Skip to content

Instantly share code, notes, and snippets.

@ammarfaizi2
Last active January 16, 2023 06:58
Show Gist options
  • Save ammarfaizi2/1ed10c9ed6423c1606a2c385fa912870 to your computer and use it in GitHub Desktop.
Save ammarfaizi2/1ed10c9ed6423c1606a2c385fa912870 to your computer and use it in GitHub Desktop.
/**
* Fresh tea file transfer.
*/
#define _GNU_SOURCE
#include <assert.h>
#include <errno.h>
#include <signal.h>
#include <stdarg.h>
#include <stdatomic.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <pthread.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <liburing.h>
#define NR_WORKERS 4
#define RING_ENTRIES 32
#define NR_MAX_CLIENTS 10240
#define BUFFER_SIZE 8192
#define TASK_COMM_LEN 16
#ifndef likely
#define likely(COND) __builtin_expect(!!(COND), 1)
#endif
#ifndef unlikely
#define unlikely(COND) __builtin_expect(!!(COND), 0)
#endif
#ifndef __always_inline
#define __always_inline __attribute__((__always_inline__)) static inline
#endif
#ifndef noinline
#define noinline __attribute__((__noinline__))
#endif
#ifndef __cold
#define __cold __attribute__((__cold__))
#endif
#ifndef __hot
#define __hot __attribute__((__hot__))
#endif
/*
 * Storage large enough for a generic, IPv4 or IPv6 socket address.
 * NOTE(review): only addr4 is used in this file (the listener is AF_INET).
 */
union tsockaddr {
	struct sockaddr addr;
	struct sockaddr_in addr4;
	struct sockaddr_in6 addr6;
};
/*
 * Header layout overlaid on the start of a client's receive buffer (see
 * union uni_pkt). The fields are a 64-bit file size, a one-byte name length
 * and a fixed 255-byte (0xff) name field. The struct is packed, so there is
 * no padding between fields.
 * NOTE(review): nothing in this file decodes it yet. The byte order of
 * file_size, and whether file_name is NUL-terminated, are not established
 * here; confirm against the client implementation.
 */
typedef struct __attribute__((__packed__)) packet_t {
	uint64_t file_size;
	uint8_t file_name_len;
	char file_name[0xffu];
} packet_t;
/*
 * Per-client receive buffer. raw_buf fixes the size at BUFFER_SIZE bytes;
 * packet views the leading bytes of the same storage as a packet_t.
 */
union uni_pkt {
	packet_t packet;
	char raw_buf[BUFFER_SIZE];
};
/* One connection slot; slots are preallocated and recycled via clients_stk. */
struct client {
	/* Connected socket, or -1 while the slot is waiting on an accept. */
	int fd;
	/* Index of this slot in server_ctx::clients (assigned once at init). */
	uint32_t idx;
	/* In/out length of the peer address passed to accept. */
	socklen_t addrlen;
	/* Peer address written by accept (only addr4 is used). */
	union tsockaddr addr;
	/* NOTE(review): not read or written anywhere in this file yet. */
	size_t pkt_len;
	/* Receive buffer for this connection. */
	union uni_pkt pkt;
};
/*
 * Per-worker state. Each worker owns one io_uring. Worker 0 is driven from
 * the main thread; the others run on threads created by init_server_worker().
 */
struct worker {
	/* This worker's io_uring instance. */
	struct io_uring ring;
	/* SQEs obtained through io_uring_get_sqe_no_fail() but not yet submitted. */
	uint32_t pending_sqe;
	/*
	 * Held while acquiring and filling SQEs. Other threads can queue SQEs
	 * on this ring, because handle_accept_event() picks the target worker
	 * round-robin.
	 */
	pthread_mutex_t lock;
	/* Bound for each io_uring_wait_cqe_timeout() call (set to 1 second). */
	struct __kernel_timespec timeout;
	/* Back-pointer to the owning server context. */
	struct server_ctx *sctx;
	/* Thread handle; only meaningful when need_cleanup is true. */
	pthread_t thread;
	/* Position of this worker in server_ctx::workers. */
	uint32_t idx;
	/* Set once a dedicated thread exists; tells teardown to join it. */
	bool need_cleanup;
};
/*
 * Mutex-protected pool of free client slot indices.
 * arr[sp .. max_sp) holds the free indices, and sp is the number of indices
 * currently handed out. pop_stack() advances sp and push_stack() moves it
 * back.
 */
struct stack {
	uint32_t sp;
	uint32_t max_sp;
	uint32_t *arr;
	pthread_mutex_t lock;
};
/*
 * Process-wide server state. It is allocated as one zeroed mapping together
 * with the trailing clients[] array of NR_MAX_CLIENTS entries (see
 * init_server_ctx()).
 */
struct server_ctx {
	/* Listening TCP socket. */
	int tcp_fd;
	/* Round-robin counter used to pick the worker for a new connection. */
	_Atomic(uint32_t) next_worker;
	/* Array of NR_WORKERS workers. */
	struct worker *workers;
	/*
	 * Shutdown flag. Every worker loop polls it through g_stop; the signal
	 * handler sets it, and so does a worker that fails.
	 * NOTE(review): volatile does not synchronise threads. An atomic flag
	 * would make these cross-thread accesses well defined.
	 */
	volatile bool stop;
	/* True while an accept SQE is outstanding on worker 0's ring. */
	bool accept_in_flight;
	/* Address and port taken from the command line by parse_arg(). */
	uint16_t bind_port;
	const char *bind_addr;
	/* Free-slot indices into clients[]. */
	struct stack clients_stk;
	/* Connection slots, indexed by struct client::idx. */
	struct client clients[];
};
/*
 * Points at server_ctx::stop once init_server_ctx() has allocated the
 * context. The worker loops read it and the signal handler writes it.
 */
static volatile bool *g_stop;
/*
 * Format a printf-style message into a bounded local buffer and print it to
 * stdout as "<file>:<line> <message>\n". Messages longer than the buffer are
 * truncated.
 */
noinline __attribute__((__format__(printf, 3, 4)))
static void __pr_notice(const char *file, uint32_t line, const char *fmt, ...)
{
	char msg[4096];
	va_list args;

	va_start(args, fmt);
	vsnprintf(msg, sizeof(msg), fmt, args);
	va_end(args);

	printf("%s:%u %s\n", file, line, msg);
}

/* Log a message tagged with the source location of the call site. */
#define pr_notice(...)						\
	do {							\
		__pr_notice(__FILE__, __LINE__, __VA_ARGS__);	\
	} while (0)
/*
 * Map @l bytes of private anonymous read/write memory, try to lock it into
 * RAM and zero it. The mlock() result is deliberately ignored: the memory is
 * usable either way. Returns NULL if mmap() fails.
 */
static void *zmalloc_mlocked(size_t l)
{
	void *mem = mmap(NULL, l, PROT_READ | PROT_WRITE,
			 MAP_ANONYMOUS | MAP_PRIVATE, -1, 0);

	if (unlikely(mem == MAP_FAILED))
		return NULL;

	mlock(mem, l);
	memset(mem, 0, l);
	return mem;
}
/* Unmap memory obtained from zmalloc_mlocked(); a NULL @ptr is a no-op. */
static void free_mlocked(void *ptr, size_t len)
{
	if (likely(ptr))
		munmap(ptr, len);
}
/*
 * Prepare @stack so that every index in [0, nr_elements) is free.
 * Returns 0, the negated pthread_mutex_init() error, or -ENOMEM.
 */
__cold static int init_stack(struct stack *stack, size_t nr_elements)
{
	uint32_t *slots;
	uint32_t n;
	int err;

	err = pthread_mutex_init(&stack->lock, NULL);
	if (unlikely(err))
		return -err;

	slots = zmalloc_mlocked(nr_elements * sizeof(*slots));
	if (unlikely(!slots)) {
		pthread_mutex_destroy(&stack->lock);
		return -ENOMEM;
	}

	for (n = 0; n < nr_elements; n++)
		slots[n] = n;

	stack->arr = slots;
	stack->sp = 0;
	stack->max_sp = nr_elements;
	return 0;
}
/*
 * Take one free index from @stack and store it in *@out.
 * Returns 0, or -EAGAIN when every index is already handed out.
 */
__hot static int pop_stack(struct stack *stack, uint32_t *out)
{
	int ret = -EAGAIN;

	pthread_mutex_lock(&stack->lock);
	if (likely(stack->sp != stack->max_sp)) {
		*out = stack->arr[stack->sp];
		stack->sp++;
		ret = 0;
	}
	pthread_mutex_unlock(&stack->lock);

	return ret;
}
/*
 * Return index @in to @stack.
 * Returns 0, or -EAGAIN if no index is currently handed out.
 */
__hot static int push_stack(struct stack *stack, uint32_t in)
{
	int ret = -EAGAIN;

	pthread_mutex_lock(&stack->lock);
	if (likely(stack->sp != 0)) {
		stack->sp--;
		stack->arr[stack->sp] = in;
		ret = 0;
	}
	pthread_mutex_unlock(&stack->lock);

	return ret;
}
/* Unmap the index array, destroy the lock and leave stack->arr NULL. */
__cold static void destroy_stack(struct stack *stack)
{
	size_t bytes = (size_t)stack->max_sp * sizeof(*stack->arr);

	munmap(stack->arr, bytes);
	pthread_mutex_destroy(&stack->lock);
	stack->arr = NULL;
}
/*
 * Handler for SIGINT/SIGHUP/SIGTERM. It requests a graceful shutdown by
 * setting *g_stop, which every worker loop polls.
 *
 * Only async-signal-safe operations are used. The previous version called
 * putchar() and printf() (through pr_notice()). Those are not
 * async-signal-safe and can deadlock, or corrupt stdio state, when the signal
 * interrupts another stdio call. The message is therefore built by hand and
 * emitted with write(2). errno is saved and restored for the interrupted
 * code. abort() replaces raise(SIGABRT) + __builtin_unreachable(), because
 * raise() returns if SIGABRT is caught or ignored, which made the
 * "unreachable" annotation undefined behaviour.
 */
__cold static void signal_handler(int sig)
{
	static const char no_ctx[] = "Invalid condition, g_stop is NULL!\n";
	char msg[32] = "\nGot signal ";
	size_t len = sizeof("\nGot signal ") - 1;
	unsigned int v = (unsigned int)sig;
	char digits[12];
	size_t nd = 0;
	int saved_errno = errno;
	ssize_t unused;

	/* Decimal conversion by hand: snprintf() is not async-signal-safe. */
	do {
		digits[nd++] = (char)('0' + v % 10u);
		v /= 10u;
	} while (v);
	while (nd)
		msg[len++] = digits[--nd];
	msg[len++] = '\n';
	unused = write(STDOUT_FILENO, msg, len);
	(void)unused;

	if (unlikely(!g_stop)) {
		unused = write(STDOUT_FILENO, no_ctx, sizeof(no_ctx) - 1);
		(void)unused;
		abort();
	}

	*g_stop = true;
	errno = saved_errno;
}
/*
 * Route SIGINT, SIGHUP and SIGTERM to signal_handler() and ignore SIGPIPE.
 * g_stop must already point at the shutdown flag.
 * Returns 0, -EINVAL if g_stop is NULL, or the negated errno of the first
 * sigaction() call that fails.
 */
__cold static int init_signal_handlers(void)
{
	static const int stop_signals[] = { SIGINT, SIGHUP, SIGTERM };
	struct sigaction act = { .sa_handler = signal_handler };
	size_t k;

	if (!g_stop) {
		pr_notice("Invalid condition, g_stop is NULL!");
		return -EINVAL;
	}

	for (k = 0; k < sizeof(stop_signals) / sizeof(stop_signals[0]); k++) {
		if (unlikely(sigaction(stop_signals[k], &act, NULL)))
			return -errno;
	}

	act.sa_handler = SIG_IGN;
	if (unlikely(sigaction(SIGPIPE, &act, NULL)))
		return -errno;

	return 0;
}
/*
 * Create the free-slot stack and stamp every client slot with its own index.
 * Returns 0 or the error from init_stack().
 */
__cold static int init_clients(struct server_ctx *sctx)
{
	uint32_t slot;
	int err = init_stack(&sctx->clients_stk, NR_MAX_CLIENTS);

	if (unlikely(err))
		return err;

	for (slot = 0; slot < NR_MAX_CLIENTS; slot++)
		sctx->clients[slot].idx = slot;

	return 0;
}
static void *server_worker_entry(void *arg);

/*
 * Initialise @worker's lock and start its dedicated thread, named
 * "tea-wrk-<idx>". Marks the worker so destroy_server_workers() joins it.
 * Returns 0 or a negated pthread error code.
 */
__cold static int init_server_worker(struct worker *worker)
{
	char name[TASK_COMM_LEN];
	int err;

	err = pthread_mutex_init(&worker->lock, NULL);
	if (unlikely(err))
		return -err;

	err = pthread_create(&worker->thread, NULL, server_worker_entry, worker);
	if (unlikely(err)) {
		pthread_mutex_destroy(&worker->lock);
		return -err;
	}

	snprintf(name, sizeof(name), "tea-wrk-%u", worker->idx);
	pthread_setname_np(worker->thread, name);
	worker->need_cleanup = true;
	return 0;
}
__cold static void destroy_server_workers(struct server_ctx *sctx)
{
struct worker *workers = sctx->workers;
size_t i;
for (i = 0; i < NR_WORKERS; i++) {
if (!workers[i].need_cleanup)
continue;
pthread_join(workers[i].thread, NULL);
pthread_mutex_lock(&workers[i].lock);
io_uring_queue_exit(&workers[i].ring);
pthread_mutex_unlock(&workers[i].lock);
pthread_mutex_destroy(&workers[i].lock);
}
free_mlocked(workers, sizeof(*workers) * NR_WORKERS);
sctx->workers = NULL;
}
/*
 * Allocate NR_WORKERS worker contexts and give each its own io_uring.
 *
 * Worker 0 gets no thread: main() runs server_worker_entry() for it on the
 * main thread. Workers 1..NR_WORKERS-1 get a dedicated thread through
 * init_server_worker().
 *
 * On failure, everything created here is released and sctx->workers is left
 * NULL. Returns 0 or a negative error code.
 */
__cold static int init_server_workers(struct server_ctx *sctx)
{
	struct worker *workers;
	bool worker0_ready = false;
	int ret = 0;
	size_t i;

	workers = zmalloc_mlocked(sizeof(*workers) * NR_WORKERS);
	if (unlikely(!workers))
		return -ENOMEM;

	for (i = 0; i < NR_WORKERS; i++) {
		workers[i].idx = i;
		workers[i].sctx = sctx;
		ret = io_uring_queue_init(RING_ENTRIES, &workers[i].ring, 0);
		if (unlikely(ret))
			break;

		if (i == 0) {
			/*
			 * Worker 0 runs on the main thread, so it skips
			 * init_server_worker(). Its lock is still taken by
			 * server_prep_accept() and server_prep_recv_start(), so
			 * initialise it properly instead of relying on
			 * zero-filled memory.
			 */
			ret = -pthread_mutex_init(&workers[0].lock, NULL);
			if (unlikely(ret)) {
				io_uring_queue_exit(&workers[0].ring);
				break;
			}
			worker0_ready = true;
			continue;
		}

		ret = init_server_worker(&workers[i]);
		if (unlikely(ret)) {
			io_uring_queue_exit(&workers[i].ring);
			break;
		}
	}

	sctx->workers = workers;
	if (likely(!ret))
		return 0;

	/*
	 * Threads that already started loop until *g_stop (which points at
	 * sctx->stop) is set. Without this, destroy_server_workers() would
	 * block forever in pthread_join().
	 */
	sctx->stop = true;

	/* destroy_server_workers() skips worker 0, so release it here. */
	if (worker0_ready) {
		io_uring_queue_exit(&workers[0].ring);
		pthread_mutex_destroy(&workers[0].lock);
	}

	destroy_server_workers(sctx);
	return ret;
}
/*
 * Tear down everything hanging off @sctx and unmap it. NULL is a no-op.
 *
 * The shutdown order matters:
 *   1. Raise stop so the worker threads leave their loops.
 *   2. Join the workers.
 *   3. Destroy the client slot stack.
 * The previous version destroyed the stack first. A worker still finishing
 * its current iteration could then call push_stack() on an unmapped array
 * and a destroyed mutex.
 */
__cold static void destroy_server_ctx(struct server_ctx *sctx)
{
	if (!sctx)
		return;

	sctx->stop = true;

	if (sctx->workers)
		destroy_server_workers(sctx);
	if (sctx->clients_stk.arr)
		destroy_stack(&sctx->clients_stk);

	free_mlocked(sctx, sizeof(*sctx) +
			   (sizeof(*sctx->clients) * NR_MAX_CLIENTS));
}
/*
 * Create a non-blocking IPv4 TCP listener on sctx->bind_addr:bind_port and
 * store it in sctx->tcp_fd.
 *
 * SO_REUSEADDR and SO_REUSEPORT are set on a best-effort basis. Returns 0
 * or a negative errno. On failure the socket is closed and sctx->tcp_fd is
 * not written.
 */
__cold static int init_server_socket(struct server_ctx *sctx)
{
	struct sockaddr_in saddr;
	int one = 1;
	int ret;
	int fd;

	memset(&saddr, 0, sizeof(saddr));
	saddr.sin_family = AF_INET;
	saddr.sin_port = htons(sctx->bind_port);

	/*
	 * inet_pton() returns 1 on success and 0 for a malformed address. The
	 * result used to be ignored, so a typo silently bound to 0.0.0.0.
	 */
	if (unlikely(inet_pton(AF_INET, sctx->bind_addr, &saddr.sin_addr) != 1)) {
		pr_notice("inet_pton(): invalid IPv4 address \"%s\"",
			  sctx->bind_addr);
		return -EINVAL;
	}

	fd = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0);
	if (unlikely(fd < 0)) {
		ret = -errno;
		pr_notice("socket(): %s", strerror(-ret));
		return ret;
	}

	/* Best effort: the listener still works without these options. */
	setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one));
	setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));

	ret = bind(fd, (struct sockaddr *)&saddr, sizeof(saddr));
	if (unlikely(ret < 0)) {
		ret = -errno;
		pr_notice("bind(): %s", strerror(-ret));
		close(fd);
		return ret;
	}

	ret = listen(fd, 1024);
	if (unlikely(ret < 0)) {
		ret = -errno;
		pr_notice("listen(): %s", strerror(-ret));
		close(fd);
		return ret;
	}

	printf("Listening %s:%hu...\n", sctx->bind_addr, sctx->bind_port);
	sctx->tcp_fd = fd;
	return 0;
}
/*
 * Parse the command line into @sctx:
 *   ./ftransfer server <bind_addr> <bind_port>
 *   ./ftransfer client <server_addr> <server_port> <filename>
 * In both modes the address and port are stored in bind_addr/bind_port.
 * NOTE(review): the client-mode filename (argv[4]) is not stored yet.
 *
 * Returns 0, or -EINVAL after printing a diagnostic or the usage text.
 */
__cold static int parse_arg(struct server_ctx *sctx, int argc, char *argv[])
{
	long port = 0;
	char *end;

	if (argc >= 4) {
		/*
		 * Use strtol() rather than atoi(). atoi() accepts trailing
		 * garbage ("80abc" -> 80) and its behaviour is undefined when
		 * the value does not fit in an int.
		 */
		errno = 0;
		port = strtol(argv[3], &end, 10);
		if (errno || end == argv[3] || *end != '\0' ||
		    port < 1 || port > 65535) {
			printf("Port must be in range [1, 65535]\n");
			return -EINVAL;
		}
	}

	if (argc == 4 && !strcmp(argv[1], "server")) {
		sctx->bind_addr = argv[2];
		sctx->bind_port = (uint16_t)port;
		return 0;
	}

	if (argc == 5 && !strcmp(argv[1], "client")) {
		sctx->bind_addr = argv[2];
		sctx->bind_port = (uint16_t)port;
		return 0;
	}

	printf("Usage:\n");
	printf(" ./ftransfer server [bind_addr] [bind_port]\n");
	printf(" ./ftransfer client [server_addr] [server_port] [filename]\n");
	return -EINVAL;
}
/*
 * Allocate and fully initialise the server context. In order, it:
 *   1. parses argv;
 *   2. installs the signal handlers;
 *   3. opens the listening socket;
 *   4. sets up the client slots;
 *   5. starts the workers.
 *
 * On success, *sctx_p receives the context and 0 is returned. On failure,
 * everything acquired so far is released and a negative errno is returned.
 * This includes the context itself when argument parsing fails (it used to
 * leak) and the listening socket.
 */
__cold static int init_server_ctx(struct server_ctx **sctx_p, int argc,
				  char *argv[])
{
	struct server_ctx *sctx;
	size_t len;
	int ret;
	int fd;

	len = sizeof(*sctx) + (sizeof(*sctx->clients) * NR_MAX_CLIENTS);
	sctx = zmalloc_mlocked(len);
	if (unlikely(!sctx))
		return -ENOMEM;

	/* -1 means "no listening socket yet" for the error path below. */
	sctx->tcp_fd = -1;

	ret = parse_arg(sctx, argc, argv);
	if (unlikely(ret))
		goto err;

	g_stop = &sctx->stop;
	ret = init_signal_handlers();
	if (unlikely(ret)) {
		pr_notice("init_signal_handlers(): %s", strerror(-ret));
		goto err;
	}

	ret = init_server_socket(sctx);
	if (unlikely(ret)) {
		pr_notice("init_server_socket(): %s", strerror(-ret));
		goto err;
	}

	ret = init_clients(sctx);
	if (unlikely(ret)) {
		pr_notice("init_clients(): %s", strerror(-ret));
		goto err;
	}

	ret = init_server_workers(sctx);
	if (unlikely(ret)) {
		pr_notice("init_server_workers(): %s", strerror(-ret));
		goto err;
	}

	*sctx_p = sctx;
	return 0;

err:
	/* destroy_server_ctx() does not close the listener, so do it here. */
	fd = sctx->tcp_fd;
	destroy_server_ctx(sctx);
	/* Do not leave the signal handler pointing into the unmapped context. */
	g_stop = NULL;
	if (fd >= 0)
		close(fd);
	return ret;
}
/*
 * Get an SQE from @w's ring and count it in w->pending_sqe. If the
 * submission queue is full, queued SQEs are flushed to the kernel once and
 * the allocation is retried. Callers in this file hold w->lock around this
 * call.
 *
 * Despite the name, this still returns NULL when the flush fails or frees
 * no room, so callers must check the result.
 *
 * pending_sqe now reflects what actually happened. Previously, a failed
 * flush still reset it to 1, and a NULL return was counted as a queued SQE.
 * That desynchronised the counter server_worker_entry() relies on.
 */
__hot static struct io_uring_sqe *io_uring_get_sqe_no_fail(struct worker *w)
{
	struct io_uring_sqe *sqe;
	int ret;

	sqe = io_uring_get_sqe(&w->ring);
	if (unlikely(!sqe)) {
		ret = io_uring_submit(&w->ring);
		if (ret > 0) {
			/* Only discount what the kernel actually consumed. */
			if ((uint32_t)ret < w->pending_sqe)
				w->pending_sqe -= (uint32_t)ret;
			else
				w->pending_sqe = 0;
		}

		sqe = io_uring_get_sqe(&w->ring);
		if (unlikely(!sqe))
			return NULL;
	}

	w->pending_sqe++;
	return sqe;
}
/*
 * Arm a single accept on the listening socket, using @w's ring.
 *
 * A free client slot is taken from clients_stk and its fd is set to -1.
 * handle_event() uses that value to recognise the completion as an accept.
 * Nothing happens if an accept is already outstanding.
 *
 * Returns 0, -EAGAIN if every slot is in use, or -EBUSY if no SQE could be
 * obtained. In the -EBUSY case the slot is returned to the stack; it used to
 * leak, and the NULL SQE was dereferenced.
 *
 * NOTE(review): handle_client_event() does not call this when a connection
 * closes. Once pop_stack() fails here because all slots are taken, no accept
 * is re-armed. accept_in_flight is a plain bool, so calling this from threads
 * other than worker 0 would race.
 */
__hot static int server_prep_accept(struct worker *w)
{
	struct server_ctx *sctx = w->sctx;
	struct io_uring_sqe *sqe;
	struct client *client;
	uint32_t idx;
	int ret;

	if (unlikely(sctx->accept_in_flight))
		return 0;

	ret = pop_stack(&sctx->clients_stk, &idx);
	if (unlikely(ret))
		return ret;

	client = &sctx->clients[idx];
	client->fd = -1;
	client->addrlen = sizeof(client->addr.addr4);

	/*
	 * Fill the SQE while still holding the lock, as
	 * server_prep_recv_start() does. A submit serialised on the same lock
	 * then never sees a half-initialised entry.
	 */
	pthread_mutex_lock(&w->lock);
	sqe = io_uring_get_sqe_no_fail(w);
	if (likely(sqe)) {
		io_uring_prep_accept(sqe, sctx->tcp_fd,
				     (struct sockaddr *)&client->addr.addr4,
				     &client->addrlen, 0);
		io_uring_sqe_set_data(sqe, client);
	}
	pthread_mutex_unlock(&w->lock);

	if (unlikely(!sqe)) {
		push_stack(&sctx->clients_stk, idx);
		return -EBUSY;
	}

	sctx->accept_in_flight = true;
	return 0;
}
/*
 * Queue a recv into @client's packet buffer on @w's ring.
 *
 * The SQE is obtained and filled under w->lock because the caller may run on
 * a different thread from the one driving @w: handle_accept_event() hands new
 * clients to other workers.
 *
 * Returns 0, or -EBUSY if no SQE could be obtained. Previously a NULL SQE
 * was dereferenced.
 */
__hot static int server_prep_recv_start(struct worker *w, struct client *client)
{
	struct io_uring_sqe *sqe;
	int ret = 0;

	pthread_mutex_lock(&w->lock);
	sqe = io_uring_get_sqe_no_fail(w);
	if (likely(sqe)) {
		io_uring_prep_recv(sqe, client->fd, &client->pkt,
				   sizeof(client->pkt), 0);
		io_uring_sqe_set_data(sqe, client);
	} else {
		ret = -EBUSY;
	}
	pthread_mutex_unlock(&w->lock);

	return ret;
}
/*
 * Handle the completion of the accept armed by server_prep_accept(). The
 * next accept is re-armed first. The new connection then goes to a worker
 * chosen round-robin, and its first recv is queued there.
 *
 * Returns 0, or a negative errno, which makes the calling worker stop the
 * server (see server_worker_entry()).
 *
 * Fixes relative to the previous version:
 *   - A failed accept returned +errno, unlike every other error path in the
 *     file, so callers printed "Unknown error -N".
 *   - A failed accept leaked its client slot.
 *   - A failed recv set-up leaked the new socket and its slot.
 */
__hot static int handle_accept_event(struct worker *w, struct io_uring_cqe *cqe,
				     struct client *client)
{
	struct server_ctx *sctx = w->sctx;
	int ret = cqe->res;
	uint32_t idx;

	sctx->accept_in_flight = false;
	server_prep_accept(w);

	if (unlikely(ret < 0)) {
		pr_notice("Cannot accept a new client: %s", strerror(-ret));
		/* The slot was never used; hand it back. */
		push_stack(&sctx->clients_stk, client->idx);
		return ret;
	}

	client->fd = ret;
	idx = atomic_fetch_add(&sctx->next_worker, 1) % NR_WORKERS;
	ret = server_prep_recv_start(&sctx->workers[idx], client);
	if (unlikely(ret)) {
		pr_notice("Cannot start recv for fd %d: %s", client->fd,
			  strerror(-ret));
		close(client->fd);
		push_stack(&sctx->clients_stk, client->idx);
	}

	return ret;
}
/*
 * Handle the completion of a recv queued by server_prep_recv_start().
 *
 * On data (res > 0), the recv is re-armed on the same worker, so that the
 * peer closing the connection (res == 0) or an error (res < 0) is eventually
 * observed. Previously nothing was queued after the first chunk. Such
 * connections were never closed, and their fd and client slot leaked for the
 * life of the process.
 *
 * NOTE(review): received bytes are not parsed yet; each recv overwrites
 * client->pkt from the start.
 *
 * On EOF, on error, or if the recv cannot be re-armed, the socket is closed
 * and the slot is returned to clients_stk. Always returns 0, so one bad
 * connection does not stop the worker.
 */
__hot static int handle_client_event(struct worker *w, struct io_uring_cqe *cqe,
				     struct client *client)
{
	struct server_ctx *sctx = w->sctx;
	int ret = cqe->res;

	if (likely(ret > 0)) {
		ret = server_prep_recv_start(w, client);
		if (likely(!ret))
			return 0;
		/* Could not re-arm the recv: drop the connection below. */
	}

	if (unlikely(ret < 0))
		pr_notice("Error: %s", strerror(-ret));
	close(client->fd);
	ret = push_stack(&sctx->clients_stk, client->idx);
	assert(ret == 0);
	(void)ret;
	return 0;
}
/*
 * Dispatch one completion. Every SQE queued in this file carries its
 * struct client as user data. A slot whose fd is still -1 is waiting on an
 * accept; any other slot is an established connection.
 */
__hot static int handle_event(struct worker *w, struct io_uring_cqe *cqe)
{
	struct client *client = io_uring_cqe_get_data(cqe);

	return (client->fd == -1) ? handle_accept_event(w, cqe, client)
				  : handle_client_event(w, cqe, client);
}
/*
 * Wait up to w->timeout for at least one completion on @w's ring, then
 * handle every completion currently available and mark them consumed.
 *
 * A timeout (-ETIME) or an interruption (-EINTR) counts as success. Returns
 * 0, any other error from io_uring_wait_cqe_timeout(), or the first non-zero
 * handler result. The failing completion is consumed; later completions stay
 * in the ring.
 */
__hot static int server_handle_event_loop(struct worker *w)
{
	struct io_uring *ring = &w->ring;
	struct io_uring_cqe *cqe;
	uint32_t seen = 0;
	uint32_t head;
	int ret;

	ret = io_uring_wait_cqe_timeout(ring, &cqe, &w->timeout);
	if (unlikely(ret))
		return (ret == -ETIME || ret == -EINTR) ? 0 : ret;

	io_uring_for_each_cqe(ring, head, cqe) {
		seen++;
		ret = handle_event(w, cqe);
		if (unlikely(ret))
			break;
	}

	io_uring_cq_advance(ring, seen);
	return ret;
}
/*
 * Main loop of a worker. Each iteration submits any SQEs queued on the
 * worker's ring, then waits for and handles completions. The loop ends when
 * *g_stop is set or an error occurs; on error the whole server is asked to
 * stop. Worker 0 also arms the first accept.
 *
 * Worker 0 runs this on the main thread; the other workers run it on their
 * own pthread. The signature matches a pthread_create() start routine.
 */
__hot noinline static void *server_worker_entry(void *arg)
{
	struct worker *w = arg;
	int ret = 0;

	w->timeout.tv_sec = 1;
	w->timeout.tv_nsec = 0;
	if (w->idx == 0)
		server_prep_accept(w);

	while (!*g_stop) {
		/*
		 * Worker 0 queues recv SQEs on other workers' rings from its
		 * own thread, under the target worker's lock. The owner must
		 * take the same lock to submit; otherwise the submission queue
		 * has two unsynchronised users.
		 *
		 * The old assert on the submit count is gone. It raced with
		 * those producers and disappeared under NDEBUG, while a
		 * negative (error) return went unnoticed.
		 *
		 * NOTE(review): on kernels without extended-arg wait support,
		 * io_uring_wait_cqe_timeout() may itself consume an SQE outside
		 * this lock. Verify the target kernel supports it.
		 */
		pthread_mutex_lock(&w->lock);
		ret = 0;
		if (w->pending_sqe) {
			ret = io_uring_submit(&w->ring);
			if (likely(ret >= 0)) {
				w->pending_sqe = 0;
				ret = 0;
			}
		}
		pthread_mutex_unlock(&w->lock);

		if (unlikely(ret)) {
			pr_notice("io_uring_submit(): %s", strerror(-ret));
			break;
		}

		ret = server_handle_event_loop(w);
		if (unlikely(ret)) {
			/* Tolerate handlers that return either -errno or +errno. */
			pr_notice("server_handle_event_loop(): %s",
				  strerror(ret < 0 ? -ret : ret));
			break;
		}
	}

	if (ret)
		*g_stop = true;
	pr_notice("Thread %u is exiting...", w->idx);
	return NULL;
}
/*
 * Program entry point:
 *   1. Switch stdout/stderr to line buffering.
 *   2. Build the server context.
 *   3. Run worker 0's loop on this thread until a stop is requested.
 *   4. Tear everything down.
 * The exit status is the positive errno of an initialisation failure,
 * otherwise 0.
 */
int main(int argc, char *argv[])
{
	struct server_ctx *sctx = NULL;
	int err;

	setvbuf(stdout, NULL, _IOLBF, 1024);
	setvbuf(stderr, NULL, _IOLBF, 1024);

	err = init_server_ctx(&sctx, argc, argv);
	if (unlikely(err))
		return -err;

	server_worker_entry(&sctx->workers[0]);
	destroy_server_ctx(sctx);
	return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment