/**
 * Fresh tea file transfer.
 */
#define _GNU_SOURCE
#include <stdio.h>
#include <errno.h>
#include <assert.h>
#include <stdint.h>
#include <stdarg.h>
#include <stdbool.h>
#include <string.h>
#include <stdlib.h>
#include <unistd.h>
#include <signal.h>
#include <pthread.h>
#include <liburing.h>
#include <sys/mman.h>
#include <stdatomic.h>
#include <arpa/inet.h>
#include <sys/socket.h>

#define NR_WORKERS	4
#define RING_ENTRIES	32
#define NR_MAX_CLIENTS	10240
#define BUFFER_SIZE	8192
#define TASK_COMM_LEN	16

#ifndef likely
#define likely(COND)	__builtin_expect(!!(COND), 1)
#endif
#ifndef unlikely
#define unlikely(COND)	__builtin_expect(!!(COND), 0)
#endif
#ifndef __always_inline
#define __always_inline	__attribute__((__always_inline__)) static inline
#endif
#ifndef noinline
#define noinline	__attribute__((__noinline__))
#endif
#ifndef __cold
#define __cold		__attribute__((__cold__))
#endif
#ifndef __hot
#define __hot		__attribute__((__hot__))
#endif

union tsockaddr {
	struct sockaddr addr;
	struct sockaddr_in addr4;
	struct sockaddr_in6 addr6;
};

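/*
 * Wire format of the transfer header: the struct is packed, so the byte
 * layout below is exactly what travels on the socket. file_name_len bounds
 * the file_name field (at most 0xff bytes).
 */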
typedef struct __attribute__((__packed__)) packet_t {
	uint64_t	file_size;
	uint8_t		file_name_len;
	char		file_name[0xffu];
} packet_t;

union uni_pkt {
	packet_t	packet;
	char		raw_buf[BUFFER_SIZE];
};

struct client {
	int			fd;
	uint32_t		idx;
	socklen_t		addrlen;
	union tsockaddr		addr;
	size_t			pkt_len;
	union uni_pkt		pkt;
};

struct worker {
	struct io_uring		ring;
	uint32_t		pending_sqe;
	pthread_mutex_t		lock;
	struct __kernel_timespec timeout;
	struct server_ctx	*sctx;
	pthread_t		thread;
	uint32_t		idx;
	bool			need_cleanup;
};

struct stack {
	uint32_t		sp;
	uint32_t		max_sp;
	uint32_t		*arr;
	pthread_mutex_t		lock;
};

struct server_ctx {
	int			tcp_fd;
	_Atomic(uint32_t)	next_worker;
	struct worker		*workers;
	volatile bool		stop;
	bool			accept_in_flight;
	uint16_t		bind_port;
	const char		*bind_addr;
	struct stack		clients_stk;
	struct client		clients[];
};

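/*
 * Overview: the server pre-allocates NR_MAX_CLIENTS client slots plus a
 * stack of free slot indexes. NR_WORKERS io_uring instances are created;
 * worker 0 runs on the main thread and owns the accept SQE, while accepted
 * connections are handed to workers in round-robin order via next_worker.
 */
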
static volatile bool *g_stop;

noinline __attribute__((__format__(printf, 3, 4)))
static void __pr_notice(const char *file, uint32_t line, const char *fmt, ...)
{
	char buf[4096];
	va_list ap;

	va_start(ap, fmt);
	vsnprintf(buf, sizeof(buf), fmt, ap);
	printf("%s:%u %s\n", file, line, buf);
	va_end(ap);
}

#define pr_notice(...)						\
do {								\
	__pr_notice(__FILE__, __LINE__, __VA_ARGS__);		\
} while (0)

static void *zmalloc_mlocked(size_t l)
{
	void *ret;

	ret = mmap(NULL, l, PROT_READ | PROT_WRITE, MAP_ANONYMOUS | MAP_PRIVATE,
		   -1, 0);
	if (unlikely(ret == MAP_FAILED))
		return NULL;

	mlock(ret, l);
	memset(ret, 0, l);
	return ret;
}

static void free_mlocked(void *ptr, size_t len)
{
	if (unlikely(!ptr))
		return;

	munmap(ptr, len);
}

__cold static int init_stack(struct stack *stack, size_t nr_elements)
{
	uint32_t *arr, i;
	size_t len;
	int ret;

	ret = pthread_mutex_init(&stack->lock, NULL);
	if (unlikely(ret))
		return -ret;

	len = nr_elements * sizeof(*arr);
	arr = zmalloc_mlocked(len);
	if (unlikely(!arr)) {
		pthread_mutex_destroy(&stack->lock);
		return -ENOMEM;
	}

	stack->arr = arr;
	stack->sp = 0;
	stack->max_sp = nr_elements;
	for (i = 0; i < nr_elements; i++)
		arr[i] = i;

	return 0;
}

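/*
 * The stack holds free client-slot indexes: pop_stack() hands one out
 * (failing with -EAGAIN when every slot is in use) and push_stack()
 * returns it once the connection is closed.
 */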
__hot static int pop_stack(struct stack *stack, uint32_t *out)
{
	int ret;

	pthread_mutex_lock(&stack->lock);
	if (unlikely(stack->sp == stack->max_sp)) {
		ret = -EAGAIN;
	} else {
		ret = 0;
		*out = stack->arr[stack->sp++];
	}
	pthread_mutex_unlock(&stack->lock);
	return ret;
}

__hot static int push_stack(struct stack *stack, uint32_t in)
{
	int ret;

	pthread_mutex_lock(&stack->lock);
	if (unlikely(!stack->sp)) {
		ret = -EAGAIN;
	} else {
		ret = 0;
		stack->arr[--stack->sp] = in;
	}
	pthread_mutex_unlock(&stack->lock);
	return ret;
}

__cold static void destroy_stack(struct stack *stack)
{
	munmap(stack->arr, stack->max_sp * sizeof(*stack->arr));
	pthread_mutex_destroy(&stack->lock);
	stack->arr = NULL;
}

__cold static void signal_handler(int sig)
{
	putchar('\n');
	pr_notice("Got signal %d", sig);
	if (!g_stop) {
		pr_notice("Invalid condition, g_stop is NULL!");
		raise(SIGABRT);
		__builtin_unreachable();
	}
	*g_stop = true;
}

__cold static int init_signal_handlers(void)
{
	struct sigaction act = { .sa_handler = signal_handler };
	int ret;

	if (!g_stop) {
		pr_notice("Invalid condition, g_stop is NULL!");
		return -EINVAL;
	}

	ret = sigaction(SIGINT, &act, NULL);
	if (unlikely(ret))
		goto err;
	ret = sigaction(SIGHUP, &act, NULL);
	if (unlikely(ret))
		goto err;
	ret = sigaction(SIGTERM, &act, NULL);
	if (unlikely(ret))
		goto err;

	act.sa_handler = SIG_IGN;
	ret = sigaction(SIGPIPE, &act, NULL);
	if (unlikely(ret))
		goto err;

	return 0;
err:
	return -errno;
}

__cold static int init_clients(struct server_ctx *sctx)
{
	struct client *clients = sctx->clients;
	uint32_t i;
	int ret;

	ret = init_stack(&sctx->clients_stk, NR_MAX_CLIENTS);
	if (unlikely(ret))
		return ret;

	for (i = 0; i < NR_MAX_CLIENTS; i++)
		clients[i].idx = i;

	return 0;
}

static void *server_worker_entry(void *arg);

__cold static int init_server_worker(struct worker *worker)
{
	pthread_t *t = &worker->thread;
	char name[TASK_COMM_LEN];
	int ret;

	ret = pthread_mutex_init(&worker->lock, NULL);
	if (unlikely(ret))
		return -ret;

	ret = pthread_create(t, NULL, server_worker_entry, worker);
	if (unlikely(ret)) {
		pthread_mutex_destroy(&worker->lock);
		return -ret;
	}

	snprintf(name, sizeof(name), "tea-wrk-%u", worker->idx);
	pthread_setname_np(*t, name);
	worker->need_cleanup = true;
	return ret;
}

__cold static void destroy_server_workers(struct server_ctx *sctx)
{
	struct worker *workers = sctx->workers;
	size_t i;

	for (i = 0; i < NR_WORKERS; i++) {
		if (!workers[i].need_cleanup)
			continue;

		pthread_join(workers[i].thread, NULL);
		pthread_mutex_lock(&workers[i].lock);
		io_uring_queue_exit(&workers[i].ring);
		pthread_mutex_unlock(&workers[i].lock);
		pthread_mutex_destroy(&workers[i].lock);
	}

	free_mlocked(workers, sizeof(*workers) * NR_WORKERS);
	sctx->workers = NULL;
}

__cold static int init_server_workers(struct server_ctx *sctx)
{
	struct worker *workers;
	int ret = 0;
	size_t i;

	workers = zmalloc_mlocked(sizeof(*workers) * NR_WORKERS);
	if (unlikely(!workers))
		return -ENOMEM;

	for (i = 0; i < NR_WORKERS; i++) {
		workers[i].idx = i;
		workers[i].sctx = sctx;

		ret = io_uring_queue_init(RING_ENTRIES, &workers[i].ring, 0);
		if (unlikely(ret))
			break;

		/*
		 * Let the first worker run in the main thread. No thread is
		 * spawned for it, so its need_cleanup stays false and its
		 * lock keeps the zero-filled state from zmalloc_mlocked()
		 * (equivalent to the default mutex initializer on glibc).
		 */
		if (i == 0)
			continue;

		ret = init_server_worker(&workers[i]);
		if (unlikely(ret)) {
			io_uring_queue_exit(&workers[i].ring);
			break;
		}
	}

	sctx->workers = workers;
	if (unlikely(ret)) {
		/* Tell already-spawned workers to stop before joining them. */
		sctx->stop = true;
		destroy_server_workers(sctx);
	}

	return ret;
}

__cold static void destroy_server_ctx(struct server_ctx *sctx)
{
	size_t len;

	if (!sctx)
		return;

	sctx->stop = true;
	if (sctx->clients_stk.arr)
		destroy_stack(&sctx->clients_stk);
	if (sctx->workers)
		destroy_server_workers(sctx);

	len = sizeof(*sctx) + (sizeof(*sctx->clients) * NR_MAX_CLIENTS);
	free_mlocked(sctx, len);
}

__cold static int init_server_socket(struct server_ctx *sctx)
{
	struct sockaddr_in saddr;
	int ret;
	int fd;

	memset(&saddr, 0, sizeof(saddr));
	saddr.sin_family = AF_INET;
	saddr.sin_port = htons(sctx->bind_port);
	ret = inet_pton(AF_INET, sctx->bind_addr, &saddr.sin_addr);
	if (unlikely(ret != 1)) {
		pr_notice("Invalid bind address: %s", sctx->bind_addr);
		return -EINVAL;
	}

	fd = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0);
	if (unlikely(fd < 0)) {
		ret = -errno;
		pr_notice("socket(): %s", strerror(-ret));
		return ret;
	}

	ret = 1;
	setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &ret, sizeof(ret));
	setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &ret, sizeof(ret));

	ret = bind(fd, (struct sockaddr *)&saddr, sizeof(saddr));
	if (unlikely(ret < 0)) {
		ret = -errno;
		pr_notice("bind(): %s", strerror(-ret));
		close(fd);
		return ret;
	}

	ret = listen(fd, 1024);
	if (unlikely(ret < 0)) {
		ret = -errno;
		pr_notice("listen(): %s", strerror(-ret));
		close(fd);
		return ret;
	}

	printf("Listening on %s:%hu...\n", sctx->bind_addr, sctx->bind_port);
	sctx->tcp_fd = fd;
	return 0;
}

__cold static int parse_arg(struct server_ctx *sctx, int argc, char *argv[])
{
	int port;

	if (argc >= 4) {
		port = atoi(argv[3]);
		if (port < 1 || port > 65535) {
			printf("Port must be in range [1, 65535]\n");
			return -EINVAL;
		}
	}

	if (argc == 4 && !strcmp(argv[1], "server")) {
		sctx->bind_addr = argv[2];
		sctx->bind_port = (uint16_t)port;
		return 0;
	}

	if (argc == 5 && !strcmp(argv[1], "client")) {
		sctx->bind_addr = argv[2];
		sctx->bind_port = (uint16_t)port;
		return 0;
	}

	printf("Usage:\n");
	printf("  ./ftransfer server [bind_addr] [bind_port]\n");
	printf("  ./ftransfer client [server_addr] [server_port] [filename]\n");
	return -EINVAL;
}

__cold static int init_server_ctx(struct server_ctx **sctx_p, int argc,
				  char *argv[])
{
	struct server_ctx *sctx;
	size_t len;
	int ret;

	len = sizeof(*sctx) + (sizeof(*sctx->clients) * NR_MAX_CLIENTS);
	sctx = zmalloc_mlocked(len);
	if (unlikely(!sctx))
		return -ENOMEM;

	ret = parse_arg(sctx, argc, argv);
	if (unlikely(ret))
		goto err;

	g_stop = &sctx->stop;
	ret = init_signal_handlers();
	if (unlikely(ret)) {
		pr_notice("init_signal_handlers(): %s", strerror(-ret));
		goto err;
	}

	ret = init_server_socket(sctx);
	if (unlikely(ret)) {
		pr_notice("init_server_socket(): %s", strerror(-ret));
		goto err;
	}

	ret = init_clients(sctx);
	if (unlikely(ret)) {
		pr_notice("init_clients(): %s", strerror(-ret));
		goto err;
	}

	ret = init_server_workers(sctx);
	if (unlikely(ret)) {
		pr_notice("init_server_workers(): %s", strerror(-ret));
		goto err;
	}

	*sctx_p = sctx;
	return 0;
err:
	destroy_server_ctx(sctx);
	return ret;
}

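/*
 * SQE accounting: pending_sqe counts prepared-but-unsubmitted SQEs per
 * worker; the worker loop flushes them with io_uring_submit(). If the SQ
 * ring is full, submit immediately and retry getting an SQE.
 */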
__hot static struct io_uring_sqe *io_uring_get_sqe_no_fail(struct worker *w)
{
	struct io_uring_sqe *sqe;

	sqe = io_uring_get_sqe(&w->ring);
	if (unlikely(!sqe)) {
		io_uring_submit(&w->ring);
		sqe = io_uring_get_sqe(&w->ring);
		w->pending_sqe = 1;
	} else {
		w->pending_sqe++;
	}
	return sqe;
}

/*
 * Arm an accept SQE on worker @w unless one is already in flight. A free
 * client slot is reserved first so the completion handler has somewhere
 * to store the accepted connection.
 */
__hot static int server_prep_accept(struct worker *w)
{
	struct server_ctx *sctx = w->sctx;
	struct io_uring_sqe *sqe;
	struct client *client;
	uint32_t idx;
	int ret;

	if (unlikely(sctx->accept_in_flight))
		return 0;

	ret = pop_stack(&sctx->clients_stk, &idx);
	if (unlikely(ret))
		return ret;

	pthread_mutex_lock(&w->lock);
	sqe = io_uring_get_sqe_no_fail(w);
	pthread_mutex_unlock(&w->lock);

	client = &sctx->clients[idx];
	client->fd = -1;
	client->addrlen = sizeof(client->addr.addr4);
	io_uring_prep_accept(sqe, sctx->tcp_fd,
			     (struct sockaddr *)&client->addr.addr4,
			     &client->addrlen, 0);
	io_uring_sqe_set_data(sqe, client);
	sctx->accept_in_flight = true;
	return 0;
}

__hot static int server_prep_recv_start(struct worker *w, struct client *client)
{
	struct io_uring_sqe *sqe;
	int fd = client->fd;

	pthread_mutex_lock(&w->lock);
	sqe = io_uring_get_sqe_no_fail(w);
	io_uring_prep_recv(sqe, fd, &client->pkt, sizeof(client->pkt), 0);
	io_uring_sqe_set_data(sqe, client);
	pthread_mutex_unlock(&w->lock);
	return 0;
}

__hot static int handle_accept_event(struct worker *w, struct io_uring_cqe *cqe,
				     struct client *client)
{
	struct server_ctx *sctx = w->sctx;
	int ret = cqe->res;
	uint32_t idx;

	sctx->accept_in_flight = false;
	server_prep_accept(w);

	if (unlikely(ret < 0)) {
		pr_notice("Cannot accept a new client: %s", strerror(-ret));
		/* Return the reserved slot; a failed accept is not fatal. */
		push_stack(&sctx->clients_stk, client->idx);
		return 0;
	}

	client->fd = ret;
	idx = atomic_fetch_add(&sctx->next_worker, 1) % NR_WORKERS;
	return server_prep_recv_start(&sctx->workers[idx], client);
}

__hot static int handle_client_event(struct worker *w, struct io_uring_cqe *cqe,
				     struct client *client)
{
	struct server_ctx *sctx;
	int ret = cqe->res;

	if (unlikely(ret <= 0))
		goto out_close;

	/*
	 * Payload handling is not implemented in this snippet: the received
	 * header is ignored and no further recv is armed on this client.
	 */
	return 0;

out_close:
	sctx = w->sctx;
	if (unlikely(ret < 0))
		pr_notice("Error: %s", strerror(-ret));

	close(client->fd);
	ret = push_stack(&sctx->clients_stk, client->idx);
	assert(ret == 0);
	return 0;
}

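/*
 * CQE dispatch: an accept CQE carries a client whose fd is still -1
 * (set in server_prep_accept); anything else is a recv completion on an
 * established connection.
 */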
__hot static int handle_event(struct worker *w, struct io_uring_cqe *cqe)
{
	struct client *client;

	client = io_uring_cqe_get_data(cqe);
	if (client->fd == -1)
		return handle_accept_event(w, cqe, client);
	else
		return handle_client_event(w, cqe, client);
}

__hot static int server_handle_event_loop(struct worker *w)
{
	struct io_uring *ring = &w->ring;
	struct io_uring_cqe *cqe;
	uint32_t head;
	uint32_t i;
	int ret;

	ret = io_uring_wait_cqe_timeout(ring, &cqe, &w->timeout);
	if (unlikely(ret)) {
		if (likely(ret == -ETIME || ret == -EINTR))
			return 0;
		return ret;
	}

	i = 0;
	io_uring_for_each_cqe(ring, head, cqe) {
		i++;
		ret = handle_event(w, cqe);
		if (unlikely(ret))
			break;
	}
	io_uring_cq_advance(ring, i);
	return ret;
}

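/*
 * Worker loop: flush any batched SQEs, then wait for completions with a
 * one-second timeout so the stop flag set by the signal handler is noticed
 * promptly.
 */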
__hot noinline static void *server_worker_entry(void *arg)
{
	struct worker *w = arg;
	int ret = 0;

	w->timeout.tv_sec = 1;
	w->timeout.tv_nsec = 0;
	if (w->idx == 0)
		server_prep_accept(w);

	while (!*g_stop) {
		if (w->pending_sqe) {
			ret = io_uring_submit(&w->ring);
			assert(w->pending_sqe == (uint32_t)ret);
			w->pending_sqe = 0;
		}

		ret = server_handle_event_loop(w);
		if (unlikely(ret)) {
			pr_notice("server_handle_event_loop(): %s",
				  strerror(-ret));
			break;
		}
	}

	if (ret)
		*g_stop = true;

	pr_notice("Thread %u is exiting...", w->idx);
	return NULL;
}

int main(int argc, char *argv[])
{
	struct server_ctx *sctx = NULL;
	int ret;

	setvbuf(stdout, NULL, _IOLBF, 1024);
	setvbuf(stderr, NULL, _IOLBF, 1024);

	ret = init_server_ctx(&sctx, argc, argv);
	if (unlikely(ret))
		return -ret;

	server_worker_entry(&sctx->workers[0]);
	destroy_server_ctx(sctx);
	if (ret < 0)
		ret = -ret;
	return ret;
}