Skip to content

Instantly share code, notes, and snippets.

@ammarfaizi2
Last active April 7, 2023 02:08
Show Gist options
  • Save ammarfaizi2/5eefcaada3e2e192d0b7bf9cba0376f8 to your computer and use it in GitHub Desktop.
Save ammarfaizi2/5eefcaada3e2e192d0b7bf9cba0376f8 to your computer and use it in GitHub Desktop.
// SPDX-License-Identifier: GPL-2.0-only
/*
* Copyright (C) 2023 Ammar Faizi <[email protected]>
*/
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif
/* ISO C */
#include <assert.h>
#include <errno.h>
#include <inttypes.h>
#include <signal.h>
#include <stdatomic.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* POSIX / Linux */
#include <arpa/inet.h>
#include <endian.h>
#include <fcntl.h>
#include <libgen.h>
#include <poll.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>

/* Third-party */
#include <liburing.h>
/* Size in bytes of each client's receive buffer and of the sender's staging buffer. */
#define BUFFER_SIZE 8192
/* Number of client slots allocated by the server. */
#define NR_CLIENTS 512
/* Number of entries requested when the io_uring instance is created. */
#define NR_RING_SQES 512
/* Directory (relative to the working directory) where received files are stored. */
#define UPLOAD_DIR "uploads"
/*
 * Branch-prediction hints: likely(x) tells the compiler x is expected to be
 * true, unlikely(x) that it is expected to be false.
 */
#ifndef likely
#define likely(x) __builtin_expect(!!(x), 1)
#endif
#ifndef unlikely
#define unlikely(x) __builtin_expect(!!(x), 0)
#endif
/* Shorthands for GCC/Clang function and type attributes. */
#ifndef noinline
#define noinline __attribute__((__noinline__))
#endif
#ifndef __hot
#define __hot __attribute__((__hot__))
#endif
#ifndef __cold
#define __cold __attribute__((__cold__))
#endif
#ifndef __aligned
#define __aligned(x) __attribute__((__aligned__(x)))
#endif
/*
 * Upload header sent by the client ahead of the file contents (filled in by
 * prep_send_file_info(), parsed by server_handle_recv_open_file()).
 * The struct is packed, so it occupies 8 + 1 + 256 bytes with no padding.
 */
struct packet {
/* File size in big-endian byte order (htobe64() on send, be64toh() on receive). */
uint64_t file_size;
/* Length of the name stored in fname, as set by the sender. */
uint8_t fname_len;
/* File name; the sender writes a NUL terminator at fname[fname_len]. */
char fname[256];
} __attribute__((__packed__));
/*
 * Per-connection state on the server. One instance per slot in
 * server_ctx.clients; a slot is handed out by server_get_client_slot() and
 * returned by server_put_client_slot().
 */
struct client {
/* Connected socket, or -1 while the slot has no accepted connection. */
int sock_fd;
/* Destination file, or -1 until the upload header has been processed. */
int file_fd;
/* Index of this slot in the clients array (set once in init_client()). */
uint32_t idx;
/* File size announced in the upload header, in host byte order. */
uint64_t file_size;
/* Running total of the byte counts reported by completed file writes. */
uint64_t file_pos;
/* Peer address and its length, filled in by the accept operation. */
struct sockaddr_storage src_addr;
socklen_t src_addr_len;
/* Number of bytes currently held in buf. */
size_t len;
/* Printable peer address and port, set by server_print_accept_info(). */
char str_addr[INET6_ADDRSTRLEN];
uint16_t port;
/* Not referenced elsewhere in the visible code. */
uint8_t buf_idx;
/* True while the slot is allocated. */
bool used;
/*
 * Receive buffer, aligned to 4096 bytes. The first bytes of a connection
 * are interpreted as the upload header through pkt.
 */
union {
struct packet pkt;
uint8_t buf[BUFFER_SIZE];
} __aligned(4096);
};
/* Server-wide state shared by all event handlers. */
struct server_ctx {
/* Listening socket. */
int sock_fd;
/* The io_uring instance used for accept, recv, write and close. */
struct io_uring ring;
/* Array of NR_CLIENTS client slots allocated by init_client(). */
struct client *clients;
/*
 * Number of SQEs prepared since the last submission; used to make sure a
 * linked write+recv pair is not split across two submissions.
 */
unsigned pending_sqes;
/*
 * Set when an accept could not be queued because no slot was free;
 * server_put_client_slot() re-queues the accept once a slot is released.
 */
bool need_rearm_accept_on_close;
};
/*
 * Event tags OR-ed into the upper 16 bits of an SQE's user_data by
 * server_prep_recv() and server_prep_write(); the low 48 bits carry the
 * struct client pointer. Accept SQEs carry the pointer without a tag and
 * close SQEs use user_data == 0.
 *
 * NOTE(review): this presumes client pointers fit in 48 bits, and the
 * enumerator values do not fit in an int, which presumably relies on a
 * compiler extension before C23 -- confirm for the target toolchain.
 */
enum {
EV_RECV = (0x1ull << 48ull),
EV_WRITE = (0x2ull << 48ull),
};
/*
 * GET_EV_MASK selects the event tag from user_data; CLEAR_EV_MASK strips it
 * so the remaining bits can be used as the client pointer.
 */
enum {
GET_EV_MASK = (0xffffull << 48ull),
CLEAR_EV_MASK = ~GET_EV_MASK,
};
/*
 * Set by handle_interrupt() to ask the server event loop to stop.
 * volatile sig_atomic_t is the type the C standard guarantees can be
 * written safely from a signal handler.
 */
static volatile sig_atomic_t g_stop;
/*
 * Parse a numeric IPv6 or IPv4 address and a decimal port into @ss.
 *
 * @addr: numeric address string; IPv6 is tried first, then IPv4.
 * @port: decimal port number, digits only, in the range 0..65535.
 * @ss:   destination. Only the family, address and port fields are written,
 *        so the caller should zero it beforehand.
 *
 * Return: 0 on success, -EINVAL if @addr is not a valid numeric address or
 * @port is not a valid port number.
 */
__cold
static int str_to_sockaddr(const char *addr, const char *port,
			   struct sockaddr_storage *ss)
{
	struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)ss;
	struct sockaddr_in *sin = (struct sockaddr_in *)ss;
	unsigned long port_num;
	char *end;

	/*
	 * atoi() cannot report errors and silently truncates out-of-range
	 * values when they are narrowed to 16 bits. Insist on a plain run of
	 * digits (strtoul() alone would also accept whitespace and a sign).
	 */
	if (*port < '0' || *port > '9')
		return -EINVAL;
	errno = 0;
	port_num = strtoul(port, &end, 10);
	if (errno || *end != '\0' || port_num > 65535)
		return -EINVAL;

	if (inet_pton(AF_INET6, addr, &sin6->sin6_addr) == 1) {
		sin6->sin6_family = AF_INET6;
		sin6->sin6_port = htons((uint16_t)port_num);
		return 0;
	}

	if (inet_pton(AF_INET, addr, &sin->sin_addr) == 1) {
		sin->sin_family = AF_INET;
		sin->sin_port = htons((uint16_t)port_num);
		return 0;
	}

	return -EINVAL;
}
/*
 * Signal handler installed by init_sighandler(): emits a newline and asks
 * the event loop to stop by setting g_stop.
 *
 * Only async-signal-safe functions may be called from a signal handler, so
 * the newline is written with write(2) rather than stdio's putchar(), and
 * errno is restored so the interrupted code does not see it change.
 */
__cold
static void handle_interrupt(int sig)
{
	int saved_errno = errno;
	ssize_t ret;

	/* Purely cosmetic; nothing useful can be done if it fails. */
	ret = write(STDOUT_FILENO, "\n", 1);
	(void)ret;
	(void)sig;

	g_stop = true;
	errno = saved_errno;
}
__cold
static int init_sighandler(void)
{
struct sigaction sa = { .sa_handler = handle_interrupt };
int ret;
ret = sigaction(SIGINT, &sa, NULL);
if (ret)
goto out_err;
ret = sigaction(SIGTERM, &sa, NULL);
if (ret)
goto out_err;
ret = sigaction(SIGHUP, &sa, NULL);
if (ret)
goto out_err;
sa.sa_handler = SIG_IGN;
ret = sigaction(SIGPIPE, &sa, NULL);
if (ret)
goto out_err;
return 0;
out_err:
ret = -errno;
perror("sigaction");
return ret;
}
/*
 * Create the listening TCP socket.
 *
 * @sock_fd: receives the listening socket on success.
 * @addr:    numeric IPv4/IPv6 address to bind to.
 * @port:    port to bind to.
 *
 * Return: 0 on success, negative errno-style code on failure (the socket is
 * closed again in that case).
 */
__cold
static int init_socket(int *sock_fd, const char *addr, const char *port)
{
	struct sockaddr_storage ss;
	socklen_t ss_len;
	int ret, fd, val;

	memset(&ss, 0, sizeof(ss));
	ret = str_to_sockaddr(addr, port, &ss);
	if (ret) {
		fprintf(stderr, "Invalid address %s\n", addr);
		return ret;
	}

	/* Pass the length of the concrete address type, not of the storage. */
	if (ss.ss_family == AF_INET6)
		ss_len = sizeof(struct sockaddr_in6);
	else
		ss_len = sizeof(struct sockaddr_in);

	fd = socket(ss.ss_family, SOCK_STREAM, 0);
	if (fd < 0) {
		ret = -errno;
		perror("socket");
		return ret;
	}

	/*
	 * Allow rebinding right after a restart while old connections are
	 * still in TIME_WAIT. Failure here is not fatal.
	 */
	val = 1;
	if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val)) < 0)
		perror("setsockopt(SO_REUSEADDR)");

	ret = bind(fd, (struct sockaddr *)&ss, ss_len);
	if (ret) {
		ret = -errno;
		perror("bind");
		goto out_err;
	}

	ret = listen(fd, 128);
	if (ret) {
		ret = -errno;
		perror("listen");
		goto out_err;
	}

	printf("Listening on %s:%s\n", addr, port);
	*sock_fd = fd;
	return 0;

out_err:
	close(fd);
	return ret;
}
/*
 * Create the io_uring instance with NR_RING_SQES entries.
 *
 * The ring is first requested with IORING_SETUP_SINGLE_ISSUER and
 * IORING_SETUP_DEFER_TASKRUN; if that fails, setup is retried once with no
 * flags. Afterwards the io-wq worker limits are set to { 1, 1 }; the result
 * of that registration is not checked.
 *
 * Return: 0 on success, or the negative error from io_uring_queue_init()
 * after printing it with perror().
 */
__cold
static int init_io_uring(struct io_uring *ring)
{
	unsigned int max_workers[2] = { 1, 1 };
	int err;

	err = io_uring_queue_init(NR_RING_SQES, ring,
				  IORING_SETUP_SINGLE_ISSUER |
				  IORING_SETUP_DEFER_TASKRUN);
	if (err < 0)
		err = io_uring_queue_init(NR_RING_SQES, ring, 0);

	if (err < 0) {
		errno = -err;
		perror("io_uring_queue_init");
		return err;
	}

	io_uring_register_iowq_max_workers(ring, max_workers);
	return 0;
}
/*
 * Allocate and initialise the array of NR_CLIENTS client slots.
 *
 * Return: 0 on success with ctx->clients set, or a negative errno value if
 * the allocation fails.
 */
__cold
static int init_client(struct server_ctx *ctx)
{
	struct client *clients = NULL;
	size_t i;
	int ret;

	/*
	 * posix_memalign() reports failure by returning a positive error
	 * number (it does not set errno or return a negative value).
	 */
	ret = posix_memalign((void **)&clients, 4096,
			     sizeof(*clients) * NR_CLIENTS);
	if (ret)
		return -ret;
	if (!clients)
		return -ENOMEM;

	/*
	 * The memory is not zeroed by the allocator. Give every bookkeeping
	 * field a defined value so nothing indeterminate is ever read or
	 * logged; the large I/O buffer is deliberately left untouched.
	 */
	for (i = 0; i < NR_CLIENTS; i++) {
		struct client *c = &clients[i];

		c->sock_fd = -1;
		c->file_fd = -1;
		c->idx = (uint32_t)i;
		c->file_size = 0;
		c->file_pos = 0;
		memset(&c->src_addr, 0, sizeof(c->src_addr));
		c->src_addr_len = 0;
		c->len = 0;
		c->str_addr[0] = '\0';
		c->port = 0;
		c->buf_idx = 0;
		c->used = false;
	}

	ctx->clients = clients;
	return 0;
}
/*
 * Get a free SQE, flushing the submission queue first if it is full
 * ("nf" presumably stands for "no fail").
 *
 * ctx->pending_sqes is kept in sync: it is reset when the queue is flushed
 * here and incremented for the SQE handed out. Callers that queue a linked
 * pair must make sure both SQEs fit before calling, because a flush in the
 * middle would split the pair.
 *
 * Return: a usable SQE, never NULL. If no SQE is available even after a
 * submit, the process is aborted.
 */
__hot
static struct io_uring_sqe *io_uring_get_sqe_nf(struct server_ctx *ctx)
{
	struct io_uring *ring = &ctx->ring;
	struct io_uring_sqe *sqe;

	sqe = io_uring_get_sqe(ring);
	if (unlikely(!sqe)) {
		int ret = io_uring_submit(ring);

		sqe = io_uring_get_sqe(ring);
		/*
		 * Check explicitly instead of assert(): with NDEBUG the
		 * assertion vanishes and callers would dereference NULL.
		 */
		if (unlikely(!sqe)) {
			fprintf(stderr,
				"io_uring: no free SQE after submit (ret=%d)\n",
				ret);
			abort();
		}
		ctx->pending_sqes = 0;
	}

	ctx->pending_sqes++;
	return sqe;
}
__hot
static struct client *server_get_client_slot(struct server_ctx *ctx)
{
struct client *clients = ctx->clients;
size_t i;
for (i = 0; i < NR_CLIENTS; i++) {
if (!clients[i].used) {
clients[i].used = true;
return &clients[i];
}
}
return NULL;
}
/*
 * Queue an asynchronous close of @fd on the ring.
 *
 * The SQE's user_data is set to 0, which is how server_handle_events()
 * recognises and skips close completions.
 */
__hot
static void server_close_fd(struct server_ctx *ctx, int fd)
{
	struct io_uring_sqe *close_sqe = io_uring_get_sqe_nf(ctx);

	io_uring_prep_close(close_sqe, fd);
	io_uring_sqe_set_data64(close_sqe, 0);
}
/*
 * Queue an accept on the listening socket, bound to a freshly claimed
 * client slot whose src_addr/src_addr_len receive the peer address.
 *
 * The SQE's user_data is the bare client pointer (no event tag) and the
 * request is flagged IOSQE_ASYNC.
 *
 * Return: 0 when the accept was queued; -ENFILE when no slot is free, in
 * which case need_rearm_accept_on_close is set so that the accept is queued
 * again once a slot is released.
 */
__hot
static int server_prep_accept(struct server_ctx *ctx)
{
	struct client *slot = server_get_client_slot(ctx);
	struct io_uring_sqe *sqe;

	if (unlikely(slot == NULL)) {
		fprintf(stderr, "No free client slots, cannot accept more clients...\n");
		ctx->need_rearm_accept_on_close = true;
		return -ENFILE;
	}

	slot->src_addr_len = sizeof(slot->src_addr);

	sqe = io_uring_get_sqe_nf(ctx);
	io_uring_prep_accept(sqe, ctx->sock_fd,
			     (struct sockaddr *)&slot->src_addr,
			     &slot->src_addr_len, 0);
	io_uring_sqe_set_data(sqe, slot);
	sqe->flags |= IOSQE_ASYNC;
	return 0;
}
/*
 * Release a client slot: queue closes for its socket and file (if open),
 * reset the slot to its idle state, and re-queue the accept if one was
 * skipped earlier for lack of a free slot.
 */
__hot
static void server_put_client_slot(struct server_ctx *ctx,
				   struct client *client)
{
	const int sock = client->sock_fd;
	const int file = client->file_fd;

	if (sock >= 0) {
		printf("Closing a connection (fd=%d) from %s:%hu (file_pos=%llu)\n",
		       sock, client->str_addr, client->port,
		       (unsigned long long)client->file_pos);
		server_close_fd(ctx, sock);
	}

	if (file >= 0)
		server_close_fd(ctx, file);

	client->sock_fd = -1;
	client->file_fd = -1;
	client->len = 0;
	client->used = false;

	if (likely(!ctx->need_rearm_accept_on_close))
		return;

	/* A slot is free again, so the deferred accept can be queued now. */
	printf("Rearming accept...\n");
	ctx->need_rearm_accept_on_close = false;
	server_prep_accept(ctx);
}
/*
 * Queue a recv on the client's socket that appends to the data already held
 * in the buffer (client->len bytes), using whatever space remains.
 * The completion is tagged EV_RECV.
 */
__hot
static void server_prep_recv(struct server_ctx *ctx, struct client *client)
{
	const size_t filled = client->len;
	struct io_uring_sqe *sqe = io_uring_get_sqe_nf(ctx);

	io_uring_prep_recv(sqe, client->sock_fd, client->buf + filled,
			   sizeof(client->buf) - filled, 0);
	io_uring_sqe_set_data(sqe, client);
	sqe->user_data |= EV_RECV;
}
/*
 * Queue a write of the first client->len bytes of the client's buffer to
 * its destination file. The offset passed is -1 (presumably "use the file's
 * current position" -- see the liburing documentation).
 *
 * @flags: extra SQE flags, e.g. IOSQE_IO_LINK to chain the following SQE.
 *
 * The completion is tagged EV_WRITE.
 */
__hot
static void server_prep_write(struct server_ctx *ctx, struct client *client,
			      unsigned flags)
{
	struct io_uring_sqe *sqe = io_uring_get_sqe_nf(ctx);

	io_uring_prep_write(sqe, client->file_fd, client->buf, client->len, -1);
	io_uring_sqe_set_data(sqe, client);
	sqe->user_data |= EV_WRITE;
	sqe->flags |= flags;
}
/*
 * Convert the peer address stored in client->src_addr into a printable
 * address and port (client->str_addr, client->port) and log the new
 * connection.
 *
 * For an unsupported address family, or if the conversion fails, the
 * address is reported as "unknown"; the port is 0 unless it could be read.
 */
__hot
static void server_print_accept_info(struct client *client)
{
	struct sockaddr_storage *addr = &client->src_addr;
	struct sockaddr_in6 *in6 = (struct sockaddr_in6 *)addr;
	struct sockaddr_in *in = (struct sockaddr_in *)addr;
	char *buf = client->str_addr;
	const char *str = NULL;

	/* Never leave a stale port from the slot's previous occupant. */
	client->port = 0;

	if (addr->ss_family == AF_INET) {
		str = inet_ntop(AF_INET, &in->sin_addr, buf,
				sizeof(client->str_addr));
		client->port = ntohs(in->sin_port);
	} else if (addr->ss_family == AF_INET6) {
		str = inet_ntop(AF_INET6, &in6->sin6_addr, buf,
				sizeof(client->str_addr));
		client->port = ntohs(in6->sin6_port);
	}

	/* inet_ntop() leaves the buffer unspecified on failure. */
	if (!str)
		snprintf(buf, sizeof(client->str_addr), "%s", "unknown");

	printf("Accepted connection (fd=%d) from %s:%hu\n", client->sock_fd,
	       buf, client->port);
}
/*
 * Handle the completion of an accept.
 *
 * @client: the slot that was bound to the accept SQE.
 * @res:    CQE result: the new socket, or a negative error.
 *
 * On success the connection is logged, the next accept is queued and the
 * first recv for this client is queued.
 *
 * Return: 0 on success. On failure the slot is released and the negative
 * result is returned, which makes the caller stop the event loop.
 */
__hot
static int server_handle_accept(struct server_ctx *ctx, struct client *client,
				int res)
{
	if (unlikely(res < 0)) {
		errno = -res;
		perror("CQE accept");
		server_put_client_slot(ctx, client);
		return res;
	}

	client->sock_fd = res;
	/*
	 * Slots are recycled and releasing one does not clear the transfer
	 * counters. Reset them here so a connection that closes before
	 * sending its header is not logged with the previous occupant's
	 * file position.
	 */
	client->file_size = 0;
	client->file_pos = 0;

	server_print_accept_info(client);
	server_prep_accept(ctx);
	server_prep_recv(ctx, client);
	return 0;
}
/*
 * Handle received data for a client whose destination file is not open yet:
 * wait for the complete upload header, validate the file name, create the
 * file under UPLOAD_DIR and start writing any payload that arrived together
 * with the header.
 *
 * Return: 0 when the connection should stay open (more header bytes are
 * awaited, or the transfer has started); a negative errno-style value when
 * the client must be dropped.
 */
__hot
static int server_handle_recv_open_file(struct server_ctx *ctx,
					struct client *client)
{
	char path[8192];
	int ret, fd;

	/*
	 * Did we hit a short recv()? If so, queue another recv; it appends
	 * after the bytes already buffered. Without it no completion would
	 * ever arrive for this client again and the connection would hang.
	 */
	if (unlikely(client->len < sizeof(client->pkt))) {
		server_prep_recv(ctx, client);
		return 0;
	}

	/*
	 * The header comes from the network: do not trust the sender to have
	 * terminated the name. fname_len is at most 255, so this index is
	 * always inside the 256-byte array.
	 */
	client->pkt.fname[client->pkt.fname_len] = '\0';

	/*
	 * Don't allow any funny business: no empty names, no absolute paths
	 * and no parent-directory components.
	 */
	if (unlikely(client->pkt.fname[0] == '\0' ||
		     client->pkt.fname[0] == '/' ||
		     strstr(client->pkt.fname, ".."))) {
		fprintf(stderr, "Invalid filename: %s from %s:%hu\n",
			client->pkt.fname, client->str_addr, client->port);
		return -EBADMSG;
	}

	client->file_size = be64toh(client->pkt.file_size);
	printf("Receiving file %s, size %llu from %s:%hu\n",
	       client->pkt.fname, (unsigned long long)client->file_size,
	       client->str_addr, client->port);

	snprintf(path, sizeof(path), UPLOAD_DIR "/%s", client->pkt.fname);
	fd = open(path, O_WRONLY | O_CREAT | O_TRUNC, 0644);
	if (unlikely(fd < 0)) {
		ret = -errno;
		perror("open");
		return ret;
	}

	ret = ftruncate(fd, client->file_size);
	if (unlikely(ret < 0)) {
		ret = -errno;
		perror("ftruncate");
		close(fd);
		return ret;
	}

	/*
	 * posix_fadvise64() returns an error number directly rather than -1
	 * with errno. It is only a hint, so a failure is reported but does
	 * not abort the upload.
	 */
	ret = posix_fadvise64(fd, 0, client->file_size, POSIX_FADV_SEQUENTIAL);
	if (unlikely(ret))
		fprintf(stderr, "posix_fadvise64: %s\n", strerror(ret));

	client->file_pos = 0;
	client->file_fd = fd;

	if (client->len > sizeof(client->pkt)) {
		unsigned sq_left = NR_RING_SQES - ctx->pending_sqes;
		size_t len = client->len - sizeof(client->pkt);
		uint8_t *src = client->buf + sizeof(client->pkt);
		uint8_t *dst = client->buf;

		/* Move the payload that followed the header to the front. */
		memmove(dst, src, len);
		client->len = len;

		/*
		 * We have some data in the buffer already. Write it out
		 * first, linked to the next recv so the recv cannot start
		 * overwriting the buffer before the write is done. Both SQEs
		 * must go out in the same submission.
		 */
		if (unlikely(sq_left < 2)) {
			io_uring_submit(&ctx->ring);
			ctx->pending_sqes = 0;
		}
		server_prep_write(ctx, client, IOSQE_IO_LINK);
	}

	client->len = 0;
	server_prep_recv(ctx, client);
	return 0;
}
/*
 * Handle received payload for a client whose destination file is open:
 * queue a write of the buffered bytes linked to the next recv.
 *
 * Return: always 0.
 */
__hot
static int server_handle_recv_write_file(struct server_ctx *ctx,
					 struct client *client)
{
	/*
	 * The write and the recv that follows it are a linked pair and must
	 * be part of the same submission; otherwise the recv could overwrite
	 * the buffer the write is still reading from. Flush first when fewer
	 * than two SQEs are left.
	 */
	if (unlikely(NR_RING_SQES - ctx->pending_sqes < 2)) {
		io_uring_submit(&ctx->ring);
		ctx->pending_sqes = 0;
	}

	server_prep_write(ctx, client, IOSQE_IO_LINK);

	/* The buffer is handed to the write; the next recv starts at 0. */
	client->len = 0;
	server_prep_recv(ctx, client);
	return 0;
}
/*
 * Handle the completion of a recv.
 *
 * @res: CQE result: number of bytes received, 0 on end of stream, or a
 *       negative error.
 *
 * A result <= 0, or a failure in the follow-up handling, releases the
 * client slot.
 *
 * Return: always 0.
 */
__hot
static int server_handle_recv(struct server_ctx *ctx, struct client *client,
			      int res)
{
	int err;

	if (unlikely(res < 0)) {
		errno = -res;
		perror("CQE recv");
	}
	if (unlikely(res <= 0))
		goto drop;

	client->len += (size_t)res;

	/* No file yet means the upload header is still being received. */
	if (unlikely(client->file_fd < 0))
		err = server_handle_recv_open_file(ctx, client);
	else
		err = server_handle_recv_write_file(ctx, client);

	if (likely(!err))
		return 0;

drop:
	server_put_client_slot(ctx, client);
	return 0;
}
/*
 * Handle the completion of a file write.
 *
 * @res: CQE result: number of bytes written, or a negative error.
 *
 * On success the byte count is added to client->file_pos. On failure the
 * socket is shut down and the client slot is released.
 *
 * Return: always 0.
 */
__hot
static int server_handle_write(struct server_ctx *ctx, struct client *client,
			       int res)
{
	if (likely(res >= 0)) {
		client->file_pos += (size_t)res;
		return 0;
	}

	errno = -res;
	perror("CQE write");
	shutdown(client->sock_fd, SHUT_RDWR);
	server_put_client_slot(ctx, client);
	return 0;
}
/*
 * Drain the completion queue and dispatch every CQE to its handler.
 *
 * The kind of completion is taken from the event tag in user_data:
 *   - user_data == 0:  close completion, ignored;
 *   - no tag:          accept (server_prep_accept() stores the bare pointer);
 *   - EV_RECV/EV_WRITE: recv or write on an established connection.
 *
 * Return: 0, or the first non-zero handler result, in which case the
 * remaining CQEs of this batch are left unprocessed.
 */
__hot
static int server_handle_events(struct server_ctx *ctx)
{
	struct io_uring *ring = &ctx->ring;
	struct io_uring_cqe *cqe;
	struct client *client;
	unsigned head, nr_seen = 0;
	uint64_t ev;
	int ret = 0;

	io_uring_for_each_cqe(ring, head, cqe) {
		nr_seen++;

		/*
		 * Skip close events.
		 */
		if (unlikely(!cqe->user_data))
			continue;

		ev = cqe->user_data & GET_EV_MASK;
		cqe->user_data &= CLEAR_EV_MASK;
		client = io_uring_cqe_get_data(cqe);

		if (unlikely(!ev)) {
			/*
			 * Identify accepts by the missing tag rather than by
			 * sock_fd < 0: a released slot also has sock_fd < 0.
			 */
			ret = server_handle_accept(ctx, client, cqe->res);
		} else if (unlikely(ev != EV_RECV && ev != EV_WRITE)) {
			fprintf(stderr, "Unknown event %llu\n",
				(unsigned long long)ev);
			abort();
		} else if (unlikely(client->sock_fd < 0)) {
			/*
			 * Stale completion for a slot that was already
			 * released, e.g. the recv linked behind a failed
			 * write completes as cancelled after the write
			 * handler dropped the client. Nothing left to do.
			 */
			continue;
		} else if (likely(ev == EV_RECV)) {
			ret = server_handle_recv(ctx, client, cqe->res);
		} else {
			ret = server_handle_write(ctx, client, cqe->res);
		}

		if (unlikely(ret))
			break;
	}

	io_uring_cq_advance(ring, nr_seen);
	return ret;
}
__hot
static int server_close_all_clients(struct server_ctx *ctx)
{
struct client *clients = ctx->clients;
size_t i;
for (i = 0; i < NR_CLIENTS; i++) {
if (clients[i].used)
server_put_client_slot(ctx, &clients[i]);
}
return 0;
}
/*
 * Main server loop: queue the first accept, then submit and wait for
 * completions until g_stop is set or an error occurs. All clients are
 * released before returning.
 *
 * Return: 0 on a requested stop, otherwise the negative error that ended
 * the loop.
 */
__hot noinline
static int server_run_event_loop(struct server_ctx *ctx)
{
	int ret = 0;

	server_prep_accept(ctx);
	while (!g_stop) {
		ret = io_uring_submit_and_wait(&ctx->ring, 1);
		if (unlikely(ret < 0)) {
			if (ret == -EINTR) {
				/*
				 * The wait was interrupted by a signal (the
				 * handlers are installed without SA_RESTART).
				 * That is not an error: go back and re-check
				 * g_stop.
				 */
				ret = 0;
				continue;
			}
			errno = -ret;
			perror("io_uring_submit_and_wait");
			break;
		}

		ctx->pending_sqes = 0;
		ret = server_handle_events(ctx);
		if (unlikely(ret))
			break;
	}

	server_close_all_clients(ctx);
	/*
	 * Releasing the clients only queued the close SQEs; submit them so
	 * they are actually issued before the ring is torn down. Best effort:
	 * the process is shutting down anyway.
	 */
	io_uring_submit(&ctx->ring);
	ctx->pending_sqes = 0;
	return ret;
}
static int run_server(const char *addr, const char *port)
{
struct server_ctx ctx;
int ret;
mkdir(UPLOAD_DIR, 0755);
ctx.pending_sqes = 0;
ret = init_sighandler();
if (ret)
return ret;
ret = init_socket(&ctx.sock_fd, addr, port);
if (ret)
return ret;
ret = init_io_uring(&ctx.ring);
if (ret)
goto out_sock;
ret = init_client(&ctx);
if (ret)
goto out_ring;
ret = server_run_event_loop(&ctx);
free(ctx.clients);
out_ring:
io_uring_queue_exit(&ctx.ring);
out_sock:
close(ctx.sock_fd);
return 0;
}
/*
 * Open the file to upload and determine its size.
 *
 * @file:      path of the file to send.
 * @file_size: receives the size reported by fstat().
 *
 * Return: the open file descriptor on success, negative errno on failure.
 */
__cold
static int init_client_file(const char *file, uint64_t *file_size)
{
	struct stat st;
	int ret, fd;

	fd = open(file, O_RDONLY);
	if (fd < 0) {
		ret = -errno;
		perror("open");
		return ret;
	}

	ret = fstat(fd, &st);
	if (ret < 0) {
		ret = -errno;
		perror("fstat");
		close(fd);
		return ret;
	}
	*file_size = (uint64_t)st.st_size;

	/*
	 * posix_fadvise64() returns an error number directly rather than -1
	 * with errno. It is only a read-ahead hint, so a failure is reported
	 * but does not prevent the upload.
	 */
	ret = posix_fadvise64(fd, 0, *file_size, POSIX_FADV_SEQUENTIAL);
	if (ret)
		fprintf(stderr, "posix_fadvise64: %s\n", strerror(ret));

	return fd;
}
/*
 * Create a non-blocking TCP socket and connect it to @addr:@port, waiting
 * for the connection to be established.
 *
 * Return: the connected socket on success; a negative value on failure
 * (negative errno where one is available, -1 if poll() did not report the
 * socket as writable).
 */
__cold
static int init_client_socket(const char *addr, const char *port)
{
	struct sockaddr_storage ss;
	struct pollfd pfd;
	socklen_t ss_len, slen;
	int ret, fd, val;

	/*
	 * str_to_sockaddr() only fills in family, address and port. Zero the
	 * rest so that e.g. the IPv6 flow info and scope id are not stack
	 * garbage.
	 */
	memset(&ss, 0, sizeof(ss));
	ret = str_to_sockaddr(addr, port, &ss);
	if (ret < 0) {
		fprintf(stderr, "Invalid address %s\n", addr);
		return ret;
	}

	/* Pass the length of the concrete address type, not of the storage. */
	if (ss.ss_family == AF_INET6)
		ss_len = sizeof(struct sockaddr_in6);
	else
		ss_len = sizeof(struct sockaddr_in);

	fd = socket(ss.ss_family, SOCK_STREAM | SOCK_NONBLOCK, 0);
	if (fd < 0) {
		ret = -errno;
		perror("socket");
		return ret;
	}

	/* Non-blocking connect: EINPROGRESS means "wait for writability". */
	ret = connect(fd, (struct sockaddr *)&ss, ss_len);
	if (ret < 0 && errno != EINPROGRESS) {
		ret = -errno;
		perror("connect");
		close(fd);
		return ret;
	}

	printf("Connecting to %s:%s...\n", addr, port);
	pfd.fd = fd;
	pfd.events = POLLOUT;
	pfd.revents = 0;
	do {
		ret = poll(&pfd, 1, -1);
	} while (ret < 0 && errno == EINTR);
	if (ret < 0) {
		ret = -errno;
		perror("poll");
		close(fd);
		return ret;
	}

	if (!(pfd.revents & POLLOUT)) {
		fprintf(stderr, "Connect failed\n");
		close(fd);
		return -1;
	}

	/* Writability alone does not mean success; fetch the real outcome. */
	val = 0;
	slen = sizeof(val);
	ret = getsockopt(fd, SOL_SOCKET, SO_ERROR, &val, &slen);
	if (ret < 0) {
		ret = -errno;
		perror("getsockopt");
		close(fd);
		return ret;
	}

	if (val) {
		errno = val;
		perror("connect");
		close(fd);
		return -val;
	}

	printf("Connected to %s:%s!\n", addr, port);
	return fd;
}
__hot
static int prep_send_file_info(struct packet *pkt, const char *file_name,
uint64_t file_size)
{
size_t len = strlen(file_name);
if (len > (sizeof(pkt->fname) - 1u)) {
fprintf(stderr, "Filename is too long (%zu)!\n", len);
return -EINVAL;
}
pkt->file_size = htobe64(file_size);
pkt->fname_len = (uint8_t)len;
memcpy(pkt->fname, file_name, pkt->fname_len);
pkt->fname[pkt->fname_len] = '\0';
return 0;
}
/*
 * Block, without a timeout, until poll() reports an event for POLLOUT on
 * @sock_fd.
 *
 * Return: 0 when poll() returns successfully, otherwise the negated errno
 * after printing it with perror().
 */
__hot
static int wait_for_socket_writable(int sock_fd)
{
	struct pollfd waiter;
	int err;

	waiter.fd = sock_fd;
	waiter.events = POLLOUT;

	if (poll(&waiter, 1, -1) >= 0)
		return 0;

	err = -errno;
	perror("poll");
	return err;
}
/*
 * Send exactly @len bytes from @buf on the non-blocking socket, waiting for
 * writability whenever the socket buffer is full.
 *
 * Return: 0 when everything was sent, 1 if send() returned 0, or a negative
 * errno on failure.
 */
__hot
static int client_send_all(int sock_fd, const char *buf, size_t len)
{
	int ret;

	while (len) {
		ssize_t n = send(sock_fd, buf, len, 0);

		if (likely(n > 0)) {
			/* Possibly a short send: advance and keep going. */
			buf += n;
			len -= (size_t)n;
			continue;
		}
		if (!n)
			return 1;
		if (errno == EINTR)
			continue;
		if (errno != EAGAIN && errno != EWOULDBLOCK) {
			ret = -errno;
			perror("send");
			return ret;
		}

		/*
		 * Do not ignore a failed wait: retrying blindly would spin
		 * at full speed on a socket that can no longer be polled.
		 */
		ret = wait_for_socket_writable(sock_fd);
		if (unlikely(ret < 0 && ret != -EINTR))
			return ret;
	}

	return 0;
}

/* read() that is restarted when interrupted by a signal. */
__hot
static ssize_t client_read_retry(int fd, void *buf, size_t len)
{
	ssize_t n;

	do {
		n = read(fd, buf, len);
	} while (unlikely(n < 0 && errno == EINTR));

	return n;
}

/*
 * Upload a file: send the header followed by the file contents.
 *
 * @sock_fd:   connected, non-blocking socket.
 * @file_fd:   file to read from, positioned at the start of the data.
 * @file_name: name announced to the server.
 * @file_size: size announced to the server.
 *
 * Return: 0 when the whole file was sent or the peer disconnected, negative
 * errno on failure.
 */
__hot noinline
static int client_send_file(int sock_fd, int file_fd, const char *file_name,
			    uint64_t file_size)
{
	union {
		struct packet pkt;
		char raw[BUFFER_SIZE];
	} b;
	ssize_t read_ret;
	size_t send_len;
	int ret;

	printf("Sending file %s (%" PRIu64 " bytes)...\n", file_name, file_size);
	ret = prep_send_file_info(&b.pkt, file_name, file_size);
	if (ret < 0)
		return ret;

	/* First chunk: the header plus as much file data as still fits. */
	read_ret = client_read_retry(file_fd, b.raw + sizeof(b.pkt),
				     sizeof(b.raw) - sizeof(b.pkt));
	if (read_ret < 0) {
		ret = -errno;
		perror("read");
		return ret;
	}
	send_len = sizeof(b.pkt) + (size_t)read_ret;

	for (;;) {
		ret = client_send_all(sock_fd, b.raw, send_len);
		if (unlikely(ret)) {
			if (ret > 0) {
				printf("Disconnected!\n");
				return 0;
			}
			return ret;
		}

		read_ret = client_read_retry(file_fd, b.raw, sizeof(b.raw));
		if (unlikely(read_ret <= 0)) {
			if (!read_ret) {
				printf("File sent!\n");
				return 0;
			}
			ret = -errno;
			perror("read");
			return ret;
		}
		send_len = (size_t)read_ret;
	}
}
/*
 * Entry point of client mode: open the file, connect to the server and
 * upload the file under its base name.
 *
 * Return: the result of client_send_file(), or the negative error of the
 * setup step that failed.
 */
static int run_client(const char *addr, const char *port, const char *file)
{
	uint64_t file_size = 0;
	int sock_fd, file_fd;
	char *path_copy;
	int ret;

	/* basename() may modify its argument, so work on a private copy. */
	path_copy = strdup(file);
	if (path_copy == NULL)
		return -ENOMEM;

	file_fd = init_client_file(file, &file_size);
	if (file_fd < 0) {
		ret = file_fd;
		goto out_free;
	}

	sock_fd = init_client_socket(addr, port);
	if (sock_fd < 0) {
		ret = sock_fd;
		goto out_close_file;
	}

	ret = client_send_file(sock_fd, file_fd, basename(path_copy),
			       file_size);
	close(sock_fd);
out_close_file:
	close(file_fd);
out_free:
	free(path_copy);
	return ret;
}
/*
 * Print the command-line usage to stdout.
 *
 * @app: program name shown in the usage lines.
 */
__cold
static void show_help(const char *app)
{
	fputs("Usage:\n", stdout);
	printf("\t%s client [server_addr] [server_port] [file_path]\n", app);
	printf("\t%s server [bind_addr] [bind_port]\n\n", app);
}
/*
 * Dispatch to server or client mode based on the first argument; print the
 * usage text when the arguments match neither form.
 *
 * The run_*() helpers return 0 or a negative code; the exit status is that
 * value negated.
 */
int main(int argc, char *argv[])
{
	const char *mode = (argc > 1) ? argv[1] : "";

	if (argc == 4 && strcmp(mode, "server") == 0)
		return -run_server(argv[2], argv[3]);

	if (argc == 5 && strcmp(mode, "client") == 0)
		return -run_client(argv[2], argv[3], argv[4]);

	show_help(argv[0]);
	return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment