Created September 23, 2023 17:21
Save zzeleznick/a39f03464df16f14c01cf6c0ea7fcbe4 to your computer and use it in GitHub Desktop.
Socket Server and CLI
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import os | |
import socket | |
import subprocess | |
import time | |
from socket_lib import attach_cleanup_handler | |
from socket_lib import cleanup_server | |
# Address the client connects to; the auto-started server also binds "localhost".
host = "localhost"
def parse_args():
    """Parse the client's command-line flags.

    Returns:
        argparse.Namespace with ``port`` (int), ``cleanup`` (bool),
        ``kill`` (bool) and ``message`` (str or None).
    """
    cli = argparse.ArgumentParser(
        description='Simple socket client with server management.'
    )
    # Registered in this exact order so the generated --help text is stable.
    option_specs = [
        ('--port', dict(type=int, default=8092,
                        help='Port on which the server runs (default: 8092)')),
        ('--cleanup', dict(action='store_true',
                           help='Cleanup any existing server processes before starting and at exit')),
        ('--kill', dict(action='store_true',
                        help='Terminate the existing server process')),
        ('--message', dict(type=str,
                           help='Send one command and then exit')),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def is_port_open(host, port, timeout=1):
    """Return True if a TCP (IPv4) connection to ``(host, port)`` succeeds.

    Prints a progress line before probing. Every socket-level failure
    (refused, timed out, unreachable host, unresolvable name, ...) is
    reported as "not open" (False) rather than propagated.

    Args:
        host: Hostname or IPv4 address to probe.
        port: TCP port to probe.
        timeout: Seconds allowed for the connection attempt (default 1,
            the previously hard-coded value).

    Returns:
        bool: True if the connect succeeded, False otherwise.
    """
    print("Checking the server connection...")
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.settimeout(timeout)  # Keep the probe short
            s.connect((host, port))
            return True
    except OSError:
        # ConnectionRefusedError, socket.timeout and socket.gaierror are all
        # OSError subclasses; catching the base avoids crashing on e.g.
        # EHOSTUNREACH, which the old narrow tuple let through.
        return False
def start_server(port, startup_delay=0.2):
    """Launch ``socket_server.py`` (from this file's directory) in the background.

    The child is started in a new session (``start_new_session=True``), so it
    is detached from this CLI's process group. The function does not wait for
    the server to be ready; it only sleeps ``startup_delay`` seconds.

    Args:
        port: Port forwarded to the server as ``--port``.
        startup_delay: Seconds to sleep after spawning (default 0.2, the
            previously hard-coded 200 ms).
    """
    import sys  # sys.executable: the interpreter running this CLI

    print("Starting the server...")
    local_dir = os.path.dirname(os.path.abspath(__file__))
    server_path = os.path.join(local_dir, 'socket_server.py')
    # Use the current interpreter instead of whatever "python" resolves to on
    # PATH, which may be missing or a different (e.g. Python 2) install.
    subprocess.Popen([sys.executable, server_path, "--port", str(port)], start_new_session=True)
    time.sleep(startup_delay)  # Give the server a moment to bind and listen
def send_request(request, port):
    """Send ``request`` to the server on ``host:port`` and print its reply.

    If nothing accepts connections on the port, the server is started first.
    The reply is read with a single ``recv`` of up to 1024 bytes.

    Args:
        request: Request line to send, e.g. ``"GET /ping"``.
        port: TCP port of the server.
    """
    if not is_port_open(host, port):
        start_server(port)
    print("Preparing to send our message...")
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.settimeout(1)  # Bounds connect, send and recv below
        s.connect((host, port))
        # sendall() retries until every byte is written; send() may stop short.
        s.sendall(request.encode('utf-8'))
        response = s.recv(1024)
    # errors='replace' so a malformed reply cannot crash the client.
    print(f"resp: {response.decode('utf-8', errors='replace')}")
def main():
    """Entry point of the CLI client.

    Modes (from the parsed flags):
      --cleanup  terminate any server already on the port, then register
                 handlers that terminate it again on SIGINT/SIGTERM and at exit.
      --kill     terminate the server on the port and return (only honoured
                 when --cleanup is not also given).
      --message  send that single request and return.
    Otherwise requests are read interactively until EOF (Ctrl-D) or Ctrl-C.
    """
    args = parse_args()
    port = args.port
    message = args.message
    if args.cleanup:
        was_running = cleanup_server(port)  # Cleanup in case of a previous crash
        attach_cleanup_handler(port)  # Cleanup on interrupt, termination, or exit
        if was_running:
            time.sleep(0.2)  # Give the terminated server time to release the port
    elif args.kill:
        cleanup_server(port)  # Cleanup and exit manually
        return
    if message:
        send_request(message, port)
        return
    while True:
        try:
            request = input("Enter a request (e.g., 'GET /echo' or 'GET /ping'): ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt ends the session cleanly instead of
            # printing a traceback; the newline keeps the shell prompt tidy.
            print()
            return
        send_request(request, port)


if __name__ == '__main__':
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import atexit | |
import os | |
import subprocess | |
import signal | |
import sys | |
def cleanup_server(port):
    """Send SIGTERM to the process(es) listening on TCP ``port``.

    ``lsof`` is asked only for sockets in the LISTEN state, so processes that
    are merely connected to the port are not targeted. The calling process is
    always skipped.

    Args:
        port: TCP port whose listening server should be terminated.

    Returns:
        bool: True if at least one process was signalled; False if nothing
        was listening, ``lsof`` is unavailable, or every signal failed.
    """
    try:
        output = subprocess.check_output(
            # -t: PIDs only; -n/-P: skip DNS/service lookups (faster);
            # -sTCP:LISTEN: ignore clients connected to the port.
            ["lsof", "-t", "-n", "-P", f"-iTCP:{port}", "-sTCP:LISTEN"],
            stderr=subprocess.DEVNULL,
        ).decode('utf-8')
    except (subprocess.CalledProcessError, OSError):
        # lsof exits non-zero when nothing matches; OSError covers lsof missing.
        print("Server process not found or unable to terminate.")
        return False

    # lsof may print several PIDs (one per line); never signal ourselves.
    pids = {int(tok) for tok in output.split() if tok.isdigit()} - {os.getpid()}
    terminated = False
    for pid in sorted(pids):
        print(f"Terminating the server process (PID {pid})...")
        try:
            os.kill(pid, signal.SIGTERM)
            terminated = True
        except OSError:  # ProcessLookupError (already gone), PermissionError
            print(f"Unable to terminate process {pid}.")
    if not terminated:
        print("Server process not found or unable to terminate.")
    return terminated
def attach_cleanup_handler(port):
    """Terminate the server on ``port`` when this process is interrupted or exits.

    Installs SIGINT and SIGTERM handlers that run the cleanup and then exit
    with status 0, and registers the same cleanup with ``atexit``. The cleanup
    runs at most once. Previously the signal handler's ``sys.exit`` also fired
    the atexit hook, which tried to kill the server a second time.

    ``signal.signal`` only works in the main thread, so call this from there.

    Args:
        port: TCP port whose server should be cleaned up.
    """
    state = {"done": False}  # Shared by both handlers below

    def run_cleanup_once():
        # Guard so the signal path and the atexit path do not both run it.
        if state["done"]:
            return
        state["done"] = True
        cleanup_server(port)

    def signal_handler(sig, frame):
        print(f"Calling cleanup after signal {sig}...")
        run_cleanup_once()
        sys.exit(0)

    def atexit_handler():
        if state["done"]:
            return  # Already cleaned up by the signal handler
        print("Calling cleanup after exit...")
        run_cleanup_once()

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)
    atexit.register(atexit_handler)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import logging | |
import os | |
import socket | |
import time | |
import tempfile | |
import threading | |
from contextvars import ContextVar | |
from dataclasses import dataclass | |
from datetime import datetime | |
from logging.handlers import RotatingFileHandler | |
from socket_lib import cleanup_server | |
# Holds the Request being processed in the current context. It is set before
# a request's log calls, and InjectingFilter reads it back to stamp each log
# record with the request's id, method and path. No default is given.
ctx_request = ContextVar("request")
def parse_args():
    """Read the server's command-line options (currently only ``--port``)."""
    port_help = 'Port to run the server on (default: 8092)'
    arg_parser = argparse.ArgumentParser(description='Simple socket server with logging.')
    arg_parser.add_argument('--port', default=8092, type=int, help=port_help)
    namespace = arg_parser.parse_args()
    return namespace
from dataclasses import dataclass | |
@dataclass
class Request:
    """Identify the request currently being served, for log enrichment.

    The default values represent the "no request" state that the server
    installs before any client is handled (e.g. during startup logging).
    """

    method: str = "main"  # Request verb parsed from the request line, e.g. "GET"
    path: str = "__init__"  # Route parsed from the request line, e.g. "/ping"
    id: str = "dummy"  # Per-connection identifier assigned by the server
class InjectingFilter(logging.Filter):
    """Logging filter that copies the current Request onto every LogRecord.

    Adds ``method``, ``path`` and ``id`` attributes, which the log format
    string references, and never drops a record.
    """

    def filter(self, record):
        # ctx_request has no default, so a bare get() raises LookupError in any
        # context where it was never set (e.g. a fresh thread that logs before
        # a Request is stored). Fall back to the placeholder Request instead of
        # breaking the log call.
        request = ctx_request.get(None)
        if request is None:
            request = Request()
        record.method = request.method
        record.path = request.path
        record.id = request.id
        return True  # Enrich only; never filter anything out
def setup_logging(log_dir):
    """Log INFO and above to stderr and to a rotating ``server.log`` in *log_dir*.

    The file handler rotates at 1 MiB and keeps 5 backups. Once the root
    logger is configured, each of its handlers gets an InjectingFilter so the
    id/method/path fields used by the format string are always present.

    Args:
        log_dir: Existing directory in which ``server.log`` is created.
    """
    log_path = os.path.join(log_dir, 'server.log')
    record_format = (
        "[%(asctime)s] [%(levelname)s] %(threadName)-11s %(name)-9s "
        "%(id)-6s %(method)-4s %(path)-11s %(message)s"
    )
    file_handler = RotatingFileHandler(log_path, maxBytes=1024 * 1024, backupCount=5)
    console_handler = logging.StreamHandler()
    logging.basicConfig(
        level=logging.INFO,
        format=record_format,
        handlers=[file_handler, console_handler],
    )
    # Filters go on the handlers (not the logger) so records from any logger
    # that propagate to root are enriched too.
    root_handlers = logging.getLogger().handlers
    for attached in root_handlers:
        attached.addFilter(InjectingFilter())
    logging.info(f"Logging to {log_path}")
def _route(method, path):
    """Return the response bytes for a parsed (method, path) request line."""
    if method != 'GET':
        return b'Invalid method\n'
    if path == '/echo':
        return b'Echo: Hello, client!\n'  # Fixed greeting; the request is not echoed
    if path == '/ping':
        return b'Pong!\n'  # Respond to a ping request
    return b'Invalid route\n'  # Respond to unknown routes


def handle_client(client_socket, client_id):
    """Serve exactly one request on *client_socket*, then close it.

    Reads up to 1024 bytes and parses them as ``"<METHOD> <PATH>"``. A
    single token is treated as the path with method ``GET``. The parsed
    Request is stored in ``ctx_request`` so log lines carry its
    id/method/path. An empty request is logged and the connection is closed
    without a reply.

    Args:
        client_socket: Connected socket returned by ``accept()``. This
            function always closes it.
        client_id: Identifier for this connection, used in log records.
    """
    # The with-block closes the socket on every path, including exceptions.
    with client_socket:
        raw = client_socket.recv(1024)
        # errors='replace': invalid UTF-8 must not kill the handler thread.
        message = raw.decode('utf-8', errors='replace').strip()
        # split() without a separator tolerates repeated whitespace/tabs.
        parts = message.split(maxsplit=1)
        if len(parts) == 2:
            method, path = parts
        elif parts:
            method, path = 'GET', parts[0]
        else:
            method, path = '', ''
        ctx_request.set(Request(method=method, path=path, id=client_id))

        if not message:
            logging.info("Empty request received, closing the connection", extra={'client_id': client_id})
            return

        # Log the received message with the client's ID
        logging.info(f"Received message: '{message}'")
        # sendall() keeps writing until the whole response is sent.
        client_socket.sendall(_route(method, path))
def bind_socket(host, port, retries=5):
    """Create a TCP socket bound to ``(host, port)``, retrying while the address is busy.

    On "address already in use", ``cleanup_server`` is asked to terminate
    whatever holds the port, and the bind is retried after 100 ms, for up to
    *retries* attempts in total.

    Args:
        host: Interface to bind (e.g. ``'localhost'``).
        port: TCP port to bind.
        retries: Maximum number of bind attempts.

    Returns:
        socket.socket: A bound socket that is not yet listening.

    Raises:
        OSError: Immediately, for any bind failure other than EADDRINUSE.
        Exception: If every attempt failed with EADDRINUSE (or retries < 1).
    """
    import errno  # EADDRINUSE differs per OS: 48 on macOS/BSD, 98 on Linux

    for _ in range(retries):
        server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)  # Allow the socket to be reused
            server_socket.bind((host, port))
        except OSError as e:
            server_socket.close()  # Don't leak the descriptor of a failed attempt
            if e.errno != errno.EADDRINUSE:
                raise
            logging.warning(f"Address {host}:{port} already in use, will try to clean up resource and retry...")
            cleanup_server(port)
            time.sleep(0.1)  # Wait 100ms before retrying
        else:
            return server_socket
    raise Exception(f"Unable to bind to address {host}:{port} after {retries} attempts.")
def main():
    """Run the server: configure logging, bind, then serve clients until interrupted.

    Logs go to stderr and to ``server.log`` in a new temporary directory
    (prefix ``server_logs_``). Each accepted connection gets an id of the form
    ``YYYYmmddHHMMSS-<counter>`` and is handled on its own thread. Ctrl-C stops
    the accept loop, and the listening socket is closed on every exit path.
    """
    args = parse_args()
    port = args.port
    # Default Request so log lines outside request handling still have values
    # for the id/method/path fields in the format string.
    ctx_request.set(Request())
    # Get a temporary directory with a known path prefix
    log_dir = tempfile.mkdtemp(prefix='server_logs_')
    setup_logging(log_dir)
    host = 'localhost'
    server_socket = bind_socket(host, port)
    with server_socket:  # Close the listening socket however the loop ends
        server_socket.listen(5)  # Allow up to 5 pending connections in the backlog
        logging.info(f"Server listening on {host}:{port}")
        client_id_counter = 0
        try:
            while True:
                client_socket, addr = server_socket.accept()  # Accept incoming connections
                # Timestamp plus a monotonically increasing counter keeps ids unique
                client_id = f"{datetime.now().strftime('%Y%m%d%H%M%S')}-{client_id_counter}"
                client_id_counter += 1
                logging.info(f"Accepted connection from {addr}", extra={'client_id': client_id})
                # Start a new thread to handle the client
                client_thread = threading.Thread(target=handle_client, args=(client_socket, client_id))
                client_thread.start()
        except KeyboardInterrupt:
            # Ctrl-C: stop accepting and shut down without a traceback.
            logging.info("Interrupted, shutting down the server")


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.