# Source code for chopsticks.tunnel

from __future__ import print_function
import subprocess
import sys
import os
import os.path
import pkgutil
import threading
import tempfile
import time
from hashlib import sha1
from base64 import b64encode
from contextlib import contextmanager

import chopsticks
from . import ioloop
from .setops import SetOps
from .serialise_main import prepare_callable

PY2 = sys.version_info < (3,)

if PY2:
    import cPickle as pickle
else:
    # Only Python 3 gets the plain module; on Python 2 the C-accelerated
    # cPickle is bound to the same name instead.
    import pickle

# Makes bare ``class X:`` statements create new-style classes on Python 2.
__metaclass__ = type

# One global loop for all tunnels
loop = ioloop.IOLoop()

# Another thread will output stderr
errloop = ioloop.IOLoop()

# Opcodes of the tunnel message protocol.
# NOTE(review): this module also uses OP_CALL, OP_START, OP_FETCH_BEGIN,
# OP_FETCH_DATA, OP_PUT_BEGIN, OP_PUT_DATA and OP_PUT_END, which are not
# defined in this copy of the source. Their values must match the remote
# agent's opcode table -- confirm and define them here.
OP_RET = 1
OP_EXC = 2
OP_IMP = 3

# Special sys.path entry; BaseTunnel.handle_imp() resolves it through
# sys.path_importer_cache instead of the filesystem.
CHOPSTICKS_PREFIX = 'chopsticks://'

def start_errloop():
    """Start the shared stderr loop on a background thread.

    Does nothing if ``errloop.running`` is already true.
    """
    if errloop.running:  # FIXME: race condition - may be stopping
        return
    # NOTE(review): the thread target was lost from this copy of the
    # source; assumes ioloop.IOLoop exposes a blocking ``run()`` -- confirm.
    t = threading.Thread(target=errloop.run)
    # Daemon thread, so a still-running stderr loop never blocks exit.
    t.daemon = True
    t.start()

class ErrorResult:
    """An error reported by the remote host.

    Exception types and traceback objects cannot cross the host boundary,
    so the error is carried purely as text in ``msg``.

    :param msg: Human-readable description of the error.
    :param tb: Optional traceback text. When non-empty it is appended to
        ``msg`` after a blank line, each of its lines prefixed by a space.
    """

    def __init__(self, msg, tb=None):
        if not tb:
            self.msg = msg
            return
        indented = '\n '.join(tb.splitlines())
        self.msg = msg + ('\n\n ' + indented)

    def __repr__(self):
        return 'ErrorResult(%r)' % self.msg

    # str() and (on Python 2) unicode() render exactly like repr().
    __str__ = __unicode__ = __repr__
class RemoteException(Exception):
    """An exception from the remote agent."""


class DepthLimitExceeded(Exception):
    """The recursive tunnel depth limit was hit."""


# Source of the remote agent. Tunnels write it raw to the child process,
# which execs it. An empty resource name (as in this copy of the source)
# makes get_data() try to read the package directory itself.
# NOTE(review): resource name reconstructed as 'bubble.py' -- confirm it
# matches the agent file shipped in the chopsticks package.
bubble = pkgutil.get_data('chopsticks', 'bubble.py')
class BaseTunnel(SetOps):
    """Base class for a connection to one remote chopsticks agent.

    Implements request/response bookkeeping on top of ``self.reader`` and
    ``self.writer``, which subclasses create in :meth:`_connect_async`.
    Subclasses must also implement :meth:`close` and set ``self.host``.
    The host identifies the tunnel for equality, hashing, repr and
    messages.
    """

    #: Highest pickle protocol supported locally. ``pickle_version`` starts
    #: here and may be lowered by a subclass when it connects.
    HIGHEST_PICKLE_PROTOCOL = pickle.HIGHEST_PROTOCOL

    def __init__(self):
        self._reset()

    def _reset(self):
        """Forget all per-connection state (pending callbacks, request ids)."""
        self.req_id = 0
        self.callbacks = {}
        self.connected = False
        self.pickle_version = self.HIGHEST_PICKLE_PROTOCOL

    def __eq__(self, ano):
        # Tunnels compare by host name. Objects without a ``host`` attribute
        # are reported as not comparable instead of raising AttributeError.
        try:
            ano_host = ano.host
        except AttributeError:
            return NotImplemented
        return self.host == ano_host

    def __ne__(self, ano):
        # Python 2 does not derive ``!=`` from ``__eq__``.
        eq = self.__eq__(ano)
        if eq is NotImplemented:
            return eq
        return not eq

    def __hash__(self):
        return hash(self.host)

    def __repr__(self):
        return '%s(%r)' % (type(self).__name__, self.host)

    def _as_group(self):
        """Tunnels behave like groups of one tunnel."""
        # Imported lazily, presumably to avoid a circular import.
        # NOTE(review): the module name was lost from this copy of the
        # source; ``.group`` is assumed -- confirm.
        from .group import Group
        return Group([self])

    def _run_loop(self):
        """Run the shared loop until a callback stops it; return the result.

        If running the loop raises, the tunnel is closed before the
        exception propagates.
        """
        try:
            # NOTE(review): this call was lost from this copy of the source;
            # assumes IOLoop.run() returns the value given to loop.stop().
            return loop.run()
        except BaseException:
            self.close()
            raise

    def connect(self):
        """Connect the tunnel, blocking until the agent has started.

        Does nothing if already connected.

        :raises RemoteException: if the remote side reports an error.
        """
        if self.connected:
            return
        assert self.host, "No host name received"
        self._connect_async(loop.stop)
        res = self._run_loop()
        if isinstance(res, ErrorResult):
            raise RemoteException(res.msg)

    def _connect_async(self, callback):
        """Connect the tunnel."""
        raise NotImplementedError('Subclasses must implement _connect_async()')

    def write_msg(self, op, req_id, data=None, **kwargs):
        """Write one message to the subprocess.

        The payload is either ``data`` or the keyword arguments (sent as a
        dict), never both. Encoding and framing are done by ``self.writer``.

        :raises TypeError: if both ``data`` and keyword arguments are given.
        """
        if data and kwargs:
            raise TypeError('Can only send kwargs or data')
        self.writer.write(op, req_id, data or kwargs)

    def _next_id(self):
        """Return a new request id, unique within the current connection."""
        self.req_id += 1
        return self.req_id

    @classmethod
    def _read_source(cls, file):
        """Return the raw bytes of ``file``."""
        with open(file, 'rb') as f:
            return f.read()

    def handle_imp(self, mod):
        """Answer a remote import request with an ``OP_IMP`` message.

        ``mod`` is one of:

        * ``'__main__'``: the real local main module is sent.
        * A module name.
        * A ``[module, filename]`` pair requesting a file inside a package.
          An empty module part means the package is the first path
          component of ``filename``.

        The reply echoes the original request as ``mod``. Each ``sys.path``
        entry is tried in order. The special ``chopsticks://`` entry is
        resolved through its importer in ``sys.path_importer_cache``.
        Other entries are looked up on the filesystem, with a package
        directory preferred over a same-named module. If nothing is found,
        a reply with ``exists=False`` is sent.
        """
        key = mod  # reply with the request exactly as received
        fname = None
        if mod == '__main__':
            # Special-case main to find real main module
            main = sys.modules['__main__']
            path = main.__file__
            self.write_msg(
                OP_IMP, 0,
                mod=mod,
                exists=True,
                is_pkg=False,
                file=os.path.basename(path),
                source=self._read_source(path)
            )
            return
        elif isinstance(mod, list):
            mod, fname = mod
            if not mod:
                mod, fname = fname.split('/', 1)

        stem = mod.replace('.', os.sep)
        # Candidate locations relative to each sys.path root. A package is
        # identified by (and its source read from) its __init__.py; probing
        # the bare directory would make _read_source() open a directory.
        paths = [
            (True, os.path.join(stem, '__init__.py')),
            (False, stem + '.py'),
        ]

        for root in sys.path:
            if root == CHOPSTICKS_PREFIX:
                importer = sys.path_importer_cache[root]
                if fname:
                    req = (mod, fname)
                else:
                    req = mod
                try:
                    imp = importer._raw_get(req)
                except ImportError:
                    continue
                else:
                    self.write_msg(
                        OP_IMP, 0,
                        mod=key,
                        exists=imp.exists,
                        is_pkg=imp.is_pkg,
                        file=imp.file,
                        source=imp.source,
                    )
                    return

            for is_pkg, rel in paths:
                path = os.path.join(root, rel)
                if os.path.exists(path):
                    if fname is not None:
                        # A data file inside the package was requested.
                        path = os.path.join(root, stem, fname)
                        if not os.path.exists(path):
                            break
                        rel = stem + '/' + fname
                        is_pkg = False
                    self.write_msg(
                        OP_IMP, 0,
                        mod=key,
                        exists=True,
                        is_pkg=is_pkg,
                        file=rel,
                        source=self._read_source(path)
                    )
                    return

        self.write_msg(
            OP_IMP, 0,
            mod=key,
            exists=False,
            is_pkg=False,
            file=None,
            source=''
        )

    def _get_callback(self, req_id, data):
        """Return the callback registered for ``req_id``.

        :raises RuntimeError: if no callback is registered for ``req_id``.
        """
        try:
            return self.callbacks[req_id]
        except KeyError:
            raise RuntimeError(
                'response received for unknown req_id %d.' % req_id +
                ' data: %s' % data +
                ' callbacks: %s' % self.callbacks
            )

    def _pop_callback(self, req_id, data):
        """Remove and return the callback registered for ``req_id``."""
        cb = self._get_callback(req_id, data)
        del self.callbacks[req_id]
        return cb

    def on_message(self, msg):
        """Dispatch one ``(op, req_id, data)`` message from the agent.

        * ``OP_RET`` and ``OP_EXC`` complete the pending request ``req_id``;
          an exception is delivered as an :class:`ErrorResult`.
        * ``OP_IMP`` is an import request, answered by :meth:`handle_imp`.
        * ``OP_FETCH_DATA`` carries a chunk for an in-progress fetch.

        Unknown opcodes produce a warning on stderr.
        """
        op, req_id, data = msg
        if op == OP_EXC:
            error = ErrorResult(
                'Host %r raised exception; traceback follows' % self.host,
                data['tb']
            )
            self._pop_callback(req_id, data)(error)
        elif op == OP_IMP:
            self.handle_imp(data['imp'])
        elif op == OP_RET:
            self.reader.stop()
            self._pop_callback(req_id, data)(data['ret'])
        elif op == OP_FETCH_DATA:
            self._get_callback(req_id, data).recv(data)
        else:
            self._warn('Unknown opcode received %r' % op)

    def _warn(self, msg):
        """Print ``msg`` to stderr, prefixed with this tunnel's host."""
        print('%s:' % self.host, msg, file=sys.stderr)

    def call(self, callable, *args, **kwargs):
        """Call the given callable on the remote host and return its result.

        The callable and its arguments are pickled for transport, so there
        is no JSON restriction on the parameters. The return value must be
        serialisable by the remote agent; the original documentation says
        as JSON.

        :raises RemoteException: if the remote call raised.
        """
        self.connect()
        self._call_async(loop.stop, callable, *args, **kwargs)
        ret = self._run_loop()
        if isinstance(ret, ErrorResult):
            raise RemoteException(ret.msg)
        return ret

    def _call_async(self, on_result, callable, *args, **kwargs):
        """Send an ``OP_CALL`` request; ``on_result`` receives the result."""
        req_id = self._next_id()
        self.callbacks[req_id] = on_result
        params = prepare_callable(callable, args, kwargs)
        self.reader.start()
        self.write_msg(
            OP_CALL,
            req_id=req_id,
            data=pickle.dumps(params, self.pickle_version)
        )

    def fetch(self, remote_path, local_path=None):
        """Fetch one file from the remote host.

        If local_path is given, it is the local path to write to. Otherwise,
        a temporary filename will be used.

        This operation supports arbitrarily large files (file data is
        streamed, not buffered in memory).

        The return value is a dict containing:

        * ``local_path`` - the local path written to
        * ``remote_path`` - the absolute remote path
        * ``size`` - the number of bytes received
        * ``sha1sum`` - a sha1 checksum of the file data

        :raises RemoteException: on remote failure or checksum mismatch.
        """
        self.connect()
        self._fetch_async(loop.stop, remote_path, local_path)
        ret = self._run_loop()
        if isinstance(ret, ErrorResult):
            raise RemoteException(ret.msg)
        return ret

    def _fetch_async(self, on_result, remote_path, local_path=None):
        """Send ``OP_FETCH_BEGIN``; incoming data goes to a Fetch object."""
        req_id = self._next_id()
        fetch = Fetch(on_result, local_path)
        self.callbacks[req_id] = fetch
        self.reader.start()
        self.write_msg(
            OP_FETCH_BEGIN,
            req_id=req_id,
            path=remote_path,
        )

    def put(self, local_path, remote_path=None, mode=0o644):
        """Copy a file to the remote host.

        If `remote_path` is given, it is the remote path to write to.
        Otherwise, a temporary filename will be used.

        `mode` gives the permission bits of the file to create, or 0o644 if
        unspecified.

        This operation supports arbitrarily large files (file data is
        streamed, not buffered in memory).

        The return value is a dict containing:

        * ``remote_path`` - the absolute remote path
        * ``size`` - the number of bytes received
        * ``sha1sum`` - a sha1 checksum of the file data

        :raises RemoteException: if the remote side reports an error.
        """
        self.connect()
        self._put_async(loop.stop, local_path, remote_path, mode)
        ret = self._run_loop()
        if isinstance(ret, ErrorResult):
            raise RemoteException(ret.msg)
        return ret

    def _put_async(
            self, on_result, local_path, remote_path=None, mode=0o644):
        """Send ``OP_PUT_BEGIN`` followed by the streamed file chunks."""
        req_id = self._next_id()
        self.callbacks[req_id] = on_result
        self.reader.start()
        self.write_msg(
            OP_PUT_BEGIN, req_id,
            path=remote_path,
            mode=mode
        )
        self.writer.write_iter(
            iter_chunks(req_id, local_path)
        )

    def close(self):
        """Disconnect the tunnel.

        Note that this will terminate the remote process and any state will
        be lost.

        This does not destroy the Tunnel object, which can be reconnected
        with :meth:`.connect()`.
        """
        raise NotImplementedError()
def iter_chunks(req_id, path, chunk_size=10240):
    """Iterate over chunks of the given file.

    Yields ``(op, req_id, payload)`` messages suitable for writing to a
    stream: one ``OP_PUT_DATA`` message per chunk of at most ``chunk_size``
    bytes, then a final ``OP_PUT_END`` whose payload holds the SHA-1 hex
    digest of the whole file.

    NOTE(review): the original chunk size was lost from this copy of the
    source; 10 KiB is an assumed default.
    """
    chksum = sha1()
    with open(path, 'rb') as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                yield OP_PUT_END, req_id, {'sha1sum': chksum.hexdigest()}
                break
            chksum.update(chunk)
            yield OP_PUT_DATA, req_id, chunk


class Fetch(object):
    """Receiver that streams a fetched file to local disk.

    :param on_result: Called with the final result dict or ErrorResult.
    :param local_path: Destination path. If not given, a temporary file
        (not deleted on close) is created and its name used.
    """

    def __init__(self, on_result, local_path=None):
        self.on_result = on_result
        if local_path:
            self.local_path = local_path
            self.file = open(local_path, 'wb')
        else:
            self.file = tempfile.NamedTemporaryFile('wb', delete=False)
            self.local_path = self.file.name
        self.size = 0
        self.chksum = sha1()

    def recv(self, data):
        """Append one chunk of file data, updating size and checksum."""
        self.chksum.update(data)
        self.file.write(data)
        self.size += len(data)

    def __call__(self, result):
        """Finish the fetch with the remote ``result`` and pass it on.

        On success the remote SHA-1 is checked against the data received,
        and ``local_path`` and ``size`` are added to the result dict. On
        any error (remote or checksum) the local file is deleted and an
        ErrorResult is delivered instead.
        """
        self.file.close()
        if not isinstance(result, ErrorResult):
            remote_chksum = result['sha1sum']
            if remote_chksum != self.chksum.hexdigest():
                result = ErrorResult('Fetch failed due to checksum mismatch')
            else:
                result['local_path'] = self.local_path
                result['size'] = self.size
        if isinstance(result, ErrorResult):
            os.unlink(self.local_path)
        self.on_result(result)


class PipeTunnel(BaseTunnel):
    """A tunnel that connects via a pair of unidirectional pipes.

    Subclasses will need to implement ``connect_pipes()`` to create the
    ``self.wpipe`` and ``self.rpipe`` pipe attributes. This class also uses
    ``self.epipe`` (the child's stderr) and ``self.proc`` (the child
    process).
    """

    def _connect_async(self, callback):
        """Start the agent: send the bootstrap source, then ``OP_START``.

        ``callback`` receives the agent's reply to ``OP_START`` (or an
        ErrorResult). It is called with None immediately if already
        connected.

        :raises DepthLimitExceeded: if the chain of tunnels through which
            this one is being created would exceed chopsticks.DEPTH_LIMIT.
        """
        if self.connected:
            callback(None)
            return
        # The chain of hosts leading here. It is only set on sys when this
        # process is itself running inside a tunnel.
        try:
            path = sys._chopsticks_path[:]
        except AttributeError:
            path = []
        path.append(self.host)
        if len(path) > chopsticks.DEPTH_LIMIT:
            raise DepthLimitExceeded(
                'Depth limit of %s exceeded at %s' % (
                    chopsticks.DEPTH_LIMIT,
                    ' -> '.join(path)
                )
            )
        self.connect_pipes()
        self.reader = loop.reader(self.rpipe, self)
        self.writer = loop.writer(self.wpipe)

        def wrapped_callback(res):
            self.connected = not isinstance(res, ErrorResult)
            if self.connected:
                # Remote sends a pickle_version in response to OP_START
                self.pickle_version = min(self.HIGHEST_PICKLE_PROTOCOL, res)
            callback(res)

        self.callbacks[0] = wrapped_callback
        self.reader.start()
        self.writer.write_raw(bubble)
        # NOTE(review): the host keyword was lost from this copy of the
        # source; ``host=`` is assumed to be what the agent expects.
        self.write_msg(
            OP_START,
            req_id=0,
            host=self.host,
            path=path,
            depthlimit=chopsticks.DEPTH_LIMIT,
        )
        self.errreader = ioloop.StderrReader(errloop, self.epipe, self.host)
        start_errloop()

    def on_error(self, err):
        """Fail every outstanding request with an ErrorResult for ``err``."""
        err = ErrorResult(err)
        for req_id in list(self.callbacks):
            self.callbacks.pop(req_id)(err)

    def _join(self, timeout=5):
        """Poll for child exit; return True if it exited within ``timeout`` s."""
        end = time.time() + timeout
        while time.time() < end:
            if self.proc.poll() is not None:
                return True
            time.sleep(0.01)
        return False

    def close(self):
        """Disconnect, shutting down the child process.

        Closes the child's stdin and gives it 1 second to exit. Then calls
        ``terminate()`` and waits up to 5 more seconds. Finally kills and
        reaps it. Does nothing if the tunnel is not connected.
        """
        if not self.connected:
            return
        self.wpipe.close()  # Terminate child
        self.reader.stop()
        self.writer.stop()
        self._reset()
        if self._join(timeout=1):
            return

        # Send TERM
        self.proc.terminate()

        # Wait for process to shut down cleanly
        if self._join(timeout=5):
            return

        # Process did not shut down cleanly; force kill it
        self.proc.kill()
        # Reap the killed child so it is not left behind as a zombie.
        self.proc.wait()
        self._warn('Timeout expired waiting for pipe to close')

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, trackback):
        self.close()


class SubprocessTunnel(PipeTunnel):
    """A tunnel that connects to a subprocess."""

    #: These arguments are used for bootstrapping Python into our remote
    #: agent. The child reads exactly len(bubble) bytes of agent source
    #: from its stdin and executes it.
    #: NOTE(review): on Python 3 an unbuffered (raw) read may return fewer
    #: bytes than requested from a pipe -- confirm this is not hit for
    #: large bubbles.
    PYTHON_ARGS = [
        '-usS',
        '-c',
        'import sys, os; sys.stdin = os.fdopen(0, \'rb\', 0); ' +
        '__bubble = sys.stdin.read(%d); ' % len(bubble) +
        'exec(compile(__bubble, \'\', \'exec\'))'
    ]

    # Paths to the Python 2/3 binary on the remote host
    python2 = '/usr/bin/python2'
    python3 = '/usr/bin/python3'

    def connect_pipes(self):
        """Spawn the child from cmd_args() with stdin/stdout/stderr piped.

        The child runs in its own process group (``os.setpgrp``).
        """
        args = self.cmd_args()
        self.proc = subprocess.Popen(
            args,
            bufsize=0,
            stdout=subprocess.PIPE,
            stdin=subprocess.PIPE,
            stderr=subprocess.PIPE,
            shell=False,
            preexec_fn=os.setpgrp
        )
        self.wpipe = self.proc.stdin
        self.rpipe = self.proc.stdout
        self.epipe = self.proc.stderr

    def cmd_args(self):
        """Return the argv for the child: the matching Python plus PYTHON_ARGS."""
        python = self.python2 if PY2 else self.python3
        return [python] + self.PYTHON_ARGS
class Local(SubprocessTunnel):
    """A tunnel to a subprocess on the same host.

    :param name: Host name used to identify this tunnel.
    """

    def __init__(self, name='localhost'):
        self.host = name
        super(Local, self).__init__()
class Sudo(SubprocessTunnel):
    """A tunnel to a process on the same host, launched with sudo.

    :param user: The user to run the agent as (``sudo -u``).
    :param name: Host name identifying the tunnel; defaults to
        ``'<user>@localhost'``.
    """

    def __init__(self, user='root', name=None):
        self.user = user
        self.host = name or user + '@localhost'
        super(Sudo, self).__init__()

    def cmd_args(self):
        """Prefix the base command with a non-interactive ``sudo -u user``."""
        args = [
            'sudo',
            '--non-interactive',
            '-u', self.user
        ]
        args += super(Sudo, self).cmd_args()
        return args

    def close(self):
        """Close the tunnel.

        Here we override the base class implementation, which tries to kill
        the tunnel if it does not shut down in a timely fashion, because we
        cannot kill a root process. Instead the child's stdin is closed and
        we wait, without a timeout, for it to exit.
        """
        if not hasattr(self, 'proc'):
            # No process was ever started (e.g. called from __del__ on a
            # tunnel that never connected): nothing to shut down.
            return
        self.wpipe.close()
        if self.connected:
            self.reader.stop()
            self.writer.stop()
        # Clear per-connection state so connect() starts a new process
        # instead of returning early on dead pipes.
        self._reset()
        self.proc.wait()
class Docker(SubprocessTunnel):
    """A tunnel connected to a throwaway Docker container.

    :param name: The name of the Docker instance to create.
    :param image: The Docker image to launch. By default, download and run
        an official Docker Python image (https://hub.docker.com/_/python/)
        corresponding to the running Python version. Official images are
        curated by Docker.
    :param rm: If true, destroy the container when the tunnel is closed
        (passes ``--rm`` to ``docker run``).
    """

    #: For the standard Python docker images, Python is not installed as
    #: /usr/bin/python[23]
    python2 = 'python'
    python3 = 'python3'

    # "major.minor" of the running interpreter, used for the default image.
    pyver = '{0}.{1}'.format(*sys.version_info)

    def __init__(self, name, image='python:' + pyver, rm=True):
        self.host = name
        self.image = image
        self.rm = rm
        super(Docker, self).__init__()

    def cmd_args(self):
        """Run the agent via ``docker run -i`` in a new container named self.host."""
        base = super(Docker, self).cmd_args()
        args = []
        if self.rm:
            args.append('--rm')

        return [
            'docker', 'run', '-i', '--name', self.host,
        ] + args + [self.image] + base
class SSHTunnel(SubprocessTunnel):
    """A tunnel that connects to a remote host over SSH.

    :param host: The hostname to connect to, as would be specified on an
                 ``ssh`` command line.
    :param user: The username to connect as.
    :param sudo: If true, use ``sudo`` on the remote end in order to run as
                 the ``root`` user. Use this when you can ``sudo`` to root
                 but not ``ssh`` directly as the root user.
    """

    def __init__(self, host, user=None, sudo=False):
        self.host = host
        self.user = user
        self.sudo = sudo
        super(SSHTunnel, self).__init__()

    def cmd_args(self):
        """Build the ``ssh`` command line that starts the remote agent.

        Password authentication is disabled. Words containing spaces are
        wrapped in double quotes because the remote command is interpreted
        by a shell on the far side.
        """
        args = ['ssh', '-o', 'PasswordAuthentication=no']
        if self.user:
            args.extend(['-l', self.user])
        args.append(self.host)
        if self.sudo:
            args.append('sudo')
        args.extend(super(SSHTunnel, self).cmd_args())
        return ['"%s"' % w if ' ' in w else w for w in args]
# An alias, because this is the default tunnel type
Tunnel = SSHTunnel