Queer European MD passionate about IT

server.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. import argparse
  2. import asyncio
  3. import collections
  4. import logging
  5. import ssl
  6. class Server:
  7. def __init__(self, host='localhost', port=5000,
  8. buffer_chunk_size=10**4, buffer_length_limit=10**4):
  9. self._host = host
  10. self._port = port
  11. self._stopping = False
  12. self.buffer = collections.deque() # Shared queue of bytes
  13. self._buffer_chunk_size = buffer_chunk_size # How many bytes per chunk
  14. self._buffer_length_limit = buffer_length_limit # How many chunks in buffer
  15. self._working = False
  16. self.at_eof = False
  17. self._server = None
  18. self._ssl_context = None
  19. @property
  20. def host(self) -> str:
  21. return self._host
  22. @property
  23. def port(self) -> int:
  24. return self._port
  25. @property
  26. def stopping(self) -> bool:
  27. return self._stopping
  28. @property
  29. def buffer_length_limit(self) -> int:
  30. return self._buffer_length_limit
  31. @property
  32. def buffer_chunk_size(self) -> int:
  33. return self._buffer_chunk_size
  34. @property
  35. def working(self) -> bool:
  36. return self._working
  37. @property
  38. def server(self) -> asyncio.base_events.Server:
  39. return self._server
  40. @property
  41. def ssl_context(self) -> ssl.SSLContext:
  42. return self._ssl_context
  43. def set_ssl_context(self, ssl_context: ssl.SSLContext):
  44. self._ssl_context = ssl_context
  45. async def run_reader(self, reader):
  46. while not self.stopping:
  47. try:
  48. # Stop if buffer is full
  49. while len(self.buffer) >= self.buffer_length_limit:
  50. await asyncio.sleep(1)
  51. continue
  52. input_data = await reader.read(self.buffer_chunk_size)
  53. if reader.at_eof():
  54. self.at_eof = True
  55. self.buffer.append(input_data)
  56. except Exception as e:
  57. logging.error(e)
  58. async def run_writer(self, writer):
  59. while not self.stopping:
  60. try:
  61. # Slow down if buffer is short
  62. if len(self.buffer) < 3:
  63. await asyncio.sleep(.1)
  64. try:
  65. input_data = self.buffer.popleft()
  66. except IndexError:
  67. if not self.at_eof:
  68. continue
  69. else:
  70. writer.write_eof()
  71. await writer.drain()
  72. self.at_eof = False
  73. break
  74. writer.write(input_data)
  75. await writer.drain()
  76. except Exception as e:
  77. logging.error(e)
  78. async def connect(self,
  79. reader: asyncio.StreamReader,
  80. writer: asyncio.StreamWriter):
  81. """Connect with client.
  82. Decide whether client is sender or receiver and start transmission.
  83. """
  84. client_hello = await reader.readline()
  85. peer_is_sender = client_hello.decode('utf-8') == 'sender\n'
  86. writer.write("Start!\n".encode('utf-8')) # Send start signal to client
  87. await writer.drain()
  88. if peer_is_sender:
  89. self._working = True
  90. logging.info("Sender is connecting...")
  91. await self.run_reader(reader=reader)
  92. logging.info("Incoming transmission ended")
  93. else:
  94. logging.info("Receiver is connecting...")
  95. await self.run_writer(writer=writer)
  96. logging.info("Outgoing transmission ended")
  97. self._working = False
  98. return
  99. def run(self):
  100. loop = asyncio.get_event_loop()
  101. logging.info("Starting file bridging server...")
  102. try:
  103. loop.run_until_complete(self.run_server())
  104. except KeyboardInterrupt:
  105. logging.info("Stopping...")
  106. # Cancel connection tasks (they should be done but are pending)
  107. for task in asyncio.all_tasks(loop):
  108. task.cancel()
  109. loop.run_until_complete(
  110. self.server.wait_closed()
  111. )
  112. loop.close()
  113. logging.info("Stopped.")
  114. async def run_server(self):
  115. self._server = await asyncio.start_server(
  116. ssl=self.ssl_context,
  117. client_connected_cb=self.connect,
  118. host=self.host,
  119. port=self.port,
  120. )
  121. async with self.server:
  122. try:
  123. await self.server.serve_forever()
  124. except KeyboardInterrupt:
  125. logging.info("Stopping...")
  126. self.server.close()
  127. await self.server.wait_closed()
  128. return
  129. def stop(self, *_):
  130. if self.working and not self.stopping:
  131. logging.info("Received interruption signal, stopping...")
  132. self._stopping = True
  133. else:
  134. raise KeyboardInterrupt("Not working yet...")
  135. if __name__ == '__main__':
  136. # noinspection SpellCheckingInspection
  137. log_formatter = logging.Formatter(
  138. "%(asctime)s [%(module)-15s %(levelname)-8s] %(message)s",
  139. style='%'
  140. )
  141. root_logger = logging.getLogger()
  142. root_logger.setLevel(logging.DEBUG)
  143. console_handler = logging.StreamHandler()
  144. console_handler.setFormatter(log_formatter)
  145. console_handler.setLevel(logging.DEBUG)
  146. root_logger.addHandler(console_handler)
  147. # Parse command-line arguments
  148. parser = argparse.ArgumentParser(description='Run server',
  149. allow_abbrev=False)
  150. parser.add_argument('--host', type=str,
  151. default=None,
  152. required=False,
  153. help='server address')
  154. parser.add_argument('--port', type=int,
  155. default=None,
  156. required=False,
  157. help='server port')
  158. args = vars(parser.parse_args())
  159. _host = args['host']
  160. _port = args['port']
  161. # If _host and _port are not provided from command-line, try to import them
  162. if _host is None:
  163. try:
  164. from config import host as _host
  165. except ImportError:
  166. _host = None
  167. if _port is None:
  168. try:
  169. from config import port as _port
  170. except ImportError:
  171. _port = None
  172. # If import fails, prompt user for _host or _port
  173. while _host is None:
  174. _host = input("Enter host:\t\t\t\t\t\t")
  175. while _port is None:
  176. try:
  177. _port = int(input("Enter port:\t\t\t\t\t\t"))
  178. except ValueError:
  179. logging.info("Invalid port. Enter a valid port number!")
  180. _port = None
  181. server = Server(
  182. host=_host,
  183. port=_port,
  184. )
  185. try:
  186. from config import certificate, key
  187. _ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
  188. _ssl_context.check_hostname = False
  189. _ssl_context.load_cert_chain(certificate, key)
  190. server.set_ssl_context(_ssl_context)
  191. except ImportError:
  192. logging.info("Please consider using SSL.")
  193. certificate, key = None, None
  194. server.run()