stream.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
"""
abstraction layer over OS-dependent byte streams (sockets, pipes and
Win32 named pipes)
"""
  4. import sys
  5. import os
  6. import socket
  7. import time
  8. import errno
  9. from rpyc.lib import safe_import
  10. from rpyc.lib.compat import select
  11. win32file = safe_import("win32file")
  12. win32pipe = safe_import("win32pipe")
  13. msvcrt = safe_import("msvcrt")
  14. ssl = safe_import("ssl")
  15. tlsapi = safe_import("tlslite.api")
  16. retry_errnos = set([errno.EAGAIN])
  17. if hasattr(errno, "WSAEWOULDBLOCK"):
  18. retry_errnos.add(errno.WSAEWOULDBLOCK)
  19. class Stream(object):
  20. __slots__ = ()
  21. def close(self):
  22. raise NotImplementedError()
  23. @property
  24. def closed(self):
  25. raise NotImplementedError()
  26. def fileno(self):
  27. raise NotImplementedError()
  28. def poll(self, timeout):
  29. """indicate whether the stream has data to read"""
  30. rl, _, _ = select([self], [], [], timeout)
  31. return bool(rl)
  32. def read(self, count):
  33. """read exactly `count` bytes, or raise EOFError"""
  34. raise NotImplementedError()
  35. def write(self, data):
  36. """write the entire `data`, or raise EOFError"""
  37. raise NotImplementedError()
  38. class ClosedFile(object):
  39. """represents a closed file object (singleton)"""
  40. __slots__ = ()
  41. def __getattr__(self, name):
  42. raise EOFError("stream has been closed")
  43. def close(self):
  44. pass
  45. @property
  46. def closed(self):
  47. return True
  48. def fileno(self):
  49. raise EOFError("stream has been closed")
  50. ClosedFile = ClosedFile()
  51. class SocketStream(Stream):
  52. __slots__ = ("sock",)
  53. MAX_IO_CHUNK = 8000
  54. def __init__(self, sock):
  55. self.sock = sock
  56. @classmethod
  57. def _connect(cls, host, port, family = socket.AF_INET, socktype = socket.SOCK_STREAM,
  58. proto = 0, timeout = 3, nodelay = False):
  59. s = socket.socket(family, socktype, proto)
  60. s.settimeout(timeout)
  61. s.connect((host, port))
  62. if nodelay:
  63. s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
  64. return s
  65. @classmethod
  66. def connect(cls, host, port, **kwargs):
  67. return cls(cls._connect(host, port, **kwargs))
  68. @classmethod
  69. def tlslite_connect(cls, host, port, username, password, **kwargs):
  70. s = cls._connect(host, port, **kwargs)
  71. s2 = tlsapi.TLSConnection(s)
  72. s2.fileno = lambda fd = s.fileno(): fd
  73. s2.handshakeClientSRP(username, password)
  74. return cls(s2)
  75. @classmethod
  76. def ssl_connect(cls, host, port, ssl_kwargs, **kwargs):
  77. s = cls._connect(host, port, **kwargs)
  78. s2 = ssl.wrap_socket(s, **ssl_kwargs)
  79. return cls(s2)
  80. @property
  81. def closed(self):
  82. return self.sock is ClosedFile
  83. def close(self):
  84. if not self.closed:
  85. try:
  86. self.sock.shutdown(socket.SHUT_RDWR)
  87. except Exception:
  88. pass
  89. self.sock.close()
  90. self.sock = ClosedFile
  91. def fileno(self):
  92. return self.sock.fileno()
  93. def read(self, count):
  94. data = []
  95. while count > 0:
  96. try:
  97. buf = self.sock.recv(min(self.MAX_IO_CHUNK, count))
  98. except socket.timeout:
  99. continue
  100. except socket.error:
  101. ex = sys.exc_info()[1]
  102. if ex[0] in retry_errnos:
  103. # windows just has to be a bitch
  104. continue
  105. self.close()
  106. raise EOFError(ex)
  107. if not buf:
  108. self.close()
  109. raise EOFError("connection closed by peer")
  110. data.append(buf)
  111. count -= len(buf)
  112. return "".join(data)
  113. def write(self, data):
  114. try:
  115. while data:
  116. count = self.sock.send(data[:self.MAX_IO_CHUNK])
  117. data = data[count:]
  118. except socket.error:
  119. ex = sys.exc_info()[1]
  120. self.close()
  121. raise EOFError(ex)
  122. class PipeStream(Stream):
  123. __slots__ = ("incoming", "outgoing")
  124. MAX_IO_CHUNK = 32000
  125. def __init__(self, incoming, outgoing):
  126. outgoing.flush()
  127. self.incoming = incoming
  128. self.outgoing = outgoing
  129. @classmethod
  130. def from_std(cls):
  131. return cls(sys.stdin, sys.stdout)
  132. @classmethod
  133. def create_pair(cls):
  134. r1, w1 = os.pipe()
  135. r2, w2 = os.pipe()
  136. side1 = cls(os.fdopen(r1, "rb"), os.fdopen(w2, "wb"))
  137. side2 = cls(os.fdopen(r2, "rb"), os.fdopen(w1, "wb"))
  138. return side1, side2
  139. @property
  140. def closed(self):
  141. return self.incoming is ClosedFile
  142. def close(self):
  143. self.incoming.close()
  144. self.outgoing.close()
  145. self.incoming = ClosedFile
  146. self.outgoing = ClosedFile
  147. def fileno(self):
  148. return self.incoming.fileno()
  149. def read(self, count):
  150. data = []
  151. try:
  152. while count > 0:
  153. buf = os.read(self.incoming.fileno(), min(self.MAX_IO_CHUNK, count))
  154. if not buf:
  155. raise EOFError("connection closed by peer")
  156. data.append(buf)
  157. count -= len(buf)
  158. except EOFError:
  159. self.close()
  160. raise
  161. except EnvironmentError:
  162. ex = sys.exc_info()[1]
  163. self.close()
  164. raise EOFError(ex)
  165. return "".join(data)
  166. def write(self, data):
  167. try:
  168. while data:
  169. chunk = data[:self.MAX_IO_CHUNK]
  170. written = os.write(self.outgoing.fileno(), chunk)
  171. data = data[written:]
  172. except EnvironmentError:
  173. ex = sys.exc_info()[1]
  174. self.close()
  175. raise EOFError(ex)
  176. class Win32PipeStream(Stream):
  177. """win32 has to suck"""
  178. __slots__ = ("incoming", "outgoing", "_fileno", "_keepalive")
  179. PIPE_BUFFER_SIZE = 130000
  180. MAX_IO_CHUNK = 32000
  181. def __init__(self, incoming, outgoing):
  182. self._keepalive = (incoming, outgoing)
  183. if hasattr(incoming, "fileno"):
  184. self._fileno = incoming.fileno()
  185. incoming = msvcrt.get_osfhandle(incoming.fileno())
  186. if hasattr(outgoing, "fileno"):
  187. outgoing = msvcrt.get_osfhandle(outgoing.fileno())
  188. self.incoming = incoming
  189. self.outgoing = outgoing
  190. @classmethod
  191. def from_std(cls):
  192. return cls(sys.stdin, sys.stdout)
  193. @classmethod
  194. def create_pair(cls):
  195. r1, w1 = win32pipe.CreatePipe(None, cls.PIPE_BUFFER_SIZE)
  196. r2, w2 = win32pipe.CreatePipe(None, cls.PIPE_BUFFER_SIZE)
  197. return cls(r1, w2), cls(r2, w1)
  198. def fileno(self):
  199. return self._fileno
  200. @property
  201. def closed(self):
  202. return self.incoming is ClosedFile
  203. def close(self):
  204. if self.closed:
  205. return
  206. try:
  207. win32file.CloseHandle(self.incoming)
  208. except Exception:
  209. pass
  210. self.incoming = ClosedFile
  211. try:
  212. win32file.CloseHandle(self.outgoing)
  213. except Exception:
  214. pass
  215. self.outgoing = ClosedFile
  216. def read(self, count):
  217. try:
  218. data = []
  219. while count > 0:
  220. dummy, buf = win32file.ReadFile(self.incoming, int(min(self.MAX_IO_CHUNK, count)))
  221. count -= len(buf)
  222. data.append(buf)
  223. except TypeError:
  224. ex = sys.exc_info()[1]
  225. if not self.closed:
  226. raise
  227. raise EOFError(ex)
  228. except win32file.error:
  229. ex = sys.exc_info()[1]
  230. self.close()
  231. raise EOFError(ex)
  232. return "".join(data)
  233. def write(self, data):
  234. try:
  235. while data:
  236. dummy, count = win32file.WriteFile(self.outgoing, data[:self.MAX_IO_CHUNK])
  237. data = data[count:]
  238. except TypeError:
  239. ex = sys.exc_info()[1]
  240. if not self.closed:
  241. raise
  242. raise EOFError(ex)
  243. except win32file.error:
  244. ex = sys.exc_info()[1]
  245. self.close()
  246. raise EOFError(ex)
  247. def poll(self, timeout, interval = 0.1):
  248. """a poor man's version of select()"""
  249. if timeout is None:
  250. timeout = sys.maxint
  251. length = 0
  252. tmax = time.time() + timeout
  253. try:
  254. while length == 0:
  255. length = win32pipe.PeekNamedPipe(self.incoming, 0)[1]
  256. if time.time() >= tmax:
  257. break
  258. time.sleep(interval)
  259. except TypeError:
  260. ex = sys.exc_info()[1]
  261. if not self.closed:
  262. raise
  263. raise EOFError(ex)
  264. return length != 0
  265. class NamedPipeStream(Win32PipeStream):
  266. NAMED_PIPE_PREFIX = r'\\.\pipe\rpyc_'
  267. PIPE_IO_TIMEOUT = 3
  268. CONNECT_TIMEOUT = 3
  269. __slots__ = ("is_server_side",)
  270. def __init__(self, handle, is_server_side):
  271. Win32PipeStream.__init__(self, handle, handle)
  272. self.is_server_side = is_server_side
  273. @classmethod
  274. def from_std(cls):
  275. raise NotImplementedError()
  276. @classmethod
  277. def create_pair(cls):
  278. raise NotImplementedError()
  279. @classmethod
  280. def create_server(cls, pipename, connect = True):
  281. if not pipename.startswith("\\\\."):
  282. pipename = cls.NAMED_PIPE_PREFIX + pipename
  283. handle = win32pipe.CreateNamedPipe(
  284. pipename,
  285. win32pipe.PIPE_ACCESS_DUPLEX,
  286. win32pipe.PIPE_TYPE_BYTE | win32pipe.PIPE_READMODE_BYTE | win32pipe.PIPE_WAIT,
  287. 1,
  288. cls.PIPE_BUFFER_SIZE,
  289. cls.PIPE_BUFFER_SIZE,
  290. cls.PIPE_IO_TIMEOUT * 1000,
  291. None
  292. )
  293. inst = cls(handle, True)
  294. if connect:
  295. inst.connect_server()
  296. return inst
  297. def connect_server(self):
  298. if not self.is_server_side:
  299. raise ValueError("this must be the server side")
  300. win32pipe.ConnectNamedPipe(self.incoming, None)
  301. @classmethod
  302. def create_client(cls, pipename, timeout = CONNECT_TIMEOUT):
  303. if not pipename.startswith("\\\\."):
  304. pipename = cls.NAMED_PIPE_PREFIX + pipename
  305. handle = win32file.CreateFile(
  306. pipename,
  307. win32file.GENERIC_READ | win32file.GENERIC_WRITE,
  308. 0,
  309. None,
  310. win32file.OPEN_EXISTING,
  311. 0,
  312. None
  313. )
  314. return cls(handle, False)
  315. def close(self):
  316. if self.closed:
  317. return
  318. if self.is_server_side:
  319. win32file.FlushFileBuffers(self.outgoing)
  320. win32pipe.DisconnectNamedPipe(self.outgoing)
  321. Win32PipeStream.close(self)
  322. if sys.platform == "win32":
  323. PipeStream = Win32PipeStream