Last active
August 12, 2021 19:15
-
-
Save evandiewald/11e6e91d4107df3d5502ad12be05d07a to your computer and use it in GitHub Desktop.
VsockListener class from vsock-parent.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class VsockListener:
    """Server that receives the enclave's output files over vsock.

    Every accepted connection is expected to carry exactly one message: an
    8-byte big-endian unsigned length header followed by that many payload
    bytes. The payload is classified by its length and written to a file in
    the current working directory. Once both the encrypted symmetric key and
    the encrypted inference have arrived, the inference is decrypted and the
    resulting class label is printed.
    """

    # Class names indexed by the integer decoded from the decrypted inference.
    LABELS = (
        'Actinic Keratoses and Intraepithelial Carcinoma',
        'Basal Cell Carcinoma',
        'Benign Keratosis',
        'Dermatofibroma',
        'Melanoma',
        'Melanocytic Nevi',
        'Vascular Lesions',
    )

    # (output file, status message) for each ``files_received`` slot:
    # 0 = encrypted symmetric key, 1 = encrypted inference, 2 = enclave public key.
    _OUTPUTS = (
        ('inference_key_received', 'Encryption key received.'),
        ('inference_received.txt.encrypted', 'Encrypted inference received.'),
        ('enclave_public_key_received.pem', 'Enclave\'s public key received.'),
    )

    def __init__(self, conn_backlog=128):
        """Create an unbound listener.

        Args:
            conn_backlog: backlog passed to ``listen()`` by :meth:`bind`.
        """
        self.conn_backlog = conn_backlog
        self.files_received = [0, 0, 0]  # --> [sym key, inference, pub key]

    def bind(self, port):
        """Bind and listen for connections on the specified port."""
        self.sock = socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM)
        self.sock.bind((socket.VMADDR_CID_ANY, port))
        self.sock.listen(self.conn_backlog)

    @staticmethod
    def _recv_exact(conn, size):
        """Read exactly *size* bytes from *conn*.

        Returns the bytes read, or None if the peer closes the connection
        before *size* bytes have arrived. A single ``recv`` may return fewer
        bytes than requested on a stream socket, so this loops.
        """
        buf = bytearray()
        while len(buf) < size:
            chunk = conn.recv(min(1024, size - len(buf)))
            if not chunk:  # b'' means the peer closed; avoid spinning forever
                return None
            buf += chunk
        return bytes(buf)

    @staticmethod
    def _slot_for_length(length):
        """Map a payload length to its ``files_received`` slot.

        121-256 bytes is taken to be the encrypted symmetric key (usually 256
        bytes). Anything up to 120 bytes is assumed to be the encrypted
        inference, and anything larger is the enclave's public key.
        """
        if 120 < length < 257:
            return 0
        if length <= 120:
            return 1
        return 2

    def recv_data_parent(self):
        """Accept connections until the key and inference arrive, then decrypt.

        Each connection is closed as soon as its single message has been read.
        A connection that closes before delivering a full header or payload is
        ignored instead of stalling the server.
        """
        while sum(self.files_received) < 3:
            from_client, _ = self.sock.accept()
            with from_client:
                header = self._recv_exact(from_client, 8)
                if header is None:
                    continue  # incomplete header: ignore this connection
                (length,) = unpack('>Q', header)
                data = self._recv_exact(from_client, length)
            if data is None:
                print('Connection closed before the full message arrived; ignoring.')
                continue

            slot = self._slot_for_length(length)
            filename, message = self._OUTPUTS[slot]
            with open(filename, 'wb') as f:
                f.write(data)
            print(message)
            self.files_received[slot] = 1
            # The key and the inference are all that decryption needs.
            if self.files_received[0] and self.files_received[1]:
                break
        print('All files received, shutting down...')
        self._decrypt_inference()

    def _decrypt_inference(self):
        """Decrypt the received inference, save it, and print its class label."""
        # in reality, this decryption would actually happen on the client's machine
        print('Attempting to decrypt inference...')
        with open('inference_key_received', 'rb') as f:
            encrypted_key = f.read()
        # decrypt() is defined elsewhere in the script. Its result is written as
        # bytes and decoded below as a big-endian index into LABELS.
        decrypted_contents = decrypt(
            'inference_received.txt.encrypted', encrypted_key, 'my_private_key.pem'
        )
        with open('inference_received_decrypted.txt', 'wb') as f:
            f.write(decrypted_contents)
        print('Decryption successful!')
        print('Classification received: ',
              self.LABELS[int.from_bytes(decrypted_contents, 'big')])
def server_handler(args):
    """Run the parent-side vsock server until the enclave's files arrive.

    Args:
        args: parsed command-line arguments; only ``args.port`` is read.
    """
    print('Server ready!')
    server = VsockListener()
    server.bind(args.port)
    try:
        server.recv_data_parent()
    finally:
        # Release the listening socket even if receiving or decryption fails.
        server.sock.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment