|
@@ -9,16 +9,25 @@ import random
|
|
|
import ssl
|
|
|
import string
|
|
|
import sys
|
|
|
+from typing import Union
|
|
|
|
|
|
from . import utilities
|
|
|
|
|
|
|
|
|
class Client:
|
|
|
- def __init__(self, host='localhost', port=3001,
|
|
|
- buffer_chunk_size=10**4, buffer_length_limit=10**4,
|
|
|
- password=None, token=None):
|
|
|
+ def __init__(self, host='localhost', port=5000, ssl_context=None,
|
|
|
+ action=None,
|
|
|
+ standalone=False,
|
|
|
+ buffer_chunk_size=10 ** 4,
|
|
|
+ buffer_length_limit=10 ** 4,
|
|
|
+ file_path=None,
|
|
|
+ password=None,
|
|
|
+ token=None):
|
|
|
self._host = host
|
|
|
self._port = port
|
|
|
+ self._ssl_context = ssl_context
|
|
|
+ self._action = action
|
|
|
+ self._standalone = standalone
|
|
|
self._stopping = False
|
|
|
self._reader = None
|
|
|
self._writer = None
|
|
@@ -28,7 +37,7 @@ class Client:
|
|
|
self._buffer_chunk_size = buffer_chunk_size
|
|
|
# How many chunks in buffer
|
|
|
self._buffer_length_limit = buffer_length_limit
|
|
|
- self._file_path = None
|
|
|
+ self._file_path = file_path
|
|
|
self._working = False
|
|
|
self._token = token
|
|
|
self._password = password
|
|
@@ -36,6 +45,7 @@ class Client:
|
|
|
self._encryption_complete = False
|
|
|
self._file_name = None
|
|
|
self._file_size = None
|
|
|
+ self._file_size_string = None
|
|
|
|
|
|
@property
|
|
|
def host(self) -> str:
|
|
@@ -45,6 +55,21 @@ class Client:
|
|
|
def port(self) -> int:
|
|
|
return self._port
|
|
|
|
|
|
+ @property
|
|
|
+ def action(self) -> str:
|
|
|
+ """Client role.
|
|
|
+
|
|
|
+ Possible values:
|
|
|
+ - `send`
|
|
|
+ - `receive`
|
|
|
+ """
|
|
|
+ return self._action
|
|
|
+
|
|
|
+ @property
|
|
|
+ def standalone(self) -> bool:
|
|
|
+ """Tell whether client should run as server as well."""
|
|
|
+ return self._standalone
|
|
|
+
|
|
|
@property
|
|
|
def stopping(self) -> bool:
|
|
|
return self._stopping
|
|
@@ -101,52 +126,128 @@ class Client:
|
|
|
def file_size(self):
|
|
|
return self._file_size
|
|
|
|
|
|
- async def run_client(self, file_path, action):
|
|
|
- self._file_path = file_path
|
|
|
- if action == 'send':
|
|
|
- file_name = os.path.basename(os.path.abspath(file_path))
|
|
|
- file_size = os.path.getsize(os.path.abspath(file_path))
|
|
|
+ @property
|
|
|
+ def file_size_string(self):
|
|
|
+ return self._file_size_string
|
|
|
+
|
|
|
+ async def run_client(self) -> None:
|
|
|
+ if self.action == 'send':
|
|
|
+ file_name = os.path.basename(os.path.abspath(self.file_path))
|
|
|
+ file_size = os.path.getsize(os.path.abspath(self.file_path))
|
|
|
+ # File size increases after encryption
|
|
|
+ # "Salted_" (8 bytes) + salt (8 bytes)
|
|
|
+ # Then, 1-16 bytes are added to make file_size a multiple of 16
|
|
|
+ # i.e., (32 - file_size mod 16) bytes are added to original size
|
|
|
+ if self.password:
|
|
|
+ file_size += 32 - (file_size % 16)
|
|
|
self.set_file_information(
|
|
|
file_name=file_name,
|
|
|
file_size=file_size
|
|
|
)
|
|
|
- try:
|
|
|
- reader, writer = await asyncio.open_connection(
|
|
|
+ if self.standalone:
|
|
|
+ server = await asyncio.start_server(
|
|
|
+ ssl=self.ssl_context,
|
|
|
+ client_connected_cb=self._connect,
|
|
|
host=self.host,
|
|
|
port=self.port,
|
|
|
- ssl=self.ssl_context
|
|
|
)
|
|
|
- self._reader = reader
|
|
|
- self._writer = writer
|
|
|
- except ConnectionRefusedError as exception:
|
|
|
- logging.error(exception)
|
|
|
- return
|
|
|
- writer.write(
|
|
|
- (
|
|
|
- f"s|{self.token}|"
|
|
|
- f"{self.file_name}|{self.file_size}\n".encode('utf-8')
|
|
|
- ) if action == 'send'
|
|
|
- else f"r|{self.token}\n".encode('utf-8')
|
|
|
- )
|
|
|
- await writer.drain()
|
|
|
+ async with server:
|
|
|
+ logging.info("Running at `{s.host}:{s.port}`".format(s=self))
|
|
|
+ await server.serve_forever()
|
|
|
+ else:
|
|
|
+ try:
|
|
|
+ reader, writer = await asyncio.open_connection(
|
|
|
+ host=self.host,
|
|
|
+ port=self.port,
|
|
|
+ ssl=self.ssl_context
|
|
|
+ )
|
|
|
+ except (ConnectionRefusedError, ConnectionResetError) as exception:
|
|
|
+ logging.error(f"Connection error: {exception}")
|
|
|
+ return
|
|
|
+ await self.connect(reader=reader, writer=writer)
|
|
|
+
|
|
|
+ async def _connect(self, reader: asyncio.StreamReader,
|
|
|
+ writer: asyncio.StreamWriter):
|
|
|
+ try:
|
|
|
+ return await self.connect(reader, writer)
|
|
|
+ except KeyboardInterrupt:
|
|
|
+ print()
|
|
|
+ except Exception as e:
|
|
|
+ logging.error(e)
|
|
|
+
|
|
|
+ async def connect(self,
|
|
|
+ reader: asyncio.StreamReader,
|
|
|
+ writer: asyncio.StreamWriter):
|
|
|
+ self._reader = reader
|
|
|
+ self._writer = writer
|
|
|
+
|
|
|
+ async def _write(message: Union[list, str, bytes],
|
|
|
+ terminate_line=True) -> int:
|
|
|
+ """Framework for `asyncio.StreamWriter.write` method.
|
|
|
+
|
|
|
+ Create string from list, encode it, send and drain writer.
|
|
|
+ Return 0 on success, 1 on error.
|
|
|
+ """
|
|
|
+ # Adapt
|
|
|
+ if type(message) is list:
|
|
|
+ message = '|'.join(map(str, message))
|
|
|
+ if type(message) is str:
|
|
|
+ if terminate_line:
|
|
|
+ message += '\n'
|
|
|
+ message = message.encode('utf-8')
|
|
|
+ if type(message) is not bytes:
|
|
|
+ return 1
|
|
|
+ try:
|
|
|
+ writer.write(message)
|
|
|
+ await writer.drain()
|
|
|
+ except ConnectionResetError:
|
|
|
+ logging.error("Client disconnected.")
|
|
|
+ except Exception as e:
|
|
|
+ logging.error(f"Unexpected exception:\n{e}", exc_info=True)
|
|
|
+ else:
|
|
|
+ return 0 # On success, return 0
|
|
|
+ # On exception, return 1
|
|
|
+ return 1
|
|
|
+
|
|
|
+ if self.action == 'send' or not self.standalone:
|
|
|
+ if await _write(
|
|
|
+ [self.action[0], self.token,
|
|
|
+ self.file_name, self.file_size]
|
|
|
+ ):
|
|
|
+ return
|
|
|
# Wait for server start signal
|
|
|
while 1:
|
|
|
- server_hello = await reader.readline()
|
|
|
+ server_hello = await self.reader.readline()
|
|
|
if not server_hello:
|
|
|
- logging.info("Server disconnected.")
|
|
|
+ logging.error("Server disconnected.")
|
|
|
return
|
|
|
server_hello = server_hello.decode('utf-8').strip('\n').split('|')
|
|
|
- if action == 'receive' and server_hello[0] == 's':
|
|
|
+ if self.action == 'receive' and server_hello[0] == 's':
|
|
|
+
|
|
|
self.set_file_information(file_name=server_hello[2],
|
|
|
file_size=server_hello[3])
|
|
|
+ elif (
|
|
|
+ self.standalone
|
|
|
+ and self.action == 'send'
|
|
|
+ and server_hello[0] == 'r'
|
|
|
+ ):
|
|
|
+ # Check token
|
|
|
+ if server_hello[1] != self.token:
|
|
|
+ if await _write("Invalid session token!"):
|
|
|
+ return
|
|
|
+ return
|
|
|
elif server_hello[0] == 'start!':
|
|
|
break
|
|
|
else:
|
|
|
logging.info(f"Server said: {'|'.join(server_hello)}")
|
|
|
- if action == 'send':
|
|
|
- await self.send(writer=writer)
|
|
|
+ if self.standalone:
|
|
|
+ if await _write("start!"):
|
|
|
+ return
|
|
|
+ break
|
|
|
+ if self.action == 'send':
|
|
|
+ await self.send(writer=self.writer)
|
|
|
else:
|
|
|
- await self.receive(reader=reader)
|
|
|
+ await self.receive(reader=self.reader)
|
|
|
|
|
|
async def encrypt_file(self, input_file, output_file):
|
|
|
self._encryption_complete = False
|
|
@@ -201,28 +302,27 @@ class Client:
|
|
|
break
|
|
|
try:
|
|
|
writer.write(output_data)
|
|
|
- await writer.drain()
|
|
|
+ await asyncio.wait_for(writer.drain(), timeout=3.0)
|
|
|
except ConnectionResetError:
|
|
|
+ print() # New line after progress_bar
|
|
|
logging.error('Server closed the connection.')
|
|
|
self.stop()
|
|
|
break
|
|
|
- bytes_sent += self.buffer_chunk_size
|
|
|
+ except asyncio.exceptions.TimeoutError:
|
|
|
+ print() # New line after progress_bar
|
|
|
+ logging.error('Server closed the connection.')
|
|
|
+ self.stop()
|
|
|
+ break
|
|
|
+ bytes_sent += len(output_data)
|
|
|
new_progress = min(
|
|
|
int(bytes_sent / self.file_size * 100),
|
|
|
100
|
|
|
)
|
|
|
- progress_showed = (new_progress // 10) * 10
|
|
|
- sys.stdout.write(
|
|
|
- f"\t\t\tSending `{self.file_name}`: "
|
|
|
- f"{'#' * (progress_showed // 10)}"
|
|
|
- f"{'.' * ((100 - progress_showed) // 10)}\t"
|
|
|
- f"{new_progress}% completed "
|
|
|
- f"({min(bytes_sent, self.file_size) // 1000} "
|
|
|
- f"of {self.file_size // 1000} KB)\r"
|
|
|
+ self.print_progress_bar(
|
|
|
+ progress=new_progress,
|
|
|
+ bytes_=bytes_sent,
|
|
|
)
|
|
|
- sys.stdout.flush()
|
|
|
- sys.stdout.write('\n')
|
|
|
- sys.stdout.flush()
|
|
|
+ print() # New line after progress_bar
|
|
|
writer.close()
|
|
|
return
|
|
|
|
|
@@ -242,26 +342,19 @@ class Client:
|
|
|
bytes_received = 0
|
|
|
while not self.stopping:
|
|
|
input_data = await reader.read(self.buffer_chunk_size)
|
|
|
- bytes_received += self.buffer_chunk_size
|
|
|
+ bytes_received += len(input_data)
|
|
|
new_progress = min(
|
|
|
int(bytes_received / self.file_size * 100),
|
|
|
100
|
|
|
)
|
|
|
- progress_showed = (new_progress // 10) * 10
|
|
|
- sys.stdout.write(
|
|
|
- f"\t\t\tReceiving `{self.file_name}`: "
|
|
|
- f"{'#' * (progress_showed // 10)}"
|
|
|
- f"{'.' * ((100 - progress_showed) // 10)}\t"
|
|
|
- f"{new_progress}% completed "
|
|
|
- f"({min(bytes_received, self.file_size) // 1000} "
|
|
|
- f"of {self.file_size // 1000} KB)\r"
|
|
|
+ self.print_progress_bar(
|
|
|
+ progress=new_progress,
|
|
|
+ bytes_=bytes_received
|
|
|
)
|
|
|
- sys.stdout.flush()
|
|
|
if not input_data:
|
|
|
break
|
|
|
file_to_receive.write(input_data)
|
|
|
- sys.stdout.write('\n')
|
|
|
- sys.stdout.flush()
|
|
|
+ print() # New line after sys.stdout.write
|
|
|
logging.info("File received.")
|
|
|
if self.password:
|
|
|
logging.info("Decrypting file...")
|
|
@@ -289,7 +382,8 @@ class Client:
|
|
|
if self.working:
|
|
|
logging.info("Received interruption signal, stopping...")
|
|
|
self._stopping = True
|
|
|
- self.writer.close()
|
|
|
+ if self.writer:
|
|
|
+ self.writer.close()
|
|
|
else:
|
|
|
raise KeyboardInterrupt("Not working yet...")
|
|
|
|
|
@@ -298,24 +392,62 @@ class Client:
|
|
|
self._file_name = file_name
|
|
|
if file_size is not None:
|
|
|
self._file_size = int(file_size)
|
|
|
+ self._file_size_string = utilities.get_file_size_representation(
|
|
|
+ self.file_size
|
|
|
+ )
|
|
|
|
|
|
- def run(self, file_path, action):
|
|
|
+ def run(self):
|
|
|
loop = asyncio.get_event_loop()
|
|
|
try:
|
|
|
loop.run_until_complete(
|
|
|
- self.run_client(file_path=file_path,
|
|
|
- action=action)
|
|
|
+ self.run_client()
|
|
|
)
|
|
|
except KeyboardInterrupt:
|
|
|
+ print()
|
|
|
logging.error("Interrupted")
|
|
|
for task in asyncio.all_tasks(loop):
|
|
|
task.cancel()
|
|
|
- self.writer.close()
|
|
|
+ if self.writer:
|
|
|
+ self.writer.close()
|
|
|
loop.run_until_complete(
|
|
|
- self.writer.wait_closed()
|
|
|
+ self.wait_closed()
|
|
|
)
|
|
|
loop.close()
|
|
|
|
|
|
+ def print_progress_bar(self, progress: int, bytes_: int):
|
|
|
+ """Print client progress bar.
|
|
|
+
|
|
|
+ `progress` % = `bytes_string` transferred
|
|
|
+ out of `self.file_size_string`.
|
|
|
+ """
|
|
|
+ action = {
|
|
|
+ 'send': "Sending",
|
|
|
+ 'receive': "Receiving"
|
|
|
+ }[self.action]
|
|
|
+ bytes_string = utilities.get_file_size_representation(
|
|
|
+ bytes_
|
|
|
+ )
|
|
|
+ utilities.print_progress_bar(
|
|
|
+ prefix=f"\t\t\t{action} `{self.file_name}`: ",
|
|
|
+ done_symbol='#',
|
|
|
+ pending_symbol='.',
|
|
|
+ progress=progress,
|
|
|
+ scale=5,
|
|
|
+ suffix=(
|
|
|
+ " completed "
|
|
|
+ f"({bytes_string} "
|
|
|
+ f"of {self.file_size_string})"
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ async def wait_closed() -> None:
|
|
|
+ """Give time to cancelled tasks to end properly.
|
|
|
+
|
|
|
+ Sleep .1 second and return.
|
|
|
+ """
|
|
|
+ await asyncio.sleep(.1)
|
|
|
+
|
|
|
|
|
|
def get_action(action):
|
|
|
"""Parse abbreviations for `action`."""
|
|
@@ -380,6 +512,11 @@ def main():
|
|
|
default=None,
|
|
|
required=False,
|
|
|
help='server SSL certificate')
|
|
|
+ cli_parser.add_argument('--key', type=str,
|
|
|
+ default=None,
|
|
|
+ required=False,
|
|
|
+ help='server SSL key (required only for '
|
|
|
+ 'SSL-secured standalone client)')
|
|
|
cli_parser.add_argument('--action', type=str,
|
|
|
default=None,
|
|
|
required=False,
|
|
@@ -397,6 +534,9 @@ def main():
|
|
|
required=False,
|
|
|
help='Session token '
|
|
|
'(must be the same for both clients)')
|
|
|
+ cli_parser.add_argument('--standalone',
|
|
|
+ action='store_true',
|
|
|
+ help='Run both as client and server')
|
|
|
cli_parser.add_argument('others',
|
|
|
metavar='R or S',
|
|
|
nargs='*',
|
|
@@ -405,10 +545,12 @@ def main():
|
|
|
host = args['host']
|
|
|
port = args['port']
|
|
|
certificate = args['certificate']
|
|
|
+ key = args['key']
|
|
|
action = get_action(args['action'])
|
|
|
file_path = args['path']
|
|
|
password = args['password']
|
|
|
token = args['token']
|
|
|
+ standalone = args['standalone']
|
|
|
|
|
|
# If host and port are not provided from command-line, try to import them
|
|
|
sys.path.append(os.path.abspath('.'))
|
|
@@ -455,6 +597,11 @@ def main():
|
|
|
from config import certificate
|
|
|
except ImportError:
|
|
|
certificate = None
|
|
|
+ if key is None or not os.path.isfile(key):
|
|
|
+ try:
|
|
|
+ from config import key
|
|
|
+ except ImportError:
|
|
|
+ key = None
|
|
|
|
|
|
# If import fails, prompt user for host or port
|
|
|
new_settings = {} # After getting these settings, offer to store them
|
|
@@ -540,17 +687,18 @@ def main():
|
|
|
logging.info("Configuration values stored.")
|
|
|
else:
|
|
|
logging.info("Proceeding without storing values...")
|
|
|
- client = Client(
|
|
|
- host=host,
|
|
|
- port=port,
|
|
|
- password=password,
|
|
|
- token=token
|
|
|
- )
|
|
|
+ ssl_context = None
|
|
|
if certificate is not None:
|
|
|
- _ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
|
|
|
- _ssl_context.check_hostname = False
|
|
|
- _ssl_context.load_verify_locations(certificate)
|
|
|
- client.set_ssl_context(_ssl_context)
|
|
|
+ if key is None: # Server-dependent client
|
|
|
+ ssl_context = ssl.create_default_context(
|
|
|
+ purpose=ssl.Purpose.SERVER_AUTH
|
|
|
+ )
|
|
|
+ ssl_context.load_verify_locations(certificate)
|
|
|
+ else: # Standalone client
|
|
|
+ ssl_context = ssl.create_default_context(
|
|
|
+ purpose=ssl.Purpose.CLIENT_AUTH
|
|
|
+ )
|
|
|
+ ssl_context.load_cert_chain(certificate, key)
|
|
|
else:
|
|
|
logging.warning(
|
|
|
"Please consider using SSL. To do so, add in `config.py` or "
|
|
@@ -559,10 +707,17 @@ def main():
|
|
|
"certificate = 'path/to/certificate.crt'"
|
|
|
)
|
|
|
logging.info("Starting client...")
|
|
|
- client.run(
|
|
|
+ client = Client(
|
|
|
+ host=host,
|
|
|
+ port=port,
|
|
|
+ ssl_context=ssl_context,
|
|
|
+ action=action,
|
|
|
+ standalone=standalone,
|
|
|
file_path=file_path,
|
|
|
- action=action
|
|
|
+ password=password,
|
|
|
+ token=token
|
|
|
)
|
|
|
+ client.run()
|
|
|
logging.info("Stopped client")
|
|
|
|
|
|
|