Queer European MD passionate about IT

client.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
  1. import argparse
  2. import asyncio
  3. import collections
  4. import logging
  5. # import signal
  6. import os
  7. import ssl
  class Client:
      """Asyncio TCP client that sends a file to, or receives a file from, a server.

      When a password is set, the file is encrypted before sending (or
      decrypted after receiving) with the ``openssl enc`` command-line tool.
      """

      # Environment variable used to hand the password to openssl. Passing it
      # via ``-pass env:...`` keeps it off the command line, where
      # ``-pass pass:...`` would expose it in the process list.
      _PASSWORD_ENV_VAR = 'FILE_TRANSFER_CLIENT_PASSWORD'

      def __init__(self, host='localhost', port=3001,
                   buffer_chunk_size=10**4, buffer_length_limit=10**4,
                   password=None):
          self._password = password
          self._host = host
          self._port = port
          self._stopping = False
          # Shared queue of bytes (not read by any method of this class)
          self.buffer = collections.deque()
          # How many bytes per chunk
          self._buffer_chunk_size = buffer_chunk_size
          # How many chunks in buffer
          self._buffer_length_limit = buffer_length_limit
          self._file_path = None
          self._working = False
          self._ssl_context = None
          # True once the current encryption run has finished (success or not)
          self._encryption_complete = False
          # True when the last encryption run exited with status 0
          self._encryption_succeeded = False
          # Strong reference to the background encryption task, so it cannot
          # be garbage-collected while it is still running.
          self._encryption_task = None

      @property
      def host(self) -> str:
          """Server address."""
          return self._host

      @property
      def port(self) -> int:
          """Server port."""
          return self._port

      @property
      def stopping(self) -> bool:
          """Whether `stop` has asked the transfer loops to finish."""
          return self._stopping

      @property
      def buffer_length_limit(self) -> int:
          """How many chunks the buffer may hold."""
          return self._buffer_length_limit

      @property
      def buffer_chunk_size(self) -> int:
          """How many bytes are read/written per chunk."""
          return self._buffer_chunk_size

      @property
      def file_path(self) -> str:
          """Path of the plain-text file being sent or received (``~`` expanded)."""
          return self._file_path

      @property
      def working(self) -> bool:
          """Whether a send or receive has started."""
          return self._working

      @property
      def ssl_context(self) -> ssl.SSLContext:
          """SSL context used for the connection, or None for plain TCP."""
          return self._ssl_context

      def set_ssl_context(self, ssl_context: ssl.SSLContext):
          """Use `ssl_context` for subsequent connections."""
          self._ssl_context = ssl_context

      @property
      def password(self):
          """Password for file encryption or decryption."""
          return self._password

      @property
      def encryption_complete(self):
          """Whether the current encryption run has finished (successfully or not)."""
          return self._encryption_complete

      async def run_sending_client(self, file_path='~/output.txt'):
          """Connect as a sender, wait for the server's go signal, then send `file_path`."""
          # open() does not expand "~", so the default path would never be found.
          self._file_path = os.path.expanduser(file_path)
          reader, writer = await asyncio.open_connection(host=self.host,
                                                         port=self.port,
                                                         ssl=self.ssl_context)
          writer.write("sender\n".encode('utf-8'))
          await writer.drain()
          await reader.readline()  # Wait for server start signal
          await self.send(writer=writer)

      def _openssl_args(self, input_file, output_file, decrypt=False):
          """Build the ``openssl enc`` argument list; the password is read from an env var."""
          args = ['openssl', 'enc', '-aes-256-cbc',
                  '-md', 'sha512', '-pbkdf2', '-iter', '100000', '-salt']
          if decrypt:
              args.append('-d')
          args += ['-in', str(input_file), '-out', str(output_file),
                   '-pass', 'env:' + self._PASSWORD_ENV_VAR]
          return args

      async def _run_openssl(self, input_file, output_file, decrypt=False):
          """Run ``openssl enc`` from `input_file` to `output_file`.

          Arguments are passed as a list (no shell), so paths and passwords
          containing spaces or shell metacharacters are handled safely.

          Returns:
              True if openssl exited with status 0. Otherwise it logs openssl's
              output (or the launch error) and returns False.
          """
          env = dict(os.environ)
          env[self._PASSWORD_ENV_VAR] = str(self.password)
          try:
              process = await asyncio.create_subprocess_exec(
                  *self._openssl_args(input_file, output_file, decrypt),
                  stdout=asyncio.subprocess.PIPE,
                  stderr=asyncio.subprocess.PIPE,
                  env=env,
              )
              stdout, stderr = await process.communicate()
          except (OSError, ValueError) as e:
              # OSError: openssl missing or not executable;
              # ValueError: e.g. a NUL byte in the password.
              logging.error("Could not run openssl: %s", e)
              return False
          if process.returncode != 0:
              logging.error(
                  "openssl exited with status %s:\n%s\n%s",
                  process.returncode,
                  stdout.decode(errors='replace').strip(),
                  stderr.decode(errors='replace').strip(),
              )
              return False
          return True

      async def encrypt_file(self, input_file, output_file):
          """Encrypt `input_file` into `output_file` with the client password.

          `encryption_complete` is set once openssl has finished, whether or
          not it succeeded, so `send` never waits forever on a failed run.
          """
          self._encryption_complete = False
          self._encryption_succeeded = False
          logging.info("Encrypting file...")
          try:
              self._encryption_succeeded = await self._run_openssl(
                  input_file, output_file)
          finally:
              self._encryption_complete = True
          if self._encryption_succeeded:
              logging.info("Encryption completed.")
          else:
              logging.error("Encryption failed.")

      async def send(self, writer: asyncio.StreamWriter):
          """Stream the (optionally encrypted) file to `writer`, then close it.

          With a password, encryption runs in a background task. The ``.enc``
          file is streamed while it is still growing, and reading stops only
          once encryption has finished and EOF is reached.
          """
          self._working = True
          file_path = self.file_path
          try:
              if self.password:
                  file_path = self.file_path + '.enc'
                  # Remove already-encrypted file if present (salt would differ)
                  if os.path.isfile(file_path):
                      os.remove(file_path)
                  # Reset now: the task below only starts at the next await, so a
                  # stale True from a previous run would end the wait loop early.
                  self._encryption_complete = False
                  self._encryption_succeeded = False
                  self._encryption_task = asyncio.ensure_future(
                      self.encrypt_file(
                          input_file=self.file_path,
                          output_file=file_path
                      )
                  )
                  # Give encryption an edge, but stop waiting if it ended
                  # without ever producing the output file.
                  while (not os.path.isfile(file_path)
                         and not self.encryption_complete):
                      await asyncio.sleep(.5)
                  if not os.path.isfile(file_path):
                      logging.error("Encrypted file was not created; nothing sent.")
                      return
              logging.info("Sending file...")
              with open(file_path, 'rb') as file_to_send:
                  while not self.stopping:
                      output_data = file_to_send.read(self.buffer_chunk_size)
                      if not output_data:
                          # If encryption is in progress, wait and read again later
                          if self.password and not self.encryption_complete:
                              await asyncio.sleep(1)
                              continue
                          break
                      try:
                          writer.write(output_data)
                          await writer.drain()
                      except ConnectionError:
                          # Covers reset, broken pipe and aborted connections.
                          logging.info('Server closed the connection.')
                          self.stop()
                          break
              if (self.password and self.encryption_complete
                      and not self._encryption_succeeded):
                  logging.error("Encryption failed; the data sent may be incomplete.")
          finally:
              writer.close()

      async def run_receiving_client(self, file_path='~/input.txt'):
          """Connect as a receiver, wait for the server's go signal, then receive into `file_path`."""
          # open() does not expand "~", so the default path would not be usable.
          self._file_path = os.path.expanduser(file_path)
          reader, writer = await asyncio.open_connection(host=self.host,
                                                         port=self.port,
                                                         ssl=self.ssl_context)
          writer.write("receiver\n".encode('utf-8'))
          await writer.drain()
          await reader.readline()  # Wait for server start signal
          await self.receive(reader=reader)

      async def receive(self, reader: asyncio.StreamReader):
          """Write everything read from `reader` to the target file.

          With a password, the data is stored as ``<file_path>.enc`` and then
          decrypted into `file_path`.
          """
          self._working = True
          file_path = self.file_path
          logging.info("Receiving file...")
          if self.password:
              file_path += '.enc'
          with open(file_path, 'wb') as file_to_receive:
              while not self.stopping:
                  input_data = await reader.read(self.buffer_chunk_size)
                  if not input_data:
                      break
                  file_to_receive.write(input_data)
          logging.info("File received.")
          if self.password:
              logging.info("Decrypting file...")
              if await self._run_openssl(file_path, self.file_path, decrypt=True):
                  logging.info("Decryption completed.")
              else:
                  logging.error("Decryption failed.")

      def stop(self, *_):
          """Ask the running transfer to stop; before it starts, raise KeyboardInterrupt.

          Extra positional arguments are ignored.
          """
          if self.working:
              logging.info("Received interruption signal, stopping...")
              self._stopping = True
          else:
              raise KeyboardInterrupt("Not working yet...")
  def get_action(action):
      """Map a user-supplied action abbreviation to 'receive' or 'send'.

      Only the first character matters, case-insensitively: ``r`` gives
      'receive' and ``s`` gives 'send'. Any other string, and any non-string
      value, gives None.
      """
      if isinstance(action, str):
          return {'r': 'receive', 's': 'send'}.get(action.lower()[:1])
      return None
  def get_file_path(path, action='receive'):
      """Check that file `path` is usable for `action` and return it.

      A leading ``~`` is expanded, and the expanded path is what gets
      checked and returned.

      - ``'send'``: the path must name an existing regular file.
      - ``'receive'``: the path must not be an existing directory, and its
        parent directory must be writable.

      Empty or whitespace-only strings are always rejected.

      Returns:
          The (expanded) path, or None when it is invalid. An error is logged
          for every rejected `path` that is not None.
      """
      if isinstance(path, str) and path.strip():
          expanded = os.path.expanduser(path)
          if action == 'send' and os.path.isfile(expanded):
              return expanded
          if (
              action == 'receive'
              # open(..., 'wb') on a directory would fail later
              and not os.path.isdir(expanded)
              and os.access(os.path.dirname(os.path.abspath(expanded)), os.W_OK)
          ):
              return expanded
      if path is not None:
          logging.error(f"Invalid file: `{path}`")
      return None
  def _config_value(name, default=None):
      """Return attribute `name` of the optional local ``config`` module.

      Falls back to `default` when ``config`` cannot be imported or does not
      define `name`. These are the same cases in which
      ``from config import name`` would raise ImportError.
      """
      try:
          import config
      except ImportError:
          return default
      return getattr(config, name, default)


  if __name__ == '__main__':
      # noinspection SpellCheckingInspection
      log_formatter = logging.Formatter(
          "%(asctime)s [%(module)-15s %(levelname)-8s] %(message)s",
          style='%'
      )
      root_logger = logging.getLogger()
      root_logger.setLevel(logging.DEBUG)
      console_handler = logging.StreamHandler()
      console_handler.setFormatter(log_formatter)
      console_handler.setLevel(logging.DEBUG)
      root_logger.addHandler(console_handler)

      # Parse command-line arguments. `--_host`/`--_port` stay accepted as
      # aliases of `--host`/`--port` for backward compatibility.
      cli_parser = argparse.ArgumentParser(description='Run client',
                                           allow_abbrev=False)
      cli_parser.add_argument('--host', '--_host', dest='_host', type=str,
                              default=None,
                              required=False,
                              help='server address')
      cli_parser.add_argument('--port', '--_port', dest='_port', type=int,
                              default=None,
                              required=False,
                              help='server port')
      cli_parser.add_argument('--action', type=str,
                              default=None,
                              required=False,
                              help='[S]end or [R]eceive')
      cli_parser.add_argument('--path', type=str,
                              default=None,
                              required=False,
                              help='File path')
      cli_parser.add_argument('--password', '--p', '--pass', type=str,
                              default=None,
                              required=False,
                              help='Password for file encryption or decryption')
      cli_parser.add_argument('others',
                              metavar='R or S',
                              nargs='*',
                              help='[S]end or [R]eceive (see `action`)')
      args = vars(cli_parser.parse_args())

      # Command-line values take precedence; otherwise fall back to config.py.
      _host = args['_host']
      if _host is None:
          _host = _config_value('host')
      _port = args['_port']
      if _port is None:
          _port = _config_value('port')
      # Action: --action, then a bare `s`/`r` positional, then config.py.
      _action = get_action(args['action'])
      if _action is None:
          _action = next(
              (found for found in map(get_action, args['others']) if found),
              None
          )
      if _action is None:
          _action = get_action(_config_value('action'))
      _file_path = args['path']
      if _file_path is None:
          _file_path = _config_value('file_path')
      _password = args['password']
      if _password is None:
          _password = _config_value('password')

      # Prompt the user for anything still missing.
      while not _host:
          _host = input("Enter host:\t\t\t\t\t\t").strip()
      while _port is None:
          try:
              _port = int(input("Enter port:\t\t\t\t\t\t"))
          except ValueError:
              logging.info("Invalid port. Enter a valid port number!")
      while _action is None:
          _action = get_action(
              input("Do you want to (R)eceive or (S)end a file?\t\t")
          )
      # A path can only be validated once the action is known.
      if _file_path is not None:
          _file_path = get_file_path(path=_file_path, action=_action)
      while _file_path is None:
          _file_path = get_file_path(
              path=input(f"Enter file to {_action}:\t\t\t\t\t\t"),
              action=_action
          )
      if _password is None:
          logging.warning(
              "You have provided no password for file encryption.\n"
              "Your file will be unencrypted unless you provide a password "
              "(--password or config file)."
          )

      client = Client(
          host=_host,
          port=_port,
          password=_password
      )

      certificate = _config_value('certificate')
      if certificate is None:
          logging.warning("Please consider using SSL.")
      else:
          _ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
          # NOTE(review): hostname checking is disabled, presumably because
          # the server uses a certificate pinned via config.py — confirm.
          _ssl_context.check_hostname = False
          _ssl_context.load_verify_locations(certificate)
          client.set_ssl_context(_ssl_context)

      logging.info("Starting client...")
      if _action == 'send':
          _transfer = client.run_sending_client(file_path=_file_path)
      else:
          _transfer = client.run_receiving_client(file_path=_file_path)
      # asyncio.run creates, runs and closes its own event loop, replacing the
      # deprecated get_event_loop()/run_until_complete()/close() sequence.
      asyncio.run(_transfer)
      logging.info("Stopped client")