Queer European MD passionate about IT
Browse Source

Allow multiple client connections

Davte 5 years ago
parent
commit
e68ab4282c
2 changed files with 169 additions and 54 deletions
  1. 83 13
      src/client.py
  2. 86 41
      src/server.py

+ 83 - 13
src/client.py

@@ -4,14 +4,15 @@ import collections
 import logging
 import logging
 # import signal
 # import signal
 import os
 import os
+import random
 import ssl
 import ssl
+import string
 
 
 
 
 class Client:
 class Client:
     def __init__(self, host='localhost', port=3001,
     def __init__(self, host='localhost', port=3001,
                  buffer_chunk_size=10**4, buffer_length_limit=10**4,
                  buffer_chunk_size=10**4, buffer_length_limit=10**4,
-                 password=None):
-        self._password = password
+                 password=None, token=None):
         self._host = host
         self._host = host
         self._port = port
         self._port = port
         self._stopping = False
         self._stopping = False
@@ -23,6 +24,8 @@ class Client:
         self._buffer_length_limit = buffer_length_limit
         self._buffer_length_limit = buffer_length_limit
         self._file_path = None
         self._file_path = None
         self._working = False
         self._working = False
+        self._token = token
+        self._password = password
         self._ssl_context = None
         self._ssl_context = None
         self._encryption_complete = False
         self._encryption_complete = False
 
 
@@ -61,6 +64,10 @@ class Client:
     def set_ssl_context(self, ssl_context: ssl.SSLContext):
     def set_ssl_context(self, ssl_context: ssl.SSLContext):
         self._ssl_context = ssl_context
         self._ssl_context = ssl_context
 
 
+    @property
+    def token(self):
+        return self._token
+
     @property
     @property
     def password(self):
     def password(self):
         """Password for file encryption or decryption."""
         """Password for file encryption or decryption."""
@@ -72,12 +79,31 @@ class Client:
 
 
     async def run_sending_client(self, file_path='~/output.txt'):
     async def run_sending_client(self, file_path='~/output.txt'):
         self._file_path = file_path
         self._file_path = file_path
-        reader, writer = await asyncio.open_connection(host=self.host,
-                                                       port=self.port,
-                                                       ssl=self.ssl_context)
-        writer.write("sender\n".encode('utf-8'))
+        file_name = os.path.basename(os.path.abspath(file_path))
+        file_size = os.path.getsize(os.path.abspath(file_path))
+        try:
+            reader, writer = await asyncio.open_connection(
+                host=self.host,
+                port=self.port,
+                ssl=self.ssl_context
+            )
+        except ConnectionRefusedError as exception:
+            logging.error(exception)
+            return
+        writer.write(
+            f"s|{self.token}|{file_name}|{file_size}\n".encode('utf-8')
+        )
         await writer.drain()
         await writer.drain()
-        await reader.readline()  # Wait for server start signal
+        # Wait for server start signal
+        while 1:
+            server_hello = await reader.readline()
+            if not server_hello:
+                logging.info("Server disconnected.")
+                return
+            server_hello = server_hello.decode('utf-8').strip('\n')
+            if server_hello == 'start!':
+                break
+            logging.info(f"Server said: {server_hello}")
         await self.send(writer=writer)
         await self.send(writer=writer)
 
 
     async def encrypt_file(self, input_file, output_file):
     async def encrypt_file(self, input_file, output_file):
@@ -142,12 +168,27 @@ class Client:
 
 
     async def run_receiving_client(self, file_path='~/input.txt'):
     async def run_receiving_client(self, file_path='~/input.txt'):
         self._file_path = file_path
         self._file_path = file_path
-        reader, writer = await asyncio.open_connection(host=self.host,
-                                                       port=self.port,
-                                                       ssl=self.ssl_context)
-        writer.write("receiver\n".encode('utf-8'))
+        try:
+            reader, writer = await asyncio.open_connection(
+                host=self.host,
+                port=self.port,
+                ssl=self.ssl_context
+            )
+        except ConnectionRefusedError as exception:
+            logging.error(exception)
+            return
+        writer.write(f"r|{self.token}\n".encode('utf-8'))
         await writer.drain()
         await writer.drain()
-        await reader.readline()  # Wait for server start signal
+        # Wait for server start signal
+        while 1:
+            server_hello = await reader.readline()
+            if not server_hello:
+                logging.info("Server disconnected.")
+                return
+            server_hello = server_hello.decode('utf-8').strip('\n')
+            if server_hello == 'start!':
+                break
+            logging.info(f"Server said: {server_hello}")
         await self.receive(reader=reader)
         await self.receive(reader=reader)
 
 
     async def receive(self, reader: asyncio.StreamReader):
     async def receive(self, reader: asyncio.StreamReader):
@@ -258,6 +299,11 @@ if __name__ == '__main__':
                             default=None,
                             default=None,
                             required=False,
                             required=False,
                             help='Password for file encryption or decryption')
                             help='Password for file encryption or decryption')
+    cli_parser.add_argument('--token', '--t', '--session_token', type=str,
+                            default=None,
+                            required=False,
+                            help='Session token '
+                                 '(must be the same for both clients)')
     cli_parser.add_argument('others',
     cli_parser.add_argument('others',
                             metavar='R or S',
                             metavar='R or S',
                             nargs='*',
                             nargs='*',
@@ -268,6 +314,7 @@ if __name__ == '__main__':
     _action = get_action(args['action'])
     _action = get_action(args['action'])
     _file_path = args['path']
     _file_path = args['path']
     _password = args['password']
     _password = args['password']
+    _token = args['token']
 
 
     # If _host and _port are not provided from command-line, try to import them
     # If _host and _port are not provided from command-line, try to import them
     if _host is None:
     if _host is None:
@@ -303,6 +350,11 @@ if __name__ == '__main__':
             from config import password as _password
             from config import password as _password
         except ImportError:
         except ImportError:
             _password = None
             _password = None
+    if _token is None:
+        try:
+            from config import token as _token
+        except ImportError:
+            _token = None
 
 
     # If import fails, prompt user for _host or _port
     # If import fails, prompt user for _host or _port
     while _host is None:
     while _host is None:
@@ -328,11 +380,29 @@ if __name__ == '__main__':
             "Your file will be unencoded unless you provide a password in "
             "Your file will be unencoded unless you provide a password in "
             "config file."
             "config file."
         )
         )
+    if _token is None and _action == 'send':
+        # Generate a random [6-10] chars-long alphanumerical token
+        _token = ''.join(
+            random.SystemRandom().choice(
+                string.ascii_uppercase + string.digits
+            )
+            for _ in range(random.SystemRandom().randint(6, 10))
+        )
+        logging.info(
+            "You have not provided a token for this connection.\n"
+            f"A token has been generated for you:\t\t{_token}\n"
+            "Your peer must be informed of this token.\n"
+            "For future connections, you may provide a custom token writing "
+            "it in config file."
+        )
+    while _token is None or not (6 <= len(_token) <= 10):
+        _token = input("Please enter a 6-10 chars token.\t\t\t\t")
     loop = asyncio.get_event_loop()
     loop = asyncio.get_event_loop()
     client = Client(
     client = Client(
         host=_host,
         host=_host,
         port=_port,
         port=_port,
-        password=_password
+        password=_password,
+        token=_token
     )
     )
     try:
     try:
         from config import certificate
         from config import certificate

+ 86 - 41
src/server.py

@@ -10,9 +10,9 @@ class Server:
                  buffer_chunk_size=10**4, buffer_length_limit=10**4):
                  buffer_chunk_size=10**4, buffer_length_limit=10**4):
         self._host = host
         self._host = host
         self._port = port
         self._port = port
-        self._stopping = False
-        # Shared queue of bytes
-        self.buffer = collections.deque()
+        self.connections = collections.OrderedDict()
+        # Dict of queues of bytes
+        self.buffers = collections.OrderedDict()
         # How many bytes per chunk
         # How many bytes per chunk
         self._buffer_chunk_size = buffer_chunk_size
         self._buffer_chunk_size = buffer_chunk_size
         # How many chunks in buffer
         # How many chunks in buffer
@@ -29,10 +29,6 @@ class Server:
     def port(self) -> int:
     def port(self) -> int:
         return self._port
         return self._port
 
 
-    @property
-    def stopping(self) -> bool:
-        return self._stopping
-
     @property
     @property
     def buffer_length_limit(self) -> int:
     def buffer_length_limit(self) -> int:
         return self._buffer_length_limit
         return self._buffer_length_limit
@@ -53,28 +49,40 @@ class Server:
     def ssl_context(self) -> ssl.SSLContext:
     def ssl_context(self) -> ssl.SSLContext:
         return self._ssl_context
         return self._ssl_context
 
 
+    @property
+    def buffer_is_full(self):
+        return (
+            sum(len(buffer)
+                for buffer in self.buffers.values())
+            >= self.buffer_length_limit
+        )
+
     def set_ssl_context(self, ssl_context: ssl.SSLContext):
     def set_ssl_context(self, ssl_context: ssl.SSLContext):
         self._ssl_context = ssl_context
         self._ssl_context = ssl_context
 
 
-    async def run_reader(self, reader):
-        while not self.stopping:
+    async def run_reader(self, reader, connection_token):
+        while 1:
             try:
             try:
-                # Stop if buffer is full
-                while len(self.buffer) >= self.buffer_length_limit:
+                # Wait one second if buffer is full
+                while self.buffer_is_full:
                     await asyncio.sleep(1)
                     await asyncio.sleep(1)
                     continue
                     continue
                 input_data = await reader.read(self.buffer_chunk_size)
                 input_data = await reader.read(self.buffer_chunk_size)
-                self.buffer.append(input_data)
+                if connection_token not in self.buffers:
+                    break
+                self.buffers[connection_token].append(input_data)
             except Exception as e:
             except Exception as e:
-                logging.error(e)
+                logging.error(e, exc_info=True)
 
 
-    async def run_writer(self, writer):
+    async def run_writer(self, writer, connection_token):
         consecutive_interruptions = 0
         consecutive_interruptions = 0
         errors = 0
         errors = 0
-        while not self.stopping:
+        while 1:
             try:
             try:
                 try:
                 try:
-                    input_data = self.buffer.popleft()
+                    if connection_token not in self.buffers:
+                        break
+                    input_data = self.buffers[connection_token].popleft()
                 except IndexError:
                 except IndexError:
                     # Slow down if buffer is short
                     # Slow down if buffer is short
                     consecutive_interruptions += 1
                     consecutive_interruptions += 1
@@ -89,7 +97,7 @@ class Server:
                 writer.write(input_data)
                 writer.write(input_data)
                 await writer.drain()
                 await writer.drain()
             except Exception as e:
             except Exception as e:
-                logging.error(e)
+                logging.error(e, exc_info=True)
                 errors += 1
                 errors += 1
                 if errors > 3:
                 if errors > 3:
                     break
                     break
@@ -104,25 +112,70 @@ class Server:
         Decide whether client is sender or receiver and start transmission.
         Decide whether client is sender or receiver and start transmission.
         """
         """
         client_hello = await reader.readline()
         client_hello = await reader.readline()
-        peer_is_sender = client_hello.decode('utf-8') == 'sender\n'
+        client_hello = client_hello.decode('utf-8').strip('\n').split('|')
+        peer_is_sender = client_hello[0] == 's'
+        connection_token = client_hello[1]
+        if connection_token not in self.connections:
+            self.connections[connection_token] = dict(
+                sender=False,
+                receiver=False
+            )
         if peer_is_sender:
         if peer_is_sender:
-            self._working = True
+            if self.connections[connection_token]['sender']:
+                writer.write(
+                    "Invalid token! "
+                    "A sender client is already connected!\n".encode('utf-8')
+                )
+                await writer.drain()
+                writer.close()
+                return
+            self.connections[connection_token]['sender'] = True
+            self.buffers[connection_token] = collections.deque()
             logging.info("Sender is connecting...")
             logging.info("Sender is connecting...")
+            index, step = 0, 1
+            while not self.connections[connection_token]['receiver']:
+                index += 1
+                if index >= step:
+                    writer.write("Waiting for receiver...\n".encode('utf-8'))
+                    await writer.drain()
+                    step += 1
+                    index = 0
+                await asyncio.sleep(.5)
             # Send start signal to client
             # Send start signal to client
-            writer.write("Start!\n".encode('utf-8'))
+            writer.write("start!\n".encode('utf-8'))
             await writer.drain()
             await writer.drain()
-            await self.run_reader(reader=reader)
+            logging.info("Incoming transmission starting...")
+            await self.run_reader(reader=reader,
+                                  connection_token=connection_token)
             logging.info("Incoming transmission ended")
             logging.info("Incoming transmission ended")
         else:
         else:
+            if self.connections[connection_token]['receiver']:
+                writer.write(
+                    "Invalid token! "
+                    "A receiver client is already connected!\n".encode('utf-8')
+                )
+                await writer.drain()
+                writer.close()
+                return
+            self.connections[connection_token]['receiver'] = True
             logging.info("Receiver is connecting...")
             logging.info("Receiver is connecting...")
-            while len(self.buffer) == 0:
+            index, step = 0, 1
+            while not self.connections[connection_token]['sender']:
+                index += 1
+                if index >= step:
+                    writer.write("Waiting for sender...\n".encode('utf-8'))
+                    await writer.drain()
+                    step += 1
+                    index = 0
                 await asyncio.sleep(.5)
                 await asyncio.sleep(.5)
             # Send start signal to client
             # Send start signal to client
-            writer.write("Start!\n".encode('utf-8'))
+            writer.write("start!\n".encode('utf-8'))
             await writer.drain()
             await writer.drain()
-            await self.run_writer(writer=writer)
+            await self.run_writer(writer=writer,
+                                  connection_token=connection_token)
             logging.info("Outgoing transmission ended")
             logging.info("Outgoing transmission ended")
-            self._working = False
+            del self.buffers[connection_token]
+            del self.connections[connection_token]
         return
         return
 
 
     def run(self):
     def run(self):
@@ -149,23 +202,11 @@ class Server:
             port=self.port,
             port=self.port,
         )
         )
         async with self.server:
         async with self.server:
-            try:
-                await self.server.serve_forever()
-            except KeyboardInterrupt:
-                logging.info("Stopping...")
-                self.server.close()
-                await self.server.wait_closed()
+            await self.server.serve_forever()
         return
         return
 
 
-    def stop(self, *_):
-        if self.working and not self.stopping:
-            logging.info("Received interruption signal, stopping...")
-            self._stopping = True
-        else:
-            raise KeyboardInterrupt("Not working yet...")
-
 
 
-if __name__ == '__main__':
+def main():
     # noinspection SpellCheckingInspection
     # noinspection SpellCheckingInspection
     log_formatter = logging.Formatter(
     log_formatter = logging.Formatter(
         "%(asctime)s [%(module)-15s %(levelname)-8s]     %(message)s",
         "%(asctime)s [%(module)-15s %(levelname)-8s]     %(message)s",
@@ -221,12 +262,16 @@ if __name__ == '__main__':
         port=_port,
         port=_port,
     )
     )
     try:
     try:
+        # noinspection PyUnresolvedReferences
         from config import certificate, key
         from config import certificate, key
         _ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
         _ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
         _ssl_context.check_hostname = False
         _ssl_context.check_hostname = False
         _ssl_context.load_cert_chain(certificate, key)
         _ssl_context.load_cert_chain(certificate, key)
         server.set_ssl_context(_ssl_context)
         server.set_ssl_context(_ssl_context)
     except ImportError:
     except ImportError:
-        logging.info("Please consider using SSL.")
-        certificate, key = None, None
+        logging.warning("Please consider using SSL.")
     server.run()
     server.run()
+
+
+if __name__ == '__main__':
+    main()