Queer European MD passionate about IT
Davte %!s(int64=5) %!d(string=hai) anos
pai
achega
3f5384f9e9
Modificáronse 3 ficheiros con 305 adicións e 109 borrados
  1. 229 74
      filebridging/client.py
  2. 35 34
      filebridging/server.py
  3. 41 1
      filebridging/utilities.py

+ 229 - 74
filebridging/client.py

@@ -9,16 +9,25 @@ import random
 import ssl
 import string
 import sys
+from typing import Union
 
 from . import utilities
 
 
 class Client:
-    def __init__(self, host='localhost', port=3001,
-                 buffer_chunk_size=10**4, buffer_length_limit=10**4,
-                 password=None, token=None):
+    def __init__(self, host='localhost', port=5000, ssl_context=None,
+                 action=None,
+                 standalone=False,
+                 buffer_chunk_size=10 ** 4,
+                 buffer_length_limit=10 ** 4,
+                 file_path=None,
+                 password=None,
+                 token=None):
         self._host = host
         self._port = port
+        self._ssl_context = ssl_context
+        self._action = action
+        self._standalone = standalone
         self._stopping = False
         self._reader = None
         self._writer = None
@@ -28,7 +37,7 @@ class Client:
         self._buffer_chunk_size = buffer_chunk_size
         # How many chunks in buffer
         self._buffer_length_limit = buffer_length_limit
-        self._file_path = None
+        self._file_path = file_path
         self._working = False
         self._token = token
         self._password = password
@@ -36,6 +45,7 @@ class Client:
         self._encryption_complete = False
         self._file_name = None
         self._file_size = None
+        self._file_size_string = None
 
     @property
     def host(self) -> str:
@@ -45,6 +55,21 @@ class Client:
     def port(self) -> int:
         return self._port
 
+    @property
+    def action(self) -> str:
+        """Client role.
+
+        Possible values:
+        - `send`
+        - `receive`
+        """
+        return self._action
+
+    @property
+    def standalone(self) -> bool:
+        """Tell whether client should run as server as well."""
+        return self._standalone
+
     @property
     def stopping(self) -> bool:
         return self._stopping
@@ -101,52 +126,128 @@ class Client:
     def file_size(self):
         return self._file_size
 
-    async def run_client(self, file_path, action):
-        self._file_path = file_path
-        if action == 'send':
-            file_name = os.path.basename(os.path.abspath(file_path))
-            file_size = os.path.getsize(os.path.abspath(file_path))
+    @property
+    def file_size_string(self):
+        return self._file_size_string
+
+    async def run_client(self) -> None:
+        if self.action == 'send':
+            file_name = os.path.basename(os.path.abspath(self.file_path))
+            file_size = os.path.getsize(os.path.abspath(self.file_path))
+            # File size increases after encryption
+            # "Salted_" (8 bytes) + salt (8 bytes)
+            # Then, 1-16 bytes are added to make file_size a multiple of 16
+            # i.e., (32 - file_size mod 16) bytes are added to original size
+            if self.password:
+                file_size += 32 - (file_size % 16)
             self.set_file_information(
                 file_name=file_name,
                 file_size=file_size
             )
-        try:
-            reader, writer = await asyncio.open_connection(
+        if self.standalone:
+            server = await asyncio.start_server(
+                ssl=self.ssl_context,
+                client_connected_cb=self._connect,
                 host=self.host,
                 port=self.port,
-                ssl=self.ssl_context
             )
-            self._reader = reader
-            self._writer = writer
-        except ConnectionRefusedError as exception:
-            logging.error(exception)
-            return
-        writer.write(
-            (
-                f"s|{self.token}|"
-                f"{self.file_name}|{self.file_size}\n".encode('utf-8')
-            ) if action == 'send'
-            else f"r|{self.token}\n".encode('utf-8')
-        )
-        await writer.drain()
+            async with server:
+                logging.info("Running at `{s.host}:{s.port}`".format(s=self))
+                await server.serve_forever()
+        else:
+            try:
+                reader, writer = await asyncio.open_connection(
+                    host=self.host,
+                    port=self.port,
+                    ssl=self.ssl_context
+                )
+            except (ConnectionRefusedError, ConnectionResetError) as exception:
+                logging.error(f"Connection error: {exception}")
+                return
+            await self.connect(reader=reader, writer=writer)
+
+    async def _connect(self, reader: asyncio.StreamReader,
+                       writer: asyncio.StreamWriter):
+        try:
+            return await self.connect(reader, writer)
+        except KeyboardInterrupt:
+            print()
+        except Exception as e:
+            logging.error(e)
+
+    async def connect(self,
+                      reader: asyncio.StreamReader,
+                      writer: asyncio.StreamWriter):
+        self._reader = reader
+        self._writer = writer
+
+        async def _write(message: Union[list, str, bytes],
+                         terminate_line=True) -> int:
+            """Framework for `asyncio.StreamWriter.write` method.
+
+            Create string from list, encode it, send and drain writer.
+            Return 0 on success, 1 on error.
+            """
+            # Adapt
+            if type(message) is list:
+                message = '|'.join(map(str, message))
+            if type(message) is str:
+                if terminate_line:
+                    message += '\n'
+                message = message.encode('utf-8')
+            if type(message) is not bytes:
+                return 1
+            try:
+                writer.write(message)
+                await writer.drain()
+            except ConnectionResetError:
+                logging.error("Client disconnected.")
+            except Exception as e:
+                logging.error(f"Unexpected exception:\n{e}", exc_info=True)
+            else:
+                return 0  # On success, return 0
+            # On exception, return 1
+            return 1
+
+        if self.action == 'send' or not self.standalone:
+            if await _write(
+                    [self.action[0], self.token,
+                     self.file_name, self.file_size]
+            ):
+                return
         # Wait for server start signal
         while 1:
-            server_hello = await reader.readline()
+            server_hello = await self.reader.readline()
             if not server_hello:
-                logging.info("Server disconnected.")
+                logging.error("Server disconnected.")
                 return
             server_hello = server_hello.decode('utf-8').strip('\n').split('|')
-            if action == 'receive' and server_hello[0] == 's':
+            if self.action == 'receive' and server_hello[0] == 's':
+
                 self.set_file_information(file_name=server_hello[2],
                                           file_size=server_hello[3])
+            elif (
+                    self.standalone
+                    and self.action == 'send'
+                    and server_hello[0] == 'r'
+            ):
+                # Check token
+                if server_hello[1] != self.token:
+                    if await _write("Invalid session token!"):
+                        return
+                    return
             elif server_hello[0] == 'start!':
                 break
             else:
                 logging.info(f"Server said: {'|'.join(server_hello)}")
-        if action == 'send':
-            await self.send(writer=writer)
+            if self.standalone:
+                if await _write("start!"):
+                    return
+                break
+        if self.action == 'send':
+            await self.send(writer=self.writer)
         else:
-            await self.receive(reader=reader)
+            await self.receive(reader=self.reader)
 
     async def encrypt_file(self, input_file, output_file):
         self._encryption_complete = False
@@ -201,28 +302,27 @@ class Client:
                     break
                 try:
                     writer.write(output_data)
-                    await writer.drain()
+                    await asyncio.wait_for(writer.drain(), timeout=3.0)
                 except ConnectionResetError:
+                    print()  # New line after progress_bar
                     logging.error('Server closed the connection.')
                     self.stop()
                     break
-                bytes_sent += self.buffer_chunk_size
+                except asyncio.exceptions.TimeoutError:
+                    print()  # New line after progress_bar
+                    logging.error('Server closed the connection.')
+                    self.stop()
+                    break
+                bytes_sent += len(output_data)
                 new_progress = min(
                     int(bytes_sent / self.file_size * 100),
                     100
                 )
-                progress_showed = (new_progress // 10) * 10
-                sys.stdout.write(
-                    f"\t\t\tSending `{self.file_name}`: "
-                    f"{'#' * (progress_showed // 10)}"
-                    f"{'.' * ((100 - progress_showed) // 10)}\t"
-                    f"{new_progress}% completed "
-                    f"({min(bytes_sent, self.file_size) // 1000} "
-                    f"of {self.file_size // 1000} KB)\r"
+                self.print_progress_bar(
+                    progress=new_progress,
+                    bytes_=bytes_sent,
                 )
-                sys.stdout.flush()
-        sys.stdout.write('\n')
-        sys.stdout.flush()
+        print()  # New line after progress_bar
         writer.close()
         return
 
@@ -242,26 +342,19 @@ class Client:
             bytes_received = 0
             while not self.stopping:
                 input_data = await reader.read(self.buffer_chunk_size)
-                bytes_received += self.buffer_chunk_size
+                bytes_received += len(input_data)
                 new_progress = min(
                     int(bytes_received / self.file_size * 100),
                     100
                 )
-                progress_showed = (new_progress // 10) * 10
-                sys.stdout.write(
-                    f"\t\t\tReceiving `{self.file_name}`: "
-                    f"{'#' * (progress_showed // 10)}"
-                    f"{'.' * ((100 - progress_showed) // 10)}\t"
-                    f"{new_progress}% completed "
-                    f"({min(bytes_received, self.file_size) // 1000} "
-                    f"of {self.file_size // 1000} KB)\r"
+                self.print_progress_bar(
+                    progress=new_progress,
+                    bytes_=bytes_received
                 )
-                sys.stdout.flush()
                 if not input_data:
                     break
                 file_to_receive.write(input_data)
-        sys.stdout.write('\n')
-        sys.stdout.flush()
+        print()  # New line after sys.stdout.write
         logging.info("File received.")
         if self.password:
             logging.info("Decrypting file...")
@@ -289,7 +382,8 @@ class Client:
         if self.working:
             logging.info("Received interruption signal, stopping...")
             self._stopping = True
-            self.writer.close()
+            if self.writer:
+                self.writer.close()
         else:
             raise KeyboardInterrupt("Not working yet...")
 
@@ -298,24 +392,62 @@ class Client:
             self._file_name = file_name
         if file_size is not None:
             self._file_size = int(file_size)
+            self._file_size_string = utilities.get_file_size_representation(
+                self.file_size
+            )
 
-    def run(self, file_path, action):
+    def run(self):
         loop = asyncio.get_event_loop()
         try:
             loop.run_until_complete(
-                self.run_client(file_path=file_path,
-                                action=action)
+                self.run_client()
             )
         except KeyboardInterrupt:
+            print()
             logging.error("Interrupted")
             for task in asyncio.all_tasks(loop):
                 task.cancel()
-            self.writer.close()
+            if self.writer:
+                self.writer.close()
             loop.run_until_complete(
-                self.writer.wait_closed()
+                self.wait_closed()
             )
         loop.close()
 
+    def print_progress_bar(self, progress: int, bytes_: int):
+        """Print client progress bar.
+
+        `progress` % = `bytes_string` transferred
+        out of `self.file_size_string`.
+        """
+        action = {
+            'send': "Sending",
+            'receive': "Receiving"
+        }[self.action]
+        bytes_string = utilities.get_file_size_representation(
+            bytes_
+        )
+        utilities.print_progress_bar(
+            prefix=f"\t\t\t{action} `{self.file_name}`: ",
+            done_symbol='#',
+            pending_symbol='.',
+            progress=progress,
+            scale=5,
+            suffix=(
+                " completed "
+                f"({bytes_string} "
+                f"of {self.file_size_string})"
+            )
+        )
+
+    @staticmethod
+    async def wait_closed() -> None:
+        """Give time to cancelled tasks to end properly.
+
+        Sleep .1 second and return.
+        """
+        await asyncio.sleep(.1)
+
 
 def get_action(action):
     """Parse abbreviations for `action`."""
@@ -380,6 +512,11 @@ def main():
                             default=None,
                             required=False,
                             help='server SSL certificate')
+    cli_parser.add_argument('--key', type=str,
+                            default=None,
+                            required=False,
+                            help='server SSL key (required only for '
+                                 'SSL-secured standalone client)')
     cli_parser.add_argument('--action', type=str,
                             default=None,
                             required=False,
@@ -397,6 +534,9 @@ def main():
                             required=False,
                             help='Session token '
                                  '(must be the same for both clients)')
+    cli_parser.add_argument('--standalone',
+                            action='store_true',
+                            help='Run both as client and server')
     cli_parser.add_argument('others',
                             metavar='R or S',
                             nargs='*',
@@ -405,10 +545,12 @@ def main():
     host = args['host']
     port = args['port']
     certificate = args['certificate']
+    key = args['key']
     action = get_action(args['action'])
     file_path = args['path']
     password = args['password']
     token = args['token']
+    standalone = args['standalone']
 
     # If host and port are not provided from command-line, try to import them
     sys.path.append(os.path.abspath('.'))
@@ -455,6 +597,11 @@ def main():
             from config import certificate
         except ImportError:
             certificate = None
+    if key is None or not os.path.isfile(key):
+        try:
+            from config import key
+        except ImportError:
+            key = None
 
     # If import fails, prompt user for host or port
     new_settings = {}  # After getting these settings, offer to store them
@@ -540,17 +687,18 @@ def main():
             logging.info("Configuration values stored.")
         else:
             logging.info("Proceeding without storing values...")
-    client = Client(
-        host=host,
-        port=port,
-        password=password,
-        token=token
-    )
+    ssl_context = None
     if certificate is not None:
-        _ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
-        _ssl_context.check_hostname = False
-        _ssl_context.load_verify_locations(certificate)
-        client.set_ssl_context(_ssl_context)
+        if key is None:  # Server-dependent client
+            ssl_context = ssl.create_default_context(
+                purpose=ssl.Purpose.SERVER_AUTH
+            )
+            ssl_context.load_verify_locations(certificate)
+        else:  # Standalone client
+            ssl_context = ssl.create_default_context(
+                purpose=ssl.Purpose.CLIENT_AUTH
+            )
+            ssl_context.load_cert_chain(certificate, key)
     else:
         logging.warning(
             "Please consider using SSL. To do so, add in `config.py` or "
@@ -559,10 +707,17 @@ def main():
             "certificate = 'path/to/certificate.crt'"
         )
     logging.info("Starting client...")
-    client.run(
+    client = Client(
+        host=host,
+        port=port,
+        ssl_context=ssl_context,
+        action=action,
+        standalone=standalone,
         file_path=file_path,
-        action=action
+        password=password,
+        token=token
     )
+    client.run()
     logging.info("Stopped client")
 
 

+ 35 - 34
filebridging/server.py

@@ -13,10 +13,11 @@ from typing import Union
 
 
 class Server:
-    def __init__(self, host='localhost', port=5000,
+    def __init__(self, host='localhost', port=5000, ssl_context=None,
                  buffer_chunk_size=10 ** 4, buffer_length_limit=10 ** 4):
         self._host = host
         self._port = port
+        self._ssl_context = ssl_context
         self.connections = collections.OrderedDict()
         # Dict of queues of bytes
         self.buffers = collections.OrderedDict()
@@ -87,27 +88,24 @@ class Server:
     async def run_writer(self, writer, connection_token):
         consecutive_interruptions = 0
         errors = 0
-        while 1:
+        while connection_token in self.buffers:
             try:
-                try:
-                    if connection_token not in self.buffers:
-                        break
-                    input_data = self.buffers[connection_token].popleft()
-                except IndexError:
-                    # Slow down if buffer is short
-                    consecutive_interruptions += 1
-                    if consecutive_interruptions > 3:
-                        break
-                    await asyncio.sleep(.5)
-                    continue
-                else:
-                    consecutive_interruptions = 0
-                if not input_data:
+                input_data = self.buffers[connection_token].popleft()
+            except IndexError:
+                # Slow down if buffer is empty; after 1.5 s of silence, break
+                consecutive_interruptions += 1
+                if consecutive_interruptions > 3:
                     break
+                await asyncio.sleep(.5)
+                continue
+            else:
+                consecutive_interruptions = 0
+            if not input_data:
+                break
+            try:
                 writer.write(input_data)
                 await writer.drain()
             except ConnectionResetError as e:
-                logging.error("Here")
                 logging.error(e)
                 break
             except Exception as e:
@@ -127,7 +125,7 @@ class Server:
         """
         client_hello = await reader.readline()
         client_hello = client_hello.decode('utf-8').strip('\n').split('|')
-        if len(client_hello) not in (2, 4,):
+        if len(client_hello) != 4:
             await self.refuse_connection(writer=writer,
                                          message="Invalid client_hello!")
             return
@@ -142,7 +140,7 @@ class Server:
                          terminate_line=True) -> int:
             # Adapt
             if type(message) is list:
-                message = '|'.join(message)
+                message = '|'.join(map(str, message))
             if type(message) is str:
                 if terminate_line:
                     message += '\n'
@@ -211,12 +209,13 @@ class Server:
                     index = 0
                 await asyncio.sleep(.5)
             # Send file information and start signal to client
-            writer.write(
-                "s|hidden_token|"
-                f"{self.connections[connection_token]['file_name']}|"
-                f"{self.connections[connection_token]['file_size']}"
-                "\n".encode('utf-8')
-            )
+            if await _write(
+                    ['s',
+                     'hidden_token',
+                     self.connections[connection_token]['file_name'],
+                     self.connections[connection_token]['file_size']]
+            ):
+                return
             if await _write("start!"):
                 return
             await self.run_writer(writer=writer,
@@ -238,6 +237,7 @@ class Server:
         try:
             loop.run_until_complete(self.run_server())
         except KeyboardInterrupt:
+            print()
             logging.info("Stopping...")
             # Cancel connection tasks (they should be done but are pending)
             for task in asyncio.all_tasks(loop):
@@ -336,10 +336,6 @@ def main():
             logging.info("Invalid port. Enter a valid port number!")
             port = None
 
-    server = Server(
-        host=host,
-        port=port,
-    )
     try:
         if certificate is None or not os.path.isfile(certificate):
             from config import certificate
@@ -350,12 +346,12 @@ def main():
         if not os.path.isfile(key):
             key = None
     except ImportError:
-        pass
+        certificate = None
+        key = None
+    ssl_context = None
     if certificate and key:
-        _ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
-        _ssl_context.check_hostname = False
-        _ssl_context.load_cert_chain(certificate, key)
-        server.set_ssl_context(_ssl_context)
+        ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
+        ssl_context.load_cert_chain(certificate, key)
     else:
         logging.warning(
             "Please consider using SSL. To do so, add in `config.py` or "
@@ -364,6 +360,11 @@ def main():
             "key = 'path/to/secret.key'\n"
             "certificate = 'path/to/certificate.crt'"
         )
+    server = Server(
+        host=host,
+        port=port,
+        ssl_context=ssl_context
+    )
     server.run()
 
 

+ 41 - 1
filebridging/utilities.py

@@ -2,11 +2,51 @@
 
 import logging
 import signal
+import sys
+
+units_of_measurements = {
+    1: 'bytes',
+    1000: 'KB',
+    1000 * 1000: 'MB',
+    1000 * 1000 * 1000: 'GB',
+    1000 * 1000 * 1000 * 1000: 'TB',
+}
+
+
+def get_file_size_representation(file_size):
+    scale, unit = get_scale_and_unit(file_size=file_size)
+    if scale < 10:
+        return f"{file_size} {unit}"
+    return f"{(file_size // (scale / 100)) / 100:.2f} {unit}"
+
+
+def get_scale_and_unit(file_size):
+    scale, unit = min(units_of_measurements.items())
+    for scale, unit in sorted(units_of_measurements.items(), reverse=True):
+        if file_size > scale:
+            break
+    return scale, unit
+
+
+def print_progress_bar(prefix='',
+                       suffix='',
+                       done_symbol="#",
+                       pending_symbol=".",
+                       progress=0,
+                       scale=10):
+    progress_showed = (progress // scale) * scale
+    sys.stdout.write(
+        f"{prefix}"
+        f"{done_symbol * (progress_showed // scale)}"
+        f"{pending_symbol * ((100 - progress_showed) // scale)}\t"
+        f"{progress}%"
+        f"{suffix}      \r"
+    )
+    sys.stdout.flush()
 
 
 def timed_input(message: str = None,
                 timeout: int = 5):
-
     class TimeoutExpired(Exception):
         pass