Queer European MD passionate about IT
Quellcode durchsuchen

Pass file information to receiver client

Davte vor 5 Jahren
Ursprung
Commit
4f01831169
2 geänderte Dateien mit 56 neuen und 18 gelöschten Zeilen
  1. 21 2
      src/client.py
  2. 35 16
      src/server.py

+ 21 - 2
src/client.py

@@ -2,7 +2,6 @@ import argparse
 import asyncio
 import collections
 import logging
-# import signal
 import os
 import random
 import ssl
@@ -28,6 +27,8 @@ class Client:
         self._password = password
         self._ssl_context = None
         self._encryption_complete = False
+        self._file_name = None
+        self._file_size = None
 
     @property
     def host(self) -> str:
@@ -77,6 +78,14 @@ class Client:
     def encryption_complete(self):
         return self._encryption_complete
 
+    @property
+    def file_name(self):
+        return self._file_name
+
+    @property
+    def file_size(self):
+        return self._file_size
+
     async def run_sending_client(self, file_path='~/output.txt'):
         self._file_path = file_path
         file_name = os.path.basename(os.path.abspath(file_path))
@@ -186,7 +195,11 @@ class Client:
                 logging.info("Server disconnected.")
                 return
             server_hello = server_hello.decode('utf-8').strip('\n')
-            if server_hello == 'start!':
+            if server_hello.startswith('info'):
+                _, file_name, file_size = server_hello.split('|')
+                self.set_file_information(file_name=file_name,
+                                          file_size=file_size)
+            elif server_hello == 'start!':
                 break
             logging.info(f"Server said: {server_hello}")
         await self.receive(reader=reader)
@@ -233,6 +246,12 @@ class Client:
         else:
             raise KeyboardInterrupt("Not working yet...")
 
+    def set_file_information(self, file_name=None, file_size=None):
+        if file_name is not None:
+            self._file_name = file_name
+        if file_size is not None:
+            self._file_size = file_size
+
 
 def get_action(action):
     """Parse abbreviations for `action`."""

+ 35 - 16
src/server.py

@@ -113,23 +113,27 @@ class Server:
         """
         client_hello = await reader.readline()
         client_hello = client_hello.decode('utf-8').strip('\n').split('|')
-        peer_is_sender = client_hello[0] == 's'
+        if len(client_hello) not in (2, 4,):
+            await self.refuse_connection(writer=writer,
+                                         message="Invalid client_hello!")
+            return
         connection_token = client_hello[1]
         if connection_token not in self.connections:
             self.connections[connection_token] = dict(
                 sender=False,
                 receiver=False
             )
-        if peer_is_sender:
+        if client_hello[0] == 's':
             if self.connections[connection_token]['sender']:
-                writer.write(
-                    "Invalid token! "
-                    "A sender client is already connected!\n".encode('utf-8')
+                await self.refuse_connection(
+                    writer=writer,
+                    message="Invalid token! "
+                            "A sender client is already connected!\n"
                 )
-                await writer.drain()
-                writer.close()
                 return
             self.connections[connection_token]['sender'] = True
+            self.connections[connection_token]['file_name'] = client_hello[2]
+            self.connections[connection_token]['file_size'] = client_hello[3]
             self.buffers[connection_token] = collections.deque()
             logging.info("Sender is connecting...")
             index, step = 0, 1
@@ -148,14 +152,13 @@ class Server:
             await self.run_reader(reader=reader,
                                   connection_token=connection_token)
             logging.info("Incoming transmission ended")
-        else:
+        else:  # Receiver client connection
             if self.connections[connection_token]['receiver']:
-                writer.write(
-                    "Invalid token! "
-                    "A receiver client is already connected!\n".encode('utf-8')
+                await self.refuse_connection(
+                    writer=writer,
+                    message="Invalid token! "
+                            "A receiver client is already connected!\n"
                 )
-                await writer.drain()
-                writer.close()
                 return
             self.connections[connection_token]['receiver'] = True
             logging.info("Receiver is connecting...")
@@ -168,7 +171,13 @@ class Server:
                     step += 1
                     index = 0
                 await asyncio.sleep(.5)
-            # Send start signal to client
+            # Send file information and start signal to client
+            writer.write(
+                "info|"
+                f"{self.connections[connection_token]['file_name']}|"
+                f"{self.connections[connection_token]['file_size']}"
+                "\n".encode('utf-8')
+            )
             writer.write("start!\n".encode('utf-8'))
             await writer.drain()
             await self.run_writer(writer=writer,
@@ -176,7 +185,6 @@ class Server:
             logging.info("Outgoing transmission ended")
             del self.buffers[connection_token]
             del self.connections[connection_token]
-        return
 
     def run(self):
         loop = asyncio.get_event_loop()
@@ -203,7 +211,18 @@ class Server:
         )
         async with self.server:
             await self.server.serve_forever()
-        return
+
+    @staticmethod
+    async def refuse_connection(writer: asyncio.StreamWriter,
+                                message: str = None):
+        """Send a `message` via writer and close it."""
+        if message is None:
+            message = "Connection refused!\n"
+        writer.write(
+            message.encode('utf-8')
+        )
+        await writer.drain()
+        writer.close()
 
 
 def main():