from .setops import SetOps
from .tunnel import SSHTunnel, loop, PY2, ErrorResult, pickle, RemoteException
# Make all classes in this module new-style under Python 2.
__metaclass__ = type

# Python 3 has no ``basestring``; alias it to ``str`` so the
# isinstance() checks below work on both major versions.
if not PY2:
    basestring = str
class GroupResult(dict):
    """The results of a :meth:`Group.call()` operation.

    GroupResult behaves as a dictionary of results, keyed by hostname, although
    failures from individual hosts are represented as :class:`ErrorResult`
    objects.

    Methods are provided to easily process successes and failures separately.
    """
    if not PY2:
        def iteritems(self):
            """Implement iteritems() for Python 3."""
            return self.items()

    def successful(self):
        """Iterate over successful results as (host, value) pairs."""
        return (
            (host, value)
            for host, value in self.iteritems()
            if not isinstance(value, ErrorResult)
        )

    def failures(self):
        """Iterate over failed results as (host, err) pairs."""
        return (
            (host, value)
            for host, value in self.iteritems()
            if isinstance(value, ErrorResult)
        )

    def raise_failures(self):
        """Raise a RemoteException if there were any failures."""
        messages = [
            '[%s] %s' % (host, err.msg)
            for host, err in self.failures()
        ]
        if messages:
            summary = '{}/{} hosts had failures:\n\n{}'.format(
                len(messages),
                len(self),
                '\n'.join(messages)
            )
            raise RemoteException(summary)

    def __repr__(self):
        inner = super(GroupResult, self).__repr__()
        return '%s(%s)' % (self.__class__.__name__, inner)
class GroupOp:
    """An operation in progress on a group."""

    def __init__(self, callback):
        # Invoked with a GroupResult once every expected host has replied.
        self.callback = callback
        self.results = {}
        self.waiting = 0

    def make_callback(self, host):
        """Return a callback to store a result for the given host.

        The callback will trigger the GroupOp's callback once all group
        results have been received.
        """
        self.waiting += 1

        def receive(ret):
            self.results[host] = ret
            self.waiting -= 1
            if self.waiting <= 0:
                self.callback(GroupResult(self.results))

        return receive
class Group(SetOps):
    """A group of hosts, for performing operations in parallel."""

    def __init__(self, hosts):
        """Construct a group from a list of tunnels or hosts.

        `hosts` may contain hostnames - in which case the connections will be
        made via SSH using the default settings. Alternatively, it may contain
        tunnel instances.
        """
        self.tunnels = []
        for h in hosts:
            if isinstance(h, basestring):
                # Bare hostname: wrap it in an SSH tunnel with defaults.
                h = SSHTunnel(h)
            self.tunnels.append(h)
        # host -> ErrorResult for tunnels that failed to connect
        self.connection_errors = {}

    def _new_op(self):
        # Start a fresh group operation.  loop.stop is the completion
        # callback, so loop.run() returns once every result is in.
        self.op = GroupOp(loop.stop)
        # Pre-seed results with known connection failures so they are
        # reported alongside this operation's per-host results.
        self.op.results = self.connection_errors.copy()

    def _parallel(self, tunnels, method, *args, **kwargs):
        """Helper to call a method on all tunnels.

        Dispatches `method` asynchronously on each tunnel, then runs the
        event loop until all hosts have reported; returns the GroupResult.
        """
        self._new_op()
        for t in tunnels:
            m = getattr(t, method)
            # Each tunnel reports back through its own per-host callback.
            m(self.op.make_callback(t.host), *args, **kwargs)
        try:
            return loop.run()
        except:
            # NOTE(review): bare except looks deliberate - even on
            # KeyboardInterrupt/SystemExit the tunnels are closed before
            # the exception propagates.  Confirm before narrowing.
            self.close()
            raise

    def connect(self):
        """Connect all tunnels."""
        self._connect(force=True)

    def _connect(self, force=False):
        """Connect all disconnected tunnels.

        Return a list of the tunnels we ended up connecting. Connection errors
        are saved into self.connection_errors.

        If force is False, don't attempt to reconnect tunnels that have
        failed to connect already.
        """
        all_tunnels = {}
        connected_tunnels = []
        disconnected_tunnels = []
        for t in self.tunnels:
            all_tunnels[t.host] = t
            if t.connected:
                connected_tunnels.append(t)
            else:
                # Skip hosts with a recorded failure unless forced.
                if force or t.host not in self.connection_errors:
                    disconnected_tunnels.append(t)
        if not disconnected_tunnels:
            # Nothing to do; every usable tunnel is already connected.
            return connected_tunnels
        result = self._parallel(disconnected_tunnels, '_connect_async')
        # Drop stale errors for hosts that are no longer in the group.
        self.connection_errors = {
            host: err
            for host, err in self.connection_errors.items()
            if host in all_tunnels
        }
        pickle_versions = [pickle.HIGHEST_PROTOCOL]
        for host, r in result.iteritems():
            t = all_tunnels[host]
            err = isinstance(r, ErrorResult)
            t.connected = not err
            if err:
                self.connection_errors[host] = r
            else:
                # On success, r is the remote host's supported pickle
                # protocol version (as returned by _connect_async).
                pickle_versions.append(r)
        # Use a common pickle version for all of these tunnels
        pickle_version = min(pickle_versions)
        for t in all_tunnels.values():
            t.pickle_version = pickle_version
        connected = set(all_tunnels) - set(self.connection_errors)
        return [t for t in all_tunnels.values() if t.host in connected]

    def close(self):
        """Close all tunnels."""
        for t in self.tunnels:
            t.close()

    def __enter__(self):
        """Connect all tunnels."""
        self.connect()
        return self

    def __exit__(self, *_):
        self.close()

    def call(self, callable, *args, **kwargs):
        """Call the given callable on all hosts in the group.

        The given callable and parameters must be pickleable.

        However, the callable's return value has a tighter restriction: it must
        be serialisable as JSON, in order to ensure the orchestration host
        cannot be compromised through pickle attacks.

        The return value is a :class:`GroupResult`.
        """
        tunnels = self._connect()
        return self._parallel(tunnels, '_call_async', callable, *args, **kwargs)

    @staticmethod
    def _local_paths(tunnels, local_path):
        # Yield (tunnel, local_path) pairs; paths are expanded from the
        # {host} template and must be unique, or None for temp files.
        if local_path is not None:
            names = [local_path.format(host=t.host) for t in tunnels]
            if len(set(names)) != len(tunnels):
                raise ValueError(
                    'local_path template %s does not give unique paths' %
                    local_path
                )
            return zip(tunnels, names)
        else:
            return ((t, None) for t in tunnels)

    def fetch(self, remote_path, local_path=None):
        """Fetch files from all remote hosts.

        If `local_path` is given, it is a local path template, into which
        the tunnel's ``host`` name will be substituted using ``str.format()``.
        Hostnames generated in this way must be unique.

        For example::

            group.fetch('/etc/passwd', local_path='passwd-{host}')

        If `local_path` is not given, a temporary file will be used for
        each host.

        Return a :class:`GroupResult` of dicts, each containing:

        * ``local_path`` - the local path written to
        * ``remote_path`` - the absolute remote path
        * ``size`` - the number of bytes received
        * ``sha1sum`` - a sha1 checksum of the file data
        """
        tunnels = self._connect()
        self._new_op()
        # NOTE(review): the loop variable shadows the local_path parameter;
        # harmless here because _local_paths() was already called.
        for tun, local_path in self._local_paths(tunnels, local_path):
            tun._fetch_async(
                self.op.make_callback(tun.host),
                remote_path, local_path
            )
        try:
            return loop.run()
        except:
            # Same deliberate cleanup-then-reraise as _parallel.
            self.close()
            raise

    def put(self, local_path, remote_path=None, mode=0o644):
        """Copy a file to all remote hosts.

        If remote_path is given, it is the remote path to write to. Otherwise,
        a temporary filename will be used (which will be different on each
        host).

        `mode` gives the permission bits of the files to create, or 0o644 if
        unspecified.

        This operation supports arbitarily large files (file data is streamed,
        not buffered in memory).

        Return a :class:`GroupResult` of dicts, each containing:

        * ``remote_path`` - the absolute remote path
        * ``size`` - the number of bytes received
        * ``sha1sum`` - a sha1 checksum of the file data
        """
        tunnels = self._connect()
        return self._parallel(
            tunnels, '_put_async',
            local_path, remote_path, mode
        )

    def __repr__(self):
        return '%s(%r)' % (type(self).__name__, self.tunnels)

    def _as_group(self):
        # SetOps hook: a Group is already a group.
        return self

    def filter(self, predicate, exclude=False):
        """Return a Group of the tunnels for which `predicate` returns True.

        `predicate` must be a no-argument callable that can be pickled.

        If `exclude` is True, then return a Group that only contains tunnels
        for which predicate returns False.

        Raise RemoteException if any hosts could not be connected or fail to
        evaluate the predicate.
        """
        result = self.call(predicate)
        result.raise_failures()
        if exclude:
            op = lambda x: not x
        else:
            op = bool
        include = set(host for host, res in result.successful() if op(res))
        cls = type(self)
        # Build via type(self) so subclasses filter to their own type.
        return cls([t for t in self.tunnels if t.host in include])