Created
December 26, 2013 00:13
-
-
Save markrwilliams/8128178 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ctypes | |
import os | |
import socket | |
# C ``socklen_t`` as used in ``struct msghdr`` (an unsigned int on Linux/glibc).
socklen_t = ctypes.c_uint

# Control-message type for passing open file descriptors over a UNIX socket.
SCM_RIGHTS = 0x01

# Load the C library with ``use_errno=True`` so that ctypes snapshots errno
# after every foreign call; without it ``ctypes.get_errno()`` cannot be relied
# on to report why a call such as sendmsg() failed.
libc = ctypes.CDLL('libc.so.6', use_errno=True)
class iovec(ctypes.Structure):
    """ctypes mirror of C ``struct iovec``: one buffer for scatter/gather I/O.

    ``iov_base`` holds the buffer's address and ``iov_len`` its size in bytes.
    """

    _fields_ = [
        ('iov_base', ctypes.c_void_p),
        ('iov_len', ctypes.c_size_t),
    ]


# Pointer-to-iovec type, used for the ``msg_iov`` member of ``msghdr``.
iovec_ptr = ctypes.POINTER(iovec)
class msghdr(ctypes.Structure):
    """ctypes mirror of C ``struct msghdr`` as passed to sendmsg()/recvmsg().

    The final member is named ``msg_flags`` to match the C declaration.  The
    earlier spelling ``msg_flag`` remains available as a read/write alias so
    existing attribute access and keyword initialisation keep working.
    """

    _fields_ = [('msg_name', ctypes.c_void_p),
                ('msg_namelen', socklen_t),
                ('msg_iov', iovec_ptr),
                ('msg_iovlen', ctypes.c_size_t),
                ('msg_control', ctypes.c_void_p),
                ('msg_controllen', ctypes.c_size_t),
                ('msg_flags', ctypes.c_int)]

    @property
    def msg_flag(self):
        """Backward-compatible alias for ``msg_flags``."""
        return self.msg_flags

    @msg_flag.setter
    def msg_flag(self, value):
        self.msg_flags = value

    @property
    def has_control(self):
        """True when ``msg_controllen`` can hold at least one ``cmsghdr``."""
        return self.msg_controllen >= ctypes.sizeof(cmsghdr)
class cmsghdr(ctypes.Structure):
    """ctypes mirror of the fixed header of C ``struct cmsghdr``.

    The C struct ends in a flexible array member (the payload), which ctypes
    cannot express directly; use :meth:`with_data` to build a header together
    with its payload.
    """

    _fields_ = [('cmsg_len', ctypes.c_size_t),
                ('cmsg_level', ctypes.c_int),
                ('cmsg_type', ctypes.c_int)]

    @classmethod
    def with_data(cls, cmsg_len, cmsg_level, cmsg_type, cmsg_data):
        """Return a header immediately followed by a copy of *cmsg_data*.

        The flexible array member is faked by creating, per call, a new
        Structure subclass whose last field is a ``c_ubyte`` array exactly
        ``sizeof(cmsg_data)`` bytes long.

        :param cmsg_len: value stored in ``cmsg_len``.
        :param cmsg_level: value stored in ``cmsg_level``.
        :param cmsg_type: value stored in ``cmsg_type``.
        :param cmsg_data: ctypes instance whose raw bytes become the payload.
        :returns: an instance of the generated structure; its ``cmsg_data``
            field holds the copied bytes.
        """
        FlexArray = ctypes.c_ubyte * ctypes.sizeof(cmsg_data)

        class _cmsghdr(ctypes.Structure):
            _fields_ = cls._fields_ + [('cmsg_data', FlexArray)]

        # from_buffer_copy() duplicates the raw bytes in a single step and
        # exists on both Python 2 and 3, unlike the ``buffer`` builtin.
        as_bytes = FlexArray.from_buffer_copy(cmsg_data)
        return _cmsghdr(cmsg_len=cmsg_len,
                        cmsg_level=cmsg_level,
                        cmsg_type=cmsg_type,
                        cmsg_data=as_bytes)
def CMSG_ALIGN(length):
    """Round *length* up to a multiple of ``sizeof(size_t)``.

    Returns the rounded value wrapped in a ``ctypes.c_size_t``.
    """
    mask = ctypes.sizeof(ctypes.c_size_t) - 1
    rounded = (length + mask) & ~mask
    return ctypes.c_size_t(rounded)
def CMSG_SPACE(length):
    """Return, as a ``c_size_t``, the aligned payload size plus the aligned
    ``cmsghdr`` size for a control message carrying *length* payload bytes.
    """
    total = (CMSG_ALIGN(length).value
             + CMSG_ALIGN(ctypes.sizeof(cmsghdr)).value)
    return ctypes.c_size_t(total)
def CMSG_LEN(length):
    """Return, as a ``c_size_t``, the aligned ``cmsghdr`` size plus *length*.

    Unlike the payload in ``CMSG_SPACE``, *length* itself is not rounded up.
    """
    aligned_header = CMSG_ALIGN(ctypes.sizeof(cmsghdr))
    return ctypes.c_size_t(aligned_header.value + length)
# Prototype: ssize_t sendmsg(int sockfd, const struct msghdr *msg, int flags);
_sendmsg = libc.sendmsg
_sendmsg.argtypes = [ctypes.c_int,
                     ctypes.POINTER(msghdr),
                     ctypes.c_int]
# sendmsg() returns ssize_t, not int; declare it as such so the result is not
# truncated on platforms where ssize_t is wider than int.
_sendmsg.restype = ctypes.c_ssize_t
def sendmsg_test(socket_path, file_path):
    """Send an open file descriptor over a UNIX stream socket via sendmsg().

    Opens *file_path* read-only, connects to the UNIX socket at *socket_path*
    and sends one message whose ordinary payload is the C int 12345 and whose
    control data is a single ``SCM_RIGHTS`` message carrying the descriptor.
    The byte count returned by sendmsg() is printed.

    :param socket_path: filesystem path of the listening UNIX socket.
    :param file_path: path of the file whose descriptor is passed.
    :raises RuntimeError: if sendmsg() returns -1.
    """
    # Ordinary payload: one int, described by a single iovec entry.
    data = ctypes.c_int(12345)
    iov = iovec(iov_base=ctypes.addressof(data),
                iov_len=ctypes.sizeof(data))

    fd = os.open(file_path, os.O_RDONLY)
    try:
        cfd = ctypes.c_int(fd)
        cmhp = cmsghdr.with_data(
            cmsg_len=CMSG_LEN(ctypes.sizeof(cfd)).value,
            cmsg_level=socket.SOL_SOCKET,
            cmsg_type=SCM_RIGHTS,
            cmsg_data=cfd)

        # ``msg_namelen`` is the real field name; an unknown keyword would be
        # accepted by ctypes and silently stored as a plain attribute instead.
        msgh = msghdr(msg_name=None,
                      msg_namelen=0,
                      msg_iov=iovec_ptr(iov),
                      msg_iovlen=1,
                      msg_control=ctypes.addressof(cmhp),
                      msg_controllen=ctypes.sizeof(cmhp))

        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            sock.connect(socket_path)
            res = _sendmsg(sock.fileno(), ctypes.byref(msgh), 0)
            if res == -1:
                raise RuntimeError('sendmsg() failed')
            # Parenthesised single-argument form prints the same text on
            # Python 2 and Python 3.
            print('sendmsg() returned %ld\n' % res)
        finally:
            sock.close()
    finally:
        # Always release our copy of the descriptor, including when
        # connect() or sendmsg() fails.
        os.close(fd)
def _v_test_iov_prep(iovs, ints):
    """Point each entry of *iovs* at the matching ctypes object in *ints*.

    For every index ``i``, ``iovs[i].iov_base`` is set to the address of
    ``ints[i]`` and ``iovs[i].iov_len`` to that object's size.

    :param iovs: indexable collection of iovec-like objects, at least as long
        as *ints*.
    :param ints: iterable of ctypes instances.  Only their addresses are
        stored, so the caller must keep the objects alive for as long as
        *iovs* is in use.
    :returns: total number of bytes described by the filled entries
        (0 when *ints* is empty).
    """
    tot_required = 0
    for i, c_i in enumerate(ints):
        # Size each element individually rather than peeking at ints[0], so
        # an empty or non-indexable iterable is handled too.
        size = ctypes.sizeof(c_i)
        iovs[i].iov_base = ctypes.addressof(c_i)
        iovs[i].iov_len = size
        tot_required += size
    return tot_required
def test_writev():
    """Write the C ints 1, 2, 3 to ``/tmp/test1`` with one writev() call.

    :returns: the number of bytes writev() reported as written.
    :raises AssertionError: if that count differs from the combined size of
        the three ints (this includes a -1 error return).
    """
    writev = libc.writev
    writev.argtypes = [ctypes.c_int, ctypes.POINTER(iovec), ctypes.c_int]
    # writev() returns ssize_t; an unsigned restype would turn the -1 error
    # value into a huge positive number.
    writev.restype = ctypes.c_ssize_t

    # A real list (not a lazy map object) keeps the c_int objects alive while
    # their addresses are referenced from the iovec array.
    ints = [ctypes.c_int(value) for value in (1, 2, 3)]
    iovs = (iovec * len(ints))()
    tot_required = _v_test_iov_prep(iovs, ints)

    # Binary mode: the payload is raw C ints, not text.
    with open('/tmp/test1', 'wb') as f:
        num_written = writev(f.fileno(), iovs, len(ints))
        assert num_written == tot_required
    return num_written
def test_readv():
    """Read three C ints from ``/tmp/test1`` with one readv() call, print them.

    Each iovec's base pointer is cast back to ``POINTER(c_int)`` and the
    pointed-to value is printed.

    :raises AssertionError: if the byte count returned by readv() differs
        from the combined size of the three ints (this includes a -1 error
        return).
    """
    readv = libc.readv
    readv.argtypes = [ctypes.c_int, ctypes.POINTER(iovec), ctypes.c_int]
    # readv() returns ssize_t; an unsigned restype would turn the -1 error
    # value into a huge positive number.
    readv.restype = ctypes.c_ssize_t

    # A real list (not a lazy map object) keeps the destination c_int objects
    # alive while readv() writes through their addresses.
    ints = [ctypes.c_int(0) for _ in range(3)]
    iovs = (iovec * len(ints))()
    tot_required = _v_test_iov_prep(iovs, ints)

    # Binary mode: the file holds raw C ints, not text.
    with open('/tmp/test1', 'rb') as f:
        num_read = readv(f.fileno(), iovs, len(ints))
        assert num_read == tot_required

    for iov in iovs:
        # Parenthesised single-argument form works on Python 2 and 3.
        print(ctypes.cast(iov.iov_base, ctypes.POINTER(ctypes.c_int)).contents)
if __name__ == '__main__':
    # Command-line entry point: takes the UNIX socket path and the path of
    # the file whose descriptor should be sent, then runs the sendmsg demo.
    import argparse

    parser = argparse.ArgumentParser()
    for positional in ('socket_path', 'file_path'):
        parser.add_argument(positional)
    cli = parser.parse_args()
    sendmsg_test(cli.socket_path, cli.file_path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment