|
@@ -20,6 +20,8 @@ class Client:
|
|
|
self._host = host
|
|
|
self._port = port
|
|
|
self._stopping = False
|
|
|
+ self._reader = None
|
|
|
+ self._writer = None
|
|
|
# Shared queue of bytes
|
|
|
self.buffer = collections.deque()
|
|
|
# How many bytes per chunk
|
|
@@ -47,6 +49,14 @@ class Client:
|
|
|
def stopping(self) -> bool:
|
|
|
return self._stopping
|
|
|
|
|
|
+ @property
|
|
|
+ def reader(self) -> asyncio.StreamReader:
|
|
|
+ return self._reader
|
|
|
+
|
|
|
+ @property
|
|
|
+ def writer(self) -> asyncio.StreamWriter:
|
|
|
+ return self._writer
|
|
|
+
|
|
|
@property
|
|
|
def buffer_length_limit(self) -> int:
|
|
|
return self._buffer_length_limit
|
|
@@ -91,36 +101,52 @@ class Client:
|
|
|
def file_size(self):
|
|
|
return self._file_size
|
|
|
|
|
|
- async def run_sending_client(self, file_path='~/output.txt'):
|
|
|
+ async def run_client(self, file_path, action):
|
|
|
self._file_path = file_path
|
|
|
- file_name = os.path.basename(os.path.abspath(file_path))
|
|
|
- file_size = os.path.getsize(os.path.abspath(file_path))
|
|
|
+ if action == 'send':
|
|
|
+ file_name = os.path.basename(os.path.abspath(file_path))
|
|
|
+ file_size = os.path.getsize(os.path.abspath(file_path))
|
|
|
+ self.set_file_information(
|
|
|
+ file_name=file_name,
|
|
|
+ file_size=file_size
|
|
|
+ )
|
|
|
try:
|
|
|
reader, writer = await asyncio.open_connection(
|
|
|
host=self.host,
|
|
|
port=self.port,
|
|
|
ssl=self.ssl_context
|
|
|
)
|
|
|
+ self._reader = reader
|
|
|
+ self._writer = writer
|
|
|
except ConnectionRefusedError as exception:
|
|
|
logging.error(exception)
|
|
|
return
|
|
|
writer.write(
|
|
|
- f"s|{self.token}|{file_name}|{file_size}\n".encode('utf-8')
|
|
|
+ (
|
|
|
+ f"s|{self.token}|"
|
|
|
+ f"{self.file_name}|{self.file_size}\n".encode('utf-8')
|
|
|
+ ) if action == 'send'
|
|
|
+ else f"r|{self.token}\n".encode('utf-8')
|
|
|
)
|
|
|
- self.set_file_information(file_name=file_name,
|
|
|
- file_size=file_size)
|
|
|
await writer.drain()
|
|
|
# Wait for server start signal
|
|
|
while 1:
|
|
|
server_hello = await reader.readline()
|
|
|
if not server_hello:
|
|
|
- logging.error("Server disconnected.")
|
|
|
+ logging.info("Server disconnected.")
|
|
|
return
|
|
|
- server_hello = server_hello.decode('utf-8').strip('\n')
|
|
|
- if server_hello == 'start!':
|
|
|
+ server_hello = server_hello.decode('utf-8').strip('\n').split('|')
|
|
|
+ if action == 'receive' and server_hello[0] == 's':
|
|
|
+ self.set_file_information(file_name=server_hello[2],
|
|
|
+ file_size=server_hello[3])
|
|
|
+ elif server_hello[0] == 'start!':
|
|
|
break
|
|
|
- logging.info(f"Server said: {server_hello}")
|
|
|
- await self.send(writer=writer)
|
|
|
+ else:
|
|
|
+ logging.info(f"Server said: {'|'.join(server_hello)}")
|
|
|
+ if action == 'send':
|
|
|
+ await self.send(writer=writer)
|
|
|
+ else:
|
|
|
+ await self.receive(reader=reader)
|
|
|
|
|
|
async def encrypt_file(self, input_file, output_file):
|
|
|
self._encryption_complete = False
|
|
@@ -177,7 +203,7 @@ class Client:
|
|
|
writer.write(output_data)
|
|
|
await writer.drain()
|
|
|
except ConnectionResetError:
|
|
|
- logging.info('Server closed the connection.')
|
|
|
+ logging.error('Server closed the connection.')
|
|
|
self.stop()
|
|
|
break
|
|
|
bytes_sent += self.buffer_chunk_size
|
|
@@ -200,36 +226,6 @@ class Client:
|
|
|
writer.close()
|
|
|
return
|
|
|
|
|
|
- async def run_receiving_client(self, file_path='~/input.txt'):
|
|
|
- self._file_path = file_path
|
|
|
- try:
|
|
|
- reader, writer = await asyncio.open_connection(
|
|
|
- host=self.host,
|
|
|
- port=self.port,
|
|
|
- ssl=self.ssl_context
|
|
|
- )
|
|
|
- except ConnectionRefusedError as exception:
|
|
|
- logging.error(exception)
|
|
|
- return
|
|
|
- writer.write(f"r|{self.token}\n".encode('utf-8'))
|
|
|
- await writer.drain()
|
|
|
- # Wait for server start signal
|
|
|
- while 1:
|
|
|
- server_hello = await reader.readline()
|
|
|
- if not server_hello:
|
|
|
- logging.info("Server disconnected.")
|
|
|
- return
|
|
|
- server_hello = server_hello.decode('utf-8').strip('\n')
|
|
|
- if server_hello.startswith('info'):
|
|
|
- _, file_name, file_size = server_hello.split('|')
|
|
|
- self.set_file_information(file_name=file_name,
|
|
|
- file_size=file_size)
|
|
|
- elif server_hello == 'start!':
|
|
|
- break
|
|
|
- else:
|
|
|
- logging.info(f"Server said: {server_hello}")
|
|
|
- await self.receive(reader=reader)
|
|
|
-
|
|
|
async def receive(self, reader: asyncio.StreamReader):
|
|
|
self._working = True
|
|
|
file_path = os.path.join(
|
|
@@ -293,6 +289,7 @@ class Client:
|
|
|
if self.working:
|
|
|
logging.info("Received interruption signal, stopping...")
|
|
|
self._stopping = True
|
|
|
+ self.writer.close()
|
|
|
else:
|
|
|
raise KeyboardInterrupt("Not working yet...")
|
|
|
|
|
@@ -302,6 +299,23 @@ class Client:
|
|
|
if file_size is not None:
|
|
|
self._file_size = int(file_size)
|
|
|
|
|
|
+ def run(self, file_path, action):
|
|
|
+ loop = asyncio.get_event_loop()
|
|
|
+ try:
|
|
|
+ loop.run_until_complete(
|
|
|
+ self.run_client(file_path=file_path,
|
|
|
+ action=action)
|
|
|
+ )
|
|
|
+ except KeyboardInterrupt:
|
|
|
+ logging.error("Interrupted")
|
|
|
+ for task in asyncio.all_tasks(loop):
|
|
|
+ task.cancel()
|
|
|
+ self.writer.close()
|
|
|
+ loop.run_until_complete(
|
|
|
+ self.writer.wait_closed()
|
|
|
+ )
|
|
|
+ loop.close()
|
|
|
+
|
|
|
|
|
|
def get_action(action):
|
|
|
"""Parse abbreviations for `action`."""
|
|
@@ -362,6 +376,10 @@ def main():
|
|
|
default=None,
|
|
|
required=False,
|
|
|
help='server port')
|
|
|
+ cli_parser.add_argument('--certificate', type=str,
|
|
|
+ default=None,
|
|
|
+ required=False,
|
|
|
+ help='server SSL certificate')
|
|
|
cli_parser.add_argument('--action', type=str,
|
|
|
default=None,
|
|
|
required=False,
|
|
@@ -386,6 +404,7 @@ def main():
|
|
|
args = vars(cli_parser.parse_args())
|
|
|
host = args['host']
|
|
|
port = args['port']
|
|
|
+ certificate = args['certificate']
|
|
|
action = get_action(args['action'])
|
|
|
file_path = args['path']
|
|
|
password = args['password']
|
|
@@ -431,6 +450,11 @@ def main():
|
|
|
from config import token
|
|
|
except ImportError:
|
|
|
token = None
|
|
|
+ if certificate is None or not os.path.isfile(certificate):
|
|
|
+ try:
|
|
|
+ from config import certificate
|
|
|
+ except ImportError:
|
|
|
+ certificate = None
|
|
|
|
|
|
# If import fails, prompt user for host or port
|
|
|
new_settings = {} # After getting these settings, offer to store them
|
|
@@ -516,42 +540,29 @@ def main():
|
|
|
logging.info("Configuration values stored.")
|
|
|
else:
|
|
|
logging.info("Proceeding without storing values...")
|
|
|
- loop = asyncio.get_event_loop()
|
|
|
client = Client(
|
|
|
host=host,
|
|
|
port=port,
|
|
|
password=password,
|
|
|
token=token
|
|
|
)
|
|
|
- try:
|
|
|
- from config import certificate
|
|
|
+ if certificate is not None:
|
|
|
_ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
|
|
|
_ssl_context.check_hostname = False
|
|
|
_ssl_context.load_verify_locations(certificate)
|
|
|
client.set_ssl_context(_ssl_context)
|
|
|
- except ImportError:
|
|
|
+ else:
|
|
|
logging.warning(
|
|
|
"Please consider using SSL. To do so, add in `config.py` or "
|
|
|
"provide via Command Line Interface the path to a valid SSL "
|
|
|
"certificate. Example:\n\n"
|
|
|
"certificate = 'path/to/certificate.crt'"
|
|
|
)
|
|
|
- # noinspection PyUnusedLocal
|
|
|
- certificate = None
|
|
|
logging.info("Starting client...")
|
|
|
- if action == 'send':
|
|
|
- loop.run_until_complete(
|
|
|
- client.run_sending_client(
|
|
|
- file_path=file_path
|
|
|
- )
|
|
|
- )
|
|
|
- else:
|
|
|
- loop.run_until_complete(
|
|
|
- client.run_receiving_client(
|
|
|
- file_path=file_path
|
|
|
- )
|
|
|
- )
|
|
|
- loop.close()
|
|
|
+ client.run(
|
|
|
+ file_path=file_path,
|
|
|
+ action=action
|
|
|
+ )
|
|
|
logging.info("Stopped client")
|
|
|
|
|
|
|