Queer European MD passionate about IT

server.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. import argparse
  2. import asyncio
  3. import collections
  4. import logging
  5. class Server:
  6. def __init__(self, host='localhost', port=5000,
  7. buffer_chunk_size=10**4, buffer_length_limit=10**4):
  8. self._host = host
  9. self._port = port
  10. self._stopping = False
  11. self.buffer = collections.deque() # Shared queue of bytes
  12. self._buffer_chunk_size = buffer_chunk_size # How many bytes per chunk
  13. self._buffer_length_limit = buffer_length_limit # How many chunks in buffer
  14. self._working = False
  15. self.at_eof = False
  16. self._server = None
  17. @property
  18. def host(self) -> str:
  19. return self._host
  20. @property
  21. def port(self) -> int:
  22. return self._port
  23. @property
  24. def stopping(self) -> bool:
  25. return self._stopping
  26. @property
  27. def buffer_length_limit(self) -> int:
  28. return self._buffer_length_limit
  29. @property
  30. def buffer_chunk_size(self) -> int:
  31. return self._buffer_chunk_size
  32. @property
  33. def working(self) -> bool:
  34. return self._working
  35. @property
  36. def server(self) -> asyncio.base_events.Server:
  37. return self._server
  38. async def run_reader(self, reader):
  39. while not self.stopping:
  40. try:
  41. # Stop if buffer is full
  42. while len(self.buffer) >= self.buffer_length_limit:
  43. await asyncio.sleep(1)
  44. continue
  45. input_data = await reader.read(self.buffer_chunk_size)
  46. if reader.at_eof():
  47. self.at_eof = True
  48. self.buffer.append(input_data)
  49. except Exception as e:
  50. logging.error(e)
  51. async def run_writer(self, writer):
  52. while not self.stopping:
  53. try:
  54. # Slow down if buffer is short
  55. if len(self.buffer) < 3:
  56. await asyncio.sleep(.1)
  57. try:
  58. input_data = self.buffer.popleft()
  59. except IndexError:
  60. if not self.at_eof:
  61. continue
  62. else:
  63. writer.write_eof()
  64. await writer.drain()
  65. self.at_eof = False
  66. break
  67. writer.write(input_data)
  68. await writer.drain()
  69. except Exception as e:
  70. logging.error(e)
  71. async def connect(self,
  72. reader: asyncio.StreamReader,
  73. writer: asyncio.StreamWriter):
  74. """Connect with client.
  75. Decide whether client is sender or receiver and start transmission.
  76. """
  77. client_hello = await reader.readline()
  78. peer_is_sender = client_hello.decode('utf-8') == 'sender\n'
  79. writer.write("Start!\n".encode('utf-8')) # Send start signal to client
  80. await writer.drain()
  81. self._working = True
  82. if peer_is_sender:
  83. logging.info("Sender is connecting...")
  84. await self.run_reader(reader=reader)
  85. logging.info("Incoming transmission ended")
  86. else:
  87. logging.info("Receiver is connecting...")
  88. await self.run_writer(writer=writer)
  89. logging.info("Outgoing transmission ended")
  90. self._working = False # Reset peer_is_sender
  91. return
  92. def run(self):
  93. loop = asyncio.get_event_loop()
  94. logging.info("Starting file bridging server...")
  95. try:
  96. loop.run_until_complete(self.run_server())
  97. except KeyboardInterrupt:
  98. logging.info("Stopping...")
  99. # Cancel connection tasks (they should be done but are pending)
  100. for task in asyncio.all_tasks(loop):
  101. task.cancel()
  102. loop.run_until_complete(
  103. self.server.wait_closed()
  104. )
  105. loop.close()
  106. logging.info("Stopped.")
  107. async def run_server(self):
  108. self._server = await asyncio.start_server(
  109. client_connected_cb=self.connect,
  110. host=self.host,
  111. port=self.port
  112. )
  113. async with self.server:
  114. try:
  115. await self.server.serve_forever()
  116. except KeyboardInterrupt:
  117. logging.info("Stopping...")
  118. self.server.close()
  119. await self.server.wait_closed()
  120. return
  121. def stop(self, *_):
  122. if self.working and not self.stopping:
  123. logging.info("Received interruption signal, stopping...")
  124. self._stopping = True
  125. else:
  126. raise KeyboardInterrupt("Not working yet...")
  127. if __name__ == '__main__':
  128. # noinspection SpellCheckingInspection
  129. log_formatter = logging.Formatter(
  130. "%(asctime)s [%(module)-15s %(levelname)-8s] %(message)s",
  131. style='%'
  132. )
  133. root_logger = logging.getLogger()
  134. root_logger.setLevel(logging.DEBUG)
  135. console_handler = logging.StreamHandler()
  136. console_handler.setFormatter(log_formatter)
  137. console_handler.setLevel(logging.DEBUG)
  138. root_logger.addHandler(console_handler)
  139. # Parse command-line arguments
  140. parser = argparse.ArgumentParser(description='Run server',
  141. allow_abbrev=False)
  142. parser.add_argument('--host', type=str,
  143. default=None,
  144. required=False,
  145. help='server address')
  146. parser.add_argument('--port', type=int,
  147. default=None,
  148. required=False,
  149. help='server port')
  150. args = vars(parser.parse_args())
  151. _host = args['host']
  152. _port = args['port']
  153. # If _host and _port are not provided from command-line, try to import them
  154. if _host is None:
  155. try:
  156. from config import host as _host
  157. except ImportError:
  158. _host = None
  159. if _port is None:
  160. try:
  161. from config import port as _port
  162. except ImportError:
  163. _port = None
  164. # If import fails, prompt user for _host or _port
  165. while _host is None:
  166. _host = input("Enter host:\t\t\t\t\t\t")
  167. while _port is None:
  168. try:
  169. _port = int(input("Enter port:\t\t\t\t\t\t"))
  170. except ValueError:
  171. logging.info("Invalid port. Enter a valid port number!")
  172. _port = None
  173. server = Server(
  174. host=_host,
  175. port=_port,
  176. )
  177. server.run()