Queer European MD passionate about IT
Browse Source

Implemented SSL

Davte 5 years ago
parent
commit
db0da8b24b
3 changed files with 45 additions and 3 deletions
  1. 3 0
      .gitignore
  2. 22 2
      src/client.py
  3. 20 1
      src/server.py

+ 3 - 0
.gitignore

@@ -3,6 +3,9 @@
 # Configuration file
 *config.py
 
+# Data folder
+data/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]

+ 22 - 2
src/client.py

@@ -4,6 +4,7 @@ import collections
 import logging
 # import signal
 import os
+import ssl
 
 
 class Client:
@@ -17,6 +18,7 @@ class Client:
         self._buffer_length_limit = buffer_length_limit  # How many chunks in buffer
         self._file_path = None
         self._working = False
+        self._ssl_context = None
 
     @property
     def host(self) -> str:
@@ -46,10 +48,18 @@ class Client:
     def working(self) -> bool:
         return self._working
 
+    @property
+    def ssl_context(self) -> ssl.SSLContext:
+        return self._ssl_context
+
+    def set_ssl_context(self, ssl_context: ssl.SSLContext):
+        self._ssl_context = ssl_context
+
     async def run_sending_client(self, file_path='~/output.txt'):
         self._file_path = file_path
         reader, writer = await asyncio.open_connection(host=self.host,
-                                                       port=self.port)
+                                                       port=self.port,
+                                                       ssl=self.ssl_context)
         writer.write("sender\n".encode('utf-8'))
         await writer.drain()
         await reader.readline()  # Wait for server start signal
@@ -78,7 +88,8 @@ class Client:
     async def run_receiving_client(self, file_path='~/input.txt'):
         self._file_path = file_path
         reader, writer = await asyncio.open_connection(host=self.host,
-                                                       port=self.port)
+                                                       port=self.port,
+                                                       ssl=self.ssl_context)
         writer.write("receiver\n".encode('utf-8'))
         await writer.drain()
         await reader.readline()  # Wait for server start signal
@@ -227,6 +238,15 @@ if __name__ == '__main__':
         host=_host,
         port=_port,
     )
+    try:
+        from config import certificate
+        _ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
+        _ssl_context.check_hostname = False
+        _ssl_context.load_verify_locations(certificate)
+        client.set_ssl_context(_ssl_context)
+    except ImportError:
+        logging.info("Please consider using SSL.")
+        certificate, key = None, None
     logging.info("Starting client...")
     if _action == 'send':
         loop.run_until_complete(

+ 20 - 1
src/server.py

@@ -2,6 +2,7 @@ import argparse
 import asyncio
 import collections
 import logging
+import ssl
 
 
 class Server:
@@ -16,6 +17,7 @@ class Server:
         self._working = False
         self.at_eof = False
         self._server = None
+        self._ssl_context = None
 
     @property
     def host(self) -> str:
@@ -45,6 +47,13 @@ class Server:
     def server(self) -> asyncio.base_events.Server:
         return self._server
 
+    @property
+    def ssl_context(self) -> ssl.SSLContext:
+        return self._ssl_context
+
+    def set_ssl_context(self, ssl_context: ssl.SSLContext):
+        self._ssl_context = ssl_context
+
     async def run_reader(self, reader):
         while not self.stopping:
             try:
@@ -121,9 +130,10 @@ class Server:
 
     async def run_server(self):
         self._server = await asyncio.start_server(
+            ssl=self.ssl_context,
             client_connected_cb=self.connect,
             host=self.host,
-            port=self.port
+            port=self.port,
         )
         async with self.server:
             try:
@@ -197,4 +207,13 @@ if __name__ == '__main__':
         host=_host,
         port=_port,
     )
+    try:
+        from config import certificate, key
+        _ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
+        _ssl_context.check_hostname = False
+        _ssl_context.load_cert_chain(certificate, key)
+        server.set_ssl_context(_ssl_context)
+    except ImportError:
+        logging.info("Please consider using SSL.")
+        certificate, key = None, None
     server.run()