protocol.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473
  1. """
  2. The RPyC protocol
  3. """
  4. import sys
  5. import select
  6. import weakref
  7. import itertools
  8. import cPickle as pickle
  9. from threading import Lock
  10. from rpyc.lib.colls import WeakValueDict, RefCountingColl
  11. from rpyc.core import consts, brine, vinegar, netref
  12. from rpyc.core.async import AsyncResult
  13. class PingError(Exception):
  14. pass
  15. DEFAULT_CONFIG = dict(
  16. # ATTRIBUTES
  17. allow_safe_attrs = True,
  18. allow_exposed_attrs = True,
  19. allow_public_attrs = False,
  20. allow_all_attrs = False,
  21. safe_attrs = set(['__abs__', '__add__', '__and__', '__cmp__', '__contains__',
  22. '__delitem__', '__delslice__', '__div__', '__divmod__', '__doc__',
  23. '__eq__', '__float__', '__floordiv__', '__ge__', '__getitem__',
  24. '__getslice__', '__gt__', '__hash__', '__hex__', '__iadd__', '__iand__',
  25. '__idiv__', '__ifloordiv__', '__ilshift__', '__imod__', '__imul__',
  26. '__index__', '__int__', '__invert__', '__ior__', '__ipow__', '__irshift__',
  27. '__isub__', '__iter__', '__itruediv__', '__ixor__', '__le__', '__len__',
  28. '__long__', '__lshift__', '__lt__', '__mod__', '__mul__', '__ne__',
  29. '__neg__', '__new__', '__nonzero__', '__oct__', '__or__', '__pos__',
  30. '__pow__', '__radd__', '__rand__', '__rdiv__', '__rdivmod__', '__repr__',
  31. '__rfloordiv__', '__rlshift__', '__rmod__', '__rmul__', '__ror__',
  32. '__rpow__', '__rrshift__', '__rshift__', '__rsub__', '__rtruediv__',
  33. '__rxor__', '__setitem__', '__setslice__', '__str__', '__sub__',
  34. '__truediv__', '__xor__', 'next', '__length_hint__', '__enter__',
  35. '__exit__', ]),
  36. exposed_prefix = "exposed_",
  37. allow_getattr = True,
  38. allow_setattr = False,
  39. allow_delattr = False,
  40. # EXCEPTIONS
  41. include_local_traceback = True,
  42. instantiate_custom_exceptions = False,
  43. import_custom_exceptions = False,
  44. instantiate_oldstyle_exceptions = False, # which don't derive from Exception
  45. propagate_SystemExit_locally = False, # whether to propagate SystemExit locally or to the other party
  46. # MISC
  47. allow_pickle = False,
  48. connid = None,
  49. credentials = None,
  50. )
  51. _connection_id_generator = itertools.count(1)
  52. class Connection(object):
  53. """The RPyC connection (also know as the RPyC protocol).
  54. * service: the service to expose
  55. * channel: the channcel over which messages are passed
  56. * config: this connection's config dict (overriding parameters from the
  57. default config dict)
  58. * _lazy: whether or not to initialize the service with the creation of the
  59. connection. default is True. if set to False, you will need to call
  60. _init_service manually later
  61. """
  62. def __init__(self, service, channel, config = {}, _lazy = False):
  63. self._closed = True
  64. self._config = DEFAULT_CONFIG.copy()
  65. self._config.update(config)
  66. if self._config["connid"] is None:
  67. self._config["connid"] = "conn%d" % (_connection_id_generator.next(),)
  68. self._channel = channel
  69. self._seqcounter = itertools.count()
  70. self._recvlock = Lock()
  71. self._sendlock = Lock()
  72. self._sync_replies = {}
  73. self._async_callbacks = {}
  74. self._local_objects = RefCountingColl()
  75. self._last_traceback = None
  76. self._proxy_cache = WeakValueDict()
  77. self._netref_classes_cache = {}
  78. self._remote_root = None
  79. self._local_root = service(weakref.proxy(self))
  80. if not _lazy:
  81. self._init_service()
  82. self._closed = False
  83. def _init_service(self):
  84. self._local_root.on_connect()
  85. def __del__(self):
  86. self.close()
  87. def __enter__(self):
  88. return self
  89. def __exit__(self, t, v, tb):
  90. self.close()
  91. def __repr__(self):
  92. a, b = object.__repr__(self).split(" object ")
  93. return "%s %r object %s" % (a, self._config["connid"], b)
  94. #
  95. # IO
  96. #
  97. def _cleanup(self, _anyway = True):
  98. if self._closed and not _anyway:
  99. return
  100. self._closed = True
  101. self._channel.close()
  102. self._local_root.on_disconnect()
  103. self._sync_replies.clear()
  104. self._async_callbacks.clear()
  105. self._local_objects.clear()
  106. self._proxy_cache.clear()
  107. self._netref_classes_cache.clear()
  108. self._last_traceback = None
  109. self._last_traceback = None
  110. self._remote_root = None
  111. self._local_root = None
  112. #self._seqcounter = None
  113. #self._config.clear()
  114. def close(self, _catchall = True):
  115. if self._closed:
  116. return
  117. self._closed = True
  118. try:
  119. try:
  120. self._async_request(consts.HANDLE_CLOSE)
  121. except EOFError:
  122. pass
  123. except Exception:
  124. if not _catchall:
  125. raise
  126. finally:
  127. self._cleanup(_anyway = True)
  128. @property
  129. def closed(self):
  130. return self._closed
  131. def fileno(self):
  132. return self._channel.fileno()
  133. def ping(self, data = "the world is a vampire!" * 20, timeout = 3):
  134. """assert that the other party is functioning properly"""
  135. res = self.async_request(consts.HANDLE_PING, data, timeout = timeout)
  136. if res.value != data:
  137. raise PingError("echo mismatches sent data")
  138. def _send(self, msg, seq, args):
  139. data = brine.dump((msg, seq, args))
  140. self._sendlock.acquire()
  141. try:
  142. self._channel.send(data)
  143. finally:
  144. self._sendlock.release()
  145. def _send_request(self, handler, args):
  146. seq = self._seqcounter.next()
  147. self._send(consts.MSG_REQUEST, seq, (handler, self._box(args)))
  148. return seq
  149. def _send_reply(self, seq, obj):
  150. self._send(consts.MSG_REPLY, seq, self._box(obj))
  151. def _send_exception(self, seq, exctype, excval, exctb):
  152. exc = vinegar.dump(exctype, excval, exctb,
  153. include_local_traceback = self._config["include_local_traceback"])
  154. self._send(consts.MSG_EXCEPTION, seq, exc)
  155. #
  156. # boxing
  157. #
  158. def _box(self, obj):
  159. """store a local object in such a way that it could be recreated on
  160. the remote party either by-value or by-reference"""
  161. if brine.dumpable(obj):
  162. return consts.LABEL_VALUE, obj
  163. if type(obj) is tuple:
  164. return consts.LABEL_TUPLE, tuple(self._box(item) for item in obj)
  165. elif isinstance(obj, netref.BaseNetref) and obj.____conn__() is self:
  166. return consts.LABEL_LOCAL_REF, obj.____oid__
  167. else:
  168. self._local_objects.add(obj)
  169. try:
  170. cls = obj.__class__
  171. except Exception:
  172. # see issue #16
  173. cls = type(obj)
  174. return consts.LABEL_REMOTE_REF, (id(obj), cls.__name__, cls.__module__)
  175. def _unbox(self, package):
  176. """recreate a local object representation of the remote object: if the
  177. object is passed by value, just return it; if the object is passed by
  178. reference, create a netref to it"""
  179. label, value = package
  180. if label == consts.LABEL_VALUE:
  181. return value
  182. if label == consts.LABEL_TUPLE:
  183. return tuple(self._unbox(item) for item in value)
  184. if label == consts.LABEL_LOCAL_REF:
  185. return self._local_objects[value]
  186. if label == consts.LABEL_REMOTE_REF:
  187. oid, clsname, modname = value
  188. if oid in self._proxy_cache:
  189. return self._proxy_cache[oid]
  190. proxy = self._netref_factory(oid, clsname, modname)
  191. self._proxy_cache[oid] = proxy
  192. return proxy
  193. raise ValueError("invalid label %r" % (label,))
  194. def _netref_factory(self, oid, clsname, modname):
  195. typeinfo = (clsname, modname)
  196. if typeinfo in self._netref_classes_cache:
  197. cls = self._netref_classes_cache[typeinfo]
  198. elif typeinfo in netref.builtin_classes_cache:
  199. cls = netref.builtin_classes_cache[typeinfo]
  200. else:
  201. info = self.sync_request(consts.HANDLE_INSPECT, oid)
  202. cls = netref.class_factory(clsname, modname, info)
  203. self._netref_classes_cache[typeinfo] = cls
  204. return cls(weakref.ref(self), oid)
  205. #
  206. # dispatching
  207. #
  208. def _dispatch_request(self, seq, raw_args):
  209. try:
  210. handler, args = raw_args
  211. args = self._unbox(args)
  212. res = self._HANDLERS[handler](self, *args)
  213. except KeyboardInterrupt:
  214. raise
  215. except:
  216. # need to catch old style exceptions too
  217. t, v, tb = sys.exc_info()
  218. self._last_traceback = tb
  219. if t is SystemExit and self._config["propagate_SystemExit_locally"]:
  220. raise
  221. self._send_exception(seq, t, v, tb)
  222. else:
  223. self._send_reply(seq, res)
  224. def _dispatch_reply(self, seq, raw):
  225. obj = self._unbox(raw)
  226. if seq in self._async_callbacks:
  227. self._async_callbacks.pop(seq)(False, obj)
  228. else:
  229. self._sync_replies[seq] = (False, obj)
  230. def _dispatch_exception(self, seq, raw):
  231. obj = vinegar.load(raw,
  232. import_custom_exceptions = self._config["import_custom_exceptions"],
  233. instantiate_custom_exceptions = self._config["instantiate_custom_exceptions"],
  234. instantiate_oldstyle_exceptions = self._config["instantiate_oldstyle_exceptions"])
  235. if seq in self._async_callbacks:
  236. self._async_callbacks.pop(seq)(True, obj)
  237. else:
  238. self._sync_replies[seq] = (True, obj)
  239. #
  240. # serving
  241. #
  242. def _recv(self, timeout, wait_for_lock):
  243. if not self._recvlock.acquire(wait_for_lock):
  244. return None
  245. try:
  246. try:
  247. if self._channel.poll(timeout):
  248. data = self._channel.recv()
  249. else:
  250. data = None
  251. except EOFError:
  252. self.close()
  253. raise
  254. finally:
  255. self._recvlock.release()
  256. return data
  257. def _dispatch(self, data):
  258. msg, seq, args = brine.load(data)
  259. if msg == consts.MSG_REQUEST:
  260. self._dispatch_request(seq, args)
  261. elif msg == consts.MSG_REPLY:
  262. self._dispatch_reply(seq, args)
  263. elif msg == consts.MSG_EXCEPTION:
  264. self._dispatch_exception(seq, args)
  265. else:
  266. raise ValueError("invalid message type: %r" % (msg,))
  267. def poll(self, timeout = 0):
  268. """serve a single transaction, should one arrives in the given
  269. interval. note that handling a request/reply may trigger nested
  270. requests, which are all part of the transaction.
  271. returns True if one was served, False otherwise"""
  272. data = self._recv(timeout, wait_for_lock = False)
  273. if not data:
  274. return False
  275. self._dispatch(data)
  276. return True
  277. def serve(self, timeout = 1):
  278. """serve a single request or reply that arrives within the given
  279. time frame (default is 1 sec). note that the dispatching of a request
  280. might trigger multiple (nested) requests, thus this function may be
  281. reentrant. returns True if a request or reply were received, False
  282. otherwise."""
  283. data = self._recv(timeout, wait_for_lock = True)
  284. if not data:
  285. return False
  286. self._dispatch(data)
  287. return True
  288. def serve_all(self):
  289. """serve all requests and replies while the connection is alive"""
  290. try:
  291. try:
  292. while True:
  293. self.serve(0.1)
  294. except select.error:
  295. if not self.closed:
  296. raise
  297. except EOFError:
  298. pass
  299. finally:
  300. self.close()
  301. def poll_all(self, timeout = 0):
  302. """serve all requests and replies that arrive within the given interval.
  303. returns True if at least one was served, False otherwise"""
  304. at_least_once = False
  305. try:
  306. while self.poll(timeout):
  307. at_least_once = True
  308. except EOFError:
  309. pass
  310. return at_least_once
  311. #
  312. # requests
  313. #
  314. def sync_request(self, handler, *args):
  315. """send a request and wait for the reply to arrive"""
  316. seq = self._send_request(handler, args)
  317. while seq not in self._sync_replies:
  318. self.serve(0.1)
  319. isexc, obj = self._sync_replies.pop(seq)
  320. if isexc:
  321. raise obj
  322. else:
  323. return obj
  324. def _async_request(self, handler, args = (), callback = (lambda a, b: None)):
  325. seq = self._send_request(handler, args)
  326. self._async_callbacks[seq] = callback
  327. def async_request(self, handler, *args, **kwargs):
  328. """send a request and return an AsyncResult object, which will
  329. eventually hold the reply"""
  330. timeout = kwargs.pop("timeout", None)
  331. if kwargs:
  332. raise TypeError("got unexpected keyword argument %r" % (kwargs.keys()[0],))
  333. res = AsyncResult(weakref.proxy(self))
  334. self._async_request(handler, args, res)
  335. if timeout is not None:
  336. res.set_expiry(timeout)
  337. return res
  338. @property
  339. def root(self):
  340. """fetch the root object of the other party"""
  341. if self._remote_root is None:
  342. self._remote_root = self.sync_request(consts.HANDLE_GETROOT)
  343. return self._remote_root
  344. #
  345. # attribute access
  346. #
  347. def _check_attr(self, obj, name):
  348. if self._config["allow_exposed_attrs"]:
  349. if name.startswith(self._config["exposed_prefix"]):
  350. name2 = name
  351. else:
  352. name2 = self._config["exposed_prefix"] + name
  353. if hasattr(obj, name2):
  354. return name2
  355. if self._config["allow_all_attrs"]:
  356. return name
  357. if self._config["allow_safe_attrs"] and name in self._config["safe_attrs"]:
  358. return name
  359. if self._config["allow_public_attrs"] and not name.startswith("_"):
  360. return name
  361. return False
  362. def _access_attr(self, oid, name, args, overrider, param, default):
  363. if type(name) is unicode:
  364. name = str(name) # IronPython issue #10
  365. elif type(name) is not str:
  366. raise TypeError("attr name must be a string")
  367. obj = self._local_objects[oid]
  368. accessor = getattr(type(obj), overrider, None)
  369. if accessor is None:
  370. name2 = self._check_attr(obj, name)
  371. if not self._config[param] or not name2:
  372. raise AttributeError("cannot access %r" % (name,))
  373. accessor = default
  374. name = name2
  375. return accessor(obj, name, *args)
  376. #
  377. # handlers
  378. #
  379. def _handle_ping(self, data):
  380. return data
  381. def _handle_close(self):
  382. self._cleanup()
  383. def _handle_getroot(self):
  384. return self._local_root
  385. def _handle_del(self, oid):
  386. self._local_objects.decref(oid)
  387. def _handle_repr(self, oid):
  388. return repr(self._local_objects[oid])
  389. def _handle_str(self, oid):
  390. return str(self._local_objects[oid])
  391. def _handle_cmp(self, oid, other):
  392. # cmp() might enter recursive resonance... yet another workaround
  393. #return cmp(self._local_objects[oid], other)
  394. obj = self._local_objects[oid]
  395. try:
  396. return type(obj).__cmp__(obj, other)
  397. except TypeError:
  398. return NotImplemented
  399. def _handle_hash(self, oid):
  400. return hash(self._local_objects[oid])
  401. def _handle_call(self, oid, args, kwargs=()):
  402. return self._local_objects[oid](*args, **dict(kwargs))
  403. def _handle_dir(self, oid):
  404. return tuple(dir(self._local_objects[oid]))
  405. def _handle_inspect(self, oid):
  406. return tuple(netref.inspect_methods(self._local_objects[oid]))
  407. def _handle_getattr(self, oid, name):
  408. return self._access_attr(oid, name, (), "_rpyc_getattr", "allow_getattr", getattr)
  409. def _handle_delattr(self, oid, name):
  410. return self._access_attr(oid, name, (), "_rpyc_delattr", "allow_delattr", delattr)
  411. def _handle_setattr(self, oid, name, value):
  412. return self._access_attr(oid, name, (value,), "_rpyc_setattr", "allow_setattr", setattr)
  413. def _handle_callattr(self, oid, name, args, kwargs):
  414. return self._handle_getattr(oid, name)(*args, **dict(kwargs))
  415. def _handle_pickle(self, oid, proto):
  416. if not self._config["allow_pickle"]:
  417. raise ValueError("pickling is disabled")
  418. return pickle.dumps(self._local_objects[oid], proto)
  419. def _handle_buffiter(self, oid, count):
  420. items = []
  421. obj = self._local_objects[oid]
  422. i = 0
  423. try:
  424. while i < count:
  425. items.append(obj.next())
  426. i += 1
  427. except StopIteration:
  428. pass
  429. return tuple(items)
  430. # collect handlers
  431. _HANDLERS = {}
  432. for name, obj in locals().items():
  433. if name.startswith("_handle_"):
  434. name2 = "HANDLE_" + name[8:].upper()
  435. if hasattr(consts, name2):
  436. _HANDLERS[getattr(consts, name2)] = obj
  437. else:
  438. raise NameError("no constant defined for %r", name)
  439. del name, name2, obj