Skip to content

Instantly share code, notes, and snippets.

@ssrlive
Last active April 4, 2024 09:24
Show Gist options
  • Save ssrlive/ea9f9b0f9c05600775658e88c3ed5043 to your computer and use it in GitHub Desktop.
Save ssrlive/ea9f9b0f9c05600775658e88c3ed5043 to your computer and use it in GitHub Desktop.
CreateIoCompletionPort
//
// [dependencies]
// windows = { version = "0.51", features = [
// "Win32_System_SystemInformation",
// "Win32_Networking_WinSock",
// "Win32_System_Memory",
// "Win32_Foundation",
// "Win32_System_IO",
// "Win32_System_Threading",
// "Win32_Security",
// ] }
//
use std::{ffi::c_void, mem::MaybeUninit, net::Ipv4Addr, ptr::write_bytes};
use windows::{
    core::PSTR,
    Win32::{
        Foundation::{
            CloseHandle, ERROR_NOT_ENOUGH_MEMORY, HANDLE, INVALID_HANDLE_VALUE, WIN32_ERROR,
        },
        Networking::WinSock::{
            accept, bind, closesocket, inet_ntoa, listen, ntohs, send, socket, WSACleanup,
            WSAGetLastError, WSARecv, WSAStartup, AF_INET, INVALID_SOCKET, IPPROTO_TCP,
            SEND_RECV_FLAGS, SOCKADDR, SOCKADDR_IN, SOCKET, SOCK_STREAM, WSABUF, WSADATA,
            WSA_IO_PENDING,
        },
        System::{
            Memory::{GetProcessHeap, HeapAlloc, HeapFree, HEAP_FLAGS, HEAP_ZERO_MEMORY},
            SystemInformation::{GetSystemInfo, SYSTEM_INFO},
            Threading::{CreateThread, WaitForSingleObject, INFINITE, THREAD_CREATION_FLAGS},
            IO::{
                CreateIoCompletionPort, GetQueuedCompletionStatus, PostQueuedCompletionStatus,
                OVERLAPPED,
            },
        },
    },
};
/// Packs two bytes into a 16-bit word: `a` becomes the low-order byte and `b` the
/// high-order byte.
// The name mirrors the Win32 macro of the same name, hence the lint exemption.
#[allow(non_snake_case)]
#[inline]
pub fn MAKEWORD(a: u8, b: u8) -> u16 {
    // Little-endian assembly places the first byte in the low-order position.
    u16::from_le_bytes([a, b])
}
/// TCP port number (host byte order) that the listening socket is bound to.
const PORT: u16 = 5150;
/// Size in bytes of the receive buffer embedded in every `PerIOOperationData`.
const MSGSIZE: usize = 1024;
/// Kind of overlapped operation that a `PerIOOperationData` record describes.
///
/// Stored as a `u32`. `RecvPosted` is the first variant and therefore has discriminant 0,
/// so a zero-filled record already holds a valid value.
#[repr(u32)]
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
enum OperationType {
    /// An overlapped receive has been posted on the socket.
    RecvPosted,
}
/// State attached to one overlapped socket operation.
///
/// The layout is `#[repr(C)]` with `overlap` as the first field: the worker thread casts the
/// `OVERLAPPED` pointer it gets back from the completion port to a pointer to this struct,
/// which is only valid while `overlap` sits at offset 0.
#[repr(C)]
struct PerIOOperationData {
    /// Must stay the first field (see the struct-level note).
    overlap: OVERLAPPED,
    /// Buffer descriptor handed to `WSARecv`; it is set up to point at `message`.
    buffer: WSABUF,
    /// Storage that received bytes are written into.
    message: [u8; MSGSIZE],
    /// Byte-count out slot available to `WSARecv`; the completion handler relies on the
    /// count reported by the completion port instead.
    number_of_bytes_received: u32,
    /// In/out flags word passed to `WSARecv`.
    flags: u32,
    /// Which kind of operation this record was posted for.
    operation_type: OperationType,
}
fn main() -> Result<(), windows::core::Error> {
let mut wsa_data = WSADATA::default();
unsafe { WSAStartup(MAKEWORD(2, 2), &mut wsa_data) };
let completion_port = unsafe { CreateIoCompletionPort(INVALID_HANDLE_VALUE, None, 0, 0)? };
let mut systeminfo = SYSTEM_INFO::default();
unsafe { GetSystemInfo(&mut systeminfo) };
for _ in 0..systeminfo.dwNumberOfProcessors {
let mut thread_id = 0_u32;
unsafe {
CreateThread(
None,
0,
Some(worker_thread),
Some(&completion_port as *const HANDLE as *const c_void),
THREAD_CREATION_FLAGS(0),
Some(&mut thread_id),
)?
};
}
let s_listen = unsafe { socket(AF_INET.0.into(), SOCK_STREAM, IPPROTO_TCP.0) };
let local = SOCKADDR_IN {
sin_addr: Ipv4Addr::UNSPECIFIED.into(),
sin_family: AF_INET,
sin_port: u16::from_be_bytes(PORT.to_ne_bytes()),
..Default::default()
};
unsafe {
bind(
s_listen,
&local as *const SOCKADDR_IN as *const SOCKADDR,
std::mem::size_of::<SOCKADDR_IN>() as i32,
)
};
unsafe { listen(s_listen, 3) };
loop {
let mut client = SOCKADDR_IN::default();
let mut addrlen = std::mem::size_of::<SOCKADDR_IN>() as i32;
let s_client = unsafe {
accept(
s_listen,
Some(&mut client as *mut SOCKADDR_IN as *mut SOCKADDR),
Some(&mut addrlen),
)
};
println!(
"Accepted client {:?} ({}:{})",
s_client,
unsafe { inet_ntoa(client.sin_addr).to_string()? },
unsafe { ntohs(client.sin_port) }
);
// Associate the newly arrived client socket with completion port
unsafe {
CreateIoCompletionPort(
HANDLE(s_client.0 as isize),
completion_port,
s_client.0 as usize,
0,
)?
};
// Launch an asynchronous operation for new arrived connection
unsafe {
let lp_per_iodata = HeapAlloc(
GetProcessHeap()?,
HEAP_ZERO_MEMORY,
std::mem::size_of::<PerIOOperationData>(),
);
let lp_per_iodata = lp_per_iodata as *mut PerIOOperationData;
(*lp_per_iodata).buffer.len = MSGSIZE as u32;
(*lp_per_iodata).buffer.buf = PSTR(&mut (*lp_per_iodata).message as *mut u8);
(*lp_per_iodata).operation_type = OperationType::RecvPosted;
WSARecv(
s_client,
std::slice::from_ref(&(*lp_per_iodata).buffer),
Some(&mut (*lp_per_iodata).number_of_bytes_received as *mut u32),
&mut (*lp_per_iodata).flags,
Some(&mut (*lp_per_iodata).overlap),
None,
)
};
}
#[allow(unreachable_code)]
unsafe {
PostQueuedCompletionStatus(completion_port, 0xFFFFFFFF, 0, None)?;
CloseHandle(completion_port)?;
closesocket(s_listen);
WSACleanup();
Ok(())
}
}
unsafe extern "system" fn worker_thread(lpthreadparameter: *mut c_void) -> u32 {
let completion_port = *(lpthreadparameter as *const HANDLE);
let block = || {
loop {
let mut bytes_transferred = 0_u32;
let mut s_client = 0_usize;
let mut lp_overlapped: *mut OVERLAPPED = std::ptr::null_mut();
GetQueuedCompletionStatus(
completion_port,
&mut bytes_transferred,
&mut s_client,
&mut lp_overlapped as *mut *mut OVERLAPPED,
INFINITE,
)?;
let lp_per_iodata = lp_overlapped as *mut PerIOOperationData;
let s_client = SOCKET(s_client);
if bytes_transferred == 0xFFFFFFFF {
return Ok::<(), windows::core::Error>(());
}
if (*lp_per_iodata).operation_type == OperationType::RecvPosted {
if bytes_transferred == 0 {
println!("Client {:?} exiting", s_client);
// Connection was closed by client
closesocket(s_client);
HeapFree(
GetProcessHeap()?,
HEAP_FLAGS(0),
Some(lp_per_iodata as *const c_void),
)?;
} else {
// Echo the received data back to the client
let data = &(*lp_per_iodata).message[..bytes_transferred as usize];
send(s_client, data, SEND_RECV_FLAGS(0));
// Launch another asynchronous operation for sClient
write_bytes(lp_per_iodata as *mut MaybeUninit<PerIOOperationData>, 0, 1);
(*lp_per_iodata).buffer.len = MSGSIZE as _;
(*lp_per_iodata).buffer.buf = PSTR(&mut (*lp_per_iodata).message as *mut u8);
(*lp_per_iodata).operation_type = OperationType::RecvPosted;
WSARecv(
s_client,
std::slice::from_ref(&(*lp_per_iodata).buffer),
Some(&mut (*lp_per_iodata).number_of_bytes_received as *mut u32),
&mut (*lp_per_iodata).flags,
Some(&mut (*lp_per_iodata).overlap),
None,
);
}
}
}
};
if let Err(e) = block() {
eprintln!("Error: {:?}", e);
}
0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment