| 
          """The actual WSGI app.""" | 
        
        
           | 
          from __future__ import annotations | 
        
        
           | 
          
 | 
        
        
           | 
          import base64 | 
        
        
           | 
          import binascii | 
        
        
           | 
          import datetime | 
        
        
           | 
          import hashlib | 
        
        
           | 
          import hmac | 
        
        
           | 
          import json | 
        
        
           | 
          import os | 
        
        
           | 
          import re | 
        
        
           | 
          import typing | 
        
        
           | 
          
 | 
        
        
           | 
          from Crypto.Cipher import AES, PKCS1_OAEP | 
        
        
           | 
          from Crypto.PublicKey import RSA | 
        
        
           | 
          
 | 
        
        
           | 
          import flask | 
        
        
           | 
          
 | 
        
        
           | 
          import peewee | 
        
        
           | 
          
 | 
        
        
           | 
          
 | 
        
        
           | 
          # Demo lookup table: /get_secret decrypts a name, looks it up here and
          # returns the matching value encrypted with the session key.
          SECRETS = dict(
              foobar='barbat',
              batfoo='foobar',
              barbat='batfoo',
          )
        
        
           | 
          
 | 
        
        
           | 
          
 | 
        
        
           | 
          class BadRequest(Exception):
              """An error that carries a plain-text message and an HTTP status code.

              The code defaults to 400 but any status may be given, so one Flask
              error handler can render every failure raised this way.
              """

              def __init__(self, message: str, code: int = 400):
                  """Keep ``message`` and ``code`` for the error handler to use."""
                  super().__init__(message)
                  self.message = message
                  self.code = code
        
        
           | 
          
 | 
        
        
           | 
          
 | 
        
        
           | 
          class Server:
              """Hold the server's RSA key pair and implement the encrypted protocol.

              Account creation and session start arrive RSA-OAEP encrypted with our
              public key; after that, requests and replies are AES-CFB encrypted
              with the per-session key the client supplied at session start.
              """

              # Where the RSA private key is persisted (relative to the working dir).
              KEY_FILE = 'key.pem'
              # Key lengths in bytes that AES accepts (AES-128/192/256).
              AES_KEY_SIZES = (16, 24, 32)

              def __init__(self):
                  """Load, or generate and persist, the asymmetric key pair."""
                  self.key = None
                  self.public_key = None
                  self.get_asymmetric_key()

              def get_asymmetric_key(self):
                  """Load the RSA private key from disk, generating one if needed.

                  A missing or empty key file causes a new 4096-bit key to be
                  generated and written back. Sets ``self.key`` (the private key) and
                  ``self.public_key`` (the exported public key as text).
                  """
                  try:
                      with open(self.KEY_FILE, 'rb') as f:
                          raw_key = f.read()
                  except FileNotFoundError:
                      raw_key = None
                  if raw_key:
                      self.key = RSA.importKey(raw_key)
                  else:
                      self.key = RSA.generate(4096)
                      # Create the file owner-read/write only: a plain open() would
                      # use the process umask and could leave the private key
                      # readable by other users.
                      fd = os.open(
                          self.KEY_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
                      with os.fdopen(fd, 'wb') as f:
                          f.write(self.key.exportKey())
                  self.public_key = self.key.publickey().exportKey().decode()

              def decrypt_asymmetric(self, data: bytes) -> dict:
                  """Decrypt a JSON object encrypted with our public key (RSA-OAEP).

                  Raises:
                      BadRequest: 400 if decryption or UTF-8 decoding fails, if the
                          plain text is not valid JSON, or if it is not a JSON object.
                  """
                  cipher = PKCS1_OAEP.new(self.key)
                  try:
                      # UnicodeDecodeError subclasses ValueError, so it lands here too.
                      plain_text = cipher.decrypt(data).decode()
                  except ValueError:
                      raise BadRequest('Bad cipher.', 400)
                  try:
                      payload = json.loads(plain_text)
                  except json.decoder.JSONDecodeError:
                      raise BadRequest('Invalid JSON.', 400)
                  if not isinstance(payload, dict):
                      # Callers index the result by key; a list/number would 500.
                      raise BadRequest('JSON object expected.', 400)
                  return payload

              @staticmethod
              def _require_credentials(payload: dict) -> typing.Tuple[str, str]:
                  """Return non-empty string ``(username, password)`` or raise a 400."""
                  username = payload.get('username')
                  password = payload.get('password')
                  if not (isinstance(username, str) and username
                          and isinstance(password, str) and password):
                      raise BadRequest('Username or password missing.', 400)
                  return username, password

              def create_account(self, data: bytes):
                  """Create an account from an RSA-encrypted JSON request body.

                  Args:
                      data: request body encrypted with our public key; it must
                          decrypt to an object with non-empty string ``username``
                          and ``password`` fields.

                  Raises:
                      BadRequest: 400 if the body is undecryptable/invalid or a
                          credential is missing, 409 if the username is taken.
                  """
                  # Decrypt the argument we were given instead of re-reading
                  # flask.request, so the method honours its own signature.
                  payload = self.decrypt_asymmetric(data)
                  username, password = self._require_credentials(payload)
                  try:
                      User.create(username=username, password=password)
                  except peewee.IntegrityError:
                      raise BadRequest('Username already taken.', 409)

              def start_session(self, token: bytes) -> int:
                  """Receive, validate and store a client-chosen session key.

                  Args:
                      token: RSA-encrypted JSON object with ``username``,
                          ``password`` and a base64 ``key`` usable as an AES key.

                  Returns:
                      The id of the newly created ``UserSession``.

                  Raises:
                      BadRequest: 400 for a missing/malformed key or credentials,
                          401 for a wrong password.
                      peewee.DoesNotExist: if no user has the given username.
                  """
                  data = self.decrypt_asymmetric(token)
                  try:
                      key = base64.b64decode(data['key'])
                  except KeyError:
                      raise BadRequest('Key missing.', 400)
                  except (binascii.Error, TypeError, ValueError):
                      # TypeError: key is not str/bytes; ValueError: non-ASCII str.
                      raise BadRequest('Bad base 64 encoding of key.', 400)
                  if len(key) not in self.AES_KEY_SIZES:
                      # Reject now rather than failing inside AES.new on first use.
                      raise BadRequest('Key must be 16, 24 or 32 bytes.', 400)
                  username, password = self._require_credentials(data)
                  user = User.get(User.username == username)
                  if not user.password == password:
                      raise BadRequest('Incorrect password.', 401)
                  session = UserSession.create(key=key, user=user)
                  return session.id

              def encrypt_message(
                      self, session: UserSession,
                      plain_text: typing.Union[bytes, str]) -> bytes:
                  """Encrypt a message with the session's AES key in CFB mode.

                  ``str`` input is UTF-8 encoded first, since the cipher operates
                  on bytes.

                  Returns:
                      A fresh random IV followed by the cipher text (raw bytes).
                  """
                  if isinstance(plain_text, str):
                      plain_text = plain_text.encode()
                  iv = os.urandom(AES.block_size)
                  cipher = AES.new(session.key, AES.MODE_CFB, iv)
                  return iv + cipher.encrypt(plain_text)

              def decrypt_message(
                      self, session: UserSession, cipher_text: str) -> bytes:
                  """Decrypt a base64-encoded ``IV + cipher text`` message.

                  The decoded bytes must start with the one-block IV, followed by
                  the AES-CFB cipher text under the session key.

                  Raises:
                      BadRequest: 400 if the base64 is malformed or the message is
                          shorter than the IV.
                  """
                  try:
                      raw = base64.b64decode(cipher_text)
                  except (binascii.Error, ValueError):
                      raise BadRequest('Bad base 64 encoding of request.', 400)
                  if len(raw) < AES.block_size:
                      raise BadRequest('Request too short.', 400)
                  # Use the transmitted IV directly (equivalent to the old trick of
                  # decrypting with a random IV and dropping the first block).
                  iv, body = raw[:AES.block_size], raw[AES.block_size:]
                  cipher = AES.new(session.key, AES.MODE_CFB, iv)
                  return cipher.decrypt(body)

              def decrypt_request(
                      self, params: typing.Mapping[str, typing.Any]) -> typing.Tuple[
                          UserSession, bytes]:
                  """Look up the session and decrypt a symmetrically encrypted request.

                  Args:
                      params: mapping (e.g. query args) holding ``session_id`` and
                          ``request`` (base64 cipher text).

                  Returns:
                      The ``UserSession`` and the decrypted request bytes.

                  Raises:
                      BadRequest: 400 if a parameter is missing or the request is
                          malformed.
                      peewee.DoesNotExist: if no session has that id.
                  """
                  session_id = params.get('session_id')
                  if not session_id:
                      raise BadRequest('Session ID missing.', 400)
                  request = params.get('request')
                  if not request:
                      raise BadRequest('Request missing.', 400)
                  # Both inputs validated; only now touch the database.
                  session = UserSession.get(UserSession.id == session_id)
                  return session, self.decrypt_message(session, request)
        
        
           | 
          
 | 
        
        
           | 
          
 | 
        
        
           | 
          class HashedPassword:
              """A salted PBKDF2 password hash that can be compared to plain text.

              The stored value is ``salt + derived_key``. ``hashed == 'plain'``
              re-derives the key from ``'plain'`` with the stored salt and compares
              the two keys in constant time.
              """

              # Length in bytes of the random salt stored in front of the key.
              SALT_SIZE = 32
              # PBKDF2 parameters; changing them makes existing hashes unverifiable.
              HASH_NAME = 'sha3-256'
              ITERATIONS = 100_000

              @classmethod
              def hash_password(cls, password: str) -> HashedPassword:
                  """Hash ``password`` under a fresh random salt."""
                  salt = os.urandom(cls.SALT_SIZE)
                  return cls(salt + cls._derive_key(password, salt))

              @classmethod
              def _derive_key(cls, password: str, salt: bytes) -> bytes:
                  """Derive the PBKDF2-HMAC key for ``password`` under ``salt``."""
                  return hashlib.pbkdf2_hmac(
                      cls.HASH_NAME, password.encode(), salt, cls.ITERATIONS
                  )

              def __init__(self, hashed_password: bytes):
                  """Store an existing ``salt + key`` hash."""
                  self.hashed_password = hashed_password

              def __eq__(self, password: object) -> bool:
                  """Check a plain-text password against the stored hash.

                  Non-``str`` operands return ``NotImplemented``, so Python falls
                  back to its default identity comparison (``False``) instead of
                  raising ``AttributeError`` on ``.encode()``.
                  """
                  if not isinstance(password, str):
                      return NotImplemented
                  salt = self.hashed_password[:self.SALT_SIZE]
                  key = self.hashed_password[self.SALT_SIZE:]
                  # compare_digest is designed to resist timing analysis.
                  return hmac.compare_digest(key, self._derive_key(password, salt))

              def __bytes__(self) -> bytes:
                  """Return the raw ``salt + key`` bytes, e.g. for storage."""
                  return self.hashed_password
        
        
           | 
          
 | 
        
        
           | 
          
 | 
        
        
           | 
          # Flask application object that the routes below register on.
          app = flask.Flask(__name__)
          # SQLite database file (path relative to the working directory); the
          # models bind to it through BaseModel.Meta.
          db = peewee.SqliteDatabase('db.sqlite3')
          # Built at import time: loads key.pem, or generates and writes a new key.
          server = Server()
        
        
           | 
          
 | 
        
        
           | 
          
 | 
        
        
           | 
          class BaseModel(peewee.Model):
              """A base model that binds every subclass to ``db``."""

              class Meta:
                  """Options inherited by all models: the DB and new table names."""

                  database = db
                  # peewee's option is spelled ``legacy_table_names``. The previous
                  # ``use_legacy_table_names`` is not a peewee option and was
                  # ignored, so legacy table names were still being generated.
                  # NOTE(review): multi-word models (e.g. UserSession) get a
                  # different table name under the new scheme; an existing
                  # db.sqlite3 may need that table renamed/migrated.
                  legacy_table_names = False
        
        
           | 
          
 | 
        
        
           | 
          
 | 
        
        
           | 
          class User(BaseModel):
              """A registered account, whose password is stored only as a hash."""

              username = peewee.CharField(max_length=32, unique=True)
              # ``salt + key`` bytes as produced by HashedPassword.
              password_hash = peewee.BlobField()
              # Nullable; not read by the code visible here — TODO confirm usage.
              secret = peewee.CharField(null=True)

              def _get_password(self) -> HashedPassword:
                  """Wrap the stored hash so ``user.password == plain`` verifies it."""
                  stored = self.password_hash
                  return HashedPassword(stored)

              def _set_password(self, password: str):
                  """Store a freshly salted hash of ``password``."""
                  hashed = HashedPassword.hash_password(password)
                  self.password_hash = bytes(hashed)

              password = property(
                  _get_password,
                  _set_password,
                  doc='Return an object that will use hashing in it\'s equality check.',
              )
        
        
           | 
          
 | 
        
        
           | 
          
 | 
        
        
           | 
          class UserSession(BaseModel):
              """An encrypted session: a client-supplied AES key bound to a user."""

              # Symmetric key that the server encrypts/decrypts session traffic with.
              key = peewee.BlobField()
              # When the key was stored (naive local time from datetime.now).
              key_updated = peewee.DateTimeField(default=datetime.datetime.now)
              # Owning account.
              user = peewee.ForeignKeyField(User)
        
        
           | 
          
 | 
        
        
           | 
          
 | 
        
        
           | 
          @app.route('/public_key')
          def get_public_key() -> flask.Response:
              """Serve the server's exported RSA public key as plain text.

              Clients encrypt their /create_account and /start_session bodies
              with this key.
              """
              return flask.Response(server.public_key, mimetype='text/plain')
        
        
           | 
          
 | 
        
        
           | 
          
 | 
        
        
           | 
          @app.route('/create_account', methods=['POST'])
          def create_account() -> flask.Response:
              """Register a user from an RSA-encrypted JSON request body.

              Failures are raised as BadRequest and rendered by its error handler.
              """
              body = flask.request.get_data()
              server.create_account(body)
              return flask.Response('OK', mimetype='text/plain')
        
        
           | 
          
 | 
        
        
           | 
          
 | 
        
        
           | 
          @app.route('/start_session', methods=['POST'])
          def start_session() -> flask.Response:
              """Open an encrypted session and reply with its id as plain text."""
              body = flask.request.get_data()
              new_id = server.start_session(body)
              return flask.Response(f'{new_id}', mimetype='text/plain')
        
        
           | 
          
 | 
        
        
           | 
          
 | 
        
        
           | 
          @app.route('/get_secret')
          def get_secret() -> flask.Response:
              """Process an example encrypted request and response.

              The ``session_id`` and ``request`` query args are decrypted with the
              session key. The decrypted text names an entry in SECRETS, whose
              value is encrypted back with the same key.
              """
              session, plain_request = server.decrypt_request(flask.request.args)
              try:
                  name = plain_request.decode()
              except UnicodeDecodeError:
                  # A wrong key or corrupted cipher text yields non-UTF-8 bytes.
                  raise BadRequest('Request is not valid UTF-8.', 400)
              try:
                  secret = SECRETS[name]
              except KeyError:
                  raise BadRequest('Secret not found.', 404)
              # The AES cipher works on bytes, so encode the str value first.
              response = server.encrypt_message(session, secret.encode())
              # NOTE(review): the reply is raw IV + cipher bytes, whereas requests
              # arrive base64-encoded; consider base64 here if clients allow it.
              return flask.Response(response, mimetype='text/plain')
        
        
           | 
          
 | 
        
        
           | 
          
 | 
        
        
           | 
          @app.errorhandler(peewee.DoesNotExist)
          def model_not_found(
                  error: peewee.DoesNotExist) -> typing.Tuple[flask.Response, int]:
              """Turn a failed model lookup into a 404 that names the model.

              The model name is taken from a ``<Model: Name>`` fragment in the error
              message. If that is absent, it falls back to the exception class name
              (presumably ``NameDoesNotExist`` for peewee's per-model exceptions),
              so the handler itself can never fail with a 500.
              """
              match = re.search(r'<Model: (\w+)>', str(error))
              if match:
                  model = match.group(1)
              else:
                  model = type(error).__name__
                  suffix = 'DoesNotExist'
                  if model.endswith(suffix):
                      model = model[:-len(suffix)]
                  model = model or 'Record'
              return flask.Response(f'{model} not found.', mimetype='text/plain'), 404
        
        
           | 
          
 | 
        
        
           | 
          
 | 
        
        
           | 
          @app.errorhandler(BadRequest)
          def other_error(details: BadRequest) -> typing.Tuple[flask.Response, int]:
              """Render a BadRequest as plain text with the status code it carries."""
              return flask.Response(details.message, mimetype='text/plain'), details.code
        
        
           | 
          
 | 
        
        
           | 
          
 | 
        
        
           | 
          # Ensure the model tables exist whenever the module is imported.
          db.create_tables([User, UserSession])

          # Start Flask's development server only when run as a script. A WSGI
          # server that imports this module serves ``app`` itself and must not
          # be blocked by app.run() at import time.
          if __name__ == '__main__':
              app.run()