Queer European MD passionate about IT
瀏覽代碼

IT WORKS!

Davte 5 年之前
父節點
當前提交
248d4ccb88
共有 2 個文件被更改,包括 44 次插入20 次删除
  1. 7 2
      src/client.py
  2. 37 18
      src/server.py

+ 7 - 2
src/client.py

@@ -55,8 +55,12 @@ class Client:
         with open(self.file_path, 'rb') as file_to_send:
             while not self.stopping:
                 output_data = file_to_send.read(self.buffer_chunk_size)
+                if not output_data:
+                    break
                 writer.write(output_data)
                 await writer.drain()
+        writer.write_eof()
+        await writer.drain()
 
     async def run_receiving_client(self, file_path='~/input.txt'):
         self._file_path = file_path
@@ -68,9 +72,10 @@ class Client:
         with open(self.file_path, 'wb') as file_to_receive:
             while not self.stopping:
                 input_data = await reader.read(self.buffer_chunk_size)
+                if reader.at_eof():
+                    break
                 if not input_data:
                     continue
-                print(input_data)
                 file_to_receive.write(input_data)
 
     def stop(self, *_):
@@ -112,7 +117,7 @@ if __name__ == '__main__':
     loop = asyncio.get_event_loop()
     client = Client(
         host='127.0.0.1',
-        port=5000,
+        port=(5000 if action == 'send' else 5001),
     )
     # loop.add_signal_handler(signal.SIGINT, client.stop, loop)
     logging.info("Starting client...")

+ 37 - 18
src/server.py

@@ -1,27 +1,32 @@
 import asyncio
 import collections
 import logging
-import signal
 
 
 class Server:
-    def __init__(self, host='localhost', port=3001,
+    def __init__(self, host='localhost', input_port=5000, output_port=5001,
                  buffer_chunk_size=10**4, buffer_length_limit=10**4):
         self._host = host
-        self._port = port
+        self._input_port = input_port
+        self._output_port = output_port
         self._stopping = False
         self.buffer = collections.deque()  # Shared queue of bytes
         self._buffer_chunk_size = buffer_chunk_size   # How many bytes per chunk
         self._buffer_length_limit = buffer_length_limit  # How many chunks in buffer
         self._working = False
+        self.at_eof = False
 
     @property
     def host(self) -> str:
         return self._host
 
     @property
-    def port(self) -> int:
-        return self._port
+    def input_port(self) -> int:
+        return self._input_port
+
+    @property
+    def output_port(self) -> int:
+        return self._output_port
 
     @property
     def stopping(self) -> bool:
@@ -46,10 +51,9 @@ class Server:
                 while len(self.buffer) >= self.buffer_length_limit:
                     await asyncio.sleep(1)
                     continue
-                try:
-                    input_data = await reader.readexactly(self.buffer_chunk_size)
-                except asyncio.IncompleteReadError as e:
-                    input_data = e.partial
+                input_data = await reader.read(self.buffer_chunk_size)
+                if reader.at_eof():
+                    self.at_eof = True
                 self.buffer.append(input_data)
             except Exception as e:
                 logging.error(e)
@@ -63,26 +67,39 @@ class Server:
                 try:
                     input_data = self.buffer.popleft()
                 except IndexError:
-                    continue
+                    if not self.at_eof:
+                        continue
+                    else:
+                        writer.write_eof()
+                        await writer.drain()
+                        self.at_eof = False
+                        break
                 writer.write(input_data)
                 await writer.drain()
             except Exception as e:
                 logging.error(e)
 
-    async def forward_bytes(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
+    # noinspection PyUnusedLocal
+    async def handle_incoming_data(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
         self._working = True
         asyncio.ensure_future(self.run_reader(reader=reader))
+
+    # noinspection PyUnusedLocal
+    async def handle_outgoing_data(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
+        self._working = True
         asyncio.ensure_future(self.run_writer(writer=writer))
 
     async def run_server(self):
-        reader_server = await asyncio.start_server(client_connected_cb=self.forward_bytes,
-                                                   host=self.host, port=self.port)
+        reader_server = await asyncio.start_server(client_connected_cb=self.handle_incoming_data,
+                                                   host=self.host, port=self.input_port)
+        await asyncio.start_server(client_connected_cb=self.handle_outgoing_data,
+                                   host=self.host, port=self.output_port)
         async with reader_server:
             await reader_server.serve_forever()
         return
 
     def stop(self, *_):
-        if self.working:
+        if self.working and not self.stopping:
             logging.info("Received interruption signal, stopping...")
             self._stopping = True
         else:
@@ -105,10 +122,12 @@ if __name__ == '__main__':
     loop = asyncio.get_event_loop()
     server = Server(
         host='127.0.0.1',
-        port=5000,
+        input_port=5000,
+        output_port=5001
     )
-    # loop.add_signal_handler(signal.SIGINT, server.stop, loop)
     logging.info("Starting file bridging server...")
-    loop.run_until_complete(server.run_server())
+    try:
+        loop.run_until_complete(server.run_server())
+    except KeyboardInterrupt:
+        logging.info("Stopping...")
     loop.close()
-    logging.info("Stopped server")