Skip to content

Instantly share code, notes, and snippets.

@wallabra
Last active October 8, 2018 01:54
Show Gist options
  • Save wallabra/ed712ce344180d0e0461ce1a62a65370 to your computer and use it in GitHub Desktop.
Save wallabra/ed712ce344180d0e0461ce1a62a65370 to your computer and use it in GitHub Desktop.
The internal NDRL structure.
import time
import threading
import base64
import io
import random
import socket
import struct
import queue
from typing import Any, Callable, Dict, Generic, Tuple, Type, TypeVar, Union
from ndrl import ndrlmeta
class NDRLException(Exception):
    """Base class for NDRL errors.

    Derives from ``Exception`` (not ``BaseException``) so ordinary
    ``except Exception`` handlers see it; ``except BaseException`` still does.

    Args:
        address: NDRL address the error refers to (``"none"`` if unknown).
        description: human-readable explanation of the error.
    """

    def __init__(self, address : str = "none", description : str = ""):
        # Forward both values so ``args`` and ``repr`` carry them instead of
        # being empty.
        super().__init__(address, description)
        self.address = address
        self.description = description

    def __str__(self) -> str:
        return "{} (reference address: {})".format(self.description, self.address)
class NDRLVersionMismatch(NDRLException):
    """Raised by decode_packet when a packet's version differs from the local protocol version."""
class NDRLSocket(object):
    """A virtual connection between an NDRLAgent and one remote (address, port).

    The owning agent routes matching incoming packets into ``incoming`` via
    :meth:`_put_recv`; :meth:`recv` turns their payloads into a byte stream.
    Outgoing data goes through :meth:`send`, which hands the encoded frame to
    the :meth:`_send` transport hook.
    """

    def __init__(self, agent, target_addr : str, target_port : int):
        self.agent = agent
        self.target_addr = target_addr
        self.target_port = target_port
        # Decoded packet dicts (as returned by NDRLAgent.decode_packet)
        # waiting to be read by recv().
        self.incoming = queue.Queue()
        # Payload bytes already taken off the queue but not yet returned.
        self.remainder_buffer = b''
        self.closed = False

    def close(self):
        """Detach this socket from its agent.

        Returns:
            True if the socket was open and is now closed, False if it was
            already closed.
        """
        if self.closed:
            return False
        self.agent.sockets.remove(self)
        self.closed = True
        return True

    def packet(self, payload : bytes = b""):
        """Build an (unencoded) packet dict carrying *payload* to this socket's target."""
        return self.agent.build_packet(payload, self.target_addr, self.target_port)

    def _send(self, packet):
        """Transport hook called by :meth:`send` with one raw binary frame.

        Does nothing here; a concrete transport is expected to override it.
        """
        pass

    def _put_recv(self, packet):
        """Queue a decoded packet dict for :meth:`recv`."""
        self.incoming.put(packet)

    def recv(self, amount = 1024):
        """Return up to *amount* bytes of received payload without blocking.

        Queued packet payloads are concatenated in arrival order; bytes beyond
        *amount* are kept for the next call. Returns ``b''`` when nothing is
        available.
        """
        res = self.remainder_buffer
        # Stop pulling packets as soon as enough bytes are buffered.
        while len(res) < amount:
            try:
                packet = self.incoming.get_nowait()
            except queue.Empty:
                break
            # Packets are dicts, so the payload lives under the 'Payload' key.
            res += packet['Payload']
        self.remainder_buffer = res[amount:]
        return res[:amount]

    def send(self, data):
        """Wrap *data* in a packet to the target and pass the frame to :meth:`_send`.

        The frame is raw binary (``b64=False``): that is the format that
        :meth:`receive_bytes` and :meth:`NDRLAgent.read_packets` can split,
        whereas concatenated base85 frames cannot be delimited.
        """
        frame = self.agent.encode_packet(self.packet(data), b64=False)
        self._send(frame)

    def receive_bytes(self, data):
        """Feed raw binary frame bytes (as produced by :meth:`send`) to the agent."""
        self.agent.put(data)

    def receive_packet(self, packet):
        """Feed an already-built packet dict to the agent as a raw binary frame."""
        self.agent.put(self.agent.encode_packet(packet, b64=False))


class NDRLAgent(object):
    """An NDRL endpoint: an address, its virtual sockets and a receive buffer.

    Raw binary frames handed to :meth:`put` are split, decoded and dispatched
    to the matching :class:`NDRLSocket` by a background daemon thread started
    in ``__init__`` (stop it with :meth:`stop`).

    Frame layout (network byte order, fixed 50-byte header)::

        Version         uint16
        Source Port     uint32
        Target Port     uint32
        Source Address  16 bytes, IPv6 (socket.inet_pton)
        Target Address  16 bytes, IPv6 (socket.inet_pton)
        Payload Length  uint64
        Payload         <Payload Length> bytes
    """

    # '!' gives standard sizes, network byte order and no padding, so every
    # platform agrees on the header (native 'L' is 8 bytes on LP64 systems).
    _HEADER = struct.Struct('!HLL16s16sQ')
    HEADER_SIZE = _HEADER.size

    def __init__(self, address : Union[str, None] = None, socket_type : Type[NDRLSocket] = NDRLSocket):
        # Use the requested address unless it is empty or already taken;
        # otherwise draw random ones until exists() reports a free address.
        # (_exists() cannot be used here: it also matches self.address.)
        candidate = address or self._random_address()
        while self.exists(candidate):
            candidate = self._random_address()
        self.address = candidate
        self.ports = {}
        self._data_buf = bytearray()
        # put() runs on caller threads while the reader thread consumes the
        # buffer; the lock keeps concurrent appends from being lost.
        self._buf_lock = threading.Lock()
        self._socket_type = socket_type
        # Created before agent_init() so the hook may already call connect().
        self.sockets = []
        self._stop_event = threading.Event()
        self.agent_init(self.address)
        # Daemon thread: an agent that is never stopped must not keep the
        # interpreter alive.
        self._reader = threading.Thread(target=self._read_loop, daemon=True)
        self._reader.start()

    @staticmethod
    def _random_address() -> str:
        """Return a random IPv6 address in inet_ntop's canonical text form."""
        raw = random.getrandbits(128).to_bytes(16, 'big')
        return socket.inet_ntop(socket.AF_INET6, raw)

    @staticmethod
    def _canonical_address(addr):
        """Return *addr* in canonical IPv6 text form, or unchanged if it does not parse."""
        try:
            return socket.inet_ntop(socket.AF_INET6, socket.inet_pton(socket.AF_INET6, addr))
        except (OSError, TypeError, ValueError):
            return addr

    def agent_init(self, addr):
        """Subclass hook run once from ``__init__`` with the agent's final address.

        Runs before the reader thread starts. Does nothing by default.
        """
        pass

    def connect(self, target_addr : str, target_port : int):
        """Create a socket (of this agent's socket type) to the target and register it."""
        connection = self._socket_type(self, target_addr, target_port)
        self.sockets.append(connection)
        return connection

    def put(self, data : bytes = b''):
        """Append raw binary frame bytes to the receive buffer (thread-safe)."""
        with self._buf_lock:
            self._data_buf.extend(data)

    def stop(self, timeout : Union[float, None] = None):
        """Signal the background reader to exit and wait up to *timeout* seconds for it."""
        self._stop_event.set()
        if self._reader is not threading.current_thread():
            self._reader.join(timeout)

    def on_error(self, exc):
        """Hook for a frame the reader thread could not decode (e.g. NDRLVersionMismatch).

        The offending frame has already been removed from the buffer. Does
        nothing by default; override to log or count rejected packets.
        """
        pass

    def _read_loop(self):
        """Reader-thread body: parse and dispatch buffered frames until :meth:`stop`.

        Must be a plain function: threading.Thread would only create (and
        never await) a coroutine from an ``async def``.
        """
        while not self._stop_event.is_set():
            try:
                self.read_packets()
            except NDRLException as exc:
                self.on_error(exc)
            # ~1 ms poll; wait() returns immediately once stop() is called.
            self._stop_event.wait(0.001)

    def _frame_length(self, buf, offset : int = 0):
        """Size of the complete frame starting at *offset* in *buf*, or None if not all there yet."""
        available = len(buf) - offset
        if available < self.HEADER_SIZE:
            return None
        # Payload Length is the last header field.
        total = self.HEADER_SIZE + self._HEADER.unpack_from(buf, offset)[-1]
        return total if available >= total else None

    def peek_packets(self, num = 0xFFFF):
        """Decode up to *num* complete buffered frames without consuming them.

        Unlike read_packets() this neither removes frames from the buffer nor
        calls on_packet(), so peeked packets are not dispatched twice.
        """
        with self._buf_lock:
            snapshot = bytes(self._data_buf)
        packets = []
        head = 0
        while len(packets) < num:
            size = self._frame_length(snapshot, head)
            if size is None:
                break
            packets.append(self.decode_packet(snapshot[head : head + size], b64=False))
            head += size
        return packets

    def read_packets(self, num = 0xFFFF):
        """Consume, decode and dispatch (via on_packet) up to *num* complete frames.

        Each frame is removed from the buffer before it is decoded, so a frame
        that fails to decode is dropped instead of blocking the frames behind
        it; the decoding exception still propagates to the caller.

        Returns:
            The list of decoded packet dicts, in buffer order.
        """
        packets = []
        for _ in range(num):
            with self._buf_lock:
                size = self._frame_length(self._data_buf)
                if size is None:
                    break
                frame = bytes(self._data_buf[:size])
                del self._data_buf[:size]
            packet = self.decode_packet(frame, b64=False)
            packets.append(packet)
            self.on_packet(packet)
        return packets

    def on_packet(self, packet):
        """Deliver a decoded packet to every socket whose target is the packet's source.

        Addresses are compared in canonical IPv6 form, so e.g. ``"::1"`` and
        ``"0:0:0:0:0:0:0:1"`` match.
        """
        source = (self._canonical_address(packet['Source Address']), packet['Source Port'])
        # Iterate over a copy: close() may remove sockets from another thread.
        for s in list(self.sockets):
            if (self._canonical_address(s.target_addr), s.target_port) == source:
                s._put_recv(packet)

    def random_unused_port(self) -> int:
        """Return a random 32-bit port number that is not a key of ``self.ports``."""
        while True:
            port = random.randint(0, 0xFFFFFFFF)
            if port not in self.ports:
                return port

    def _exists(self, addr : str) -> bool:
        """True if *addr* is this agent's own address or exists() reports it as taken."""
        return (addr == self.address) or self.exists(addr)

    def exists(self, addr : str) -> bool:
        """Hook: whether *addr* is already in use elsewhere. Always False by default."""
        return False

    def protocol_version(self) -> int:
        """Version written to and required of packets: ndrlmeta.PROTOCOL_VERSION, or 0xFF if unset."""
        return getattr(ndrlmeta, 'PROTOCOL_VERSION', 0xFF)

    def build_packet(
        self,
        payload : Union[bytes, bytearray],
        target_addr : str,
        target_port : int,
        source_addr : Union[str, None] = None,
        source_port : Union[int, None] = None,
        version : Union[int, None] = None
    ):
        """Return a packet dict (the shape decode_packet() produces) for encode_packet().

        Omitted fields default to this agent's address, a random unused source
        port and protocol_version(). ``None`` means "omitted": an explicit 0
        is honoured for the port and the version.
        """
        return {
            "Version": self.protocol_version() if version is None else version,
            "Source Address": source_addr or self.address,
            "Target Address": target_addr,
            "Source Port": self.random_unused_port() if source_port is None else source_port,
            "Target Port": target_port,
            "Payload Length": len(payload),
            "Payload": payload,
        }

    def encode_packet(self, packet : Dict[str, Any], b64 : bool = True):
        """Serialise a packet dict into one frame.

        Missing or ``None`` 'Version', 'Source Port' and 'Source Address' are
        filled in from this agent. With *b64* (the default) the frame is
        base85-encoded, despite the parameter name; pass ``b64=False`` for the
        raw binary frame that put()/read_packets() parse.

        Raises:
            ValueError: 'Payload Length' is present but differs from the
                actual payload size (the frame would be mis-split).
        """
        payload = bytes(packet['Payload'])
        declared = packet.get('Payload Length')
        if declared is not None and declared != len(payload):
            raise ValueError(
                "Payload Length {} does not match the {}-byte payload"
                .format(declared, len(payload))
            )
        version = packet.get('Version')
        if version is None:
            version = self.protocol_version()
        source_port = packet.get('Source Port')
        if source_port is None:
            source_port = self.random_unused_port()
        source_addr = packet.get('Source Address') or self.address
        data = self._HEADER.pack(
            version,
            source_port,
            packet['Target Port'],
            socket.inet_pton(socket.AF_INET6, source_addr),
            socket.inet_pton(socket.AF_INET6, packet['Target Address']),
            len(payload),
        ) + payload
        if b64:
            data = base64.b85encode(data)
        return data

    def decode_packet(self, data : bytes, b64 : bool = True):
        """Parse one frame produced by encode_packet() with the same *b64* flag.

        Returns:
            A dict with the keys 'Version', 'Source Port', 'Target Port',
            'Source Address', 'Target Address', 'Payload Length', 'Payload'.
            Bytes after the declared payload are ignored.

        Raises:
            ValueError: *data* is shorter than the header plus declared payload.
            NDRLVersionMismatch: the frame's version differs from protocol_version().
        """
        if b64:
            # Decode first: the header fields live in the binary frame, not in
            # its base85 text.
            data = base64.b85decode(data)
        data = bytes(data)
        if len(data) < self.HEADER_SIZE:
            raise ValueError(
                "Packet is {} bytes, shorter than the {}-byte header"
                .format(len(data), self.HEADER_SIZE)
            )
        version, source_port, target_port, source_raw, target_raw, payload_len = (
            self._HEADER.unpack_from(data)
        )
        expected = self.protocol_version()
        if version != expected:
            raise NDRLVersionMismatch(
                self.address,
                "Packet of version {} does not match our current version ({})!"
                .format(version, expected)
            )
        end = self.HEADER_SIZE + payload_len
        if len(data) < end:
            raise ValueError(
                "Packet declares {} payload bytes but only {} are present"
                .format(payload_len, len(data) - self.HEADER_SIZE)
            )
        return {
            'Version': version,
            'Source Port': source_port,
            'Target Port': target_port,
            'Source Address': socket.inet_ntop(socket.AF_INET6, source_raw),
            'Target Address': socket.inet_ntop(socket.AF_INET6, target_raw),
            'Payload Length': payload_len,
            'Payload': data[self.HEADER_SIZE : end],
        }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment