Created
July 12, 2024 10:01
-
-
Save tolgahanakgun/cedc81c5b98f385e68d872798ffc3ed5 to your computer and use it in GitHub Desktop.
Single Python file TFTP server implementation with >32MB file support
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This code is based on the https://github.com/m4tx/pyTFTP project.
import os | |
import errno | |
import logging | |
import socket | |
import argparse | |
from pathlib import Path, PurePosixPath | |
from threading import Thread | |
from typing import List, NewType, Tuple, Union, Dict | |
# Logger shared by every class in this module.
logger = logging.getLogger('tftpd')

# Default DATA payload size in bytes; used as the default block size of
# a TFTP connection until a "blksize" option changes it.
BLOCK_SIZE = 512
# Maximum number of bytes requested from a single recvfrom() call.
BUF_SIZE = 64 * 1024
# Socket timeout, in seconds, applied to every connection socket.
TIMEOUT = 0.5
# Number of timeouts tolerated before a transfer is abandoned.
MAX_RETRIES = 10
class TFTPOpcodes:
    """Two-byte, big-endian opcodes that start every TFTP packet."""

    RRQ = b'\x00\x01'    # read request
    WRQ = b'\x00\x02'    # write request
    DATA = b'\x00\x03'   # data block
    ACK = b'\x00\x04'    # acknowledgment
    ERROR = b'\x00\x05'  # error report
    OACK = b'\x00\x06'   # option acknowledgment
class TFTPErrorCodes:
    """Numeric TFTP error codes together with their default messages."""

    UNKNOWN = 0
    FILE_NOT_FOUND = 1
    ACCESS_VIOLATION = 2
    DISK_FULL = 3
    ILLEGAL_OPERATION = 4
    UNKNOWN_TRANSFER_ID = 5
    FILE_EXISTS = 6
    NO_SUCH_USER = 7
    INVALID_OPTIONS = 8

    # Default human-readable text for each code. UNKNOWN maps to an empty
    # string so that callers can supply their own description.
    __MESSAGES = {
        UNKNOWN: '',
        FILE_NOT_FOUND: 'File not found',
        ACCESS_VIOLATION: 'Access violation',
        DISK_FULL: 'Disk full or allocation exceeded',
        ILLEGAL_OPERATION: 'Illegal TFTP operation',
        UNKNOWN_TRANSFER_ID: 'Unknown transfer ID',
        FILE_EXISTS: 'File already exists',
        NO_SUCH_USER: 'No such user',
        INVALID_OPTIONS: 'Invalid options specified',
    }

    @classmethod
    def get_message(cls, error_code: int) -> str:
        """Look up the default message for an error code.

        :param error_code: one of the codes defined on this class
        :return: the default message text for that code
        :raise: KeyError if the code is not defined on this class
        """
        default_text = cls.__MESSAGES[error_code]
        return default_text
class TFTPOptions:
    """Lower-case names of the transfer options this module understands."""

    BLKSIZE = b'blksize'        # block size, RFC 2348
    WINDOWSIZE = b'windowsize'  # window size, RFC 7440
# Peer address in the form used with the socket's sendto()/recvfrom().
Address = NewType('Address', tuple)
# One datagram paired with the address it came from or is destined for.
Packet = NewType('Packet', Tuple[bytes, Address])
class TFTPException(Exception):
    """Base class of every TFTP-specific exception raised by this module."""
class TFTPError(TFTPException):
    """Raised when a TFTP ERROR packet has been received."""

    def __init__(self, error_id: int, message: str) -> None:
        """
        :param error_id: error code carried by the ERROR packet
        :param message: error text carried by the ERROR packet
        """
        description = f'Error {error_id}: {message}'
        super().__init__(description)
        self.error_id = error_id
        self.message = message
class TFTPTerminatedError(TFTPException):
    """Raised after this side has terminated the TFTP connection.

    Carries the error code and error text reported to the peer, plus a
    local description of what caused the termination.
    """

    def __init__(self, error_id: int, error_message: str,
                 message: str) -> None:
        """
        :param error_id: error code reported to the peer
        :param error_message: error text reported to the peer
        :param message: local description of the cause
        """
        super().__init__(
            'Terminated with error {}: {}; cause: {}'.format(
                error_id, error_message, message))
        self.error_id = error_id
        # Keep the peer-facing text and the local cause separate; the
        # former used to be overwritten with the latter.
        self.error_message = error_message
        self.message = message
class TFTP:
    """
    Base class for writing TFTP clients and servers. Handles all the basic
    communication: generic methods for sending and receiving packets, methods
    for transmitting specific packets and whole files, as well as error
    and timeout handling.

    Instances can be used as context managers; leaving the ``with`` block
    closes the socket.
    """

    def __init__(self, sock: socket.socket, addr: Address,
                 block_size: int = BLOCK_SIZE, window_size: int = 1) -> None:
        """
        :param sock: socket to use to communicate; its timeout is set to
            TIMEOUT seconds
        :param addr: address (host + port) of the connected host
        :param block_size: DATA payload size in bytes (RFC 2348)
        :param window_size: number of DATA blocks sent per ACK (RFC 7440)
        """
        self._sock = sock
        self._sock.settimeout(TIMEOUT)
        self._addr = addr
        self._block_size = block_size  # RFC 2348
        self._window_size = window_size  # RFC 7440
        # Whether to check the TID of incoming packets. If set to False, the
        # next packet received will be used to set the new TID (and this will
        # set _check_addr back to True)
        self._check_addr = True
        # Last packet sent to the connected peer; resent on timeout.
        self.__last_packet: Packet = None
        # Packet handed in from outside, returned by the next _recv() call.
        self.__packet_buffer: Packet = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._sock.close()

    ###########################################################################
    # Error handling
    ###########################################################################

    def _check_error(self, data: bytes, expected_opcodes: List[bytes]) -> None:
        """Check if the packet received has a valid opcode and terminate the
        connection if not, or raise if an ERROR packet was received.

        :param data: the packet received
        :param expected_opcodes: list of valid opcodes
        :raise: TFTPTerminatedError if the opcode was not valid
        :raise: TFTPError if an ERROR packet was received
        """
        opcode = data[0:2]
        if opcode == TFTPOpcodes.ERROR:
            # The message comes from the peer and may not be valid UTF-8;
            # replace undecodable bytes instead of failing with a
            # UnicodeDecodeError that no caller expects.
            raise TFTPError(
                int.from_bytes(data[2:4], byteorder='big'),
                data[4:-1].decode('utf-8', errors='replace'))
        elif opcode not in expected_opcodes:
            self._terminate(TFTPErrorCodes.ILLEGAL_OPERATION,
                            'Invalid packet: {}'.format(data))

    def _terminate(self, error_code: int, message: str,
                   error_message: str = None) -> None:
        """Send an ERROR packet, close the socket, and raise
        a TFTPTerminatedError.

        :param error_code: error code to send
        :param message: message to use for the exception
        :param error_message: message to send with the ERROR packet. If None,
            a default message for the given error code is used.
        :raise: TFTPTerminatedError
        """
        error_message = self._error_occurred(error_code, error_message)
        self._sock.close()
        raise TFTPTerminatedError(error_code, error_message, message)

    def _error_occurred(self, error_code: int, error_message: str = None,
                        addr: Address = None) -> str:
        """Send an ERROR packet, auto-generating the message if necessary.

        :param error_code: error code to send
        :param error_message: message to send with the ERROR packet. If None,
            a default message for the given error code is used.
        :param addr: the address to send the packet to. If None, the
            connected peer is used.
        :return: the error message that was sent
        """
        if error_message is None:
            error_message = TFTPErrorCodes.get_message(error_code)
        self._send_err(error_code, error_message, addr)
        return error_message

    ###########################################################################
    # Receiving
    ###########################################################################

    def _set_packet_buffer(self, data: bytes, addr: Address) -> None:
        """Set given packet as the "packet buffer". Packets in the buffer have
        priority when trying to retrieve data using _recv(), giving a way to
        use data from a different source (e.g. recvfrom() executed in another
        function) when receiving packets using a unified function.

        :param data: data to be set in the buffer
        :param addr: address to be set in the buffer
        """
        self.__packet_buffer = Packet((data, addr))

    def _recv(self, handle_timeout: bool = True) -> Packet:
        """Receive a packet, taking into account packets in the packet buffer,
        and retrying (by resending the last sent packet) if needed.

        :param handle_timeout: if False, a single receive is attempted and
            socket.timeout propagates to the caller
        :return: packet received
        :raise: TFTPException on timeout (only when handle_timeout is True)
        """
        if self.__packet_buffer is not None:
            buffered = self.__packet_buffer
            self.__packet_buffer = None
            return buffered

        if not handle_timeout:
            return self._sock.recvfrom(BUF_SIZE)

        retries = 0
        while retries <= MAX_RETRIES:
            try:
                return self._sock.recvfrom(BUF_SIZE)
            except socket.timeout:
                retries += 1
                if retries <= MAX_RETRIES:
                    self.__resend_last_packet()
        raise TFTPException('Timed out')

    def _recv_packet_mul(
            self, opcodes: List[bytes],
            min_data_length: int, handle_timeout: bool = True) -> Tuple[
                Address, bytes, bytes]:
        """Receive a packet and check if its opcode, length, and TID are valid.

        Packets from an unexpected address are answered with an
        UNKNOWN_TRANSFER_ID error and otherwise ignored. When
        ``self._check_addr`` is False, the sender of the first packet
        received becomes the connected peer.

        :param opcodes: list of valid opcodes
        :param min_data_length: minimum valid length of the data (not
            counting the 2-byte opcode)
        :param handle_timeout: passed through to _recv()
        :return: a 3-tuple containing: source packet address, opcode received
            and the data
        """
        while True:
            data, addr = self._recv(handle_timeout)
            if not self._check_addr or addr == self._addr:
                break
            logger.warning('Invalid TID: %s (expected: %s)', addr, self._addr)
            self._error_occurred(TFTPErrorCodes.UNKNOWN_TRANSFER_ID, addr=addr)

        if not self._check_addr:
            self._addr = addr
            self._check_addr = True

        self._check_error(data, opcodes)
        if len(data) < min_data_length + 2:
            self._terminate(TFTPErrorCodes.ILLEGAL_OPERATION,
                            'Packet too short: {}'.format(data))
        return addr, data[0:2], data[2:]

    def _recv_packet(self, opcode: bytes, min_data_length: int,
                     handle_timeout: bool = True) -> Tuple[Address, bytes]:
        """Receive a packet and check if its opcode, length, and TID are valid.

        :param opcode: valid opcode
        :param min_data_length: minimum valid length of the data
        :param handle_timeout: passed through to _recv()
        :return: a pair containing: source packet address and the data received
        """
        addr, _, data = self._recv_packet_mul([opcode], min_data_length,
                                              handle_timeout)
        return addr, data

    def _recv_data(
            self, handle_timeout: bool = True) -> Tuple[Address, bytes, bytes]:
        """Receive a DATA packet and return the block ID and the data.

        :param handle_timeout: passed through to _recv()
        :return: 3-tuple containing the source address, block ID, and the data
        """
        addr, data = self._recv_packet(TFTPOpcodes.DATA, 2, handle_timeout)
        return addr, data[0:2], data[2:]

    def _recv_ack(self, handle_timeout: bool = True) -> Tuple[Address, int]:
        """Receive an ACK packet and return the block ID.

        :param handle_timeout: passed through to _recv()
        :return: pair containing the source address and the block ID
        """
        addr, data = self._recv_packet(TFTPOpcodes.ACK, 2, handle_timeout)
        # Only the first two bytes are the block number; ignore any trailing
        # bytes so that they cannot corrupt the value.
        return addr, int.from_bytes(data[0:2], byteorder='big')

    ###########################################################################
    # Sending
    ###########################################################################

    def _send(self, data: bytes, addr: Address = None) -> None:
        """Send a packet and, if it goes to the connected peer, store it as
        the last packet sent.

        :param data: data to be sent
        :param addr: the destination address to send the packet to. If None,
            self._addr is used.
        """
        if addr is None:
            addr = self._addr
        # Packets sent elsewhere (e.g. an "unknown TID" error to a stray host)
        # must not replace the packet that is resent to the peer on timeout.
        if addr == self._addr:
            self.__last_packet = Packet((data, addr))
        self._sock.sendto(data, addr)

    def __resend_last_packet(self) -> None:
        """Resend the last packet sent (used for retries in _recv())."""
        if self.__last_packet is None:
            # Nothing has been sent yet, so there is nothing to repeat.
            return
        self._sock.sendto(*self.__last_packet)

    def _send_ack(self, block_id: Union[bytes, int]) -> None:
        """Send an ACK packet.

        :param block_id: block ID to send, either as an int or already
            encoded as two big-endian bytes
        """
        if isinstance(block_id, int):
            block_id = block_id.to_bytes(2, byteorder='big')
        self._send(TFTPOpcodes.ACK + block_id)

    def _send_data(self, block_id: int, data: bytes) -> None:
        """Send a DATA packet.

        :param block_id: block ID of the data
        :param data: the data to send
        """
        self._send(
            TFTPOpcodes.DATA + block_id.to_bytes(2, byteorder='big') + data)

    def _send_err(self, error_code: int, error_message: str = None,
                  addr: Address = None) -> None:
        """Send an ERROR packet.

        :param error_code: error code to send
        :param error_message: error message to send
        :param addr: the destination address to send the packet to
        """
        error_code_bytes = error_code.to_bytes(2, byteorder='big')
        error_message_bytes = error_message.encode('utf-8')
        self._send(TFTPOpcodes.ERROR + error_code_bytes + error_message_bytes +
                   b'\x00', addr)

    ###########################################################################
    # Options (RFC 2347)
    ###########################################################################

    def _process_options(self, options: List[bytes]) -> Dict[bytes, bytes]:
        """Process the options received in RRQ/WRQ packet.

        This is an implementation of the RFC 2347 Options Extension. The
        accepted values are stored in ``self._block_size`` and
        ``self._window_size``; nothing is stored if any option is invalid.

        :param options: list of the option strings (null-separated in
            the original packet); the list is not modified
        :return: dictionary of the processed and accepted options
        :raise: ValueError if the list is malformed or a value is not valid
        """
        # Splitting a NUL-terminated packet leaves one empty trailing element.
        # The list may also be empty if the terminator is missing.
        if options and options[-1] == b'':
            options = options[:-1]
        if len(options) % 2 == 1:
            raise ValueError('Option without a value')

        accepted = {}
        requested = {
            name.lower(): (name, value)
            for name, value in zip(options[::2], options[1::2])
        }
        block_size = self._block_size
        window_size = self._window_size

        # Block size (RFC 2348)
        if TFTPOptions.BLKSIZE in requested:
            orig_key, orig_val = requested[TFTPOptions.BLKSIZE]
            block_size = int(orig_val)
            if block_size < 8 or block_size > 65464:
                # Invalid according to RFC 2348
                raise ValueError('Invalid block size')
            accepted[orig_key] = orig_val

        # Window size (RFC 7440)
        if TFTPOptions.WINDOWSIZE in requested:
            orig_key, orig_val = requested[TFTPOptions.WINDOWSIZE]
            window_size = int(orig_val)
            if window_size < 1 or window_size > 65535:
                # Invalid according to RFC 7440
                raise ValueError('Invalid window size')
            accepted[orig_key] = orig_val

        # Commit only after every option has been validated.
        self._block_size = block_size
        self._window_size = window_size
        return accepted

    def _format_options(self, options: Dict[bytes, bytes]):
        """Create single options bytes object out of the provided dictionary.

        :param options: dictionary to convert to bytes object
        :return: generated bytes object: each name and each value followed
            by a NUL byte
        """
        return b''.join(b'%s\x00%s\x00' % option for option in options.items())

    ###########################################################################
    # Files
    ###########################################################################

    def _recv_file(self) -> bytes:
        """Receive a file by listening for DATA packets and responding
        with ACKs.

        Only the block that directly follows the last accepted one (modulo
        65536, so that block numbers may wrap around) is accepted; anything
        else is ignored. The transfer ends when an accepted block is shorter
        than the block size.

        :return: received file
        :raise: TFTPException after more than MAX_RETRIES consecutive
            timeouts without any progress
        """
        last_id = 0
        parts = []

        retries = 0
        while retries <= MAX_RETRIES:
            start_last_id = last_id
            for _ in range(self._window_size):
                try:
                    _, block_id, data = self._recv_data(handle_timeout=False)
                except socket.timeout:
                    if last_id == start_last_id:
                        # Nothing accepted in this window: count a retry.
                        retries += 1
                    # In either case stop waiting and ACK what we have so
                    # that the sender can continue from there.
                    break

                id_int = int.from_bytes(block_id, byteorder='big')
                if id_int != (last_id + 1) % 65536:
                    # Duplicate or out-of-order block.
                    continue

                parts.append(data)
                last_id = id_int
                retries = 0
                if len(data) < self._block_size:
                    self._send_ack(last_id)
                    return b''.join(parts)

            if retries <= MAX_RETRIES:
                self._send_ack(last_id)

        raise TFTPException('Timed out')

    def _send_file(self, data: bytes) -> None:
        """Send a file by sending DATA packets and listening for ACKs.

        :param data: data to be sent
        :raise: TFTPException after more than MAX_RETRIES timeouts while
            waiting for the ACK of one window
        """
        # Block numbers on the wire are 16 bits wide. outer_block_id counts
        # how many times they have wrapped around; block_id is the last
        # acknowledged on-wire block number.
        outer_block_id = 0
        block_id = 0

        while True:
            retries = 0
            while retries <= MAX_RETRIES:
                try:
                    if not self.__send_blocks(data, outer_block_id, block_id):
                        return

                    _, ack_block_id = self._recv_ack(handle_timeout=False)
                    last_block_id = block_id + self._window_size
                    in_window = block_id <= ack_block_id <= last_block_id
                    # A smaller number is a wrap-around only if this window
                    # really crossed the 16-bit boundary; otherwise it is a
                    # stale ACK and must not advance the transfer.
                    wrapped = (last_block_id >= 65536 and
                               ack_block_id <= last_block_id - 65536)
                    if in_window or wrapped:
                        # The ACK answers one of the blocks just sent:
                        # continue with the next batch. Otherwise the same
                        # window is sent again.
                        if wrapped:
                            outer_block_id += 1
                        block_id = ack_block_id
                        break
                except socket.timeout:
                    retries += 1
            else:
                raise TFTPException('Timed out')

    def __send_blocks(
            self, data: bytes, outer_block_id: int, inner_block_id: int):
        """Send a single window of data.

        :param data: data to be sent
        :param outer_block_id: starting "outer" block ID (incremented by 1
            each time inner block ID overflows)
        :param inner_block_id: starting "inner" block ID in the range [0, 65535]
        :return: False if there is no data to be sent; True otherwise
        """
        blk_size = self._block_size
        for i in range(self._window_size):
            local_blkid = outer_block_id * 65536 + inner_block_id + i
            # The comparison is strict so that data whose length is an exact
            # multiple of the block size is followed by an empty final block.
            if local_blkid * blk_size > len(data):
                if i == 0:
                    return False
                break

            to_send = data[local_blkid * blk_size:
                           (local_blkid + 1) * blk_size]
            self._send_data((local_blkid + 1) % 65536, to_send)
        return True
class TFTPClientHandler(TFTP):
    """
    Class that handles the communication with a single TFTP client on the
    server side.
    """

    def __init__(self, host: str, addr: Address, root_dir: Path,
                 allow_upload: bool, initial_buffer: bytes = None) -> None:
        """
        :param host: host of the server to bind to
        :param addr: address of the client to connect with
        :param root_dir: root directory of the files to serve
        :param allow_upload: whether or not allow to upload files
        :param initial_buffer: initial packet buffer; usually a `bytes` object
            containing the first (RRQ/WRQ) packet, or None, if there is no
            external server that catches the first packet.
        """
        super().__init__(
            socket.socket(socket.AF_INET, socket.SOCK_DGRAM), addr)

        if initial_buffer is not None:
            self._set_packet_buffer(initial_buffer, self._addr)

        try:
            # Port 0 lets the OS pick a free port for this transfer.
            self._sock.bind((host, 0))
        except OSError:
            self._sock.close()
            raise
        # Remember the local address now: it cannot be queried any more once
        # the socket has been closed.
        self.__bound_addr = self._sock.getsockname()
        logger.info('Incoming connection from %s, binding at: %s',
                    self._addr, self.__bound_addr)

        self.__root_dir = root_dir
        self.__allow_upload = allow_upload

    def handle_client(self) -> None:
        """Handle the request sent by the connected client.

        :raise: TFTPTerminatedError if the request is rejected or a file
            error occurs
        """
        opcode, file_name, mode = self.__recv_rq()
        try:
            path = self.__get_file_path(file_name)
            if opcode == TFTPOpcodes.RRQ:
                self.__handle_rrq(path)
            else:
                self.__handle_wrq(path)
        except OSError as e:
            self.__handle_file_error(e)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # The socket may already be closed here (terminating a connection
        # closes it), so log the address remembered at bind time.
        logger.info('Closing connection to %s, bound at: %s',
                    self._addr, self.__bound_addr)
        super(TFTPClientHandler, self).__exit__(exc_type, exc_val, exc_tb)

    def __recv_rq(self) -> Tuple[bytes, str, str]:
        """Receive an RRQ/WRQ packet and return received data.

        If the request carries options that are accepted, an OACK is sent;
        for an RRQ the client's ACK of that OACK is awaited as well.

        :return: 3-tuple containing: received opcode, file name and file
            transfer mode
        :raise: TFTPTerminatedError if the request is malformed, its options
            are invalid or the mode is not "octet"
        """
        _, opcode, data = self._recv_packet_mul(
            [TFTPOpcodes.RRQ, TFTPOpcodes.WRQ], 2)

        try:
            # Raises ValueError if the file name or the mode is missing.
            file_name_bytes, mode_bytes, *options = data.split(b'\0')
            try:
                new_options = self._process_options(options)
                if len(new_options):
                    self.__send_oack(new_options)
                    if opcode == TFTPOpcodes.RRQ:
                        self._recv_ack()
            except ValueError:
                self._terminate(TFTPErrorCodes.INVALID_OPTIONS,
                                'Invalid options received')

            file_name = file_name_bytes.decode('utf-8')
            mode = mode_bytes.decode('utf-8')
        except ValueError as e:
            self._terminate(TFTPErrorCodes.ILLEGAL_OPERATION, str(e))

        # RFC 1350 allows the mode in any combination of upper and lower case.
        if mode.lower() != 'octet':
            self._terminate(TFTPErrorCodes.ILLEGAL_OPERATION,
                            'Mode is not "octet": {}'.format(mode))

        return opcode, file_name, mode

    def __send_oack(self, new_options: Dict[bytes, bytes]):
        """Send an OACK packet.

        :param new_options: dictionary of options to be included in
            the OACK packet.
        """
        msg = TFTPOpcodes.OACK + self._format_options(new_options)
        self._send(msg)

    def __get_file_path(self, file_name: str) -> Path:
        """Return file path inside server root directory, rejecting "evil"
        paths, like "../../secret_file", "/etc/fstab", etc.

        :param file_name: file name to get the path to
        :return: absolute path inside the server root directory
        :raise: TFTPTerminatedError if the path points outside the root
            directory
        """
        while PurePosixPath(file_name).is_absolute():
            file_name = file_name[1:]

        # Resolve both paths before comparing: relative_to() is purely
        # lexical and would accept "root/../../secret_file" as being
        # "inside" root. Resolving also follows symlinks.
        root = self.__root_dir.resolve()
        path = root.joinpath(file_name).resolve()
        try:
            path.relative_to(root)
        except ValueError:
            self._terminate(TFTPErrorCodes.ACCESS_VIOLATION,
                            'Invalid path: {}'.format(file_name))
        return path

    def __handle_rrq(self, path: Path) -> None:
        """Handle RRQ request: read and send the requested file.

        The whole file is read into memory before it is sent.

        :param path: path to the requested file
        """
        self._send_file(path.read_bytes())

    def __handle_wrq(self, path: Path) -> None:
        """Handle WRQ request: download and save the file from the client,
        taking into account the `__allow_upload` setting.

        :param path: path to save the file as
        :raise: TFTPTerminatedError if uploads are disabled or the file
            already exists
        """
        if not self.__allow_upload:
            self._terminate(TFTPErrorCodes.ACCESS_VIOLATION,
                            'Upload not allowed')

        if path.exists():
            self._terminate(TFTPErrorCodes.FILE_EXISTS,
                            'File exists: {}'.format(path))

        # NOTE(review): when an OACK was already sent for this WRQ, RFC 2347
        # treats that OACK as the acknowledgment of the request; the extra
        # ACK 0 is kept to preserve the existing behavior. Confirm against
        # the clients in use.
        self._send_ack(b'\x00\x00')
        path.write_bytes(self._recv_file())

    def __handle_file_error(self, e: OSError) -> None:
        """Handle given IO error, sending an appropriate ERROR packet and
        terminating the transmission.

        :param e: error raised when trying to access the file
        :raise: TFTPTerminatedError always
        """
        # None makes the default message of the error code be sent.
        error_message = None
        if e.errno == errno.ENOENT:
            error_code = TFTPErrorCodes.FILE_NOT_FOUND
        elif e.errno in (errno.EPERM, errno.EACCES):
            error_code = TFTPErrorCodes.ACCESS_VIOLATION
        elif e.errno in (errno.EFBIG, errno.ENOSPC):
            error_code = TFTPErrorCodes.DISK_FULL
        else:
            # No dedicated TFTP code: pass the OS description to the client.
            error_code = TFTPErrorCodes.UNKNOWN
            error_message = e.strerror
        self._terminate(error_code, e.strerror, error_message)
class TFTPServer:
    """
    Class that handles communication with multiple TFTP clients. Uses
    TFTPClientHandler for the communication with each single client, running
    one instance of that class in a separate thread for each client.
    """

    def __init__(self, host: str, port: int, root_dir: Union[str, Path],
                 allow_upload: bool) -> None:
        """
        :param host: host of the server to bind to
        :param port: port to bind to
        :param root_dir: the directory where the files should be served from
        :param allow_upload: whether or not allow uploading new files
        """
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        addr = (host, port)
        logger.info('Starting TFTP server, listening on %s', addr)
        try:
            self.sock.bind(addr)
        except OSError:
            # Do not leak the socket if the address cannot be bound.
            self.sock.close()
            raise

        self.host = host
        self.root_dir = Path(root_dir)
        self.allow_upload = allow_upload

    def __enter__(self):
        return self

    def serve(self) -> None:
        """Run the main server loop: wait for new connections and run
        TFTPClientHandler for each in its own thread.
        """
        while True:
            data, addr = self.sock.recvfrom(BUF_SIZE)
            # Hand the packet over as thread arguments: a closure over the
            # loop variables could see the values of a later request.
            Thread(target=self._handle_client, args=(data, addr)).start()

    def _handle_client(self, data: bytes, addr: tuple) -> None:
        """Serve one client request; runs in a worker thread.

        :param data: the first (RRQ/WRQ) packet received from the client
        :param addr: address of the client
        """
        try:
            # The context manager closes the handler's socket when done.
            with TFTPClientHandler(
                    self.host, addr, self.root_dir, self.allow_upload,
                    data) as handler:
                handler.handle_client()
        except (TFTPException, OSError) as e:
            # A failed transfer only concerns this client; report it and
            # let the thread finish.
            logger.warning('Transfer with %s failed: %s', addr, e)

    def __exit__(self, exc_type, exc_val, exc_tb):
        logger.info('Stopping TFTP server')
        self.sock.close()
if __name__ == "__main__":
    # Command-line entry point: serve the current working directory, with
    # uploads enabled, on the given IP address and port.
    cli = argparse.ArgumentParser(
        description="Run a simple TFTP server that supports upload and download functionality with >32MB file size")
    cli.add_argument("IP", help="IP address to listen")
    cli.add_argument(
        "-p", "--port",
        type=int,
        default=69,
        help="Port number to listen (default: %(default)s)")
    options = cli.parse_args()

    logging.basicConfig(level=logging.INFO)

    serve_root = os.getcwd()
    with TFTPServer(options.IP, options.port, serve_root, True) as server:
        server.serve()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment