| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473 |
- """
- The RPyC protocol
- """
- import sys
- import select
- import weakref
- import itertools
- import cPickle as pickle
- from threading import Lock
- from rpyc.lib.colls import WeakValueDict, RefCountingColl
- from rpyc.core import consts, brine, vinegar, netref
- from rpyc.core.async import AsyncResult
class PingError(Exception):
    """Raised by a ping when the echoed data does not match the data sent."""
- DEFAULT_CONFIG = dict(
- # ATTRIBUTES
- allow_safe_attrs = True,
- allow_exposed_attrs = True,
- allow_public_attrs = False,
- allow_all_attrs = False,
- safe_attrs = set(['__abs__', '__add__', '__and__', '__cmp__', '__contains__',
- '__delitem__', '__delslice__', '__div__', '__divmod__', '__doc__',
- '__eq__', '__float__', '__floordiv__', '__ge__', '__getitem__',
- '__getslice__', '__gt__', '__hash__', '__hex__', '__iadd__', '__iand__',
- '__idiv__', '__ifloordiv__', '__ilshift__', '__imod__', '__imul__',
- '__index__', '__int__', '__invert__', '__ior__', '__ipow__', '__irshift__',
- '__isub__', '__iter__', '__itruediv__', '__ixor__', '__le__', '__len__',
- '__long__', '__lshift__', '__lt__', '__mod__', '__mul__', '__ne__',
- '__neg__', '__new__', '__nonzero__', '__oct__', '__or__', '__pos__',
- '__pow__', '__radd__', '__rand__', '__rdiv__', '__rdivmod__', '__repr__',
- '__rfloordiv__', '__rlshift__', '__rmod__', '__rmul__', '__ror__',
- '__rpow__', '__rrshift__', '__rshift__', '__rsub__', '__rtruediv__',
- '__rxor__', '__setitem__', '__setslice__', '__str__', '__sub__',
- '__truediv__', '__xor__', 'next', '__length_hint__', '__enter__',
- '__exit__', ]),
- exposed_prefix = "exposed_",
- allow_getattr = True,
- allow_setattr = False,
- allow_delattr = False,
- # EXCEPTIONS
- include_local_traceback = True,
- instantiate_custom_exceptions = False,
- import_custom_exceptions = False,
- instantiate_oldstyle_exceptions = False, # which don't derive from Exception
- propagate_SystemExit_locally = False, # whether to propagate SystemExit locally or to the other party
- # MISC
- allow_pickle = False,
- connid = None,
- credentials = None,
- )
- _connection_id_generator = itertools.count(1)
class Connection(object):
    """The RPyC connection (also known as the RPyC protocol).

    * service: the service to expose
    * channel: the channel over which messages are passed
    * config: this connection's config dict (overriding parameters from the
      default config dict)
    * _lazy: whether to defer the initialization of the service. the default
      is False, meaning the service is initialized together with the
      connection; if set to True, you will need to call _init_service
      manually later
    """
    def __init__(self, service, channel, config = None, _lazy = False):
        # stay "closed" until fully set up, so that __del__ -> close() is a
        # no-op should anything below raise
        self._closed = True
        self._config = DEFAULT_CONFIG.copy()
        if config is not None:
            self._config.update(config)
        if self._config["connid"] is None:
            self._config["connid"] = "conn%d" % (_connection_id_generator.next(),)

        self._channel = channel
        self._seqcounter = itertools.count()
        self._recvlock = Lock()
        self._sendlock = Lock()
        self._sync_replies = {}
        self._async_callbacks = {}
        self._local_objects = RefCountingColl()
        self._last_traceback = None
        self._proxy_cache = WeakValueDict()
        self._netref_classes_cache = {}
        self._remote_root = None
        # the service only gets a weak proxy to this connection
        self._local_root = service(weakref.proxy(self))
        if not _lazy:
            self._init_service()
        self._closed = False

    def _init_service(self):
        """notify the local service that the connection was established"""
        self._local_root.on_connect()

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, t, v, tb):
        self.close()

    def __repr__(self):
        # splice the connection id into the default object repr
        a, b = object.__repr__(self).split(" object ")
        return "%s %r object %s" % (a, self._config["connid"], b)

    #
    # IO
    #
    def _cleanup(self, _anyway = True):
        """close the channel, notify the local service of the disconnection
        and drop all per-connection state. if `_anyway` is false, nothing is
        done when the connection is already marked as closed"""
        if self._closed and not _anyway:
            return
        self._closed = True
        self._channel.close()
        self._local_root.on_disconnect()
        self._sync_replies.clear()
        self._async_callbacks.clear()
        self._local_objects.clear()
        self._proxy_cache.clear()
        self._netref_classes_cache.clear()
        self._last_traceback = None
        self._remote_root = None
        self._local_root = None
        #self._seqcounter = None
        #self._config.clear()

    def close(self, _catchall = True):
        """close the connection: try to send a HANDLE_CLOSE request to the
        other party and then clean up. an EOFError raised while sending is
        always ignored; any other Exception is ignored unless `_catchall` is
        false. the cleanup is performed in any case"""
        if self._closed:
            return
        self._closed = True
        try:
            try:
                self._async_request(consts.HANDLE_CLOSE)
            except EOFError:
                pass
            except Exception:
                if not _catchall:
                    raise
        finally:
            self._cleanup(_anyway = True)

    @property
    def closed(self):
        """whether the connection is marked as closed"""
        return self._closed

    def fileno(self):
        """the file descriptor number reported by the underlying channel"""
        return self._channel.fileno()

    def ping(self, data = "the world is a vampire!" * 20, timeout = 3):
        """assert that the other party is functioning properly: send `data`
        and raise PingError if the echoed value differs from it"""
        res = self.async_request(consts.HANDLE_PING, data, timeout = timeout)
        if res.value != data:
            raise PingError("echo mismatches sent data")

    def _send(self, msg, seq, args):
        """dump the (msg, seq, args) triple with brine and send it over the
        channel while holding the send lock"""
        data = brine.dump((msg, seq, args))
        self._sendlock.acquire()
        try:
            self._channel.send(data)
        finally:
            self._sendlock.release()

    def _send_request(self, handler, args):
        """send a request for `handler` with the boxed `args`; returns the
        sequence number allocated for the request"""
        seq = self._seqcounter.next()
        self._send(consts.MSG_REQUEST, seq, (handler, self._box(args)))
        return seq

    def _send_reply(self, seq, obj):
        """send the boxed `obj` as the reply to request number `seq`"""
        self._send(consts.MSG_REPLY, seq, self._box(obj))

    def _send_exception(self, seq, exctype, excval, exctb):
        """dump the given exception info with vinegar and send it as the
        outcome of request number `seq`"""
        exc = vinegar.dump(exctype, excval, exctb,
            include_local_traceback = self._config["include_local_traceback"])
        self._send(consts.MSG_EXCEPTION, seq, exc)

    #
    # boxing
    #
    def _box(self, obj):
        """store a local object in such a way that it could be recreated on
        the remote party either by-value or by-reference"""
        if brine.dumpable(obj):
            return consts.LABEL_VALUE, obj
        if type(obj) is tuple:
            return consts.LABEL_TUPLE, tuple(self._box(item) for item in obj)
        elif isinstance(obj, netref.BaseNetref) and obj.____conn__() is self:
            # a netref that belongs to this very connection: only its oid
            # needs to be sent
            return consts.LABEL_LOCAL_REF, obj.____oid__
        else:
            # passed by reference: keep the object in the local collection
            self._local_objects.add(obj)
            try:
                cls = obj.__class__
            except Exception:
                # see issue #16
                cls = type(obj)
            return consts.LABEL_REMOTE_REF, (id(obj), cls.__name__, cls.__module__)

    def _unbox(self, package):
        """recreate a local object representation of the remote object: if the
        object is passed by value, just return it; if the object is passed by
        reference, create a netref to it. raises ValueError for an unknown
        label"""
        label, value = package
        if label == consts.LABEL_VALUE:
            return value
        if label == consts.LABEL_TUPLE:
            return tuple(self._unbox(item) for item in value)
        if label == consts.LABEL_LOCAL_REF:
            return self._local_objects[value]
        if label == consts.LABEL_REMOTE_REF:
            oid, clsname, modname = value
            # reuse the cached proxy for this oid, if there is one
            if oid in self._proxy_cache:
                return self._proxy_cache[oid]
            proxy = self._netref_factory(oid, clsname, modname)
            self._proxy_cache[oid] = proxy
            return proxy
        raise ValueError("invalid label %r" % (label,))

    def _netref_factory(self, oid, clsname, modname):
        """create a netref instance for the remote object `oid`. the netref
        class is taken from this connection's cache or from netref's builtin
        classes cache; otherwise it is built from the result of a
        HANDLE_INSPECT request and stored in this connection's cache"""
        typeinfo = (clsname, modname)
        if typeinfo in self._netref_classes_cache:
            cls = self._netref_classes_cache[typeinfo]
        elif typeinfo in netref.builtin_classes_cache:
            cls = netref.builtin_classes_cache[typeinfo]
        else:
            info = self.sync_request(consts.HANDLE_INSPECT, oid)
            cls = netref.class_factory(clsname, modname, info)
            self._netref_classes_cache[typeinfo] = cls
        return cls(weakref.ref(self), oid)

    #
    # dispatching
    #
    def _dispatch_request(self, seq, raw_args):
        """unbox and run an incoming request, then send back either its
        result or the exception it raised. KeyboardInterrupt is always
        re-raised locally, and so is SystemExit when
        propagate_SystemExit_locally is set"""
        try:
            handler, args = raw_args
            args = self._unbox(args)
            res = self._HANDLERS[handler](self, *args)
        except KeyboardInterrupt:
            raise
        except:
            # need to catch old style exceptions too
            t, v, tb = sys.exc_info()
            self._last_traceback = tb
            if t is SystemExit and self._config["propagate_SystemExit_locally"]:
                raise
            self._send_exception(seq, t, v, tb)
        else:
            self._send_reply(seq, res)

    def _dispatch_reply(self, seq, raw):
        """unbox an incoming reply and hand it to the async callback
        registered for `seq`, or store it as a sync reply if there is none"""
        obj = self._unbox(raw)
        if seq in self._async_callbacks:
            self._async_callbacks.pop(seq)(False, obj)
        else:
            self._sync_replies[seq] = (False, obj)

    def _dispatch_exception(self, seq, raw):
        """load an incoming exception with vinegar and hand it to the async
        callback registered for `seq`, or store it as a sync reply if there
        is none"""
        obj = vinegar.load(raw,
            import_custom_exceptions = self._config["import_custom_exceptions"],
            instantiate_custom_exceptions = self._config["instantiate_custom_exceptions"],
            instantiate_oldstyle_exceptions = self._config["instantiate_oldstyle_exceptions"])
        if seq in self._async_callbacks:
            self._async_callbacks.pop(seq)(True, obj)
        else:
            self._sync_replies[seq] = (True, obj)

    #
    # serving
    #
    def _recv(self, timeout, wait_for_lock):
        """receive a single message from the channel under the receive lock.
        returns None if the lock could not be acquired (only possible when
        `wait_for_lock` is false) or if the channel's poll(timeout) returned
        false. on EOFError the connection is closed and the error is
        re-raised"""
        if not self._recvlock.acquire(wait_for_lock):
            return None
        try:
            try:
                if self._channel.poll(timeout):
                    data = self._channel.recv()
                else:
                    data = None
            except EOFError:
                self.close()
                raise
        finally:
            self._recvlock.release()
        return data

    def _dispatch(self, data):
        """load a raw message with brine and route it according to its
        message type. raises ValueError for an unknown message type"""
        msg, seq, args = brine.load(data)
        if msg == consts.MSG_REQUEST:
            self._dispatch_request(seq, args)
        elif msg == consts.MSG_REPLY:
            self._dispatch_reply(seq, args)
        elif msg == consts.MSG_EXCEPTION:
            self._dispatch_exception(seq, args)
        else:
            raise ValueError("invalid message type: %r" % (msg,))

    def poll(self, timeout = 0):
        """serve a single transaction, should one arrive in the given
        interval. note that handling a request/reply may trigger nested
        requests, which are all part of the transaction.

        returns True if one was served, False otherwise"""
        # unlike serve(), do not block waiting for the receive lock
        data = self._recv(timeout, wait_for_lock = False)
        if not data:
            return False
        self._dispatch(data)
        return True

    def serve(self, timeout = 1):
        """serve a single request or reply that arrives within the given
        time frame (default is 1 sec). note that the dispatching of a request
        might trigger multiple (nested) requests, thus this function may be
        reentrant. returns True if a request or reply were received, False
        otherwise."""
        data = self._recv(timeout, wait_for_lock = True)
        if not data:
            return False
        self._dispatch(data)
        return True

    def serve_all(self):
        """serve all requests and replies while the connection is alive"""
        try:
            try:
                while True:
                    self.serve(0.1)
            except select.error:
                # ignored only if the connection has been closed meanwhile
                if not self.closed:
                    raise
            except EOFError:
                pass
        finally:
            self.close()

    def poll_all(self, timeout = 0):
        """serve all requests and replies that arrive within the given interval.
        returns True if at least one was served, False otherwise"""
        at_least_once = False
        try:
            while self.poll(timeout):
                at_least_once = True
        except EOFError:
            pass
        return at_least_once

    #
    # requests
    #
    def sync_request(self, handler, *args):
        """send a request and wait for the reply to arrive. returns the
        replied object, or raises it if the reply is an exception"""
        seq = self._send_request(handler, args)
        # keep serving until the reply for this request shows up
        while seq not in self._sync_replies:
            self.serve(0.1)
        isexc, obj = self._sync_replies.pop(seq)
        if isexc:
            raise obj
        else:
            return obj

    def _async_request(self, handler, args = (), callback = (lambda a, b: None)):
        """send a request and register `callback` for its sequence number;
        the callback is invoked as callback(is_exception, obj) when the
        outcome is dispatched"""
        seq = self._send_request(handler, args)
        self._async_callbacks[seq] = callback

    def async_request(self, handler, *args, **kwargs):
        """send a request and return an AsyncResult object, which will
        eventually hold the reply. the only keyword argument accepted is
        `timeout`; when it is not None it is passed to the result's
        set_expiry"""
        timeout = kwargs.pop("timeout", None)
        if kwargs:
            raise TypeError("got unexpected keyword argument %r" % (kwargs.keys()[0],))
        res = AsyncResult(weakref.proxy(self))
        self._async_request(handler, args, res)
        if timeout is not None:
            res.set_expiry(timeout)
        return res

    @property
    def root(self):
        """fetch the root object of the other party"""
        # fetched on first access and kept until the connection is cleaned up
        if self._remote_root is None:
            self._remote_root = self.sync_request(consts.HANDLE_GETROOT)
        return self._remote_root

    #
    # attribute access
    #
    def _check_attr(self, obj, name):
        """check `name` against the attribute-access settings of the config.
        returns the name under which the attribute should be accessed (the
        exposed-prefixed name, if `obj` has such an attribute), or False if
        no setting allows the access"""
        if self._config["allow_exposed_attrs"]:
            if name.startswith(self._config["exposed_prefix"]):
                name2 = name
            else:
                name2 = self._config["exposed_prefix"] + name
            if hasattr(obj, name2):
                return name2
        if self._config["allow_all_attrs"]:
            return name
        if self._config["allow_safe_attrs"] and name in self._config["safe_attrs"]:
            return name
        if self._config["allow_public_attrs"] and not name.startswith("_"):
            return name
        return False

    def _access_attr(self, oid, name, args, overrider, param, default):
        """perform an attribute operation on the local object `oid`.

        if the object's type defines a method named `overrider`, it is called
        with the name as given; otherwise the config entry `param` must be
        true and the name must pass _check_attr, in which case `default` is
        called with the name returned by _check_attr. raises TypeError if
        `name` is not a string and AttributeError if access is refused"""
        if type(name) is unicode:
            name = str(name) # IronPython issue #10
        elif type(name) is not str:
            raise TypeError("attr name must be a string")
        obj = self._local_objects[oid]
        accessor = getattr(type(obj), overrider, None)
        if accessor is None:
            name2 = self._check_attr(obj, name)
            if not self._config[param] or not name2:
                raise AttributeError("cannot access %r" % (name,))
            accessor = default
            name = name2
        return accessor(obj, name, *args)

    #
    # handlers
    #
    def _handle_ping(self, data):
        return data

    def _handle_close(self):
        self._cleanup()

    def _handle_getroot(self):
        return self._local_root

    def _handle_del(self, oid):
        self._local_objects.decref(oid)

    def _handle_repr(self, oid):
        return repr(self._local_objects[oid])

    def _handle_str(self, oid):
        return str(self._local_objects[oid])

    def _handle_cmp(self, oid, other):
        # cmp() might enter recursive resonance... yet another workaround
        #return cmp(self._local_objects[oid], other)
        obj = self._local_objects[oid]
        try:
            return type(obj).__cmp__(obj, other)
        except TypeError:
            return NotImplemented

    def _handle_hash(self, oid):
        return hash(self._local_objects[oid])

    def _handle_call(self, oid, args, kwargs=()):
        return self._local_objects[oid](*args, **dict(kwargs))

    def _handle_dir(self, oid):
        return tuple(dir(self._local_objects[oid]))

    def _handle_inspect(self, oid):
        return tuple(netref.inspect_methods(self._local_objects[oid]))

    def _handle_getattr(self, oid, name):
        return self._access_attr(oid, name, (), "_rpyc_getattr", "allow_getattr", getattr)

    def _handle_delattr(self, oid, name):
        return self._access_attr(oid, name, (), "_rpyc_delattr", "allow_delattr", delattr)

    def _handle_setattr(self, oid, name, value):
        return self._access_attr(oid, name, (value,), "_rpyc_setattr", "allow_setattr", setattr)

    def _handle_callattr(self, oid, name, args, kwargs):
        return self._handle_getattr(oid, name)(*args, **dict(kwargs))

    def _handle_pickle(self, oid, proto):
        """return the pickled form of the local object `oid`; raises
        ValueError unless allow_pickle is set in the config"""
        if not self._config["allow_pickle"]:
            raise ValueError("pickling is disabled")
        return pickle.dumps(self._local_objects[oid], proto)

    def _handle_buffiter(self, oid, count):
        """fetch up to `count` items from the local iterator `oid`, stopping
        early if it is exhausted; the items are returned as a tuple"""
        items = []
        obj = self._local_objects[oid]
        i = 0
        try:
            while i < count:
                items.append(obj.next())
                i += 1
        except StopIteration:
            pass
        return tuple(items)

    # collect handlers: map every HANDLE_* constant of consts to the
    # _handle_* function of the same name defined above
    _HANDLERS = {}
    # iterate over a snapshot, since the loop itself binds new names in the
    # class namespace
    for name, obj in list(locals().items()):
        if name.startswith("_handle_"):
            name2 = "HANDLE_" + name[8:].upper()
            if hasattr(consts, name2):
                _HANDLERS[getattr(consts, name2)] = obj
            else:
                raise NameError("no constant defined for %r" % (name,))
    del name, name2, obj
|