Queer European MD passionate about IT
瀏覽代碼

Refactoring

Davte 5 年之前
父節點
當前提交
3b7aa265ab
共有 3 個文件被更改,包括 190 次插入93 次删除
  1. 32 0
      README.md
  2. 72 61
      filebridging/client.py
  3. 86 32
      filebridging/server.py

+ 32 - 0
README.md

@@ -1,3 +1,35 @@
 # filebridging
 
 Share files via a bridge server using TCP over SSL and aes-256-cbc encryption.
+
+## Requirements
+Python3.8+ is needed for this package.
+
+## Usage
+If you need a virtual environment, create it.
+```bash
+python3.8 -m venv env;
+alias pip="env/bin/pip";
+alias python="env/bin/python";
+```
+
+Install filebridging and read the help.
+```bash
+pip install filebridging
+python -m filebridging.server --help
+python -m filebridging.client --help
+```
+
+## Examples
+Client-server example
+```bash
+# 3 distinct tabs
+python -m filebridging.server --host localhost --port 5000 --certificate ~/.ssh/server.crt --key ~/.ssh/server.key
+python -m filebridging.client s --host localhost --port 5000 --certificate ~/.ssh/server.crt --token 12345678 --password supersecretpasswordhere --path ~/file_to_send 
+python -m filebridging.client r --host localhost --port 5000 --certificate ~/.ssh/server.crt --token 12345678 --password supersecretpasswordhere --path ~/Downloads 
+```
+
+Client-client example
+```bash
+
+```

+ 72 - 61
filebridging/client.py

@@ -20,6 +20,8 @@ class Client:
         self._host = host
         self._port = port
         self._stopping = False
+        self._reader = None
+        self._writer = None
         # Shared queue of bytes
         self.buffer = collections.deque()
         # How many bytes per chunk
@@ -47,6 +49,14 @@ class Client:
     def stopping(self) -> bool:
         return self._stopping
 
+    @property
+    def reader(self) -> asyncio.StreamReader:
+        return self._reader
+
+    @property
+    def writer(self) -> asyncio.StreamWriter:
+        return self._writer
+
     @property
     def buffer_length_limit(self) -> int:
         return self._buffer_length_limit
@@ -91,36 +101,52 @@ class Client:
     def file_size(self):
         return self._file_size
 
-    async def run_sending_client(self, file_path='~/output.txt'):
+    async def run_client(self, file_path, action):
         self._file_path = file_path
-        file_name = os.path.basename(os.path.abspath(file_path))
-        file_size = os.path.getsize(os.path.abspath(file_path))
+        if action == 'send':
+            file_name = os.path.basename(os.path.abspath(file_path))
+            file_size = os.path.getsize(os.path.abspath(file_path))
+            self.set_file_information(
+                file_name=file_name,
+                file_size=file_size
+            )
         try:
             reader, writer = await asyncio.open_connection(
                 host=self.host,
                 port=self.port,
                 ssl=self.ssl_context
             )
+            self._reader = reader
+            self._writer = writer
         except ConnectionRefusedError as exception:
             logging.error(exception)
             return
         writer.write(
-            f"s|{self.token}|{file_name}|{file_size}\n".encode('utf-8')
+            (
+                f"s|{self.token}|"
+                f"{self.file_name}|{self.file_size}\n".encode('utf-8')
+            ) if action == 'send'
+            else f"r|{self.token}\n".encode('utf-8')
         )
-        self.set_file_information(file_name=file_name,
-                                  file_size=file_size)
         await writer.drain()
         # Wait for server start signal
         while 1:
             server_hello = await reader.readline()
             if not server_hello:
-                logging.error("Server disconnected.")
+                logging.info("Server disconnected.")
                 return
-            server_hello = server_hello.decode('utf-8').strip('\n')
-            if server_hello == 'start!':
+            server_hello = server_hello.decode('utf-8').strip('\n').split('|')
+            if action == 'receive' and server_hello[0] == 's':
+                self.set_file_information(file_name=server_hello[2],
+                                          file_size=server_hello[3])
+            elif server_hello[0] == 'start!':
                 break
-            logging.info(f"Server said: {server_hello}")
-        await self.send(writer=writer)
+            else:
+                logging.info(f"Server said: {'|'.join(server_hello)}")
+        if action == 'send':
+            await self.send(writer=writer)
+        else:
+            await self.receive(reader=reader)
 
     async def encrypt_file(self, input_file, output_file):
         self._encryption_complete = False
@@ -177,7 +203,7 @@ class Client:
                     writer.write(output_data)
                     await writer.drain()
                 except ConnectionResetError:
-                    logging.info('Server closed the connection.')
+                    logging.error('Server closed the connection.')
                     self.stop()
                     break
                 bytes_sent += self.buffer_chunk_size
@@ -200,36 +226,6 @@ class Client:
         writer.close()
         return
 
-    async def run_receiving_client(self, file_path='~/input.txt'):
-        self._file_path = file_path
-        try:
-            reader, writer = await asyncio.open_connection(
-                host=self.host,
-                port=self.port,
-                ssl=self.ssl_context
-            )
-        except ConnectionRefusedError as exception:
-            logging.error(exception)
-            return
-        writer.write(f"r|{self.token}\n".encode('utf-8'))
-        await writer.drain()
-        # Wait for server start signal
-        while 1:
-            server_hello = await reader.readline()
-            if not server_hello:
-                logging.info("Server disconnected.")
-                return
-            server_hello = server_hello.decode('utf-8').strip('\n')
-            if server_hello.startswith('info'):
-                _, file_name, file_size = server_hello.split('|')
-                self.set_file_information(file_name=file_name,
-                                          file_size=file_size)
-            elif server_hello == 'start!':
-                break
-            else:
-                logging.info(f"Server said: {server_hello}")
-        await self.receive(reader=reader)
-
     async def receive(self, reader: asyncio.StreamReader):
         self._working = True
         file_path = os.path.join(
@@ -293,6 +289,7 @@ class Client:
         if self.working:
             logging.info("Received interruption signal, stopping...")
             self._stopping = True
+            self.writer.close()
         else:
             raise KeyboardInterrupt("Not working yet...")
 
@@ -302,6 +299,23 @@ class Client:
         if file_size is not None:
             self._file_size = int(file_size)
 
+    def run(self, file_path, action):
+        loop = asyncio.get_event_loop()
+        try:
+            loop.run_until_complete(
+                self.run_client(file_path=file_path,
+                                action=action)
+            )
+        except KeyboardInterrupt:
+            logging.error("Interrupted")
+            for task in asyncio.all_tasks(loop):
+                task.cancel()
+            self.writer.close()
+            loop.run_until_complete(
+                self.writer.wait_closed()
+            )
+        loop.close()
+
 
 def get_action(action):
     """Parse abbreviations for `action`."""
@@ -362,6 +376,10 @@ def main():
                             default=None,
                             required=False,
                             help='server port')
+    cli_parser.add_argument('--certificate', type=str,
+                            default=None,
+                            required=False,
+                            help='server SSL certificate')
     cli_parser.add_argument('--action', type=str,
                             default=None,
                             required=False,
@@ -386,6 +404,7 @@ def main():
     args = vars(cli_parser.parse_args())
     host = args['host']
     port = args['port']
+    certificate = args['certificate']
     action = get_action(args['action'])
     file_path = args['path']
     password = args['password']
@@ -431,6 +450,11 @@ def main():
             from config import token
         except ImportError:
             token = None
+    if certificate is None or not os.path.isfile(certificate):
+        try:
+            from config import certificate
+        except ImportError:
+            certificate = None
 
     # If import fails, prompt user for host or port
     new_settings = {}  # After getting these settings, offer to store them
@@ -516,42 +540,29 @@ def main():
             logging.info("Configuration values stored.")
         else:
             logging.info("Proceeding without storing values...")
-    loop = asyncio.get_event_loop()
     client = Client(
         host=host,
         port=port,
         password=password,
         token=token
     )
-    try:
-        from config import certificate
+    if certificate is not None:
         _ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
         _ssl_context.check_hostname = False
         _ssl_context.load_verify_locations(certificate)
         client.set_ssl_context(_ssl_context)
-    except ImportError:
+    else:
         logging.warning(
             "Please consider using SSL. To do so, add in `config.py` or "
             "provide via Command Line Interface the path to a valid SSL "
             "certificate. Example:\n\n"
             "certificate = 'path/to/certificate.crt'"
         )
-        # noinspection PyUnusedLocal
-        certificate = None
     logging.info("Starting client...")
-    if action == 'send':
-        loop.run_until_complete(
-            client.run_sending_client(
-                file_path=file_path
-            )
-        )
-    else:
-        loop.run_until_complete(
-            client.run_receiving_client(
-                file_path=file_path
-            )
-        )
-    loop.close()
+    client.run(
+        file_path=file_path,
+        action=action
+    )
     logging.info("Stopped client")
 
 

+ 86 - 32
filebridging/server.py

@@ -7,12 +7,14 @@ import argparse
 import asyncio
 import collections
 import logging
+import os
 import ssl
+from typing import Union
 
 
 class Server:
     def __init__(self, host='localhost', port=5000,
-                 buffer_chunk_size=10**4, buffer_length_limit=10**4):
+                 buffer_chunk_size=10 ** 4, buffer_length_limit=10 ** 4):
         self._host = host
         self._port = port
         self.connections = collections.OrderedDict()
@@ -57,9 +59,9 @@ class Server:
     @property
     def buffer_is_full(self):
         return (
-            sum(len(buffer)
-                for buffer in self.buffers.values())
-            >= self.buffer_length_limit
+                sum(len(buffer)
+                    for buffer in self.buffers.values())
+                >= self.buffer_length_limit
         )
 
     def set_ssl_context(self, ssl_context: ssl.SSLContext):
@@ -80,7 +82,7 @@ class Server:
                 logging.error(e)
                 break
             except Exception as e:
-                logging.error(e, exc_info=True)
+                logging.error(f"Unexpected exception:\n{e}", exc_info=True)
 
     async def run_writer(self, writer, connection_token):
         consecutive_interruptions = 0
@@ -105,6 +107,7 @@ class Server:
                 writer.write(input_data)
                 await writer.drain()
             except ConnectionResetError as e:
+                logging.error("Here")
                 logging.error(e)
                 break
             except Exception as e:
@@ -134,7 +137,32 @@ class Server:
                 sender=False,
                 receiver=False
             )
-        if client_hello[0] == 's':
+
+        async def _write(message: Union[list, str, bytes],
+                         terminate_line=True) -> int:
+            # Adapt
+            if type(message) is list:
+                message = '|'.join(message)
+            if type(message) is str:
+                if terminate_line:
+                    message += '\n'
+                message = message.encode('utf-8')
+            if type(message) is not bytes:
+                return 1
+            try:
+                writer.write(message)
+                await writer.drain()
+            except ConnectionResetError:
+                logging.error("Client disconnected.")
+            except Exception as e:
+                logging.error(f"Unexpected exception:\n{e}", exc_info=True)
+            else:
+                return 0  # On success, return 0
+            # On exception, disconnect and return 1
+            self.disconnect(connection_token=connection_token)
+            return 1
+
+        if client_hello[0] == 's':  # Sender client connection
             if self.connections[connection_token]['sender']:
                 await self.refuse_connection(
                     writer=writer,
@@ -151,19 +179,19 @@ class Server:
             while not self.connections[connection_token]['receiver']:
                 index += 1
                 if index >= step:
-                    writer.write("Waiting for receiver...\n".encode('utf-8'))
-                    await writer.drain()
+                    if await _write("Waiting for receiver..."):
+                        return
                     step += 1
                     index = 0
                 await asyncio.sleep(.5)
             # Send start signal to client
-            writer.write("start!\n".encode('utf-8'))
-            await writer.drain()
+            if await _write("start!"):
+                return
             logging.info("Incoming transmission starting...")
             await self.run_reader(reader=reader,
                                   connection_token=connection_token)
             logging.info("Incoming transmission ended")
-        else:  # Receiver client connection
+        elif client_hello[0] == 'r':  # Receiver client connection
             if self.connections[connection_token]['receiver']:
                 await self.refuse_connection(
                     writer=writer,
@@ -177,25 +205,32 @@ class Server:
             while not self.connections[connection_token]['sender']:
                 index += 1
                 if index >= step:
-                    writer.write("Waiting for sender...\n".encode('utf-8'))
-                    await writer.drain()
+                    if await _write("Waiting for sender..."):
+                        return
                     step += 1
                     index = 0
                 await asyncio.sleep(.5)
             # Send file information and start signal to client
             writer.write(
-                "info|"
+                "s|hidden_token|"
                 f"{self.connections[connection_token]['file_name']}|"
                 f"{self.connections[connection_token]['file_size']}"
                 "\n".encode('utf-8')
             )
-            writer.write("start!\n".encode('utf-8'))
-            await writer.drain()
+            if await _write("start!"):
+                return
             await self.run_writer(writer=writer,
                                   connection_token=connection_token)
             logging.info("Outgoing transmission ended")
-            del self.buffers[connection_token]
-            del self.connections[connection_token]
+            self.disconnect(connection_token=connection_token)
+        else:
+            await self.refuse_connection(writer=writer,
+                                         message="Invalid client_hello!")
+            return
+
+    def disconnect(self, connection_token: str) -> None:
+        del self.buffers[connection_token]
+        del self.connections[connection_token]
 
     def run(self):
         loop = asyncio.get_event_loop()
@@ -255,19 +290,29 @@ def main():
     root_logger.addHandler(console_handler)
 
     # Parse command-line arguments
-    parser = argparse.ArgumentParser(description='Run server',
-                                     allow_abbrev=False)
-    parser.add_argument('--host', type=str,
-                        default=None,
-                        required=False,
-                        help='server address')
-    parser.add_argument('--port', type=int,
-                        default=None,
-                        required=False,
-                        help='server port')
-    args = vars(parser.parse_args())
+    cli_parser = argparse.ArgumentParser(description='Run server',
+                                         allow_abbrev=False)
+    cli_parser.add_argument('--host', type=str,
+                            default=None,
+                            required=False,
+                            help='server address')
+    cli_parser.add_argument('--port', type=int,
+                            default=None,
+                            required=False,
+                            help='server port')
+    cli_parser.add_argument('--certificate', type=str,
+                            default=None,
+                            required=False,
+                            help='server SSL certificate')
+    cli_parser.add_argument('--key', type=str,
+                            default=None,
+                            required=False,
+                            help='server SSL key')
+    args = vars(cli_parser.parse_args())
     host = args['host']
     port = args['port']
+    certificate = args['certificate']
+    key = args['key']
 
     # If host and port are not provided from command-line, try to import them
     if host is None:
@@ -296,13 +341,22 @@ def main():
         port=port,
     )
     try:
-        # noinspection PyUnresolvedReferences
-        from config import certificate, key
+        if certificate is None or not os.path.isfile(certificate):
+            from config import certificate
+        if key is None or not os.path.isfile(key):
+            from config import key
+        if not os.path.isfile(certificate):
+            certificate = None
+        if not os.path.isfile(key):
+            key = None
+    except ImportError:
+        pass
+    if certificate and key:
         _ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
         _ssl_context.check_hostname = False
         _ssl_context.load_cert_chain(certificate, key)
         server.set_ssl_context(_ssl_context)
-    except ImportError:
+    else:
         logging.warning(
             "Please consider using SSL. To do so, add in `config.py` or "
             "provide via Command Line Interface the path to a valid SSL "