123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200 |
- import argparse
- import asyncio
- import collections
- import logging
class Server:
    """Relay a byte stream from a "sender" client to a "receiver" client.

    Every client first sends a single line. A line equal to ``sender``
    (newline-terminated) marks the client whose data is read into
    :attr:`buffer`; any other line marks a receiver, which is sent the
    buffered chunks and then end-of-stream once the sender has finished.
    """

    def __init__(self, host='localhost', port=5000,
                 buffer_chunk_size=10**4, buffer_length_limit=10**4):
        """Configure the server; no socket is opened until :meth:`run`.

        :param host: address passed to ``asyncio.start_server``.
        :param port: TCP port passed to ``asyncio.start_server``.
        :param buffer_chunk_size: maximum number of bytes read per chunk.
        :param buffer_length_limit: number of queued chunks at which reading
            from the sender pauses until the receiver catches up.
        """
        self._host = host
        self._port = port
        self._stopping = False
        self.buffer = collections.deque()  # Shared FIFO of bytes chunks
        self._buffer_chunk_size = buffer_chunk_size  # Max bytes per chunk
        self._buffer_length_limit = buffer_length_limit  # Max queued chunks
        self._working = False
        # Set by run_reader once the sender reached EOF; cleared by
        # run_writer after it has forwarded EOF to the receiver.
        self.at_eof = False
        self._server = None

    @property
    def host(self) -> str:
        return self._host

    @property
    def port(self) -> int:
        return self._port

    @property
    def stopping(self) -> bool:
        return self._stopping

    @property
    def buffer_length_limit(self) -> int:
        return self._buffer_length_limit

    @property
    def buffer_chunk_size(self) -> int:
        return self._buffer_chunk_size

    @property
    def working(self) -> bool:
        return self._working

    @property
    def server(self) -> asyncio.base_events.Server:
        """The listening server, or None before :meth:`run_server` started."""
        return self._server

    async def run_reader(self, reader):
        """Read chunks from the sender into :attr:`buffer` until EOF.

        Returns when the sender closes its side (``at_eof`` is then set),
        when :attr:`stopping` becomes true, or when reading fails.
        """
        while not self.stopping:
            # Back-pressure: wait for the receiver to drain a full buffer,
            # re-checking ``stopping`` on every pass.
            if len(self.buffer) >= self.buffer_length_limit:
                await asyncio.sleep(1)
                continue
            try:
                input_data = await reader.read(self.buffer_chunk_size)
            except Exception:
                # A failed stream typically fails again immediately on every
                # retry, so stop instead of spinning on the error.
                logging.exception("Reading from sender failed")
                break
            if input_data:
                self.buffer.append(input_data)
            if reader.at_eof():
                # Stop here: further reads would return b'' at once and
                # flood the buffer with empty chunks.
                self.at_eof = True
                break

    async def run_writer(self, writer):
        """Forward buffered chunks to the receiver until the sender's EOF.

        When :attr:`buffer` is empty and ``at_eof`` is set, EOF is written to
        the receiver and ``at_eof`` is cleared. Also returns when
        :attr:`stopping` becomes true or when writing fails.
        """
        while not self.stopping:
            # Poll gently while little data is queued; this also avoids a
            # busy loop when the buffer is empty.
            if len(self.buffer) < 3:
                await asyncio.sleep(.1)
            try:
                input_data = self.buffer.popleft()
            except IndexError:
                if not self.at_eof:
                    continue  # Sender still active: wait for more data
                try:
                    writer.write_eof()
                    await writer.drain()
                except Exception:
                    logging.exception("Sending EOF to receiver failed")
                self.at_eof = False
                break
            try:
                writer.write(input_data)
                await writer.drain()
            except Exception:
                # The receiver is gone; continuing would only pop (and lose)
                # further chunks into a dead connection.
                logging.exception("Writing to receiver failed")
                break

    async def connect(self,
                      reader: asyncio.StreamReader,
                      writer: asyncio.StreamWriter):
        """Connect with client.

        Decide whether client is sender or receiver and start transmission.
        The client's writer is always closed when this coroutine finishes.
        """
        try:
            client_hello = await reader.readline()
            # Compare raw bytes: same result as decoding as UTF-8 and
            # comparing to 'sender\n', but malformed input cannot raise.
            peer_is_sender = client_hello == b'sender\n'
            writer.write("Start!\n".encode('utf-8'))  # Start signal
            await writer.drain()
            self._working = True
            try:
                if peer_is_sender:
                    logging.info("Sender is connecting...")
                    await self.run_reader(reader=reader)
                    logging.info("Incoming transmission ended")
                else:
                    logging.info("Receiver is connecting...")
                    await self.run_writer(writer=writer)
                    logging.info("Outgoing transmission ended")
            finally:
                self._working = False
        finally:
            # Release the transport; otherwise the socket leaks and the
            # server's wait_closed() can wait on it indefinitely.
            writer.close()
            try:
                await writer.wait_closed()
            except OSError:
                pass  # Peer already reset the connection; nothing to do

    def run(self):
        """Serve on a fresh event loop until Ctrl+C, then shut down."""
        # asyncio.get_event_loop() is deprecated when no loop is running,
        # so create and install a new loop explicitly.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        logging.info("Starting file bridging server...")
        try:
            loop.run_until_complete(self.run_server())
        except KeyboardInterrupt:
            logging.info("Stopping...")
            # Cancel connection tasks (they should be done but are pending)
            tasks = asyncio.all_tasks(loop)
            for task in tasks:
                task.cancel()
            if tasks:
                # Let the cancelled tasks run their cleanup (closing writers).
                loop.run_until_complete(
                    asyncio.gather(*tasks, return_exceptions=True)
                )
            # The interrupt may arrive before start_server() has returned.
            if self.server is not None:
                self.server.close()
                loop.run_until_complete(self.server.wait_closed())
        finally:
            loop.close()
        logging.info("Stopped.")

    async def run_server(self):
        """Listen on ``host:port`` and serve client connections forever."""
        self._server = await asyncio.start_server(
            client_connected_cb=self.connect,
            host=self.host,
            port=self.port
        )
        # Leaving ``async with`` closes the server and waits for it.
        async with self.server:
            try:
                await self.server.serve_forever()
            except KeyboardInterrupt:
                logging.info("Stopping...")

    def stop(self, *_):
        """Request a graceful stop of the running transfer loops.

        The ``*_`` signature accepts and ignores any arguments (presumably so
        it can be installed as a signal handler — confirm against callers).
        Raises KeyboardInterrupt when no transfer is active or a stop was
        already requested.
        """
        if self.working and not self.stopping:
            logging.info("Received interruption signal, stopping...")
            self._stopping = True
        else:
            raise KeyboardInterrupt("Not working yet...")
def configure_logging():
    """Send DEBUG-and-above records from the root logger to stderr."""
    # noinspection SpellCheckingInspection
    log_formatter = logging.Formatter(
        "%(asctime)s [%(module)-15s %(levelname)-8s] %(message)s",
        style='%'
    )
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(log_formatter)
    console_handler.setLevel(logging.DEBUG)
    root_logger.addHandler(console_handler)


def prompt_port(ask=input):
    """Ask until the user enters a TCP port number in 1..65535.

    :param ask: callable taking a prompt string and returning the user's
        answer; defaults to the builtin ``input``.
    :returns: the port as an int.
    """
    while True:
        try:
            port = int(ask("Enter port:\t\t\t\t\t\t"))
        except ValueError:
            port = None
        # Reject out-of-range numbers here instead of failing later inside
        # asyncio.start_server.
        if port is not None and 0 < port < 65536:
            return port
        logging.info("Invalid port. Enter a valid port number!")


def main(argv=None):
    """Resolve host/port (CLI, then ``config`` module, then prompt) and serve.

    :param argv: argument list for argparse; None means ``sys.argv[1:]``.
    """
    configure_logging()
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Run server',
                                     allow_abbrev=False)
    parser.add_argument('--host', type=str,
                        default=None,
                        required=False,
                        help='server address')
    parser.add_argument('--port', type=int,
                        default=None,
                        required=False,
                        help='server port')
    args = vars(parser.parse_args(argv))
    host = args['host']
    port = args['port']
    # If host and port are not provided from command-line, try to import them
    if host is None:
        try:
            from config import host
        except ImportError:
            host = None
    if port is None:
        try:
            from config import port
        except ImportError:
            port = None
    # If import fails, prompt the user. An empty host answer is passed on
    # as-is (asyncio.start_server treats '' as all interfaces).
    if host is None:
        host = input("Enter host:\t\t\t\t\t\t")
    if port is None:
        port = prompt_port()
    server = Server(
        host=host,
        port=port,
    )
    server.run()


if __name__ == '__main__':
    main()
|