123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236 |
- import argparse
- import asyncio
- import collections
- import logging
- # import signal
- import os
class Client:
    """Asyncio TCP client that streams a local file to, or from, a server.

    `run_sending_client` uploads the bytes of a local file to ``host:port``;
    `run_receiving_client` writes everything read from ``host:port`` into a
    local file.  Data is moved in chunks of `buffer_chunk_size` bytes.
    """

    def __init__(self, host='localhost', port=3001,
                 buffer_chunk_size=10**4, buffer_length_limit=10**4):
        """
        :param host: server address to connect to.
        :param port: server TCP port.
        :param buffer_chunk_size: number of bytes read/written per chunk.
        :param buffer_length_limit: maximum number of chunks intended for
            `buffer` (stored, but not used by the code in this class).
        """
        self._host = host
        self._port = port
        self._stopping = False
        self.buffer = collections.deque()  # Shared queue of bytes (not used by this class)
        self._buffer_chunk_size = buffer_chunk_size  # How many bytes per chunk
        self._buffer_length_limit = buffer_length_limit  # How many chunks in buffer
        self._file_path = None
        self._working = False

    @property
    def host(self) -> str:
        return self._host

    @property
    def port(self) -> int:
        return self._port

    @property
    def stopping(self) -> bool:
        """True once `stop` has been called while a transfer was running."""
        return self._stopping

    @property
    def buffer_length_limit(self) -> int:
        return self._buffer_length_limit

    @property
    def buffer_chunk_size(self) -> int:
        return self._buffer_chunk_size

    @property
    def file_path(self) -> str:
        """Path of the file being sent/received (None before a run starts)."""
        return self._file_path

    @property
    def working(self) -> bool:
        """True once `send` or `receive` has started."""
        return self._working

    async def run_sending_client(self, file_path='~/output.txt'):
        """Connect to the server and upload the contents of `file_path`.

        A leading ``~`` in `file_path` is expanded to the user's home
        directory.  The connection is closed when the upload finishes or fails.
        """
        # open() does not expand "~", so the default path would never resolve.
        self._file_path = os.path.expanduser(file_path)
        _, writer = await asyncio.open_connection(host=self.host, port=self.port)
        try:
            await self.send(writer=writer)
        finally:
            await self._close_writer(writer)

    async def send(self, writer: asyncio.StreamWriter):
        """Stream `file_path` to `writer` chunk by chunk.

        After the whole file has been written, EOF is sent so the peer knows
        the transfer is complete.  If the transfer is interrupted (via `stop`
        or because the server closed the connection), no EOF is sent.
        """
        self._working = True
        with open(self.file_path, 'rb') as file_to_send:
            while not self.stopping:
                output_data = file_to_send.read(self.buffer_chunk_size)
                if not output_data:
                    # Entire file transmitted: signal end of stream.
                    writer.write_eof()
                    await self._drain(writer)
                    return
                writer.write(output_data)
                if not await self._drain(writer):
                    return

    async def run_receiving_client(self, file_path='~/input.txt'):
        """Connect to the server and save everything it sends to `file_path`.

        A leading ``~`` in `file_path` is expanded to the user's home
        directory.  The connection is closed when the download finishes or fails.
        """
        self._file_path = os.path.expanduser(file_path)
        reader, writer = await asyncio.open_connection(host=self.host, port=self.port)
        try:
            await self.receive(reader=reader)
        finally:
            await self._close_writer(writer)

    async def receive(self, reader: asyncio.StreamReader):
        """Write data read from `reader` into `file_path` until EOF or `stop`."""
        self._working = True
        with open(self.file_path, 'wb') as file_to_receive:
            while not self.stopping:
                input_data = await reader.read(self.buffer_chunk_size)
                # read() returns b'' only after EOF with an empty buffer.
                # Checking at_eof() here instead would drop the final chunk.
                if not input_data:
                    break
                file_to_receive.write(input_data)

    async def _drain(self, writer):
        """Flush `writer`; return False (and stop) if the server hung up."""
        try:
            await writer.drain()
        except (ConnectionResetError, BrokenPipeError):
            logging.info('Server closed the connection.')
            self.stop()
            return False
        return True

    @staticmethod
    async def _close_writer(writer):
        """Close `writer` and wait for the transport to shut down."""
        writer.close()
        try:
            await writer.wait_closed()
        except ConnectionError:
            # The peer is already gone; there is nothing left to clean up.
            pass

    def stop(self, *_):
        """Request that a running transfer stop after its current chunk.

        Extra positional arguments are ignored, so this can be used directly
        as a signal handler.

        :raises KeyboardInterrupt: if no transfer has started yet.
        """
        if self.working:
            logging.info("Received interruption signal, stopping...")
            self._stopping = True
        else:
            raise KeyboardInterrupt("Not working yet...")
def get_action(action):
    """Expand an action abbreviation into its canonical name.

    Any string whose lower-cased form starts with 'r' maps to 'receive', and
    one starting with 's' maps to 'send'.  Everything else (non-strings,
    empty strings, other letters) yields None.
    """
    if not isinstance(action, str):
        return None
    initial = action.lower()[:1]
    known = {'r': 'receive', 's': 'send'}
    return known.get(initial)
def get_file_path(path, action='receive'):
    """Check that file `path` is usable for `action` and return it.

    A leading ``~`` is expanded to the user's home directory, and the expanded
    path is returned.

    * ``action == 'send'``: the path must name an existing regular file.
    * ``action == 'receive'``: the path must be non-empty and must not be an
      existing directory.  Its parent directory must be writable, and if the
      file already exists it must be writable too.

    :returns: the (expanded) path, or None if it is invalid.  Invalid
        non-None values are logged as errors; None is returned silently.
    """
    if not isinstance(path, str) or not path:
        if path is not None:
            logging.error(f"Invalid file: `{path}`")
        return None

    expanded = os.path.expanduser(path)
    if action == 'send':
        valid = os.path.isfile(expanded)
    elif action == 'receive':
        parent = os.path.dirname(os.path.abspath(expanded))
        valid = (
            not os.path.isdir(expanded)
            and os.access(parent, os.W_OK)
            # An existing target will be overwritten, so it must be writable.
            and (not os.path.exists(expanded) or os.access(expanded, os.W_OK))
        )
    else:
        valid = False

    if valid:
        return expanded
    logging.error(f"Invalid file: `{path}`")
    return None
if __name__ == '__main__':
    # Send log records of every level to the console.
    log_formatter = logging.Formatter(
        "%(asctime)s [%(module)-15s %(levelname)-8s] %(message)s",
        style='%'
    )
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(log_formatter)
    console_handler.setLevel(logging.DEBUG)
    root_logger.addHandler(console_handler)

    # Parse command-line arguments
    cli_parser = argparse.ArgumentParser(description='Run client',
                                         allow_abbrev=False)
    cli_parser.add_argument('--host', type=str,
                            default=None,
                            required=False,
                            help='server address')
    cli_parser.add_argument('--port', type=int,
                            default=None,
                            required=False,
                            help='server port')
    cli_parser.add_argument('--action', type=str,
                            default=None,
                            required=False,
                            help='[S]end or [R]eceive')
    cli_parser.add_argument('--path', type=str,
                            default=None,
                            required=False,
                            help='File path')
    cli_parser.add_argument('others',
                            metavar='R or S',
                            nargs='*',
                            help='[S]end or [R]eceive (see `action`)')
    args = vars(cli_parser.parse_args())
    _host = args['host']
    _port = args['port']
    _action = get_action(args['action'])
    _file_path = args['path']

    # Settings missing from the command line fall back to a local `config`
    # module (a missing module or attribute both raise ImportError).
    if _host is None:
        try:
            from config import host as _host
        except ImportError:
            _host = None
    if _port is None:
        try:
            from config import port as _port
        except ImportError:
            _port = None
    # Take `s`, `r` etc. from positional arguments as `_action`
    if _action is None:
        for arg in args['others']:
            _action = get_action(arg)
            if _action:
                break
    if _action is None:
        try:
            from config import action as _action
            _action = get_action(_action)
        except ImportError:
            _action = None
    if _file_path is None:
        # Validated below, once the action is known.
        try:
            from config import file_path as _file_path
        except ImportError:
            _file_path = None

    # Prompt the user for whatever is still missing.
    while not _host:
        _host = input("Enter host:\t\t\t\t\t\t").strip()
    while _port is None:
        try:
            _port = int(input("Enter port:\t\t\t\t\t\t"))
        except ValueError:
            logging.warning("Invalid port. Enter a valid port number!")
    while _action is None:
        _action = get_action(
            input("Do you want to (R)eceive or (S)end a file?\t\t")
        )
    # A path from the command line or `config` must be checked against the
    # action; if it is invalid, get_file_path logs it and the user is prompted.
    if _file_path is not None:
        _file_path = get_file_path(path=_file_path, action=_action)
    while _file_path is None:
        _file_path = get_file_path(
            path=input(f"Enter file to {_action}:\t\t\t\t\t\t"),
            action=_action
        )

    client = Client(
        host=_host,
        port=_port,
    )
    logging.info("Starting client...")
    if _action == 'send':
        transfer = client.run_sending_client(file_path=_file_path)
    else:
        transfer = client.run_receiving_client(file_path=_file_path)
    # asyncio.run creates, runs and closes a fresh event loop.
    asyncio.run(transfer)
    logging.info("Stopped client")
|