// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2023 Ammar Faizi <[email protected]>
 *
 * freshtea file transfer: a simple file transfer app built on
 * io_uring.
 */
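/*
 * Build (a sketch; the source file and binary names here are
 * assumptions, and liburing development headers must be installed):
 *
 *	gcc -Wall -O2 freshtea.c -o freshtea -luring -lpthread
 */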
#define FRESHTEA_DEBUG 0
#define NR_WORKERS 2
#define RING_ENTRIES 32
#define NR_MAX_CLIENTS 16
#define TASK_COMM_LEN 16
#define UPLOAD_DIR "uploads"
#define USE_DIRECT_MMAP_ALLOCATOR 1
#if !FRESHTEA_DEBUG
#define NDEBUG
#endif
#define _GNU_SOURCE
#include <stdio.h>
#include <errno.h>
#include <fcntl.h>
#include <assert.h>
#include <stdint.h>
#include <stdarg.h>
#include <string.h>
#include <stdlib.h>
#include <unistd.h>
#include <signal.h>
#include <endian.h>
#include <stdbool.h>
#include <pthread.h>
#include <liburing.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <stdatomic.h>
#include <arpa/inet.h>
#include <sys/socket.h>
/*
 * With a single worker everything runs on one thread, so the mutex
 * ops can be compiled out. With two or more workers, the accept path
 * preps SQEs on other workers' rings, so real locks are required.
 */
#if NR_WORKERS <= 1
#define pthread_mutex_lock(LOCK) do { (void)(LOCK); } while (0)
#define pthread_mutex_unlock(LOCK) do { (void)(LOCK); } while (0)
#endif
#ifndef likely
#define likely(COND) __builtin_expect(!!(COND), 1)
#endif
#ifndef unlikely
#define unlikely(COND) __builtin_expect(!!(COND), 0)
#endif
/*
* TODO(ammarfaizi2): Add sparse context lock check.
*/
#ifndef __must_hold
#define __must_hold(LOCK)
#endif
#ifndef __always_inline
#define __always_inline __attribute__((__always_inline__)) static inline
#endif
#ifndef noinline
#define noinline __attribute__((__noinline__))
#endif
#ifndef __cold
#define __cold __attribute__((__cold__))
#endif
#ifndef __hot
#define __hot __attribute__((__hot__))
#endif
#ifndef __packed
#define __packed __attribute__((__packed__))
#endif
#define MIN_T(T, A, B) ({ \
T ___a = (T)(A); \
T ___b = (T)(B); \
\
((___a > ___b) ? ___b : ___a); \
})
#define SALIGNMENT 64
#ifndef ____cacheline_aligned_in_smp
#define ____cacheline_aligned_in_smp __attribute__((__aligned__(SALIGNMENT)))
#endif
enum {
EV_CQ_ACCEPT = (1ull << 0ull),
EV_CQ_WRITE = (1ull << 1ull),
EV_CQ_RECV = (1ull << 2ull),
EV_CQ_CLOSE = (1ull << 3ull),
/* This goes last, obviously. */
EV_CQ_LAST
};
#define EV_CQ_ALL (EV_CQ_ACCEPT | EV_CQ_WRITE | EV_CQ_RECV | EV_CQ_CLOSE)
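/*
 * Event tags travel in the low bits of the io_uring user_data. struct
 * client is 64-byte aligned (SALIGNMENT), so the low bits of a client
 * pointer are always zero and can carry one EV_CQ_* flag. Encode and
 * decode sketch (see server_handle_event() for the real thing):
 *
 *	sqe->user_data = (__u64)(uintptr_t)c | EV_CQ_RECV;
 *	event = (uint32_t)(cqe->user_data & EV_CQ_ALL);
 *	c = (struct client *)(uintptr_t)(cqe->user_data & ~EV_CQ_ALL);
 *
 * A static assert in main() guarantees the alignment is sufficient.
 */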
union tsockaddr {
struct sockaddr addr;
struct sockaddr_in addr4;
struct sockaddr_in6 addr6;
};
typedef struct __packed packet_t {
uint64_t file_size;
uint8_t file_name_len;
char file_name[256];
} packet_t;
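/*
 * The wire format is this header, sent verbatim (file_size in network
 * byte order via htobe64()), followed by the raw file bytes. Keep the
 * packed layout honest at compile time:
 */
_Static_assert(sizeof(packet_t) == 8 + 1 + 256,
	       "packet_t crosses the wire verbatim and must stay packed");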
union uni_pkt {
packet_t packet;
char raw_buf[65536];
};
struct stack {
uint32_t sp;
uint32_t max_sp;
uint32_t *arr;
pthread_mutex_t lock;
};
struct client {
int tcp_fd;
int file_fd;
uint32_t idx;
socklen_t addrlen;
union tsockaddr addr;
uint64_t file_size;
uint32_t ref_count;
size_t pkt_len;
uint64_t pending_file_size;
size_t pending_write_len;
bool got_file_info;
bool write_in_flight;
bool recv_in_flight;
uint8_t recv_idx;
union uni_pkt pkt[2];
} ____cacheline_aligned_in_smp;
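/*
 * pkt[] is a double buffer: while recv() fills pkt[recv_idx % 2], the
 * previously received slot is being written to file_fd. ref_count
 * counts in-flight SQEs referencing this slot, so the slot is only
 * recycled once they have all completed.
 */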
struct worker {
struct io_uring ring;
uint32_t nr_pending_sqe;
pthread_mutex_t lock;
struct server_ctx *sctx;
pthread_t thread;
uint32_t idx;
bool need_cleanup;
} ____cacheline_aligned_in_smp;
struct server_ctx {
int tcp_fd;
uint32_t next_worker;
struct worker *workers;
volatile bool stop;
bool accept_is_in_flight;
struct stack clients_stk;
struct client clients[] ____cacheline_aligned_in_smp;
};
struct client_ctx {
struct io_uring ring;
bool need_ring_cleanup;
int tcp_fd;
packet_t pkt_hdr;
char *map;
uint64_t file_size;
uint64_t send_pos;
};
static pthread_mutex_t g_print_lock = PTHREAD_MUTEX_INITIALIZER;
static volatile bool *g_stop;
static __thread uint32_t g_thread_idx;
noinline __attribute__((__format__(printf, 3, 4)))
static void __pr_notice(const char *file, uint32_t line, const char *fmt, ...)
{
char buf[4096];
va_list ap;
va_start(ap, fmt);
vsnprintf(buf, sizeof(buf), fmt, ap);
pthread_mutex_lock(&g_print_lock);
printf("[T%05u][ %s:%05u ]: %s\n", g_thread_idx, file, line, buf);
pthread_mutex_unlock(&g_print_lock);
va_end(ap);
}
#define pr_notice(...) \
do { \
__pr_notice(__FILE__, __LINE__, __VA_ARGS__); \
} while (0)
#if FRESHTEA_DEBUG
#define pr_debug(...) pr_notice(__VA_ARGS__)
#else
#define pr_debug(...) do {} while (0)
#endif
__cold static void signal_handler(int sig)
{
if (!g_stop) {
pr_notice("Invalid condition, g_stop is NULL! (sig = %d)", sig);
raise(SIGABRT);
__builtin_unreachable();
}
if (*g_stop)
return;
*g_stop = true;
}
static void *zmalloc_mlocked(size_t l)
{
void *ret;
#if USE_DIRECT_MMAP_ALLOCATOR
ret = mmap(NULL, l, PROT_READ | PROT_WRITE,
MAP_POPULATE | MAP_ANONYMOUS | MAP_PRIVATE | MAP_LOCKED,
-1, 0);
if (unlikely(ret == MAP_FAILED))
return NULL;
#else
if (unlikely(posix_memalign(&ret, 4096, l)))
return NULL;
#endif
mlock(ret, l);
memset(ret, 0, l);
return ret;
}
static void free_mlocked(void *ptr, size_t len)
{
if (unlikely(!ptr))
return;
#if USE_DIRECT_MMAP_ALLOCATOR
munmap(ptr, len);
#else
munlock(ptr, len);
free(ptr);
#endif
}
static int init_server_socket(struct server_ctx *sctx, char *argv[])
{
struct sockaddr_in saddr;
int ret;
int fd;
ret = atoi(argv[3]);
if (unlikely(ret < 1 || ret > 65535)) {
pr_notice("The port must be in range [1, 65535]");
return -EINVAL;
}
memset(&saddr, 0, sizeof(saddr));
saddr.sin_family = AF_INET;
saddr.sin_port = htons((uint16_t)ret);
ret = inet_pton(AF_INET, argv[2], &saddr.sin_addr);
if (unlikely(ret <= 0)) {
if (ret < 0)
return -errno;
pr_notice("Invalid IPv4 address: %s\n", argv[2]);
return -EINVAL;
}
fd = socket(AF_INET, SOCK_STREAM, 0);
if (unlikely(fd < 0)) {
ret = errno;
pr_notice("socket(): %s", strerror(ret));
return -ret;
}
ret = 1;
setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &ret, sizeof(ret));
setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &ret, sizeof(ret));
ret = bind(fd, (struct sockaddr *)&saddr, sizeof(saddr));
if (unlikely(ret < 0)) {
ret = errno;
pr_notice("bind(): %s", strerror(ret));
goto err;
}
ret = listen(fd, 1024);
if (unlikely(ret < 0)) {
ret = errno;
pr_notice("listen(): %s", strerror(ret));
goto err;
}
pr_notice("Listening on %s:%s...", argv[2], argv[3]);
sctx->tcp_fd = fd;
return 0;
err:
close(fd);
return -ret;
}
__cold static int init_signal_handlers(void)
{
struct sigaction act = { .sa_handler = signal_handler };
int ret;
if (!g_stop) {
pr_notice("Invalid condition, g_stop is NULL!");
return -EINVAL;
}
ret = sigaction(SIGINT, &act, NULL);
if (unlikely(ret))
goto err;
ret = sigaction(SIGHUP, &act, NULL);
if (unlikely(ret))
goto err;
ret = sigaction(SIGTERM, &act, NULL);
if (unlikely(ret))
goto err;
act.sa_handler = SIG_IGN;
ret = sigaction(SIGPIPE, &act, NULL);
if (unlikely(ret))
goto err;
return 0;
err:
return -errno;
}
__cold static int init_stack(struct stack *stack, size_t nr_elements)
{
uint32_t *arr, i;
size_t len;
int ret;
ret = pthread_mutex_init(&stack->lock, NULL);
if (unlikely(ret))
return -ret;
len = nr_elements * sizeof(*arr);
arr = zmalloc_mlocked(len);
if (unlikely(!arr)) {
pthread_mutex_destroy(&stack->lock);
return -ENOMEM;
}
stack->arr = arr;
stack->sp = 0;
stack->max_sp = nr_elements;
for (i = 0; i < nr_elements; i++)
arr[i] = i;
return 0;
}
__hot static int pop_stack(struct stack *stack, uint32_t *out)
{
int ret;
pthread_mutex_lock(&stack->lock);
if (unlikely(stack->sp == stack->max_sp)) {
ret = -EAGAIN;
} else {
ret = 0;
*out = stack->arr[stack->sp++];
}
pthread_mutex_unlock(&stack->lock);
return ret;
}
__hot static int push_stack(struct stack *stack, uint32_t in)
{
int ret;
pthread_mutex_lock(&stack->lock);
if (unlikely(!stack->sp)) {
ret = -EAGAIN;
} else {
ret = 0;
stack->arr[--stack->sp] = in;
}
pthread_mutex_unlock(&stack->lock);
return ret;
}
__cold static void destroy_stack(struct stack *stack)
{
	/* The array came from zmalloc_mlocked(), free it the same way. */
	free_mlocked(stack->arr, stack->max_sp * sizeof(*stack->arr));
pthread_mutex_destroy(&stack->lock);
stack->arr = NULL;
}
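/*
 * The stack is a free-list of client slot indices: init_stack() seeds
 * it with 0..N-1, pop_stack() hands a free index out, push_stack()
 * returns it. A minimal usage sketch (names are illustrative):
 *
 *	uint32_t idx;
 *
 *	if (!pop_stack(&sctx->clients_stk, &idx)) {
 *		struct client *c = &sctx->clients[idx];
 *		... use the slot, then later ...
 *		push_stack(&sctx->clients_stk, idx);
 *	}
 */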
__cold static int init_upload_directory(struct server_ctx *sctx)
{
struct stat st;
int ret;
ret = stat(UPLOAD_DIR, &st);
if (unlikely(ret < 0)) {
pr_notice("UPLOAD_DIR \"" UPLOAD_DIR "\" cannot be stat'ed");
pr_notice("Creating directory: \"" UPLOAD_DIR "\"");
ret = mkdir(UPLOAD_DIR, 0755);
if (unlikely(ret < 0)) {
ret = errno;
pr_notice("mkdir(): %s", strerror(ret));
return -ret;
}
ret = stat(UPLOAD_DIR, &st);
if (unlikely(ret < 0)) {
ret = errno;
pr_notice("stat(): %s", strerror(ret));
return -ret;
}
}
if (unlikely(!S_ISDIR(st.st_mode))) {
pr_notice("UPLOAD_DIR \"" UPLOAD_DIR "\" is not a directory");
return -ENOTDIR;
}
(void)sctx;
return 0;
}
__cold static int init_server_client_slots(struct server_ctx *sctx)
{
struct client *clients = sctx->clients;
uint32_t i;
int ret;
ret = init_stack(&sctx->clients_stk, NR_MAX_CLIENTS);
if (unlikely(ret))
return ret;
for (i = 0; i < NR_MAX_CLIENTS; i++) {
clients[i].idx = i;
clients[i].tcp_fd = -1;
clients[i].file_fd = -1;
}
return 0;
}
noinline static void *server_worker_entry(void *arg);
__cold static int init_server_worker(struct worker *worker)
{
pthread_t *t = &worker->thread;
char name[TASK_COMM_LEN];
int ret;
ret = pthread_mutex_init(&worker->lock, NULL);
if (unlikely(ret))
return -ret;
ret = pthread_create(t, NULL, server_worker_entry, worker);
if (unlikely(ret)) {
pthread_mutex_destroy(&worker->lock);
return -ret;
}
snprintf(name, sizeof(name), "tea-wrk-%u", worker->idx);
pthread_setname_np(*t, name);
worker->need_cleanup = true;
return ret;
}
__cold static int init_server_workers(struct server_ctx *sctx)
{
struct worker *workers;
struct worker *w;
int ret = 0;
size_t i;
workers = zmalloc_mlocked(sizeof(*workers) * NR_WORKERS);
if (unlikely(!workers))
return -ENOMEM;
for (i = 0; i < NR_WORKERS; i++) {
unsigned int max[2];
w = &workers[i];
w->idx = i;
w->sctx = sctx;
ret = io_uring_queue_init(RING_ENTRIES, &w->ring, 0);
if (unlikely(ret))
break;
max[0] = 1;
max[1] = 1;
io_uring_register_iowq_max_workers(&w->ring, max);
		/*
		 * Let the first worker run in the main thread. No thread
		 * is created for it, but its lock must still be
		 * initialized because it is taken on the SQE path.
		 */
		if (i == 0) {
			pthread_mutex_init(&w->lock, NULL);
			continue;
		}
ret = init_server_worker(&workers[i]);
if (unlikely(ret)) {
io_uring_queue_exit(&workers[i].ring);
break;
}
}
/*
* If @ret != 0, we hit an error, but the caller
* is still responsible for the cleanup.
*/
sctx->workers = workers;
return ret;
}
__cold static void close_all_client_slots(struct server_ctx *sctx)
{
struct client *clients = sctx->clients;
struct client *c;
size_t i;
for (i = 0; i < NR_MAX_CLIENTS; i++) {
c = &clients[i];
if (c->file_fd >= 0) {
pr_notice("Closing file fd (fd=%d)", c->file_fd);
close(c->file_fd);
}
/*
* c->tcp_fd == -1 doesn't need to be closed and pushed.
*
* c->tcp_fd == -2 doesn't need to be closed, but needs to be
* pushed back onto the stack.
*
*/
if (c->tcp_fd == -1)
continue;
if (c->tcp_fd >= 0) {
pr_notice("Closing client (fd=%d)", c->tcp_fd);
close(c->tcp_fd);
}
push_stack(&sctx->clients_stk, c->idx);
}
}
__cold static void destroy_server_workers(struct server_ctx *sctx)
{
struct worker *workers = sctx->workers;
size_t i;
	for (i = 0; i < NR_WORKERS; i++) {
		if (!workers[i].need_cleanup) {
			/*
			 * Worker 0 runs on the main thread: there is no
			 * thread to join, but its ring still needs to be
			 * torn down if it was created.
			 */
			if (i == 0 && workers[i].ring.ring_fd > 0)
				io_uring_queue_exit(&workers[i].ring);
			continue;
		}
		pthread_kill(workers[i].thread, SIGTERM);
		pthread_join(workers[i].thread, NULL);
		pthread_mutex_lock(&workers[i].lock);
		io_uring_queue_exit(&workers[i].ring);
		pthread_mutex_unlock(&workers[i].lock);
		pthread_mutex_destroy(&workers[i].lock);
	}
free_mlocked(workers, sizeof(*workers) * NR_WORKERS);
sctx->workers = NULL;
}
__cold static void destroy_server_context(struct server_ctx *sctx)
{
size_t len;
if (sctx->tcp_fd != -1) {
pr_notice("Closing the main TCP socket (fd=%d)", sctx->tcp_fd);
close(sctx->tcp_fd);
}
if (sctx->workers)
destroy_server_workers(sctx);
if (sctx->clients_stk.sp > 0) {
close_all_client_slots(sctx);
assert(sctx->clients_stk.sp == 0);
}
if (sctx->clients_stk.arr)
destroy_stack(&sctx->clients_stk);
len = sizeof(*sctx) + (sizeof(*sctx->clients) * NR_MAX_CLIENTS);
free_mlocked(sctx, len);
}
static int _run_server(char *argv[])
{
struct server_ctx *sctx;
size_t len;
int ret;
len = sizeof(*sctx) + (sizeof(*sctx->clients) * NR_MAX_CLIENTS);
sctx = zmalloc_mlocked(len);
if (unlikely(!sctx))
return -ENOMEM;
g_stop = &sctx->stop;
sctx->tcp_fd = -1;
ret = init_server_socket(sctx, argv);
if (unlikely(ret)) {
pr_notice("init_server_socket(): %s", strerror(-ret));
goto out;
}
ret = init_signal_handlers();
if (unlikely(ret)) {
pr_notice("init_signal_handlers(): %s", strerror(-ret));
goto out;
}
ret = init_server_client_slots(sctx);
if (unlikely(ret)) {
pr_notice("init_client_slots(): %s", strerror(-ret));
goto out;
}
ret = init_upload_directory(sctx);
if (unlikely(ret)) {
pr_notice("init_upload_directory(): %s", strerror(-ret));
goto out;
}
ret = init_server_workers(sctx);
if (unlikely(ret)) {
pr_notice("init_server_workers(): %s", strerror(-ret));
goto out;
}
server_worker_entry(&sctx->workers[0]);
out:
destroy_server_context(sctx);
return ret;
}
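/*
 * Grab an SQE; if the SQ ring is full, flush it with io_uring_submit()
 * and retry. With RING_ENTRIES > 0 the retry cannot fail, hence
 * "no_fail". The caller must hold w->lock.
 */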
__hot static struct io_uring_sqe *io_uring_get_sqe_no_fail(struct worker *w)
__must_hold(&w->lock)
{
struct io_uring_sqe *sqe;
sqe = io_uring_get_sqe(&w->ring);
if (unlikely(!sqe)) {
io_uring_submit(&w->ring);
sqe = io_uring_get_sqe(&w->ring);
w->nr_pending_sqe = 1;
} else {
w->nr_pending_sqe++;
}
return sqe;
}
__hot static int server_prep_accept(struct worker *w)
{
struct server_ctx *sctx = w->sctx;
struct io_uring_sqe *sqe;
struct client *c;
uint32_t idx;
int ret;
ret = pop_stack(&sctx->clients_stk, &idx);
if (unlikely(ret))
return ret;
pthread_mutex_lock(&w->lock);
sqe = io_uring_get_sqe_no_fail(w);
c = &sctx->clients[idx];
c->tcp_fd = -2;
c->file_fd = -1;
c->got_file_info = false;
c->addrlen = sizeof(c->addr.addr4);
io_uring_prep_accept(sqe, sctx->tcp_fd,
(struct sockaddr *)&c->addr.addr4, &c->addrlen, 0);
c->ref_count = 1;
io_uring_sqe_set_data(sqe, c);
sqe->user_data |= EV_CQ_ACCEPT;
pthread_mutex_unlock(&w->lock);
sctx->accept_is_in_flight = true;
return 0;
}
__hot static int server_prep_recv_start(struct worker *w, struct client *c)
{
struct io_uring_sqe *sqe;
int fd = c->tcp_fd;
pthread_mutex_lock(&w->lock);
sqe = io_uring_get_sqe_no_fail(w);
io_uring_prep_recv(sqe, fd, &c->pkt[0], sizeof(c->pkt[0]), 0);
io_uring_sqe_set_data(sqe, c);
sqe->user_data |= EV_CQ_RECV;
c->ref_count++;
c->recv_in_flight = true;
io_uring_submit(&w->ring);
w->nr_pending_sqe = 0;
pthread_mutex_unlock(&w->lock);
return 0;
}
__hot static int server_handle_accept_event(struct worker *w,
struct io_uring_cqe *cqe,
struct client *client)
{
struct server_ctx *sctx = w->sctx;
int ret = cqe->res;
uint32_t idx;
w->sctx->accept_is_in_flight = false;
server_prep_accept(w);
if (unlikely(ret < 0)) {
pr_notice("Cannot accept a new client: %s\n", strerror(-ret));
return -ret;
}
client->tcp_fd = ret;
idx = sctx->next_worker++ % NR_WORKERS;
return server_prep_recv_start(&sctx->workers[idx], client);
}
__hot static int server_close_client(struct worker *w, struct client *c)
{
struct io_uring_sqe *sqe;
int ret = 0;
pr_debug("c->ref_count = %u", c->ref_count);
if (c->ref_count)
return 0;
pr_debug("close client: %u", c->idx);
pthread_mutex_lock(&w->lock);
sqe = io_uring_get_sqe_no_fail(w);
io_uring_prep_close(sqe, c->tcp_fd);
io_uring_sqe_set_data(sqe, NULL);
sqe->user_data |= EV_CQ_CLOSE;
if (c->file_fd >= 0) {
sqe = io_uring_get_sqe_no_fail(w);
io_uring_prep_close(sqe, c->file_fd);
io_uring_sqe_set_data(sqe, NULL);
sqe->user_data |= EV_CQ_CLOSE;
}
pthread_mutex_unlock(&w->lock);
c->tcp_fd = -1;
c->file_fd = -1;
c->pkt_len = 0;
c->recv_idx = 0;
c->pending_write_len = 0;
c->pending_file_size = 0;
c->got_file_info = false;
c->write_in_flight = false;
c->recv_in_flight = false;
ret = push_stack(&w->sctx->clients_stk, c->idx);
assert(ret == 0);
return ret;
}
static int __server_prep_recv(struct worker *w, struct client *c, void *buf,
size_t len, unsigned flags)
{
struct io_uring_sqe *sqe;
assert(c->tcp_fd >= 0);
pthread_mutex_lock(&w->lock);
sqe = io_uring_get_sqe_no_fail(w);
io_uring_prep_recv(sqe, c->tcp_fd, buf, len, flags);
io_uring_sqe_set_data(sqe, c);
sqe->user_data |= EV_CQ_RECV;
c->recv_in_flight = true;
c->ref_count++;
pthread_mutex_unlock(&w->lock);
pr_debug("prep recv: %zu", len);
return 0;
}
#define server_prep_recv(w, c, buf, len, f) \
({ \
if (buf == (void *)&c->pkt[0]) { \
pr_debug("rv idx 0"); \
} else if (buf == (void *)&c->pkt[1]) { \
pr_debug("rv idx 1"); \
} else { \
pr_debug("rv idx ???"); \
} \
__server_prep_recv(w, c, buf, len, f); \
})
static int __server_prep_write(struct worker *w, struct client *c,
const void *buf, size_t len)
{
struct io_uring_sqe *sqe;
assert(c->file_fd >= 0);
pthread_mutex_lock(&w->lock);
sqe = io_uring_get_sqe_no_fail(w);
io_uring_prep_write(sqe, c->file_fd, buf, len, -1);
io_uring_sqe_set_data(sqe, c);
sqe->user_data |= EV_CQ_WRITE;
c->write_in_flight = true;
c->pending_write_len = len;
c->ref_count++;
pthread_mutex_unlock(&w->lock);
pr_debug("prep write: %zu", len);
return 0;
}
#define server_prep_write(w, c, buf, len) \
({ \
if (buf == (void *)&c->pkt[0]) { \
pr_debug("wr idx 0"); \
} else if (buf == (void *)&c->pkt[1]) { \
pr_debug("wr idx 1"); \
} else { \
pr_debug("wr idx ???"); \
} \
__server_prep_write(w, c, buf, len); \
})
static int server_handle_short_recv_file_info(struct worker *w,
struct client *c)
{
size_t len;
char *buf;
buf = &c->pkt[0].raw_buf[c->pkt_len];
len = sizeof(c->pkt[0].raw_buf) - c->pkt_len;
return server_prep_recv(w, c, buf, len, 0);
}
__always_inline bool validate_file_name(const char *file_name)
{
	/*
	 * Reject any file name containing ".." to block path traversal
	 * out of UPLOAD_DIR.
	 */
return strstr(file_name, "..") == NULL;
}
static int server_handle_client_file_single_buffer(struct client *c)
{
size_t len;
char *buf;
int ret;
pr_debug("Received a small file "
"(pending_write_len = %llu; file_size = %llu",
(unsigned long long)c->pending_write_len,
(unsigned long long)c->file_size);
buf = &c->pkt[0].raw_buf[sizeof(c->pkt[0].packet)];
len = c->file_size;
	/*
	 * A small file is not worth an io_uring request; it would only
	 * complicate the short-write handling.
	 */
while (1) {
ret = write(c->file_fd, buf, len);
if (unlikely(ret < 0)) {
ret = errno;
pr_notice("write(): %s", strerror(ret));
return -ret;
}
len -= (size_t)ret;
if (likely(!len))
break;
buf += (size_t)ret;
}
	/*
	 * We are done. Return a sentinel error so that the caller
	 * closes this client and cleans it up.
	 */
return -ENETDOWN;
}
static int _server_handle_client_file_metadata(struct worker *w,
struct client *c)
{
char *buf;
size_t len;
if (c->pkt_len > sizeof(packet_t)) {
c->pending_write_len = c->pkt_len - sizeof(packet_t);
if (c->pending_write_len >= c->file_size)
/*
* Huh, a small file. We received it in a single
* recv() buffer!
*/
return server_handle_client_file_single_buffer(c);
/*
* We need to write() &c->pkt[0].raw_buf[sizeof(packet_t)]
* to the disk. So increment the recv_idx so that the next
* recv won't clobber the bytes to be written to disk.
*/
c->recv_idx++;
buf = &c->pkt[1].raw_buf[0];
assert(c->recv_idx == 1);
server_prep_write(w, c, &c->pkt[0].raw_buf[sizeof(packet_t)],
c->pkt_len - sizeof(packet_t));
} else {
buf = &c->pkt[0].raw_buf[0];
}
c->pkt_len = 0;
len = MIN_T(uint64_t, sizeof(c->pkt[0].raw_buf), c->pending_file_size);
return server_prep_recv(w, c, buf, len, MSG_WAITALL);
}
#define PATH_SIZE (sizeof(UPLOAD_DIR) + 32 + 255 + 1 + 1 + 1)
#define PATH_CP_OFFSET (sizeof(UPLOAD_DIR "/") - 1)
#define PATH_CP_SIZE (PATH_SIZE - PATH_CP_OFFSET)
static int server_handle_client_file_metadata(struct worker *w,
struct client *c)
{
char target_file[PATH_SIZE] = UPLOAD_DIR "/";
packet_t *pkt = &c->pkt[0].packet;
char *fname;
int fd;
assert(c->recv_idx == 0);
fname = pkt->file_name;
fname[sizeof(pkt->file_name) - 1] = '\0';
if (unlikely(!validate_file_name(fname))) {
pr_notice("Got an invalid file name from a client: %s", fname);
return -EPERM;
}
strncpy(&target_file[PATH_CP_OFFSET], fname, PATH_CP_SIZE);
fd = open(target_file, O_CREAT | O_TRUNC | O_WRONLY, S_IRUSR | S_IWUSR);
if (unlikely(fd < 0)) {
fd = errno;
pr_notice("open(%s): %s", target_file, strerror(fd));
return -fd;
}
c->file_fd = fd;
c->file_size = be64toh(pkt->file_size);
c->pending_file_size = c->file_size;
pr_debug("file = %zu", c->pending_file_size);
return _server_handle_client_file_metadata(w, c);
}
#undef PATH_SIZE
#undef PATH_CP_OFFSET
#undef PATH_CP_SIZE
static int server_handle_recv_need_file_info(struct worker *w, struct client *c)
{
	if (unlikely(c->pkt_len < sizeof(packet_t)))
		/*
		 * Aiee, we hit a short recv(). The file metadata is
		 * incomplete at this point. We can't do anything other
		 * than retry and wait until we have received at least
		 * sizeof(packet_t) bytes.
		 */
		return server_handle_short_recv_file_info(w, c);
c->got_file_info = true;
return server_handle_client_file_metadata(w, c);
}
__hot static int server_handle_recv_file_content(struct worker *w,
struct client *c)
{
size_t len;
char *buf;
if (!c->pending_file_size)
return -ENETDOWN;
if (c->write_in_flight)
/*
* Uh oh, we can't use the next buffer slot because
* it's still used by the write operation.
*
* Do nothing. The write() event will take care
* of recv() reissue if needed.
*/
return 0;
buf = &c->pkt[c->recv_idx % 2].raw_buf[0];
len = MIN_T(uint64_t, c->pkt_len, c->pending_file_size);
server_prep_write(w, c, buf, len);
c->pkt_len = 0;
buf = &c->pkt[++c->recv_idx % 2].raw_buf[0];
len = MIN_T(uint64_t, sizeof(c->pkt[0].raw_buf), c->pending_file_size);
return server_prep_recv(w, c, buf, len, MSG_WAITALL);
}
__hot static int server_handle_recv_event(struct worker *w,
struct io_uring_cqe *cqe,
struct client *c)
{
int ret = cqe->res;
c->recv_in_flight = false;
if (unlikely(ret <= 0))
goto out_close;
c->pkt_len += (size_t)ret;
assert(c->pkt_len <= sizeof(union uni_pkt));
if (!c->got_file_info)
ret = server_handle_recv_need_file_info(w, c);
else
ret = server_handle_recv_file_content(w, c);
if (likely(!ret))
return 0;
out_close:
return server_close_client(w, c);
}
static int server_handle_short_write_event(struct worker *w, struct client *c,
size_t wr_ret)
{
char *buf = c->pkt[(c->recv_idx + 1) % 2].raw_buf;
size_t move_len;
assert(c->pending_file_size > 0);
pr_notice("Hit a short write");
/*
* Don't be a hero, just use memmove(). OK?
* Your life is already complicated enough. Idiot!
*/
c->pending_write_len -= wr_ret;
move_len = c->pending_write_len;
memmove(&buf[0], &buf[wr_ret], move_len);
return server_prep_write(w, c, buf, move_len);
}
__hot static int server_handle_write_event(struct worker *w,
struct io_uring_cqe *cqe,
struct client *c)
{
int ret = cqe->res;
size_t wr_ret;
size_t len;
char *buf;
c->write_in_flight = false;
if (unlikely(ret <= 0)) {
pr_notice("write() CQE: %s", strerror(-ret));
goto out_close;
}
assert(c->pending_write_len <= sizeof(union uni_pkt));
assert(c->pending_file_size > 0);
wr_ret = (size_t)ret;
pr_debug("c->pending_file_size = %zu", c->pending_file_size);
c->pending_file_size -= wr_ret;
pr_debug("c->pending_file_size = %zu", c->pending_file_size);
if (unlikely(c->pending_write_len > wr_ret)) {
/*
* Oh shoot, we hit a short write!
*/
return server_handle_short_write_event(w, c, wr_ret);
} else {
c->pending_write_len -= wr_ret;
assert(c->pending_write_len == 0);
}
	if (c->recv_in_flight) {
		/*
		 * The next buffer has not been received yet.
		 *
		 * Do nothing. The recv() completion will take care of
		 * reissuing the write() if needed.
		 */
		return 0;
	}
if (!c->pending_file_size)
goto out_close;
buf = &c->pkt[c->recv_idx % 2].raw_buf[0];
len = MIN_T(uint64_t, c->pkt_len, c->pending_file_size);
server_prep_write(w, c, buf, len);
c->pkt_len = 0;
buf = &c->pkt[++c->recv_idx % 2].raw_buf[0];
len = MIN_T(uint64_t, sizeof(c->pkt[0].raw_buf), c->pending_file_size);
return server_prep_recv(w, c, buf, len, MSG_WAITALL);
out_close:
return server_close_client(w, c);
}
__hot static int server_handle_event(struct worker *w, struct io_uring_cqe *cqe)
{
struct client *c;
uint32_t event;
if (!cqe->user_data)
return 0;
event = (uint32_t)(cqe->user_data & EV_CQ_ALL);
cqe->user_data &= ~EV_CQ_ALL;
assert(cqe->user_data % SALIGNMENT == 0);
c = io_uring_cqe_get_data(cqe);
if (c)
c->ref_count--;
switch (event) {
case EV_CQ_ACCEPT:
pr_debug("Got EV_CQ_ACCEPT = %d", cqe->res);
assert(c->tcp_fd == -2);
assert(c->file_fd == -1);
assert(!c->got_file_info);
assert(c->pkt_len == 0);
return server_handle_accept_event(w, cqe, c);
case EV_CQ_RECV:
pr_debug("Got EV_CQ_RECV = %d", cqe->res);
assert(c->recv_in_flight);
return server_handle_recv_event(w, cqe, c);
case EV_CQ_WRITE:
pr_debug("Got EV_CQ_WRITE = %d", cqe->res);
assert(c->write_in_flight);
return server_handle_write_event(w, cqe, c);
case EV_CQ_CLOSE:
pr_debug("Got EV_CQ_CLOSE = %d", cqe->res);
return 0;
case EV_CQ_LAST:
default:
pr_notice("Invalid event: %u", event);
abort();
__builtin_unreachable();
}
}
__hot static int server_handle_event_loop(struct worker *w)
{
static const uint32_t max_iter = 32;
struct io_uring *ring = &w->ring;
struct io_uring_cqe *cqe;
uint32_t head;
uint32_t i;
int ret;
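	/*
	 * Only sleep in io_uring_wait_cqe() when the CQ ring is empty
	 * (khead == ktail); otherwise fall through and drain the
	 * already-completed CQEs without blocking.
	 */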
if (*ring->cq.khead == io_uring_smp_load_acquire((ring)->cq.ktail)) {
ret = io_uring_wait_cqe(ring, &cqe);
if (unlikely(ret < 0)) {
if (likely(ret == -EINTR))
return 0;
return ret;
}
}
i = 0;
ret = 0;
io_uring_for_each_cqe(ring, head, cqe) {
i++;
ret = server_handle_event(w, cqe);
if (unlikely(ret))
break;
if (i >= max_iter)
break;
}
io_uring_cq_advance(ring, i);
return ret;
}
__hot static int server_handle_pending_sqes(struct worker *w)
__must_hold(&w->lock)
{
int ret;
if (!w->nr_pending_sqe)
return 0;
ret = io_uring_submit(&w->ring);
if (unlikely(ret < 0)) {
pr_notice("io_uring_submit(): %s", strerror(-ret));
return ret;
}
assert(w->nr_pending_sqe == (uint32_t)ret);
w->nr_pending_sqe = 0;
return 0;
}
noinline static void *server_worker_entry(void *arg)
{
volatile bool *stop = g_stop;
struct worker *w = arg;
int ret;
g_thread_idx = w->idx;
pr_notice("Worker %u is ready!", w->idx);
	/*
	 * The main thread is responsible for accepting new connections.
	 */
if (w->idx == 0)
server_prep_accept(w);
	while (!*stop) {
		/*
		 * Re-arm the accept if it is not in flight. This is done
		 * outside the locked section below because
		 * server_prep_accept() takes w->lock itself.
		 */
		if (w->idx == 0 && unlikely(!w->sctx->accept_is_in_flight))
			server_prep_accept(w);
		pthread_mutex_lock(&w->lock);
		ret = server_handle_pending_sqes(w);
		pthread_mutex_unlock(&w->lock);
if (unlikely(ret))
break;
ret = server_handle_event_loop(w);
if (unlikely(ret))
break;
}
if (ret)
*g_stop = true;
pr_notice("Worker %u is exiting...", w->idx);
return NULL;
}
__cold static int init_client_file_context(struct client_ctx *cctx,
char *argv[])
{
struct stat st;
void *map;
int ret;
int fd;
fd = open(argv[4], O_RDONLY);
if (unlikely(fd < 0)) {
ret = errno;
pr_notice("open(%s): %s", argv[4], strerror(ret));
return -ret;
}
ret = fstat(fd, &st);
if (unlikely(ret < 0)) {
ret = errno;
close(fd);
pr_notice("open(): %s", strerror(ret));
return -ret;
}
map = mmap(NULL, st.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
if (unlikely(map == MAP_FAILED)) {
ret = errno;
close(fd);
pr_notice("mmap(): %s", strerror(ret));
return -ret;
}
close(fd);
cctx->pkt_hdr.file_size = htobe64((uint64_t)st.st_size);
cctx->file_size = (uint64_t)st.st_size;
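	/*
	 * _GNU_SOURCE plus <string.h> gives the GNU basename(), which
	 * does not modify its argument.
	 */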
strncpy(cctx->pkt_hdr.file_name, basename(argv[4]),
sizeof(cctx->pkt_hdr.file_name));
cctx->pkt_hdr.file_name[sizeof(cctx->pkt_hdr.file_name) - 1] = '\0';
cctx->pkt_hdr.file_name_len = strlen(cctx->pkt_hdr.file_name);
cctx->map = map;
pr_notice("File size = %llu", (unsigned long long)cctx->file_size);
return 0;
}
__cold static int init_client_ring(struct client_ctx *cctx)
{
const uint32_t slist[] = {
16384,
8192,
4096,
2048,
1024,
512
};
size_t i = 0;
int ret;
for (i = 0; i < (sizeof(slist) / sizeof(slist[0])); i++) {
pr_notice("Trying %u ring enties", slist[i]);
ret = io_uring_queue_init(slist[i], &cctx->ring, 0);
if (ret >= 0)
break;
}
if (unlikely(ret < 0)) {
pr_notice("io_uring_queue_init(): %s", strerror(-ret));
return ret;
}
cctx->need_ring_cleanup = true;
return 0;
}
__cold static int init_client_socket(struct client_ctx *cctx, char *argv[])
{
struct sockaddr_in saddr;
int ret;
int fd;
ret = atoi(argv[3]);
if (unlikely(ret < 1 || ret > 65535)) {
pr_notice("The port must be in range [1, 65535]");
return -EINVAL;
}
memset(&saddr, 0, sizeof(saddr));
saddr.sin_family = AF_INET;
saddr.sin_port = htons((uint16_t)ret);
ret = inet_pton(AF_INET, argv[2], &saddr.sin_addr);
if (unlikely(ret <= 0)) {
if (ret < 0)
return -errno;
pr_notice("Invalid IPv4 address: %s\n", argv[2]);
return -EINVAL;
}
fd = socket(AF_INET, SOCK_STREAM, 0);
if (unlikely(fd < 0)) {
ret = errno;
pr_notice("socket(): %s", strerror(ret));
return -ret;
}
	ret = 1024 * 1024 * 1024;
	ret = setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &ret, sizeof(ret));
	pr_notice("setsockopt(SO_SNDBUF) returned %d", ret);
pr_notice("Connecting to %s:%s...", argv[2], argv[3]);
ret = connect(fd, (struct sockaddr *)&saddr, sizeof(saddr));
if (unlikely(ret < 0)) {
ret = errno;
pr_notice("connect(): %s", strerror(ret));
goto err;
}
pr_notice("Connected to %s:%s!", argv[2], argv[3]);
cctx->tcp_fd = fd;
return 0;
err:
close(fd);
return -ret;
}
#define CLIENT_BUFFER_SIZE 65536
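/*
 * Send path: queue the file header plus a batch of 64 KiB send() SQEs
 * chained with IOSQE_IO_LINK (so they complete in order), submit, reap
 * the CQEs, and repeat until the whole mmap()'ed file is out. Every
 * eighth round, MADV_DONTNEED drops the already-sent pages to bound the
 * client's resident memory.
 */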
noinline __hot static int client_send_file_to_server(struct client_ctx *cctx)
{
uint64_t file_size = cctx->file_size;
uint64_t file_size_bias = file_size;
uint64_t last_madvise_pos = 0;
uint64_t map_pos = 0;
uint64_t map_pos_bias = map_pos;
struct io_uring_sqe *sqe;
struct io_uring_cqe *cqe;
int fd = cctx->tcp_fd;
size_t send_len;
uint32_t head;
uint32_t i, j = 0;
pr_notice("Sending file...");
sqe = io_uring_get_sqe(&cctx->ring);
io_uring_prep_madvise(sqe, cctx->map, file_size, MADV_SEQUENTIAL);
sqe->flags |= IOSQE_IO_LINK;
sqe->user_data = 0;
pr_notice("MADV_SEQUENTIAL &map[0] %#llx",
(unsigned long long)file_size);
sqe = io_uring_get_sqe(&cctx->ring);
io_uring_prep_send(sqe, fd, &cctx->pkt_hdr, sizeof(cctx->pkt_hdr), 0);
sqe->flags |= IOSQE_IO_LINK;
sqe->user_data = 0;
repeat:
i = 0;
while (1) {
if (!file_size_bias)
break;
sqe = io_uring_get_sqe(&cctx->ring);
if (!sqe)
break;
send_len = MIN_T(uint64_t, CLIENT_BUFFER_SIZE, file_size_bias);
file_size_bias -= send_len;
io_uring_prep_send(sqe, fd, &cctx->map[map_pos_bias], send_len,
0);
sqe->flags |= IOSQE_IO_LINK;
sqe->user_data = send_len;
map_pos_bias += send_len;
i++;
}
io_uring_submit(&cctx->ring);
if (map_pos > 0 && (++j % 8 == 0)) {
char *buf = &cctx->map[last_madvise_pos];
size_t len = map_pos - last_madvise_pos;
sqe = io_uring_get_sqe(&cctx->ring);
io_uring_prep_madvise(sqe, buf, len, MADV_DONTNEED);
sqe->flags |= IOSQE_ASYNC;
sqe->user_data = 0;
pr_notice("MADV_DONTNEED &map[%#llx] len=%#llx",
(unsigned long long)last_madvise_pos,
(unsigned long long)len);
last_madvise_pos = map_pos;
}
io_uring_wait_cqe_nr(&cctx->ring, &cqe, i);
i = 0;
io_uring_for_each_cqe(&cctx->ring, head, cqe) {
int ret = cqe->res;
size_t send_ret = (size_t)ret;
if (unlikely(ret < 0)) {
pr_notice("CQE send(): %s", strerror(-ret));
return -ret;
}
if (cqe->user_data && (int)cqe->user_data != ret) {
pr_notice("cqe->user_data != ret -- %d %d",
(int)cqe->user_data, ret);
return -ECANCELED;
}
pr_debug("cqe = %d; udata = %llu", ret, cqe->user_data);
i++;
if (likely(cqe->user_data)) {
file_size -= send_ret;
map_pos += send_ret;
}
}
io_uring_cq_advance(&cctx->ring, i);
file_size_bias = file_size;
map_pos_bias = map_pos;
pr_debug("file size = %lu", file_size);
if (likely(file_size > 0))
goto repeat;
return 0;
}
__cold static void destroy_client_context(struct client_ctx *cctx)
{
if (cctx->tcp_fd != -1)
close(cctx->tcp_fd);
if (cctx->map)
munmap(cctx->map, cctx->file_size);
if (cctx->need_ring_cleanup)
io_uring_queue_exit(&cctx->ring);
}
static int _run_client(char *argv[])
{
struct client_ctx cctx;
int ret;
memset(&cctx, 0, sizeof(cctx));
cctx.tcp_fd = -1;
ret = init_client_file_context(&cctx, argv);
if (unlikely(ret < 0)) {
pr_notice("init_client_file_context(): %s", strerror(-ret));
goto out;
}
ret = init_client_ring(&cctx);
if (unlikely(ret < 0)) {
pr_notice("init_client_ring(): %s", strerror(-ret));
goto out;
}
ret = init_client_socket(&cctx, argv);
if (unlikely(ret < 0)) {
pr_notice("init_client_socket(): %s", strerror(-ret));
goto out;
}
ret = client_send_file_to_server(&cctx);
out:
destroy_client_context(&cctx);
return ret;
}
noinline static int run_server(char *argv[])
{
int ret;
ret = _run_server(argv);
if (ret < 0)
ret = -ret;
return ret;
}
noinline static int run_client(char *argv[])
{
int ret;
ret = _run_client(argv);
if (ret < 0)
ret = -ret;
return ret;
}
__cold static void show_usage(const char *app)
{
printf("Usage:\n");
printf(" %s server [bind_addr] [bind_port]\n", app);
printf(" %s client [server_addr] [server_port] [file_name]\n", app);
}
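/*
 * Example session (the addresses, port, and file name here are
 * hypothetical):
 *
 *	$ ./freshtea server 0.0.0.0 8888
 *	$ ./freshtea client 127.0.0.1 8888 ./linux-6.1.tar.xz
 */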
int main(int argc, char *argv[])
{
	/* assert() is compiled out under NDEBUG; check this statically. */
	_Static_assert(SALIGNMENT >= EV_CQ_LAST,
		       "SALIGNMENT must be large enough to carry EV_CQ_* tags");
setvbuf(stdout, NULL, _IOLBF, 1024);
setvbuf(stderr, NULL, _IOLBF, 1024);
if (argc != 4 && argc != 5)
goto out;
if (argc == 4 && !strcmp(argv[1], "server"))
return run_server(argv);
if (argc == 5 && !strcmp(argv[1], "client"))
return run_client(argv);
out:
show_usage(argv[0]);
return EINVAL;
}