Queer European MD passionate about IT

server.py 13 KB


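"""Server class.

May be a local server or a publicly reachable server. It relays a file
from a sender client to a receiver client that present the same
connection token.

Arguments (command line first, then `config.py`; host and port are
finally prompted for interactively):
- host: localhost, IPv4 address or domain (e.g. www.example.com)
- port: port to listen on (must be open/reachable)
- certificate [optional]: server certificate for SSL
- key [optional]: private key matching the certificate; SSL is enabled
  only when both certificate and key point to existing files
"""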
import argparse
import asyncio
import collections
import logging
import os
import ssl
from typing import Optional, Union
  16. class Server:
  17. def __init__(self, host='localhost', port=5000, ssl_context=None,
  18. buffer_chunk_size=10 ** 4, buffer_length_limit=10 ** 4):
  19. self._host = host
  20. self._port = port
  21. self._ssl_context = ssl_context
  22. self.connections = collections.OrderedDict()
  23. # Dict of queues of bytes
  24. self.buffers = collections.OrderedDict()
  25. # How many bytes per chunk
  26. self._buffer_chunk_size = buffer_chunk_size
  27. # How many chunks in buffer
  28. self._buffer_length_limit = buffer_length_limit
  29. self._working = False
  30. self._server = None
  31. self._ssl_context = None
  32. @property
  33. def host(self) -> str:
  34. return self._host
  35. @property
  36. def port(self) -> int:
  37. return self._port
  38. @property
  39. def buffer_length_limit(self) -> int:
  40. return self._buffer_length_limit
  41. @property
  42. def buffer_chunk_size(self) -> int:
  43. return self._buffer_chunk_size
  44. @property
  45. def working(self) -> bool:
  46. return self._working
  47. @property
  48. def server(self) -> asyncio.base_events.Server:
  49. return self._server
  50. @property
  51. def ssl_context(self) -> ssl.SSLContext:
  52. return self._ssl_context
  53. @property
  54. def buffer_is_full(self):
  55. return (
  56. sum(len(buffer)
  57. for buffer in self.buffers.values())
  58. >= self.buffer_length_limit
  59. )
  60. def set_ssl_context(self, ssl_context: ssl.SSLContext):
  61. self._ssl_context = ssl_context
  62. async def run_reader(self, reader, connection_token):
  63. while 1:
  64. try:
  65. # Wait one second if buffer is full
  66. while self.buffer_is_full:
  67. await asyncio.sleep(1)
  68. continue
  69. input_data = await reader.read(self.buffer_chunk_size)
  70. if connection_token not in self.buffers:
  71. break
  72. self.buffers[connection_token].append(input_data)
  73. except ConnectionResetError as e:
  74. logging.error(e)
  75. break
  76. except Exception as e:
  77. logging.error(f"Unexpected exception:\n{e}", exc_info=True)
  78. async def run_writer(self, writer, connection_token):
  79. consecutive_interruptions = 0
  80. errors = 0
  81. while connection_token in self.buffers:
  82. try:
  83. input_data = self.buffers[connection_token].popleft()
  84. except IndexError:
  85. # Slow down if buffer is empty; after 1.5 s of silence, break
  86. consecutive_interruptions += 1
  87. if consecutive_interruptions > 3:
  88. break
  89. await asyncio.sleep(.5)
  90. continue
  91. else:
  92. consecutive_interruptions = 0
  93. if not input_data:
  94. break
  95. try:
  96. writer.write(input_data)
  97. await writer.drain()
  98. except ConnectionResetError as e:
  99. logging.error(e)
  100. break
  101. except Exception as e:
  102. logging.error(e, exc_info=True)
  103. errors += 1
  104. if errors > 3:
  105. break
  106. await asyncio.sleep(0.5)
  107. writer.close()
  108. async def connect(self,
  109. reader: asyncio.StreamReader,
  110. writer: asyncio.StreamWriter):
  111. """Connect with client.
  112. Decide whether client is sender or receiver and start transmission.
  113. """
  114. client_hello = await reader.readline()
  115. client_hello = client_hello.decode('utf-8').strip('\n').split('|')
  116. if len(client_hello) != 4:
  117. await self.refuse_connection(writer=writer,
  118. message="Invalid client_hello!")
  119. return
  120. connection_token = client_hello[1]
  121. if connection_token not in self.connections:
  122. self.connections[connection_token] = dict(
  123. sender=False,
  124. receiver=False
  125. )
  126. async def _write(message: Union[list, str, bytes],
  127. terminate_line=True) -> int:
  128. # Adapt
  129. if type(message) is list:
  130. message = '|'.join(map(str, message))
  131. if type(message) is str:
  132. if terminate_line:
  133. message += '\n'
  134. message = message.encode('utf-8')
  135. if type(message) is not bytes:
  136. return 1
  137. try:
  138. writer.write(message)
  139. await writer.drain()
  140. except ConnectionResetError:
  141. logging.error("Client disconnected.")
  142. except Exception as e:
  143. logging.error(f"Unexpected exception:\n{e}", exc_info=True)
  144. else:
  145. return 0 # On success, return 0
  146. # On exception, disconnect and return 1
  147. self.disconnect(connection_token=connection_token)
  148. return 1
  149. if client_hello[0] == 's': # Sender client connection
  150. if self.connections[connection_token]['sender']:
  151. await self.refuse_connection(
  152. writer=writer,
  153. message="Invalid token! "
  154. "A sender client is already connected!\n"
  155. )
  156. return
  157. self.connections[connection_token]['sender'] = True
  158. self.connections[connection_token]['file_name'] = client_hello[2]
  159. self.connections[connection_token]['file_size'] = client_hello[3]
  160. self.buffers[connection_token] = collections.deque()
  161. logging.info("Sender is connecting...")
  162. index, step = 0, 1
  163. while not self.connections[connection_token]['receiver']:
  164. index += 1
  165. if index >= step:
  166. if await _write("Waiting for receiver..."):
  167. return
  168. step += 1
  169. index = 0
  170. await asyncio.sleep(.5)
  171. # Send start signal to client
  172. if await _write("start!"):
  173. return
  174. logging.info("Incoming transmission starting...")
  175. await self.run_reader(reader=reader,
  176. connection_token=connection_token)
  177. logging.info("Incoming transmission ended")
  178. elif client_hello[0] == 'r': # Receiver client connection
  179. if self.connections[connection_token]['receiver']:
  180. await self.refuse_connection(
  181. writer=writer,
  182. message="Invalid token! "
  183. "A receiver client is already connected!\n"
  184. )
  185. return
  186. self.connections[connection_token]['receiver'] = True
  187. logging.info("Receiver is connecting...")
  188. index, step = 0, 1
  189. while not self.connections[connection_token]['sender']:
  190. index += 1
  191. if index >= step:
  192. if await _write("Waiting for sender..."):
  193. return
  194. step += 1
  195. index = 0
  196. await asyncio.sleep(.5)
  197. # Send file information and start signal to client
  198. if await _write(
  199. ['s',
  200. 'hidden_token',
  201. self.connections[connection_token]['file_name'],
  202. self.connections[connection_token]['file_size']]
  203. ):
  204. return
  205. if await _write("start!"):
  206. return
  207. await self.run_writer(writer=writer,
  208. connection_token=connection_token)
  209. logging.info("Outgoing transmission ended")
  210. self.disconnect(connection_token=connection_token)
  211. else:
  212. await self.refuse_connection(writer=writer,
  213. message="Invalid client_hello!")
  214. return
  215. def disconnect(self, connection_token: str) -> None:
  216. del self.buffers[connection_token]
  217. del self.connections[connection_token]
  218. def run(self):
  219. loop = asyncio.get_event_loop()
  220. logging.info("Starting file bridging server...")
  221. try:
  222. loop.run_until_complete(self.run_server())
  223. except KeyboardInterrupt:
  224. print()
  225. logging.info("Stopping...")
  226. # Cancel connection tasks (they should be done but are pending)
  227. for task in asyncio.all_tasks(loop):
  228. task.cancel()
  229. loop.run_until_complete(
  230. self.server.wait_closed()
  231. )
  232. loop.close()
  233. logging.info("Stopped.")
  234. async def run_server(self):
  235. self._server = await asyncio.start_server(
  236. ssl=self.ssl_context,
  237. client_connected_cb=self.connect,
  238. host=self.host,
  239. port=self.port,
  240. )
  241. async with self.server:
  242. logging.info("Running at `{s.host}:{s.port}`".format(s=self))
  243. await self.server.serve_forever()
  244. @staticmethod
  245. async def refuse_connection(writer: asyncio.StreamWriter,
  246. message: str = None):
  247. """Send a `message` via writer and close it."""
  248. if message is None:
  249. message = "Connection refused!\n"
  250. writer.write(
  251. message.encode('utf-8')
  252. )
  253. await writer.drain()
  254. writer.close()
  255. def main():
  256. # noinspection SpellCheckingInspection
  257. log_formatter = logging.Formatter(
  258. "%(asctime)s [%(module)-15s %(levelname)-8s] %(message)s",
  259. style='%'
  260. )
  261. root_logger = logging.getLogger()
  262. root_logger.setLevel(logging.DEBUG)
  263. # noinspection PyUnresolvedReferences
  264. asyncio.selector_events.logger.setLevel(logging.ERROR)
  265. console_handler = logging.StreamHandler()
  266. console_handler.setFormatter(log_formatter)
  267. console_handler.setLevel(logging.DEBUG)
  268. root_logger.addHandler(console_handler)
  269. # Parse command-line arguments
  270. cli_parser = argparse.ArgumentParser(description='Run server',
  271. allow_abbrev=False)
  272. cli_parser.add_argument('--host', type=str,
  273. default=None,
  274. required=False,
  275. help='server address')
  276. cli_parser.add_argument('--port', type=int,
  277. default=None,
  278. required=False,
  279. help='server port')
  280. cli_parser.add_argument('--certificate', type=str,
  281. default=None,
  282. required=False,
  283. help='server SSL certificate')
  284. cli_parser.add_argument('--key', type=str,
  285. default=None,
  286. required=False,
  287. help='server SSL key')
  288. args = vars(cli_parser.parse_args())
  289. host = args['host']
  290. port = args['port']
  291. certificate = args['certificate']
  292. key = args['key']
  293. # If host and port are not provided from command-line, try to import them
  294. if host is None:
  295. try:
  296. from config import host
  297. except ImportError:
  298. host = None
  299. if port is None:
  300. try:
  301. from config import port
  302. except ImportError:
  303. port = None
  304. # If import fails, prompt user for host or port
  305. while host is None:
  306. host = input("Enter host:\t\t\t\t\t\t")
  307. while port is None:
  308. try:
  309. port = int(input("Enter port:\t\t\t\t\t\t"))
  310. except ValueError:
  311. logging.info("Invalid port. Enter a valid port number!")
  312. port = None
  313. try:
  314. if certificate is None or not os.path.isfile(certificate):
  315. from config import certificate
  316. if key is None or not os.path.isfile(key):
  317. from config import key
  318. if not os.path.isfile(certificate):
  319. certificate = None
  320. if not os.path.isfile(key):
  321. key = None
  322. except ImportError:
  323. certificate = None
  324. key = None
  325. ssl_context = None
  326. if certificate and key:
  327. ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
  328. ssl_context.load_cert_chain(certificate, key)
  329. else:
  330. logging.warning(
  331. "Please consider using SSL. To do so, add in `config.py` or "
  332. "provide via Command Line Interface the path to a valid SSL "
  333. "key and certificate. Example:\n\n"
  334. "key = 'path/to/secret.key'\n"
  335. "certificate = 'path/to/certificate.crt'"
  336. )
  337. server = Server(
  338. host=host,
  339. port=port,
  340. ssl_context=ssl_context
  341. )
  342. server.run()
  343. if __name__ == '__main__':
  344. main()