|
@@ -10,9 +10,9 @@ class Server:
|
|
buffer_chunk_size=10**4, buffer_length_limit=10**4):
|
|
buffer_chunk_size=10**4, buffer_length_limit=10**4):
|
|
self._host = host
|
|
self._host = host
|
|
self._port = port
|
|
self._port = port
|
|
- self._stopping = False
|
|
|
|
- # Shared queue of bytes
|
|
|
|
- self.buffer = collections.deque()
|
|
|
|
|
|
+ self.connections = collections.OrderedDict()
|
|
|
|
+ # Dict of queues of bytes
|
|
|
|
+ self.buffers = collections.OrderedDict()
|
|
# How many bytes per chunk
|
|
# How many bytes per chunk
|
|
self._buffer_chunk_size = buffer_chunk_size
|
|
self._buffer_chunk_size = buffer_chunk_size
|
|
# How many chunks in buffer
|
|
# How many chunks in buffer
|
|
@@ -29,10 +29,6 @@ class Server:
|
|
def port(self) -> int:
|
|
def port(self) -> int:
|
|
return self._port
|
|
return self._port
|
|
|
|
|
|
- @property
|
|
|
|
- def stopping(self) -> bool:
|
|
|
|
- return self._stopping
|
|
|
|
-
|
|
|
|
@property
|
|
@property
|
|
def buffer_length_limit(self) -> int:
|
|
def buffer_length_limit(self) -> int:
|
|
return self._buffer_length_limit
|
|
return self._buffer_length_limit
|
|
@@ -53,28 +49,40 @@ class Server:
|
|
def ssl_context(self) -> ssl.SSLContext:
|
|
def ssl_context(self) -> ssl.SSLContext:
|
|
return self._ssl_context
|
|
return self._ssl_context
|
|
|
|
|
|
|
|
+ @property
|
|
|
|
+ def buffer_is_full(self):
|
|
|
|
+ return (
|
|
|
|
+ sum(len(buffer)
|
|
|
|
+ for buffer in self.buffers.values())
|
|
|
|
+ >= self.buffer_length_limit
|
|
|
|
+ )
|
|
|
|
+
|
|
def set_ssl_context(self, ssl_context: ssl.SSLContext):
|
|
def set_ssl_context(self, ssl_context: ssl.SSLContext):
|
|
self._ssl_context = ssl_context
|
|
self._ssl_context = ssl_context
|
|
|
|
|
|
- async def run_reader(self, reader):
|
|
|
|
- while not self.stopping:
|
|
|
|
|
|
+ async def run_reader(self, reader, connection_token):
|
|
|
|
+ while 1:
|
|
try:
|
|
try:
|
|
- # Stop if buffer is full
|
|
|
|
- while len(self.buffer) >= self.buffer_length_limit:
|
|
|
|
|
|
+ # Wait one second if buffer is full
|
|
|
|
+ while self.buffer_is_full:
|
|
await asyncio.sleep(1)
|
|
await asyncio.sleep(1)
|
|
continue
|
|
continue
|
|
input_data = await reader.read(self.buffer_chunk_size)
|
|
input_data = await reader.read(self.buffer_chunk_size)
|
|
- self.buffer.append(input_data)
|
|
|
|
|
|
+ if connection_token not in self.buffers:
|
|
|
|
+ break
|
|
|
|
+ self.buffers[connection_token].append(input_data)
|
|
except Exception as e:
|
|
except Exception as e:
|
|
- logging.error(e)
|
|
|
|
|
|
+ logging.error(e, exc_info=True)
|
|
|
|
|
|
- async def run_writer(self, writer):
|
|
|
|
|
|
+ async def run_writer(self, writer, connection_token):
|
|
consecutive_interruptions = 0
|
|
consecutive_interruptions = 0
|
|
errors = 0
|
|
errors = 0
|
|
- while not self.stopping:
|
|
|
|
|
|
+ while 1:
|
|
try:
|
|
try:
|
|
try:
|
|
try:
|
|
- input_data = self.buffer.popleft()
|
|
|
|
|
|
+ if connection_token not in self.buffers:
|
|
|
|
+ break
|
|
|
|
+ input_data = self.buffers[connection_token].popleft()
|
|
except IndexError:
|
|
except IndexError:
|
|
# Slow down if buffer is short
|
|
# Slow down if buffer is short
|
|
consecutive_interruptions += 1
|
|
consecutive_interruptions += 1
|
|
@@ -89,7 +97,7 @@ class Server:
|
|
writer.write(input_data)
|
|
writer.write(input_data)
|
|
await writer.drain()
|
|
await writer.drain()
|
|
except Exception as e:
|
|
except Exception as e:
|
|
- logging.error(e)
|
|
|
|
|
|
+ logging.error(e, exc_info=True)
|
|
errors += 1
|
|
errors += 1
|
|
if errors > 3:
|
|
if errors > 3:
|
|
break
|
|
break
|
|
@@ -104,25 +112,70 @@ class Server:
|
|
Decide whether client is sender or receiver and start transmission.
|
|
Decide whether client is sender or receiver and start transmission.
|
|
"""
|
|
"""
|
|
client_hello = await reader.readline()
|
|
client_hello = await reader.readline()
|
|
- peer_is_sender = client_hello.decode('utf-8') == 'sender\n'
|
|
|
|
|
|
+ client_hello = client_hello.decode('utf-8').strip('\n').split('|')
|
|
|
|
+ peer_is_sender = client_hello[0] == 's'
|
|
|
|
+ connection_token = client_hello[1]
|
|
|
|
+ if connection_token not in self.connections:
|
|
|
|
+ self.connections[connection_token] = dict(
|
|
|
|
+ sender=False,
|
|
|
|
+ receiver=False
|
|
|
|
+ )
|
|
if peer_is_sender:
|
|
if peer_is_sender:
|
|
- self._working = True
|
|
|
|
|
|
+ if self.connections[connection_token]['sender']:
|
|
|
|
+ writer.write(
|
|
|
|
+ "Invalid token! "
|
|
|
|
+ "A sender client is already connected!\n".encode('utf-8')
|
|
|
|
+ )
|
|
|
|
+ await writer.drain()
|
|
|
|
+ writer.close()
|
|
|
|
+ return
|
|
|
|
+ self.connections[connection_token]['sender'] = True
|
|
|
|
+ self.buffers[connection_token] = collections.deque()
|
|
logging.info("Sender is connecting...")
|
|
logging.info("Sender is connecting...")
|
|
|
|
+ index, step = 0, 1
|
|
|
|
+ while not self.connections[connection_token]['receiver']:
|
|
|
|
+ index += 1
|
|
|
|
+ if index >= step:
|
|
|
|
+ writer.write("Waiting for receiver...\n".encode('utf-8'))
|
|
|
|
+ await writer.drain()
|
|
|
|
+ step += 1
|
|
|
|
+ index = 0
|
|
|
|
+ await asyncio.sleep(.5)
|
|
# Send start signal to client
|
|
# Send start signal to client
|
|
- writer.write("Start!\n".encode('utf-8'))
|
|
|
|
|
|
+ writer.write("start!\n".encode('utf-8'))
|
|
await writer.drain()
|
|
await writer.drain()
|
|
- await self.run_reader(reader=reader)
|
|
|
|
|
|
+ logging.info("Incoming transmission starting...")
|
|
|
|
+ await self.run_reader(reader=reader,
|
|
|
|
+ connection_token=connection_token)
|
|
logging.info("Incoming transmission ended")
|
|
logging.info("Incoming transmission ended")
|
|
else:
|
|
else:
|
|
|
|
+ if self.connections[connection_token]['receiver']:
|
|
|
|
+ writer.write(
|
|
|
|
+ "Invalid token! "
|
|
|
|
+ "A receiver client is already connected!\n".encode('utf-8')
|
|
|
|
+ )
|
|
|
|
+ await writer.drain()
|
|
|
|
+ writer.close()
|
|
|
|
+ return
|
|
|
|
+ self.connections[connection_token]['receiver'] = True
|
|
logging.info("Receiver is connecting...")
|
|
logging.info("Receiver is connecting...")
|
|
- while len(self.buffer) == 0:
|
|
|
|
|
|
+ index, step = 0, 1
|
|
|
|
+ while not self.connections[connection_token]['sender']:
|
|
|
|
+ index += 1
|
|
|
|
+ if index >= step:
|
|
|
|
+ writer.write("Waiting for sender...\n".encode('utf-8'))
|
|
|
|
+ await writer.drain()
|
|
|
|
+ step += 1
|
|
|
|
+ index = 0
|
|
await asyncio.sleep(.5)
|
|
await asyncio.sleep(.5)
|
|
# Send start signal to client
|
|
# Send start signal to client
|
|
- writer.write("Start!\n".encode('utf-8'))
|
|
|
|
|
|
+ writer.write("start!\n".encode('utf-8'))
|
|
await writer.drain()
|
|
await writer.drain()
|
|
- await self.run_writer(writer=writer)
|
|
|
|
|
|
+ await self.run_writer(writer=writer,
|
|
|
|
+ connection_token=connection_token)
|
|
logging.info("Outgoing transmission ended")
|
|
logging.info("Outgoing transmission ended")
|
|
- self._working = False
|
|
|
|
|
|
+ del self.buffers[connection_token]
|
|
|
|
+ del self.connections[connection_token]
|
|
return
|
|
return
|
|
|
|
|
|
def run(self):
|
|
def run(self):
|
|
@@ -149,23 +202,11 @@ class Server:
|
|
port=self.port,
|
|
port=self.port,
|
|
)
|
|
)
|
|
async with self.server:
|
|
async with self.server:
|
|
- try:
|
|
|
|
- await self.server.serve_forever()
|
|
|
|
- except KeyboardInterrupt:
|
|
|
|
- logging.info("Stopping...")
|
|
|
|
- self.server.close()
|
|
|
|
- await self.server.wait_closed()
|
|
|
|
|
|
+ await self.server.serve_forever()
|
|
return
|
|
return
|
|
|
|
|
|
- def stop(self, *_):
|
|
|
|
- if self.working and not self.stopping:
|
|
|
|
- logging.info("Received interruption signal, stopping...")
|
|
|
|
- self._stopping = True
|
|
|
|
- else:
|
|
|
|
- raise KeyboardInterrupt("Not working yet...")
|
|
|
|
-
|
|
|
|
|
|
|
|
-if __name__ == '__main__':
|
|
|
|
|
|
+def main():
|
|
# noinspection SpellCheckingInspection
|
|
# noinspection SpellCheckingInspection
|
|
log_formatter = logging.Formatter(
|
|
log_formatter = logging.Formatter(
|
|
"%(asctime)s [%(module)-15s %(levelname)-8s] %(message)s",
|
|
"%(asctime)s [%(module)-15s %(levelname)-8s] %(message)s",
|
|
@@ -221,12 +262,16 @@ if __name__ == '__main__':
|
|
port=_port,
|
|
port=_port,
|
|
)
|
|
)
|
|
try:
|
|
try:
|
|
|
|
+ # noinspection PyUnresolvedReferences
|
|
from config import certificate, key
|
|
from config import certificate, key
|
|
_ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
|
_ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
|
_ssl_context.check_hostname = False
|
|
_ssl_context.check_hostname = False
|
|
_ssl_context.load_cert_chain(certificate, key)
|
|
_ssl_context.load_cert_chain(certificate, key)
|
|
server.set_ssl_context(_ssl_context)
|
|
server.set_ssl_context(_ssl_context)
|
|
except ImportError:
|
|
except ImportError:
|
|
- logging.info("Please consider using SSL.")
|
|
|
|
- certificate, key = None, None
|
|
|
|
|
|
+ logging.warning("Please consider using SSL.")
|
|
server.run()
|
|
server.run()
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+if __name__ == '__main__':
|
|
|
|
+ main()
|