Skip to content

Instantly share code, notes, and snippets.

@ammarfaizi2
Created May 25, 2025 09:46
Show Gist options
  • Save ammarfaizi2/e91bd6709f559244b26dd935d46a5463 to your computer and use it in GitHub Desktop.
Save ammarfaizi2/e91bd6709f559244b26dd935d46a5463 to your computer and use it in GitHub Desktop.
// SPDX-License-Identifier: GPL-2.0-only
/*
* Author: Ammar Faizi <[email protected]>
*/
/* Build-time knobs; each one may be overridden with -D on the command line. */
#if !defined(_GNU_SOURCE)
# define _GNU_SOURCE	/* accept4(), pthread_setname_np() */
#endif

#if !defined(PM_USE_TCP)
# define PM_USE_TCP 1	/* epoll-based TCP server core */
#endif

#if !defined(PM_USE_HTTP)
# define PM_USE_HTTP 1
#endif

#if !defined(PM_USE_SSL)
# define PM_USE_SSL 0	/* OpenSSL layer on top of the TCP core; off by default */
#endif

#if !defined(__maybe_unused)
# define __maybe_unused __attribute__((__unused__))
#endif
#include <ctype.h>
#if PM_USE_TCP
#include <netinet/in.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <sys/eventfd.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/epoll.h>
#include <stdatomic.h>
#include <pthread.h>
#include <assert.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <stdarg.h>
#include <errno.h>
#include <stdio.h>
/*
 * Storage large enough for an IPv4 or IPv6 socket address. Code in this
 * file reads sa.sa_family to decide whether v4 or v6 is the live member.
 */
struct sockaddr_in46 {
union {
struct sockaddr sa;
struct sockaddr_in v4;
struct sockaddr_in6 v6;
};
};
/*
 * Growable byte buffer: @buf points to @cap bytes of storage, of which the
 * first @len bytes are in use.
 */
struct pm_buf {
size_t len;
size_t cap;
char *buf;
};
struct pm_net_tcp_ctx;
typedef struct pm_net_tcp_ctx pm_net_tcp_ctx_t;
typedef struct pm_net_tcp_client pm_net_tcp_client_t;
/*
 * Per-connection callbacks, installed with pm_net_tcp_client_set_*_cb().
 * recv_cb runs after new bytes were appended to the client's recv_buf.
 * send_cb runs after part of send_buf was written to the socket.
 * close_cb runs when the client slot is released, or when the context is
 * destroyed while the client is still open.
 * For recv_cb and send_cb, -EAGAIN is treated like 0 and other non-zero
 * values are treated as errors. The close_cb return value is not used.
 */
typedef int (*pm_net_tcp_recv_cb_t)(pm_net_tcp_client_t *c);
typedef int (*pm_net_tcp_send_cb_t)(pm_net_tcp_client_t *c);
typedef int (*pm_net_tcp_close_cb_t)(pm_net_tcp_client_t *c);
/*
 * Called for every accepted connection before its socket is registered
 * with a worker's epoll instance; typically installs callbacks/udata.
 */
typedef int (*pm_net_tcp_accept_cb_t)(pm_net_tcp_ctx_t *ctx, pm_net_tcp_client_t *c);
/* Server configuration; copied by pm_net_tcp_ctx_init(). */
struct pm_net_tcp_arg {
uint16_t nr_workers;		/* worker threads; 0 is rejected (-EINVAL) */
uint32_t client_init_cap;	/* client slots per worker */
int sock_backlog;		/* listen() backlog */
struct sockaddr_in46 bind_addr;	/* AF_INET or AF_INET6 address to bind */
};
/* pm_buf helpers: 0 on success, negative errno (-ENOMEM) on failure. */
int pm_buf_init(struct pm_buf *b, size_t cap);
int pm_buf_append(struct pm_buf *b, const void *data, size_t len);
int pm_buf_append_fmt(struct pm_buf *b, const void *fmt, ...);
int pm_buf_resize(struct pm_buf *b, size_t new_cap);
void pm_buf_destroy(struct pm_buf *b);
/*
 * Context lifecycle: init binds/listens and starts parked worker threads,
 * run releases them, wait blocks until a stop is requested, stop requests
 * it, and destroy stops and tears everything down.
 */
int pm_net_tcp_ctx_init(pm_net_tcp_ctx_t **ctx_p, const struct pm_net_tcp_arg *arg);
void pm_net_tcp_ctx_run(pm_net_tcp_ctx_t *ctx);
void pm_net_tcp_ctx_wait(pm_net_tcp_ctx_t *ctx);
void pm_net_tcp_ctx_stop(pm_net_tcp_ctx_t *ctx);
void pm_net_tcp_ctx_destroy(pm_net_tcp_ctx_t *ctx);
void pm_net_tcp_ctx_set_udata(pm_net_tcp_ctx_t *ctx, void *udata);
void *pm_net_tcp_ctx_get_udata(pm_net_tcp_ctx_t *ctx);
void pm_net_tcp_ctx_set_accept_cb(pm_net_tcp_ctx_t *ctx, pm_net_tcp_accept_cb_t accept_cb);
/* Per-client accessors. */
void pm_net_tcp_client_set_udata(pm_net_tcp_client_t *c, void *udata);
void *pm_net_tcp_client_get_udata(pm_net_tcp_client_t *c);
void pm_net_tcp_client_set_recv_cb(pm_net_tcp_client_t *c, pm_net_tcp_recv_cb_t recv_cb);
void pm_net_tcp_client_set_send_cb(pm_net_tcp_client_t *c, pm_net_tcp_send_cb_t send_cb);
void pm_net_tcp_client_set_close_cb(pm_net_tcp_client_t *c, pm_net_tcp_close_cb_t close_cb);
struct pm_buf *pm_net_tcp_client_get_recv_buf(pm_net_tcp_client_t *c);
struct pm_buf *pm_net_tcp_client_get_send_buf(pm_net_tcp_client_t *c);
const struct sockaddr_in46 *pm_net_tcp_client_get_src_addr(pm_net_tcp_client_t *c);
/* Request a close; honored after the current event has been handled. */
void pm_net_tcp_client_user_close(pm_net_tcp_client_t *c);
/*
 * Mutex-protected stack of 32-bit values. Each worker uses one as the free
 * list of client slot indices.
 */
struct pm_stack_u32 {
size_t bp;	/* capacity: push fails once sp == bp */
size_t sp;	/* number of values currently stored */
uint32_t *arr;
pthread_mutex_t lock;
};
/* One client slot; slots are preallocated per worker and reused. */
struct pm_net_tcp_client {
int fd;			/* -1 while the slot is free */
uint32_t idx;		/* position in the owning worker's clients[] */
uint32_t ep_mask;	/* epoll events currently registered for fd */
struct pm_buf recv_buf;
struct pm_buf send_buf;
struct sockaddr_in46 src_addr;	/* peer address from accept4() */
void *udata;
pm_net_tcp_recv_cb_t recv_cb;
pm_net_tcp_send_cb_t send_cb;
pm_net_tcp_close_cb_t close_cb;
bool user_close;	/* set by pm_net_tcp_client_user_close() */
bool is_used;		/* true while the slot is handed out */
};
/* Per-thread worker state. */
struct pm_net_tcp_wrk {
int ep_fd;
int ev_fd;		/* eventfd used to wake this worker's epoll_wait() */
uint32_t idx;		/* worker 0 also owns the listening socket */
uint32_t nr_events;	/* capacity of events[] */
_Atomic(uint32_t) nr_online_conn;	/* used to pick the least-loaded worker */
struct epoll_event *events;
struct pm_net_tcp_ctx *ctx;
struct pm_net_tcp_client **clients;
struct pm_stack_u32 stack;	/* free client slot indices */
size_t client_cap;
pthread_t thread;
volatile bool need_join_thread;	/* pthread_create() succeeded */
/* NOTE(review): not referenced in the visible part of this file. */
volatile bool handle_event_should_break;
};
/*
 * Server context.
 * NOTE(review): the volatile flags are written under start_lock but read
 * without it (e.g. by the worker loop); volatile does not provide
 * inter-thread synchronization, so atomics may be intended.
 */
struct pm_net_tcp_ctx {
volatile bool should_stop;
volatile bool started;
volatile bool accept_stopped;	/* accept paused after EMFILE/ENFILE */
int tcp_fd;			/* listening socket */
pm_net_tcp_accept_cb_t accept_cb;
void *ctx_udata;
struct pm_net_tcp_wrk *workers;
struct pm_net_tcp_arg arg;
pthread_mutex_t accept_lock;	/* guards accept_stopped */
pthread_mutex_t start_lock;	/* guards started/should_stop for start_cond */
pthread_cond_t start_cond;
};
/*
 * Tagging scheme for epoll_event.data.u64: the top 16 bits carry the event
 * source (EPL_EVT_*), and the low 48 bits carry a pointer (client events).
 *
 * These are macros rather than enumerators because ISO C requires an
 * enumeration constant to be representable as an int, and 1 << 48 is not.
 *
 * NOTE: this assumes user-space pointers fit in 48 bits.
 */
#define EPL_EVT_EVENTFD (1ull << 48)
#define EPL_EVT_CLIENT (2ull << 48)
#define EPL_EVT_ACCEPT (3ull << 48)
#define EPL_EV_MASK (0xffffull << 48)
/* Extract the EPL_EVT_* tag from a tagged u64. */
#define GET_EPL_EV(data) ((uint64_t)(data) & EPL_EV_MASK)
/* Extract the pointer part of a tagged u64. */
#define GET_EPL_DT(data) ((void *)(uintptr_t)((uint64_t)(data) & ~EPL_EV_MASK))
/* Initial capacity of each client's receive and send buffers. */
#define INIT_BUF_SIZE 2048
/*
 * Initialize @s as an empty stack that can hold @cap 32-bit values.
 *
 * Returns 0 on success, -ENOMEM if the storage cannot be allocated, or the
 * negated pthread_mutex_init() error. This follows the negative-errno
 * convention used by the rest of the file; callers propagate the value.
 */
static int pm_stack_u32_init(struct pm_stack_u32 *s, size_t cap)
{
	int ret;

	/* Guard the multiplication below against size_t overflow. */
	if (cap > SIZE_MAX / sizeof(*s->arr))
		return -ENOMEM;

	s->arr = malloc(cap * sizeof(*s->arr));
	if (!s->arr)
		return -ENOMEM;

	ret = pthread_mutex_init(&s->lock, NULL);
	if (ret) {
		free(s->arr);
		s->arr = NULL;
		return -ret;
	}

	s->sp = 0;	/* number of values stored */
	s->bp = cap;	/* push fails once sp reaches bp */
	return 0;
}
/* Release @s's mutex and storage, then reset every field to zero. */
static void pm_stack_u32_destroy(struct pm_stack_u32 *s)
{
	uint32_t *storage = s->arr;

	pthread_mutex_destroy(&s->lock);
	free(storage);
	memset(s, 0, sizeof(struct pm_stack_u32));
}
/*
 * Push @v onto @s without locking. The locked variant is pm_stack_u32_push().
 * Returns 0, or -EAGAIN when the stack is already full.
 */
static int __pm_stack_u32_push(struct pm_stack_u32 *s, uint32_t v)
{
	const bool full = (s->sp == s->bp);

	if (!full) {
		s->arr[s->sp] = v;
		s->sp++;
		return 0;
	}
	return -EAGAIN;
}
/* Push @v onto @s while holding s->lock; returns __pm_stack_u32_push()'s result. */
static int pm_stack_u32_push(struct pm_stack_u32 *s, uint32_t v)
{
	pthread_mutex_t *lock = &s->lock;
	int err;

	pthread_mutex_lock(lock);
	err = __pm_stack_u32_push(s, v);
	pthread_mutex_unlock(lock);

	return err;
}
/*
 * Pop the top value of @s into *@v without locking. The vacated entry is
 * poisoned with UINT32_MAX. Returns 0, or -EAGAIN when the stack is empty.
 */
static int __pm_stack_u32_pop(struct pm_stack_u32 *s, uint32_t *v)
{
	uint32_t *slot;

	if (!s->sp)
		return -EAGAIN;

	s->sp--;
	slot = &s->arr[s->sp];
	*v = *slot;
	*slot = UINT32_MAX;
	return 0;
}
/* Pop from @s while holding s->lock; returns __pm_stack_u32_pop()'s result. */
__maybe_unused
static int pm_stack_u32_pop(struct pm_stack_u32 *s, uint32_t *v)
{
	pthread_mutex_t *lock = &s->lock;
	int err;

	pthread_mutex_lock(lock);
	err = __pm_stack_u32_pop(s, v);
	pthread_mutex_unlock(lock);

	return err;
}
static int sock_init(struct pm_net_tcp_ctx *ctx)
{
int family, fd, err;
socklen_t len;
int tmp;
ctx->tcp_fd = -1;
family = ctx->arg.bind_addr.sa.sa_family;
if (family != AF_INET && family != AF_INET6)
return -EINVAL;
if (family == AF_INET)
len = sizeof(struct sockaddr_in);
else
len = sizeof(struct sockaddr_in6);
fd = socket(family, SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0);
if (fd < 0)
return -errno;
tmp = 1;
setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &tmp, sizeof(tmp));
if (bind(fd, &ctx->arg.bind_addr.sa, len) < 0) {
err = -errno;
close(fd);
return err;
}
if (listen(fd, ctx->arg.sock_backlog) < 0) {
err = -errno;
close(fd);
return err;
}
ctx->tcp_fd = fd;
return 0;
}
/* Close the listening socket if it is open and mark it as closed. */
static void sock_destroy(struct pm_net_tcp_ctx *ctx)
{
	if (ctx->tcp_fd < 0)
		return;

	close(ctx->tcp_fd);
	ctx->tcp_fd = -1;
}
/*
 * Allocate @cap bytes of storage for @b and mark it empty.
 * Returns 0, or -ENOMEM (b->buf is then NULL and the other fields are untouched).
 */
int pm_buf_init(struct pm_buf *b, size_t cap)
{
	char *mem = malloc(cap);

	b->buf = mem;
	if (mem == NULL)
		return -ENOMEM;

	b->cap = cap;
	b->len = 0;
	return 0;
}
/* Free @b's storage and reset all of its fields to zero. */
void pm_buf_destroy(struct pm_buf *b)
{
	char *mem = b->buf;

	memset(b, 0, sizeof(*b));
	free(mem);
}
/*
 * Append printf-style formatted text to @b, growing it as needed.
 *
 * The output is NUL-terminated inside the buffer. The terminator is not
 * counted in b->len, so the next append overwrites it.
 *
 * Returns 0 on success, -EINVAL if formatting fails (vsnprintf() < 0), or
 * -ENOMEM if the buffer cannot be grown.
 */
int pm_buf_append_fmt(struct pm_buf *b, const void *fmt, ...)
{
	va_list ap, ap2;
	size_t need;
	int ret = 0;
	int len;

	va_start(ap, fmt);
	/* The first pass only measures; ap2 is kept for the real write. */
	va_copy(ap2, ap);
	len = vsnprintf(NULL, 0, fmt, ap);
	if (len < 0) {
		/* A negative length must not be converted to a huge size_t. */
		ret = -EINVAL;
		goto out;
	}

	need = (size_t)len + 1u;	/* text + NUL terminator */
	if (b->cap - b->len < need) {
		/* The new capacity (len + need) * 2 must not wrap around. */
		if (b->len > SIZE_MAX / 2 || need > SIZE_MAX / 2 - b->len) {
			ret = -ENOMEM;
			goto out;
		}
		if (pm_buf_resize(b, (b->len + need) * 2)) {
			ret = -ENOMEM;
			goto out;
		}
	}

	vsnprintf(b->buf + b->len, need, fmt, ap2);
	b->len += (size_t)len;
out:
	va_end(ap2);
	va_end(ap);
	return ret;
}
/*
 * Append @len bytes from @data to @b, growing it to twice the required
 * size when it is full. @data must not point into @b's own storage,
 * because pm_buf_resize() may move it via realloc().
 *
 * Returns 0 (also for @len == 0), or -ENOMEM if the new size would
 * overflow or the reallocation fails.
 */
int pm_buf_append(struct pm_buf *b, const void *data, size_t len)
{
	size_t new_len;

	if (!len)
		return 0;

	/* Both b->len + len and (new_len + 1) * 2 below must not wrap. */
	if (b->len > SIZE_MAX / 2 - 1 || len > SIZE_MAX / 2 - 1 - b->len)
		return -ENOMEM;

	new_len = b->len + len;
	if (new_len > b->cap) {
		if (pm_buf_resize(b, (new_len + 1) * 2))
			return -ENOMEM;
	}

	memcpy(b->buf + b->len, data, len);
	b->len = new_len;
	return 0;
}
/*
 * Reallocate @b's storage to exactly @new_cap bytes. If the buffer shrinks
 * below its current length, the length is clamped to the new capacity.
 * Returns 0, or -ENOMEM (fields unchanged) when realloc() returns NULL.
 */
int pm_buf_resize(struct pm_buf *b, size_t new_cap)
{
	char *p = realloc(b->buf, new_cap);

	if (p == NULL)
		return -ENOMEM;

	b->buf = p;
	b->cap = new_cap;
	b->len = (b->len > new_cap) ? new_cap : b->len;
	return 0;
}
/* Put a freshly allocated client slot into its "free" state; always returns 0. */
static int client_init(struct pm_net_tcp_client *c)
{
	memset(&c->src_addr, 0, sizeof(c->src_addr));
	c->fd = -1;
	return 0;
}
/* Allocate a zeroed client slot and initialize it; returns NULL on failure. */
static struct pm_net_tcp_client *client_alloc(void)
{
	struct pm_net_tcp_client *c = calloc(1, sizeof(*c));

	if (c != NULL && client_init(c) != 0) {
		free(c);
		c = NULL;
	}
	return c;
}
/*
 * Tear down a client slot created by client_alloc():
 *  - run its close callback (if any),
 *  - close its socket if it is open,
 *  - free both buffers,
 *  - free the slot object itself.
 * A NULL @c is ignored. @c must not be used afterwards.
 */
static void client_destroy(struct pm_net_tcp_client *c)
{
	if (!c)
		return;

	if (c->close_cb)
		c->close_cb(c);
	if (c->fd >= 0)
		close(c->fd);
	pm_buf_destroy(&c->recv_buf);
	pm_buf_destroy(&c->send_buf);
	/* The slot was heap-allocated by client_alloc(); zeroing it alone leaked it. */
	free(c);
}
/* Destroy every client slot of worker @w and release the slot table. */
static void clients_destroy(struct pm_net_tcp_wrk *w)
{
	struct pm_net_tcp_client **table = w->clients;
	size_t i;

	if (table == NULL)
		return;

	for (i = 0; i < w->client_cap; i++)
		client_destroy(table[i]);

	w->clients = NULL;
	free(table);
}
/*
 * Allocate w->ctx->arg.client_init_cap client slots for worker @w and push
 * every slot index onto w->stack, the free list. On failure everything
 * allocated here is released and a negative errno value is returned.
 */
static int clients_init(struct pm_net_tcp_wrk *w)
{
	const uint32_t cap = w->ctx->arg.client_init_cap;
	struct pm_net_tcp_client **table;
	uint32_t idx;
	int err;

	table = calloc(cap, sizeof(*table));
	if (table == NULL)
		return -ENOMEM;

	err = pm_stack_u32_init(&w->stack, cap);
	if (err != 0) {
		free(table);
		return err;
	}

	w->client_cap = cap;
	w->clients = table;

	for (idx = 0; idx < cap; idx++) {
		struct pm_net_tcp_client *slot = client_alloc();

		if (slot == NULL) {
			clients_destroy(w);
			pm_stack_u32_destroy(&w->stack);
			return -ENOMEM;
		}
		slot->fd = -1;
		slot->idx = idx;
		table[idx] = slot;
		/* Cannot fail: the stack was sized for exactly @cap entries. */
		__pm_stack_u32_push(&w->stack, idx);
	}
	return 0;
}
/*
 * epoll_ctl() helpers. Each returns epoll_ctl()'s result unchanged:
 * 0 on success, or -1 with errno set.
 */
static int epoll_ctl_data(int ep_fd, int op, int fd, uint32_t events,
			  union epoll_data data)
{
	struct epoll_event ev = { .events = events, .data = data };

	return epoll_ctl(ep_fd, op, fd, &ev);
}

static int epoll_add(int ep_fd, int fd, uint32_t events, union epoll_data data)
{
	return epoll_ctl_data(ep_fd, EPOLL_CTL_ADD, fd, events, data);
}

static int epoll_del(int ep_fd, int fd)
{
	return epoll_ctl(ep_fd, EPOLL_CTL_DEL, fd, NULL);
}

static int epoll_mod(int ep_fd, int fd, uint32_t events, union epoll_data data)
{
	return epoll_ctl_data(ep_fd, EPOLL_CTL_MOD, fd, events, data);
}
/*
 * Wake worker @w by adding 1 to its eventfd counter.
 * Returns 0, or a negative errno value if write() fails.
 */
static int send_event_fd(struct pm_net_tcp_wrk *w)
{
	const uint64_t one = 1;

	return write(w->ev_fd, &one, sizeof(one)) < 0 ? -errno : 0;
}

/*
 * Drain worker @w's eventfd counter; the value read is discarded.
 * Returns 0, or a negative errno value if read() fails.
 */
static int recv_event_fd(struct pm_net_tcp_wrk *w)
{
	uint64_t counter;

	return read(w->ev_fd, &counter, sizeof(counter)) < 0 ? -errno : 0;
}
/*
 * Create worker @w's epoll instance, its wake-up eventfd, and the event
 * array filled by epoll_wait(). Then register the eventfd, tagged with
 * EPL_EVT_EVENTFD.
 * On failure w->ep_fd and w->ev_fd are left at -1 and nothing is leaked.
 * Returns 0 or a negative errno value.
 */
static int epoll_init(struct pm_net_tcp_wrk *w)
{
	static const uint32_t nr_events = 128;
	struct epoll_event *events;
	union epoll_data data;
	int ep_fd, ev_fd, err;

	w->ep_fd = w->ev_fd = -1;

	/* Close-on-exec, like every other descriptor this file creates. */
	ep_fd = epoll_create1(EPOLL_CLOEXEC);
	if (ep_fd < 0)
		return -errno;

	ev_fd = eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC);
	if (ev_fd < 0) {
		err = -errno;
		goto out_ep;
	}

	events = calloc(nr_events, sizeof(*events));
	if (!events) {
		err = -ENOMEM;
		goto out_ev;
	}

	data.u64 = EPL_EVT_EVENTFD;
	if (epoll_add(ep_fd, ev_fd, EPOLLIN, data)) {
		/* epoll_ctl() reports failure as -1 plus errno, not as -errno. */
		err = -errno;
		goto out_events;
	}

	w->ep_fd = ep_fd;
	w->ev_fd = ev_fd;
	w->nr_events = nr_events;
	w->events = events;
	return 0;

out_events:
	free(events);
out_ev:
	close(ev_fd);
out_ep:
	close(ep_fd);
	return err;
}
/* Release worker @w's epoll fd, eventfd and event array; safe to call twice. */
static void epoll_destroy(struct pm_net_tcp_wrk *w)
{
	int *fds[] = { &w->ep_fd, &w->ev_fd };
	size_t i;

	/* The epoll fd is closed first, then the eventfd. */
	for (i = 0; i < sizeof(fds) / sizeof(fds[0]); i++) {
		if (*fds[i] < 0)
			continue;
		close(*fds[i]);
		*fds[i] = -1;
	}

	free(w->events);
	w->events = NULL;
}
static void *worker_entry(void *arg);

/*
 * Bring up worker @w:
 *  - its client slots (and their free-index stack),
 *  - its epoll instance,
 *  - for worker 0 only, registration of the listening socket
 *    (tagged EPL_EVT_ACCEPT),
 *  - its thread, which waits in worker_wait_for_start_signal() until the
 *    context is run or stopped.
 *
 * Returns 0 or a negative errno value. On failure everything set up here
 * is released, including the free-index stack created by clients_init().
 */
static int worker_init(struct pm_net_tcp_wrk *w)
{
	int ret;

	ret = clients_init(w);
	if (ret)
		return ret;

	ret = epoll_init(w);
	if (ret)
		goto out_clients;

	/* Only worker 0 accepts; new fds go to the least-loaded worker. */
	if (w->idx == 0) {
		union epoll_data data;

		data.u64 = EPL_EVT_ACCEPT;
		if (epoll_add(w->ep_fd, w->ctx->tcp_fd, EPOLLIN, data)) {
			ret = -errno;
			goto out_epoll;
		}
	}

	ret = pthread_create(&w->thread, NULL, &worker_entry, w);
	if (ret) {
		ret = -ret;
		goto out_epoll;
	}

	w->need_join_thread = true;
	return 0;

out_epoll:
	epoll_destroy(w);
out_clients:
	clients_destroy(w);
	pm_stack_u32_destroy(&w->stack);
	return ret;
}
/*
 * Join worker @w's thread (if it was started) and release everything that
 * worker_init() set up, including the free-index stack. NULL is ignored.
 * A stop must already have been requested, otherwise the join blocks;
 * workers_destroy() calls pm_net_tcp_ctx_stop() first.
 */
static void worker_destroy(struct pm_net_tcp_wrk *w)
{
	if (!w)
		return;

	if (w->need_join_thread) {
		pthread_join(w->thread, NULL);
		w->need_join_thread = false;
	}
	epoll_destroy(w);
	clients_destroy(w);
	/* clients_init() created this stack and nothing else frees it. */
	pm_stack_u32_destroy(&w->stack);
}
static void workers_destroy(struct pm_net_tcp_ctx *ctx)
{
uint32_t i;
if (!ctx->workers)
return;
pm_net_tcp_ctx_stop(ctx);
for (i = 0; i < ctx->arg.nr_workers; i++)
worker_destroy(&ctx->workers[i]);
free(ctx->workers);
ctx->workers = NULL;
}
static int workers_init(struct pm_net_tcp_ctx *ctx)
{
struct pm_net_tcp_wrk *workers;
uint32_t i;
int ret;
if (ctx->arg.nr_workers == 0)
return -EINVAL;
workers = calloc(ctx->arg.nr_workers, sizeof(*ctx->workers));
if (!workers)
return -ENOMEM;
for (i = 0; i < ctx->arg.nr_workers; i++) {
struct pm_net_tcp_wrk *w = &workers[i];
w->idx = i;
w->ctx = ctx;
ret = worker_init(w);
if (ret) {
while (i--)
worker_destroy(&workers[i]);
free(workers);
return ret;
}
}
ctx->workers = workers;
return 0;
}
/*
 * Take a free client slot from worker @w, give it fresh receive/send
 * buffers, mark it used and increment w->nr_online_conn.
 *
 * Returns:
 *  - 0, with the slot stored in *@cp;
 *  - -EAGAIN when no slot is free;
 *  - -ENOMEM when a buffer allocation fails (the slot goes back on the
 *    free stack).
 */
static int get_client_slot(struct pm_net_tcp_wrk *w, struct pm_net_tcp_client **cp)
{
	struct pm_net_tcp_client *c;
	bool have_slot;
	uint32_t idx;

	pthread_mutex_lock(&w->stack.lock);
	have_slot = (__pm_stack_u32_pop(&w->stack, &idx) == 0);
	c = have_slot ? w->clients[idx] : NULL;
	pthread_mutex_unlock(&w->stack.lock);

	if (!have_slot)
		return -EAGAIN;

	/* A slot on the free stack must be in its reset state. */
	assert(c);
	assert(c->fd < 0);
	assert(c->idx == idx);
	assert(!c->recv_buf.len);
	assert(!c->send_buf.len);
	assert(!c->recv_cb);
	assert(!c->send_cb);
	assert(!c->close_cb);
	assert(!c->user_close);
	assert(!c->is_used);

	if (pm_buf_init(&c->recv_buf, INIT_BUF_SIZE) != 0)
		goto out_push_back;
	if (pm_buf_init(&c->send_buf, INIT_BUF_SIZE) != 0) {
		pm_buf_destroy(&c->recv_buf);
		goto out_push_back;
	}

	c->is_used = true;
	*cp = c;
	atomic_fetch_add(&w->nr_online_conn, 1u);
	return 0;

out_push_back:
	pm_stack_u32_push(&w->stack, idx);
	return -ENOMEM;
}
/*
 * Return client @c to worker @w's free stack. Under w->stack.lock:
 *  - run the close callback;
 *  - remove the socket from epoll when @del_epoll is true, then close it;
 *  - free both buffers;
 *  - reset the per-connection state so get_client_slot()'s assertions
 *    hold when the slot is reused.
 *
 * If accepting was paused (ctx->accept_stopped) after running out of
 * descriptors, closing a socket re-registers the listening socket on
 * worker 0 and wakes that worker.
 *
 * Returns 0, or -1 if re-registering the listening socket fails.
 */
static int __put_client_slot(struct pm_net_tcp_wrk *w, struct pm_net_tcp_client *c, bool del_epoll)
{
	bool close_fd = false;
	int ret = 0;

	pthread_mutex_lock(&w->stack.lock);
	assert(c->is_used);
	if (c->close_cb)
		c->close_cb(c);
	if (c->fd >= 0) {
		if (del_epoll) {
			ret = epoll_del(w->ep_fd, c->fd);
			assert(!ret);
		}
		close_fd = true;
		close(c->fd);
		c->fd = -1;
	}
	if (c->recv_buf.cap)
		pm_buf_destroy(&c->recv_buf);
	if (c->send_buf.cap)
		pm_buf_destroy(&c->send_buf);
	c->recv_buf.len = 0;
	c->send_buf.len = 0;
	c->recv_cb = NULL;
	c->send_cb = NULL;
	c->close_cb = NULL;
	/*
	 * Drop the user pointer as well. It typically refers to state that
	 * close_cb just released (the SSL layer frees its per-client object
	 * there), and it must not leak into the next connection on this slot.
	 */
	c->udata = NULL;
	c->user_close = false;
	c->is_used = false;
	ret = __pm_stack_u32_push(&w->stack, c->idx);
	assert(!ret);
	pthread_mutex_unlock(&w->stack.lock);
	atomic_fetch_sub(&w->nr_online_conn, 1u);

	if (close_fd) {
		struct pm_net_tcp_wrk *mw = &w->ctx->workers[0];
		pm_net_tcp_ctx_t *ctx = w->ctx;

		/* A descriptor was freed: resume accepting if it was paused. */
		pthread_mutex_lock(&ctx->accept_lock);
		if (ctx->accept_stopped) {
			union epoll_data data;

			data.u64 = EPL_EVT_ACCEPT;
			ret = epoll_add(mw->ep_fd, ctx->tcp_fd, EPOLLIN, data);
			assert(!ret);
			ctx->accept_stopped = false;
			send_event_fd(mw);
		}
		pthread_mutex_unlock(&ctx->accept_lock);
	}
	return ret;
}
/* Release @c after deregistering its socket from @w's epoll instance. */
static int put_client_slot(struct pm_net_tcp_wrk *w, struct pm_net_tcp_client *c)
{
	const bool del_from_epoll = true;

	return __put_client_slot(w, c, del_from_epoll);
}

/* Release @c whose socket is not registered with @w's epoll instance. */
static int put_client_slot_no_epoll(struct pm_net_tcp_wrk *w, struct pm_net_tcp_client *c)
{
	const bool del_from_epoll = false;

	return __put_client_slot(w, c, del_from_epoll);
}
/*
 * Return the worker with the fewest online connections. Ties go to the
 * lowest index. Loads are relaxed, so the choice is only a heuristic.
 */
static struct pm_net_tcp_wrk *pick_worker_for_new_conn(struct pm_net_tcp_ctx *ctx)
{
	struct pm_net_tcp_wrk *best = &ctx->workers[0];
	uint32_t best_load, i;

	if (ctx->arg.nr_workers == 1)
		return best;

	best_load = atomic_load_explicit(&best->nr_online_conn, memory_order_relaxed);
	for (i = 1; i < ctx->arg.nr_workers; i++) {
		struct pm_net_tcp_wrk *cand = &ctx->workers[i];
		uint32_t load = atomic_load_explicit(&cand->nr_online_conn,
						     memory_order_relaxed);

		if (load < best_load) {
			best_load = load;
			best = cand;
		}
	}
	return best;
}
static int handle_accept_error(int err, struct pm_net_tcp_wrk *w)
{
if (err == EAGAIN || err == EINTR)
return 0;
if (err == EMFILE || err == ENFILE) {
pthread_mutex_lock(&w->ctx->accept_lock);
w->ctx->accept_stopped = true;
pthread_mutex_unlock(&w->ctx->accept_lock);
return epoll_del(w->ep_fd, w->ctx->tcp_fd);
}
return -err;
}
/*
 * Hand a freshly accepted socket to the least-loaded worker.
 *
 * @fd:   ownership is taken by give_client_fd_to_a_worker(); it is closed
 *        on every failure path.
 * @addr: peer address, copied into the client slot.
 *
 * The context's accept_cb (if any) runs before the socket is registered
 * with the worker's epoll instance, so the worker thread cannot see events
 * for a client that is still being set up. A negative return from
 * accept_cb rejects the connection. This keeps a client whose setup failed
 * (e.g. the SSL layer's accept_cb returns -ENOMEM without installing its
 * callbacks) from being left half-initialized.
 *
 * Returns 0, or a negative errno value when the connection was dropped.
 */
static int give_client_fd_to_a_worker(struct pm_net_tcp_ctx *ctx, int fd,
				      const struct sockaddr_in46 *addr)
{
	struct pm_net_tcp_client *c;
	struct pm_net_tcp_wrk *w;
	union epoll_data data;
	int r;

	w = pick_worker_for_new_conn(ctx);
	r = get_client_slot(w, &c);
	if (r) {
		/* No free slot (-EAGAIN) or out of memory (-ENOMEM): drop it. */
		close(fd);
		return r;
	}

	c->fd = fd;
	c->src_addr = *addr;
	c->ep_mask = EPOLLIN;
	if (ctx->accept_cb) {
		r = ctx->accept_cb(ctx, c);
		if (r < 0) {
			/* Not in epoll yet; releasing the slot also closes fd. */
			put_client_slot_no_epoll(w, c);
			return r;
		}
	}

	/* Tag the pointer so handle_event() can route the event. */
	data.u64 = 0;
	data.ptr = c;
	data.u64 |= EPL_EVT_CLIENT;
	r = epoll_add(w->ep_fd, fd, c->ep_mask, data);
	if (r) {
		r = -errno;
		put_client_slot_no_epoll(w, c);
		return r;
	}

	send_event_fd(w);
	return 0;
}
static int handle_event_accept(struct pm_net_tcp_wrk *w)
{
static const uint32_t NR_MAX_ACCEPT_CYCLE = 32;
struct sockaddr_in46 addr;
uint32_t counter = 0;
socklen_t len;
int ret;
do_accept:
memset(&addr, 0, sizeof(addr));
if (w->ctx->arg.bind_addr.sa.sa_family == AF_INET)
len = sizeof(struct sockaddr_in);
else
len = sizeof(struct sockaddr_in6);
ret = accept4(w->ctx->tcp_fd, &addr.sa, &len, SOCK_NONBLOCK | SOCK_CLOEXEC);
if (ret < 0)
return handle_accept_error(errno, w);
if (len > sizeof(addr)) {
close(ret);
return -EINVAL;
}
ret = give_client_fd_to_a_worker(w->ctx, ret, &addr);
if (ret)
return 0;
if (++counter < NR_MAX_ACCEPT_CYCLE)
goto do_accept;
return 0;
}
/* Re-register @c's socket on @w's epoll instance with the events in c->ep_mask. */
static int apply_ep_mask(struct pm_net_tcp_wrk *w, struct pm_net_tcp_client *c)
{
	union epoll_data tag;

	memset(&tag, 0, sizeof(tag));
	tag.ptr = c;
	tag.u64 |= EPL_EVT_CLIENT;
	return epoll_mod(w->ep_fd, c->fd, c->ep_mask, tag);
}
static int handle_event_client_send(struct pm_net_tcp_client *c)
{
int err = 0;
ssize_t ret;
size_t len;
char *buf;
buf = c->send_buf.buf;
len = c->send_buf.len;
if (!len)
return 0;
ret = send(c->fd, buf, len, MSG_DONTWAIT);
if (ret < 0) {
err = -errno;
if (err == -EAGAIN || err == -EINTR)
return 0;
return err;
}
if (!ret)
return -ECONNRESET;
if ((size_t)ret < len) {
memmove(buf, buf + ret, len - (size_t)ret);
c->send_buf.len -= (size_t)ret;
} else {
c->send_buf.len = 0;
}
if (c->send_cb) {
err = c->send_cb(c);
if (err == -EAGAIN)
err = 0;
}
return err;
}
/*
 * Read once from @c's socket into the free tail of c->recv_buf, first
 * doubling the buffer if it is full. Then run recv_cb and try to flush any
 * pending output.
 *
 * Returns 0 to keep the connection, or a negative value to close it:
 *  - -ENOMEM if the buffer cannot grow;
 *  - -ECONNRESET on EOF;
 *  - a recv() errno (other than EAGAIN/EINTR);
 *  - a recv_cb error (other than -EAGAIN);
 *  - a send error.
 * A recv_cb error takes precedence over the flush result, so a successful
 * flush cannot mask it. For example, an SSL failure that queued an alert
 * still closes the connection.
 */
static int handle_event_client_recv(struct pm_net_tcp_client *c)
{
	int err = 0, send_err;
	ssize_t ret;
	size_t len;
	char *buf;

	buf = c->recv_buf.buf + c->recv_buf.len;
	len = c->recv_buf.cap - c->recv_buf.len;
	if (!len) {
		if (pm_buf_resize(&c->recv_buf, (c->recv_buf.cap + 1) * 2))
			return -ENOMEM;
		buf = c->recv_buf.buf + c->recv_buf.len;
		len = c->recv_buf.cap - c->recv_buf.len;
	}

	ret = recv(c->fd, buf, len, MSG_DONTWAIT);
	if (ret < 0) {
		err = -errno;
		if (err == -EAGAIN || err == -EINTR)
			return 0;
		return err;
	}
	if (!ret)
		return -ECONNRESET;

	c->recv_buf.len += (size_t)ret;
	if (c->recv_cb) {
		err = c->recv_cb(c);
		if (err == -EAGAIN)
			err = 0;
	}

	if (c->send_buf.len) {
		/* Flush even after an error so queued output gets a chance to go out. */
		send_err = handle_event_client_send(c);
		if (!err)
			err = send_err;
	}
	return err;
}
/*
 * Handle one epoll event for a client:
 *  1. read on EPOLLIN, write on EPOLLOUT;
 *  2. return -ECONNRESET if the user requested a close or the socket
 *     reported EPOLLERR/EPOLLHUP;
 *  3. otherwise keep EPOLLOUT in the interest mask exactly while
 *     c->send_buf is non-empty.
 * A non-zero return tells handle_event() to release the client.
 */
static int handle_event_client(struct pm_net_tcp_wrk *w, struct epoll_event *ev)
{
	struct pm_net_tcp_client *c = GET_EPL_DT(ev->data.u64);
	const uint32_t revents = ev->events;
	bool want_out, has_out;
	int err;

	if (revents & EPOLLIN) {
		err = handle_event_client_recv(c);
		if (err)
			return err;
	}
	if (revents & EPOLLOUT) {
		err = handle_event_client_send(c);
		if (err)
			return err;
	}
	if (c->user_close || (revents & (EPOLLERR | EPOLLHUP)))
		return -ECONNRESET;

	want_out = (c->send_buf.len != 0);
	has_out = ((c->ep_mask & EPOLLOUT) != 0);
	if (want_out == has_out)
		return 0;

	if (want_out)
		c->ep_mask |= EPOLLOUT;
	else
		c->ep_mask &= ~EPOLLOUT;
	return apply_ep_mask(w, c);
}
/*
 * Accept and eventfd events are low priority: handle_event() only records
 * that they occurred, and handle_low_priority_events() processes them after
 * the client events of the same epoll_wait() batch.
 */
struct epl_handle_ev {
bool has_event_accept;	/* listening socket readable (only worker 0 registers it) */
bool has_event_evfd;	/* this worker's eventfd readable */
};
/*
 * Route one epoll event by its EPL_EVT_* tag:
 *  - client events are handled immediately, and the client is released on
 *    any error (the error itself is not propagated);
 *  - accept and eventfd events are only recorded in @he.
 * Always returns 0.
 */
static int handle_event(struct pm_net_tcp_wrk *w, struct epoll_event *ev,
			struct epl_handle_ev *he)
{
	const uint64_t tag = GET_EPL_EV(ev->data.u64);

	if (tag == EPL_EVT_CLIENT) {
		if (handle_event_client(w, ev) != 0)
			put_client_slot(w, GET_EPL_DT(ev->data.u64));
	} else if (tag == EPL_EVT_ACCEPT) {
		he->has_event_accept = true;
	} else if (tag == EPL_EVT_EVENTFD) {
		he->has_event_evfd = true;
	}
	return 0;
}
/*
 * Process the deferred events recorded in @he: drain the eventfd first,
 * then accept new connections. Nothing is done once a stop has been
 * requested. Returns 0 or the first error.
 */
static int handle_low_priority_events(struct pm_net_tcp_wrk *w,
				      struct epl_handle_ev *he)
{
	int ret;

	/* When stopping, the worker loop is about to exit anyway. */
	if (w->ctx->should_stop)
		return 0;

	if (he->has_event_evfd) {
		ret = recv_event_fd(w);
		if (ret)
			return ret;
	}

	return he->has_event_accept ? handle_event_accept(w) : 0;
}
static int handle_events(struct pm_net_tcp_wrk *w, int nr_events)
{
struct pm_net_tcp_ctx *ctx = w->ctx;
struct epl_handle_ev he;
int ret = 0, i;
if (!nr_events)
return 0;
memset(&he, 0, sizeof(he));
for (i = 0; i < nr_events; i++) {
struct epoll_event *ev = &w->events[i];
ret = handle_event(w, ev, &he);
if (ret < 0)
break;
if (ctx->should_stop)
break;
}
if (!ret)
ret = handle_low_priority_events(w, &he);
return ret;
}
/*
 * Block in epoll_wait() on worker @w with no timeout.
 * Returns the number of ready events, 0 if interrupted by a signal, or a
 * negative errno value on failure.
 */
static int poll_events(struct pm_net_tcp_wrk *w)
{
	int n = epoll_wait(w->ep_fd, w->events, (int)w->nr_events, -1);

	if (n >= 0)
		return n;
	return (errno == EINTR) ? 0 : -errno;
}
/* Outcome of worker_wait_for_start_signal(). */
enum {
	WORKER_WAIT_RUN = 0,	/* the context was started */
	WORKER_WAIT_STOP = 1,	/* a stop was requested (checked first) */
};

/*
 * Name the calling worker thread "tcp<port>-<idx>", then block on
 * ctx->start_cond until the context is started or stopped.
 *
 * The thread names itself through pthread_self(). w->thread is written by
 * pthread_create() in the parent, and POSIX does not guarantee that store
 * is visible to the new thread when it starts running.
 */
static int worker_wait_for_start_signal(struct pm_net_tcp_wrk *w)
{
	struct pm_net_tcp_ctx *ctx = w->ctx;
	uint16_t port;
	char buf[64];
	int ret;

	if (ctx->arg.bind_addr.sa.sa_family == AF_INET)
		port = ctx->arg.bind_addr.v4.sin_port;
	else
		port = ctx->arg.bind_addr.v6.sin6_port;
	snprintf(buf, sizeof(buf), "tcp%hu-%u", ntohs(port), w->idx);
	/* Best effort: the result is ignored. */
	pthread_setname_np(pthread_self(), buf);

	pthread_mutex_lock(&ctx->start_lock);
	for (;;) {
		if (ctx->should_stop) {
			ret = WORKER_WAIT_STOP;
			break;
		}
		if (ctx->started) {
			ret = WORKER_WAIT_RUN;
			break;
		}
		pthread_cond_wait(&ctx->start_cond, &ctx->start_lock);
	}
	pthread_mutex_unlock(&ctx->start_lock);
	return ret;
}
/*
 * Worker thread body. Waits for the start signal, then alternates between
 * epoll_wait() and event handling until a stop is requested or either step
 * fails. On exit it requests a stop of the whole context, so one failing
 * worker brings down all workers.
 */
static void *worker_entry(void *arg)
{
	struct pm_net_tcp_wrk *w = arg;
	struct pm_net_tcp_ctx *ctx = w->ctx;

	if (worker_wait_for_start_signal(w) != WORKER_WAIT_STOP) {
		while (!ctx->should_stop) {
			int n = poll_events(w);

			if (n < 0 || handle_events(w, n) < 0)
				break;
		}
	}

	pm_net_tcp_ctx_stop(ctx);
	return NULL;
}
/*
 * Create a TCP server context:
 *  - copy *@arg;
 *  - set up the locks and the start condition variable;
 *  - bind and listen on arg->bind_addr;
 *  - start arg->nr_workers worker threads, which wait for
 *    pm_net_tcp_ctx_run().
 *
 * On success stores the context in *@ctx_p and returns 0. On failure
 * returns a negative errno value. pthread_*_init() report positive error
 * numbers; these are negated to match that convention.
 */
int pm_net_tcp_ctx_init(pm_net_tcp_ctx_t **ctx_p, const struct pm_net_tcp_arg *arg)
{
	pm_net_tcp_ctx_t *ctx;
	int ret;

	/* calloc() already zeroes every field. */
	ctx = calloc(1, sizeof(*ctx));
	if (!ctx)
		return -ENOMEM;

	ctx->arg = *arg;

	ret = -pthread_mutex_init(&ctx->start_lock, NULL);
	if (ret)
		goto out_ctx;
	ret = -pthread_mutex_init(&ctx->accept_lock, NULL);
	if (ret)
		goto out_start_lock;
	ret = -pthread_cond_init(&ctx->start_cond, NULL);
	if (ret)
		goto out_accept_lock;
	ret = sock_init(ctx);
	if (ret)
		goto out_start_cond;
	ret = workers_init(ctx);
	if (ret)
		goto out_sock;

	*ctx_p = ctx;
	return 0;

out_sock:
	sock_destroy(ctx);
out_start_cond:
	pthread_cond_destroy(&ctx->start_cond);
out_accept_lock:
	pthread_mutex_destroy(&ctx->accept_lock);
out_start_lock:
	pthread_mutex_destroy(&ctx->start_lock);
out_ctx:
	free(ctx);
	return ret;
}
/* Release the worker threads parked in worker_wait_for_start_signal(). */
void pm_net_tcp_ctx_run(struct pm_net_tcp_ctx *ctx)
{
	pthread_mutex_t *lock = &ctx->start_lock;

	pthread_mutex_lock(lock);
	ctx->started = true;
	pthread_cond_broadcast(&ctx->start_cond);
	pthread_mutex_unlock(lock);
}

/* Block the caller until a stop has been requested on @ctx. */
void pm_net_tcp_ctx_wait(struct pm_net_tcp_ctx *ctx)
{
	pthread_mutex_t *lock = &ctx->start_lock;

	pthread_mutex_lock(lock);
	for (;;) {
		if (ctx->should_stop)
			break;
		pthread_cond_wait(&ctx->start_cond, lock);
	}
	pthread_mutex_unlock(lock);
}
/*
 * Request a stop:
 *  - set ctx->should_stop;
 *  - wake every start_cond waiter (parked workers and
 *    pm_net_tcp_ctx_wait() callers);
 *  - poke the eventfd of each worker whose thread was started, so it
 *    returns from epoll_wait().
 * Safe to call repeatedly; exiting worker threads call it too.
 */
void pm_net_tcp_ctx_stop(struct pm_net_tcp_ctx *ctx)
{
	const uint32_t n = ctx->arg.nr_workers;
	uint32_t i;

	pthread_mutex_lock(&ctx->start_lock);
	ctx->should_stop = true;
	pthread_cond_broadcast(&ctx->start_cond);
	for (i = 0; i < n; i++) {
		if (!ctx->workers[i].need_join_thread)
			continue;
		send_event_fd(&ctx->workers[i]);
	}
	pthread_mutex_unlock(&ctx->start_lock);
}
/*
 * Stop @ctx and tear it down:
 *  - join and destroy all workers;
 *  - close the listening socket;
 *  - destroy the locks and the condition variable;
 *  - free the context allocated by pm_net_tcp_ctx_init().
 * @ctx must not be used afterwards.
 */
void pm_net_tcp_ctx_destroy(struct pm_net_tcp_ctx *ctx)
{
	pm_net_tcp_ctx_stop(ctx);
	workers_destroy(ctx);
	sock_destroy(ctx);
	pthread_cond_destroy(&ctx->start_cond);
	pthread_mutex_destroy(&ctx->start_lock);
	pthread_mutex_destroy(&ctx->accept_lock);
	/* Allocated with calloc() in pm_net_tcp_ctx_init(); zeroing alone leaked it. */
	free(ctx);
}
/* Store an opaque user pointer on the context. */
void pm_net_tcp_ctx_set_udata(pm_net_tcp_ctx_t *ctx, void *udata)
{
	ctx->ctx_udata = udata;
}

/* Fetch the pointer stored with pm_net_tcp_ctx_set_udata(). */
void *pm_net_tcp_ctx_get_udata(pm_net_tcp_ctx_t *ctx)
{
	return ctx->ctx_udata;
}

/* Install the callback run for every newly accepted connection. */
void pm_net_tcp_ctx_set_accept_cb(pm_net_tcp_ctx_t *ctx, pm_net_tcp_accept_cb_t accept_cb)
{
	ctx->accept_cb = accept_cb;
}

/* Store an opaque user pointer on a client. */
void pm_net_tcp_client_set_udata(pm_net_tcp_client_t *c, void *udata)
{
	c->udata = udata;
}

/* Fetch the pointer stored with pm_net_tcp_client_set_udata(). */
void *pm_net_tcp_client_get_udata(pm_net_tcp_client_t *c)
{
	return c->udata;
}

/* Install the callback run after new bytes land in the client's recv_buf. */
void pm_net_tcp_client_set_recv_cb(pm_net_tcp_client_t *c, pm_net_tcp_recv_cb_t recv_cb)
{
	c->recv_cb = recv_cb;
}

/* Install the callback run after part of the client's send_buf was sent. */
void pm_net_tcp_client_set_send_cb(pm_net_tcp_client_t *c, pm_net_tcp_send_cb_t send_cb)
{
	c->send_cb = send_cb;
}

/* Install the callback run when the client slot is released. */
void pm_net_tcp_client_set_close_cb(pm_net_tcp_client_t *c, pm_net_tcp_close_cb_t close_cb)
{
	c->close_cb = close_cb;
}

/* Buffer holding bytes received from the peer. */
struct pm_buf *pm_net_tcp_client_get_recv_buf(pm_net_tcp_client_t *c)
{
	return &c->recv_buf;
}

/* Buffer holding bytes queued for sending to the peer. */
struct pm_buf *pm_net_tcp_client_get_send_buf(pm_net_tcp_client_t *c)
{
	return &c->send_buf;
}

/* Peer address recorded when the connection was accepted. */
const struct sockaddr_in46 *pm_net_tcp_client_get_src_addr(pm_net_tcp_client_t *c)
{
	return &c->src_addr;
}

/* Ask for @c to be closed once the current event has been handled. */
void pm_net_tcp_client_user_close(pm_net_tcp_client_t *c)
{
	c->user_close = true;
}
#endif /* #if PM_USE_TCP */
#if PM_USE_SSL
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <openssl/bio.h>
#ifdef __cplusplus
extern "C" {
#endif
/* TLS server configuration. */
struct pm_net_tcp_ssl_arg {
char cert_file[512];	/* PEM certificate path */
char key_file[512];	/* PEM private key path */
struct pm_net_tcp_arg net_arg;	/* passed to pm_net_tcp_ctx_init() */
};
struct pm_net_tcp_ssl_ctx;
struct pm_net_tcp_ssl_client;
typedef struct pm_net_tcp_ssl_ctx pm_net_tcp_ssl_ctx_t;
typedef struct pm_net_tcp_ssl_client pm_net_tcp_ssl_client_t;
/*
 * TLS-level callbacks.
 * recv_cb runs after decrypted data was appended to the client's recv_buf.
 * send_cb runs after plaintext from send_buf was handed to SSL_write().
 * close_cb runs before the TLS client object is freed.
 * accept_cb runs once the TLS handshake has completed.
 */
typedef int (*pm_net_tcp_ssl_recv_cb_t)(pm_net_tcp_ssl_client_t *c);
typedef int (*pm_net_tcp_ssl_send_cb_t)(pm_net_tcp_ssl_client_t *c);
typedef int (*pm_net_tcp_ssl_close_cb_t)(pm_net_tcp_ssl_client_t *c);
typedef int (*pm_net_tcp_ssl_accept_cb_t)(pm_net_tcp_ssl_ctx_t *ctx, pm_net_tcp_ssl_client_t *c);
/* A NULL @arg selects built-in defaults (see set_default_arg()). */
int pm_net_tcp_ssl_ctx_init(pm_net_tcp_ssl_ctx_t **ctx_p, const struct pm_net_tcp_ssl_arg *arg);
void pm_net_tcp_ssl_ctx_run(pm_net_tcp_ssl_ctx_t *ctx);
void pm_net_tcp_ssl_ctx_wait(pm_net_tcp_ssl_ctx_t *ctx);
void pm_net_tcp_ssl_ctx_stop(pm_net_tcp_ssl_ctx_t *ctx);
void pm_net_tcp_ssl_ctx_destroy(pm_net_tcp_ssl_ctx_t *ctx_p);
void pm_net_tcp_ssl_ctx_set_udata(pm_net_tcp_ssl_ctx_t *ctx, void *udata);
void *pm_net_tcp_ssl_ctx_get_udata(pm_net_tcp_ssl_ctx_t *ctx);
void pm_net_tcp_ssl_ctx_set_accept_cb(pm_net_tcp_ssl_ctx_t *ctx, pm_net_tcp_ssl_accept_cb_t accept_cb);
/* Per-client accessors; the buffers hold plaintext. */
void pm_net_tcp_ssl_client_set_udata(pm_net_tcp_ssl_client_t *c, void *udata);
void *pm_net_tcp_ssl_client_get_udata(pm_net_tcp_ssl_client_t *c);
void pm_net_tcp_ssl_client_set_recv_cb(pm_net_tcp_ssl_client_t *c, pm_net_tcp_ssl_recv_cb_t recv_cb);
void pm_net_tcp_ssl_client_set_send_cb(pm_net_tcp_ssl_client_t *c, pm_net_tcp_ssl_send_cb_t send_cb);
void pm_net_tcp_ssl_client_set_close_cb(pm_net_tcp_ssl_client_t *c, pm_net_tcp_ssl_close_cb_t close_cb);
struct pm_buf *pm_net_tcp_ssl_client_get_recv_buf(pm_net_tcp_ssl_client_t *c);
struct pm_buf *pm_net_tcp_ssl_client_get_send_buf(pm_net_tcp_ssl_client_t *c);
const struct sockaddr_in46 *pm_net_tcp_ssl_client_get_src_addr(pm_net_tcp_ssl_client_t *c);
void pm_net_tcp_ssl_client_user_close(pm_net_tcp_ssl_client_t *c);
#include <stdio.h>
#include <assert.h>
#include <string.h>
/* TLS server context layered on top of a pm_net_tcp_ctx. */
struct pm_net_tcp_ssl_ctx {
SSL_CTX *ssl_ctx;
struct pm_net_tcp_ctx *net_ctx;		/* underlying TCP server */
struct pm_net_tcp_ssl_arg arg;		/* copy of the init arguments */
pm_net_tcp_ssl_accept_cb_t accept_cb;	/* run after the TLS handshake */
void *udata;
};
/*
 * Per-connection TLS state, stored as the TCP client's udata. Ciphertext
 * moves between the TCP client's buffers and the two memory BIOs;
 * recv_buf/send_buf below hold plaintext for the user.
 */
struct pm_net_tcp_ssl_client {
SSL *ssl;
BIO *rbio;	/* ciphertext received from the peer, read by OpenSSL */
BIO *wbio;	/* ciphertext produced by OpenSSL, to be sent */
struct pm_net_tcp_ssl_ctx *ssl_ctx;
pm_net_tcp_client_t *net_client;
void *udata;
struct pm_buf recv_buf;
struct pm_buf send_buf;
pm_net_tcp_ssl_recv_cb_t recv_cb;
pm_net_tcp_ssl_send_cb_t send_cb;
pm_net_tcp_ssl_close_cb_t close_cb;
bool has_accepted;	/* SSL_accept() has completed */
};
/*
 * Feed ciphertext from @rbuf (the TCP receive buffer) into the SSL read
 * BIO and drop the consumed bytes from @rbuf.
 *
 * Returns:
 *  - 0 when @rbuf was consumed entirely;
 *  - -EAGAIN when @rbuf was empty or only partially consumed;
 *  - -EIO if BIO_write() failed;
 *  - -EINVAL if it reports more bytes than were offered.
 */
static int do_bio_write(struct pm_net_tcp_ssl_client *ssl_c, struct pm_buf *rbuf)
{
	size_t consumed;
	int n;

	if (!rbuf->len)
		return -EAGAIN;

	n = BIO_write(ssl_c->rbio, rbuf->buf, rbuf->len);
	if (n <= 0)
		return -EIO;

	consumed = (size_t)n;
	if (consumed > rbuf->len)
		return -EINVAL;
	if (consumed == rbuf->len) {
		rbuf->len = 0;
		return 0;
	}

	rbuf->len -= consumed;
	memmove(rbuf->buf, rbuf->buf + consumed, rbuf->len);
	return -EAGAIN;
}
static int do_bio_read(struct pm_net_tcp_ssl_client *ssl_c, struct pm_buf *sbuf)
{
size_t len;
char *buf;
int ret;
if (BIO_ctrl_pending(ssl_c->wbio) == 0)
return -EAGAIN;
len = sbuf->cap - sbuf->len;
if (!len) {
if (pm_buf_resize(sbuf, (sbuf->cap + 1) * 2))
return -ENOMEM;
len = sbuf->cap - sbuf->len;
}
buf = sbuf->buf + sbuf->len;
ret = BIO_read(ssl_c->wbio, buf, len);
if (ret <= 0)
return -EIO;
sbuf->len += (size_t)ret;
return 0;
}
/*
 * Act on a failed SSL call's return value @ret:
 *  - SSL_ERROR_WANT_READ: flush pending SSL output into @tcp_sbuf;
 *  - SSL_ERROR_WANT_WRITE: feed more of @tcp_rbuf into the SSL read BIO;
 *  - anything else: return -EIO.
 */
static int handle_ssl_err(struct pm_net_tcp_ssl_client *ssl_c, int ret,
			  struct pm_buf *tcp_rbuf, struct pm_buf *tcp_sbuf)
{
	const int reason = SSL_get_error(ssl_c->ssl, ret);

	if (reason == SSL_ERROR_WANT_READ)
		return do_bio_read(ssl_c, tcp_sbuf);
	if (reason == SSL_ERROR_WANT_WRITE)
		return do_bio_write(ssl_c, tcp_rbuf);
	return -EIO;
}
static int do_ssl_read(struct pm_net_tcp_ssl_client *ssl_c,
struct pm_buf *tcp_rbuf, struct pm_buf *tcp_sbuf)
{
struct pm_buf *ssl_rbuf = &ssl_c->recv_buf;
size_t uret, len;
char *buf;
int ret;
len = ssl_rbuf->cap - ssl_rbuf->len;
if (len < 2) {
if (pm_buf_resize(ssl_rbuf, (ssl_rbuf->cap + 1) * 2))
return -ENOMEM;
len = ssl_rbuf->cap - ssl_rbuf->len;
}
buf = ssl_rbuf->buf + ssl_rbuf->len;
ret = SSL_read(ssl_c->ssl, buf, len - 1);
if (ret <= 0)
return handle_ssl_err(ssl_c, ret, tcp_rbuf, tcp_sbuf);
uret = (size_t)ret;
if (uret > len - 1)
return -EINVAL;
ssl_rbuf->len += uret;
ssl_rbuf->buf[ssl_rbuf->len] = '\0';
ret = 0;
if (ssl_c->recv_cb)
ret = ssl_c->recv_cb(ssl_c);
if (ret == -EAGAIN)
ret = 0;
return ret;
}
/*
 * Encrypt pending plaintext from ssl_c->send_buf with one SSL_write() and
 * drop the consumed bytes. Then run the user's send_cb and move the
 * produced ciphertext into @tcp_sbuf.
 * Returns -EAGAIN if there was nothing to write; otherwise an SSL error,
 * a send_cb error (other than -EAGAIN), or do_bio_read()'s result.
 */
static int do_ssl_write(struct pm_net_tcp_ssl_client *ssl_c,
			struct pm_buf *tcp_rbuf, struct pm_buf *tcp_sbuf)
{
	struct pm_buf *plain = &ssl_c->send_buf;
	size_t pending = plain->len;
	size_t taken;
	int n;

	if (pending == 0)
		return -EAGAIN;

	n = SSL_write(ssl_c->ssl, plain->buf, pending);
	if (n <= 0)
		return handle_ssl_err(ssl_c, n, tcp_rbuf, tcp_sbuf);

	taken = (size_t)n;
	if (taken > pending)
		return -EINVAL;
	if (taken < pending)
		memmove(plain->buf, plain->buf + taken, pending - taken);
	plain->len = pending - taken;

	if (ssl_c->send_cb) {
		int err = ssl_c->send_cb(ssl_c);

		if (err < 0 && err != -EAGAIN)
			return err;
	}

	return do_bio_read(ssl_c, tcp_sbuf);
}
/*
 * One TLS processing step for a client:
 *  1. feed received ciphertext into OpenSSL;
 *  2. finish the handshake if needed;
 *  3. decrypt application data;
 *  4. encrypt pending output.
 *
 * The user's accept_cb runs once, right after the handshake completes. A
 * negative return other than -EAGAIN is returned, which rejects the
 * connection, matching how the other user callbacks' errors are treated.
 * Returns -EAGAIN when there is nothing more to do for now.
 */
static int __pm_net_tcp_ssl_client_recv_cb(pm_net_tcp_ssl_client_t *ssl_c,
					   struct pm_buf *tcp_rbuf,
					   struct pm_buf *tcp_sbuf)
{
	int ret;

	ret = do_bio_write(ssl_c, tcp_rbuf);
	if (ret)
		return ret;

	if (!ssl_c->has_accepted) {
		struct pm_net_tcp_ssl_ctx *ssl_ctx = ssl_c->ssl_ctx;

		ret = SSL_accept(ssl_c->ssl);
		if (ret != 1)
			return handle_ssl_err(ssl_c, ret, tcp_rbuf, tcp_sbuf);

		ssl_c->has_accepted = true;
		if (ssl_ctx->accept_cb) {
			ret = ssl_ctx->accept_cb(ssl_ctx, ssl_c);
			if (ret < 0 && ret != -EAGAIN)
				return ret;
		}
	}

	ret = do_ssl_read(ssl_c, tcp_rbuf, tcp_sbuf);
	if (ret)
		return ret;

	return do_ssl_write(ssl_c, tcp_rbuf, tcp_sbuf);
}
/*
 * TCP-level recv_cb installed by the SSL layer. It repeats the TLS
 * processing step while ciphertext remains in the TCP receive buffer.
 * -EAGAIN means "partially consumed or nothing more to do" and lets the
 * loop continue. Any other error ends the loop immediately. Otherwise a
 * persistent failure (e.g. BIO_write() failing while data is still
 * buffered) would spin here forever.
 */
static int pm_net_tcp_ssl_client_recv_cb(pm_net_tcp_client_t *c)
{
	pm_net_tcp_ssl_client_t *ssl_c = pm_net_tcp_client_get_udata(c);
	struct pm_buf *tcp_rbuf = pm_net_tcp_client_get_recv_buf(c);
	struct pm_buf *tcp_sbuf = pm_net_tcp_client_get_send_buf(c);
	int ret;

	do {
		ret = __pm_net_tcp_ssl_client_recv_cb(ssl_c, tcp_rbuf, tcp_sbuf);
		if (ret && ret != -EAGAIN)
			break;
	} while (tcp_rbuf->len);

	return ret;
}
/*
 * TCP-level send_cb installed by the SSL layer. It encrypts more pending
 * plaintext. If there is none, it still moves any ciphertext left in the
 * SSL write BIO into the TCP send buffer.
 */
static int pm_net_tcp_ssl_client_send_cb(pm_net_tcp_client_t *c)
{
	pm_net_tcp_ssl_client_t *ssl_c = pm_net_tcp_client_get_udata(c);
	struct pm_buf *tcp_rbuf = pm_net_tcp_client_get_recv_buf(c);
	struct pm_buf *tcp_sbuf = pm_net_tcp_client_get_send_buf(c);
	int ret;

	ret = do_ssl_write(ssl_c, tcp_rbuf, tcp_sbuf);
	if (ret == -EAGAIN)
		ret = do_bio_read(ssl_c, tcp_sbuf);
	return ret;
}
/*
 * TCP-level close_cb installed by the SSL layer. It runs the user's
 * close_cb, then frees the plaintext buffers, the SSL object and the
 * per-client TLS state. Always returns 0.
 */
static int pm_net_tcp_ssl_client_close_cb(pm_net_tcp_client_t *c)
{
	pm_net_tcp_ssl_client_t *ssl_c = pm_net_tcp_client_get_udata(c);
	pm_net_tcp_ssl_close_cb_t user_close = ssl_c->close_cb;

	if (user_close != NULL)
		user_close(ssl_c);

	pm_buf_destroy(&ssl_c->recv_buf);
	pm_buf_destroy(&ssl_c->send_buf);
	/* Also frees rbio/wbio, which SSL_set_bio() handed to the SSL object. */
	SSL_free(ssl_c->ssl);
	free(ssl_c);
	return 0;
}
/*
 * TCP-level accept_cb installed by the SSL layer. It creates the
 * per-client TLS state:
 *  - an SSL object in server (accept) mode;
 *  - two memory BIOs;
 *  - 2048-byte plaintext buffers.
 * It stores that state as the TCP client's udata and installs the SSL
 * recv/send/close callbacks. Returns 0, or -ENOMEM after freeing everything
 * allocated here.
 */
static int pm_net_ssl_accept_cb(pm_net_tcp_ctx_t *ctx, pm_net_tcp_client_t *c)
{
	pm_net_tcp_ssl_ctx_t *ssl_ctx = pm_net_tcp_ctx_get_udata(ctx);
	pm_net_tcp_ssl_client_t *sc;

	sc = calloc(1, sizeof(*sc));
	if (sc == NULL)
		return -ENOMEM;

	sc->ssl = SSL_new(ssl_ctx->ssl_ctx);
	if (sc->ssl == NULL)
		goto err_free_client;
	sc->rbio = BIO_new(BIO_s_mem());
	if (sc->rbio == NULL)
		goto err_free_ssl;
	sc->wbio = BIO_new(BIO_s_mem());
	if (sc->wbio == NULL)
		goto err_free_rbio;
	if (pm_buf_init(&sc->recv_buf, 2048) != 0)
		goto err_free_wbio;
	if (pm_buf_init(&sc->send_buf, 2048) != 0)
		goto err_free_recv_buf;

	sc->ssl_ctx = ssl_ctx;
	/* From here on SSL_free() owns both BIOs. */
	SSL_set_bio(sc->ssl, sc->rbio, sc->wbio);
	SSL_set_accept_state(sc->ssl);
	sc->net_client = c;

	pm_net_tcp_client_set_udata(c, sc);
	pm_net_tcp_client_set_recv_cb(c, &pm_net_tcp_ssl_client_recv_cb);
	pm_net_tcp_client_set_send_cb(c, &pm_net_tcp_ssl_client_send_cb);
	pm_net_tcp_client_set_close_cb(c, &pm_net_tcp_ssl_client_close_cb);
	return 0;

err_free_recv_buf:
	pm_buf_destroy(&sc->recv_buf);
err_free_wbio:
	BIO_free(sc->wbio);
err_free_rbio:
	BIO_free(sc->rbio);
err_free_ssl:
	SSL_free(sc->ssl);
err_free_client:
	free(sc);
	return -ENOMEM;
}
static void set_default_arg(struct pm_net_tcp_ssl_arg *arg)
{
memset(arg, 0, sizeof(*arg));
arg->net_arg.nr_workers = 1;
arg->net_arg.client_init_cap = 1024;
arg->net_arg.sock_backlog = 1024;
}
/*
 * Create a TLS server context layered on top of a plain TCP context.
 *
 * @ctx_p: receives the new context on success (untouched on failure).
 * @arg:   listener and certificate configuration; copied into the context.
 *         When NULL, set_default_arg() is used. Its zeroed cert/key paths
 *         presumably make certificate loading fail with -EINVAL — confirm.
 *
 * Loads the PEM certificate and private key, creates the TCP context, and
 * registers pm_net_ssl_accept_cb() so that every accepted connection gets
 * its own SSL state.
 *
 * Return: 0 on success; -ENOMEM if the SSL_CTX cannot be created; -EINVAL
 * if the certificate or key cannot be loaded; otherwise the error returned
 * by pm_net_tcp_ctx_init().
 */
int pm_net_tcp_ssl_ctx_init(pm_net_tcp_ssl_ctx_t **ctx_p, const struct pm_net_tcp_ssl_arg *arg)
{
pm_net_tcp_ssl_ctx_t *ctx;
int ret;
ctx = calloc(1, sizeof(*ctx));
if (!ctx)
return -ENOMEM;
if (arg)
ctx->arg = *arg;
else
set_default_arg(&ctx->arg);
/* Legacy OpenSSL initialization; these calls are kept for older OpenSSL releases. */
SSL_library_init();
SSL_load_error_strings();
ctx->ssl_ctx = SSL_CTX_new(TLS_server_method());
if (!ctx->ssl_ctx) {
ret = -ENOMEM;
goto out_ctx;
}
/* OpenSSL returns 1 on success, <= 0 on failure for both loaders. */
ret = SSL_CTX_use_certificate_file(ctx->ssl_ctx, ctx->arg.cert_file, SSL_FILETYPE_PEM);
if (ret <= 0) {
ret = -EINVAL;
goto out_ssl_ctx;
}
ret = SSL_CTX_use_PrivateKey_file(ctx->ssl_ctx, ctx->arg.key_file, SSL_FILETYPE_PEM);
if (ret <= 0) {
ret = -EINVAL;
goto out_ssl_ctx;
}
ret = pm_net_tcp_ctx_init(&ctx->net_ctx, &ctx->arg.net_arg);
if (ret)
goto out_ssl_ctx;
/* The accept callback recovers this SSL context through the TCP context's udata. */
pm_net_tcp_ctx_set_udata(ctx->net_ctx, ctx);
pm_net_tcp_ctx_set_accept_cb(ctx->net_ctx, &pm_net_ssl_accept_cb);
*ctx_p = ctx;
return 0;
out_ssl_ctx:
SSL_CTX_free(ctx->ssl_ctx);
out_ctx:
free(ctx);
return ret;
}
/* Start the SSL server by running the wrapped TCP context. */
void pm_net_tcp_ssl_ctx_run(pm_net_tcp_ssl_ctx_t *ctx)
{
	pm_net_tcp_ctx_t *net = ctx->net_ctx;

	pm_net_tcp_ctx_run(net);
}

/* Wait on the wrapped TCP context (see pm_net_tcp_ctx_wait()). */
void pm_net_tcp_ssl_ctx_wait(pm_net_tcp_ssl_ctx_t *ctx)
{
	pm_net_tcp_ctx_t *net = ctx->net_ctx;

	pm_net_tcp_ctx_wait(net);
}

/* Ask the wrapped TCP context to stop. */
void pm_net_tcp_ssl_ctx_stop(pm_net_tcp_ssl_ctx_t *ctx)
{
	pm_net_tcp_ctx_t *net = ctx->net_ctx;

	pm_net_tcp_ctx_stop(net);
}

/*
 * Destroy the SSL server: the TCP context goes first, then the SSL_CTX the
 * per-client SSL objects were created from, then the wrapper itself.
 */
void pm_net_tcp_ssl_ctx_destroy(pm_net_tcp_ssl_ctx_t *ctx)
{
	pm_net_tcp_ctx_destroy(ctx->net_ctx);
	SSL_CTX_free(ctx->ssl_ctx);
	free(ctx);
}
/* Attach an opaque user pointer to the SSL context. */
void pm_net_tcp_ssl_ctx_set_udata(pm_net_tcp_ssl_ctx_t *ctx, void *udata)
{
	ctx->udata = udata;
}

/* Return the pointer stored by pm_net_tcp_ssl_ctx_set_udata(). */
void *pm_net_tcp_ssl_ctx_get_udata(pm_net_tcp_ssl_ctx_t *ctx)
{
	void *ud = ctx->udata;

	return ud;
}

/* Register the user-level accept callback of the SSL context. */
void pm_net_tcp_ssl_ctx_set_accept_cb(pm_net_tcp_ssl_ctx_t *ctx, pm_net_tcp_ssl_accept_cb_t accept_cb)
{
	ctx->accept_cb = accept_cb;
}

/* Attach an opaque user pointer to an SSL client. */
void pm_net_tcp_ssl_client_set_udata(pm_net_tcp_ssl_client_t *c, void *udata)
{
	c->udata = udata;
}

/* Return the pointer stored by pm_net_tcp_ssl_client_set_udata(). */
void *pm_net_tcp_ssl_client_get_udata(pm_net_tcp_ssl_client_t *c)
{
	void *ud = c->udata;

	return ud;
}

/* Register the user-level receive callback of an SSL client. */
void pm_net_tcp_ssl_client_set_recv_cb(pm_net_tcp_ssl_client_t *c, pm_net_tcp_ssl_recv_cb_t recv_cb)
{
	c->recv_cb = recv_cb;
}

/* Register the user-level send callback of an SSL client. */
void pm_net_tcp_ssl_client_set_send_cb(pm_net_tcp_ssl_client_t *c, pm_net_tcp_ssl_send_cb_t send_cb)
{
	c->send_cb = send_cb;
}

/* Register the user-level close callback of an SSL client. */
void pm_net_tcp_ssl_client_set_close_cb(pm_net_tcp_ssl_client_t *c, pm_net_tcp_ssl_close_cb_t close_cb)
{
	c->close_cb = close_cb;
}

/* Plaintext receive buffer owned by the SSL client. */
struct pm_buf *pm_net_tcp_ssl_client_get_recv_buf(pm_net_tcp_ssl_client_t *c)
{
	struct pm_buf *rb = &c->recv_buf;

	return rb;
}

/* Plaintext send buffer owned by the SSL client. */
struct pm_buf *pm_net_tcp_ssl_client_get_send_buf(pm_net_tcp_ssl_client_t *c)
{
	struct pm_buf *sb = &c->send_buf;

	return sb;
}

/* Peer address, taken from the underlying TCP client. */
const struct sockaddr_in46 *pm_net_tcp_ssl_client_get_src_addr(pm_net_tcp_ssl_client_t *c)
{
	pm_net_tcp_client_t *nc = c->net_client;

	return pm_net_tcp_client_get_src_addr(nc);
}

/* Request closing of the underlying TCP client. */
void pm_net_tcp_ssl_client_user_close(pm_net_tcp_ssl_client_t *c)
{
	pm_net_tcp_client_t *nc = c->net_client;

	pm_net_tcp_client_user_close(nc);
}
#endif /* #if PM_USE_SSL */
#if PM_USE_HTTP
/*
 * Transport kind of an entry in pm_http_net_ctx_arr::type_arr; selects the
 * valid member of the pm_http_net_ctx union.
 */
enum {
PM_HTTP_NET_CTX_PLAIN = 0,
PM_HTTP_NET_CTX_SSL = 1
};
/*
 * Request methods as stored in pm_http_req::method. Value 0 is deliberately
 * unused; pm_http_method() maps unknown values to "UNKNOWN".
 */
enum {
PM_HTTP_METHOD_GET = 1,
PM_HTTP_METHOD_POST = 2,
PM_HTTP_METHOD_PUT = 3,
PM_HTTP_METHOD_DELETE = 4,
PM_HTTP_METHOD_HEAD = 5,
PM_HTTP_METHOD_OPTIONS = 6,
PM_HTTP_METHOD_TRACE = 7,
PM_HTTP_METHOD_CONNECT = 8,
PM_HTTP_METHOD_PATCH = 9
};
/*
 * Protocol versions for pm_http_req::ver / pm_http_res::ver. The request
 * parser only accepts HTTP/1.0 and HTTP/1.1; the others are rejected.
 */
enum {
PM_HTTP_VER_09 = 0,
PM_HTTP_VER_10 = 1,
PM_HTTP_VER_11 = 2,
PM_HTTP_VER_20 = 3,
PM_HTTP_VER_30 = 4,
PM_HTTP_VER_31 = 5,
};
struct pm_http_net_ctx;
struct pm_http_ctx;
typedef struct pm_http_net_ctx pm_http_net_ctx_t;
typedef struct pm_http_ctx pm_http_ctx_t;
/*
 * One listening transport used by the HTTP layer. Which union member is
 * valid is not stored here: it is recorded next to it as a
 * PM_HTTP_NET_CTX_* value in pm_http_net_ctx_arr::type_arr.
 */
struct pm_http_net_ctx {
union {
pm_net_tcp_ctx_t *plain;
#if PM_USE_SSL
pm_net_tcp_ssl_ctx_t *ssl;
#endif
};
};
/*
 * Configuration for pm_http_ctx_easy_init(). At least one of use_plain /
 * use_ssl must be set. Each enabled listener binds to the IPv6 wildcard
 * address on its port (host byte order).
 */
struct pm_http_easy_arg {
bool use_plain;
bool use_ssl;
uint16_t plain_port;
uint16_t ssl_port;
uint16_t nr_workers;
/* PEM certificate and private key paths; only used when use_ssl is set. */
const char *cert_file;
const char *key_file;
};
/*
 * One header field. key and val are heap-allocated NUL-terminated copies;
 * pm_http_hdr_add() lower-cases the key. The lengths are those of the
 * strings passed to pm_http_hdr_add() (stored as uint16_t).
 */
struct pm_http_hdr_pair {
char *key;
char *val;
uint16_t key_len;
uint16_t val_len;
};
/* Growable array of header fields, owned by the request/response. */
struct pm_http_hdr {
size_t nr_pairs;
struct pm_http_hdr_pair *pairs;
};
/* Heap-allocated NUL-terminated string; len excludes the terminator. */
struct pm_http_str {
char *str;
size_t len;
};
/* A parsed request as handed to the pm_http_req_cb_t callback. */
struct pm_http_req {
uint8_t method;	/* PM_HTTP_METHOD_* */
uint8_t ver;	/* PM_HTTP_VER_* */
uint64_t content_length;	/* value of the Content-Length header, 0 if absent */
uint64_t cl_remain;	/* body bytes still expected from the peer */
struct pm_http_str uri;	/* request path, without the query string */
struct pm_http_str qs;	/* text after '?', or unset when there is none */
struct pm_http_hdr hdr;
struct pm_buf body;
};
/* Response filled in by the request callback and serialized afterwards. */
struct pm_http_res {
uint8_t ver;	/* PM_HTTP_VER_10 selects an "HTTP/1.0" status line, else "HTTP/1.1" */
uint16_t status_code;
struct pm_http_hdr hdr;	/* extra header fields, emitted verbatim */
struct pm_buf body;
};
/*
 * Per-request handler: inspect @req and fill @res. @arg is the opaque
 * pointer given to pm_http_ctx_set_req_cb().
 */
typedef void (*pm_http_req_cb_t)(struct pm_http_req *req, struct pm_http_res *res, void *arg);
/* Header helpers: add a copied key/value pair; look a key up case-insensitively. */
int pm_http_hdr_add(struct pm_http_hdr *hdr, const char *key, const char *val);
int pm_http_hdr_get(struct pm_http_hdr *hdr, const char *key, char **val);
/* Context lifecycle: create, attach transports (or do both via the easy helper), serve, stop, free. */
int pm_http_ctx_init(pm_http_ctx_t **ctx_p);
int pm_http_ctx_add_net_ctx(pm_http_ctx_t *ctx, pm_http_net_ctx_t *net_ctx, uint8_t type);
int pm_http_ctx_easy_init(pm_http_ctx_t **ctx_p, const struct pm_http_easy_arg *arg);
void pm_http_ctx_set_req_cb(pm_http_ctx_t *ctx, pm_http_req_cb_t cb, void *arg);
const char *pm_http_method(uint8_t method);
void pm_http_ctx_run(pm_http_ctx_t *ctx);
void pm_http_ctx_wait(pm_http_ctx_t *ctx);
void pm_http_ctx_stop(pm_http_ctx_t *ctx);
void pm_http_ctx_destroy(pm_http_ctx_t *ctx_p);
/* Parallel arrays of transports and their PM_HTTP_NET_CTX_* kinds. */
struct pm_http_net_ctx_arr {
size_t nr_ctx;
uint8_t *type_arr;
struct pm_http_net_ctx *ctx_arr;
};
struct pm_http_ctx {
struct pm_http_net_ctx_arr net_ctx_arr;
pm_http_req_cb_t req_cb;
void *req_cb_arg;
};
/* Per-connection HTTP state, stored as the transport client's udata. */
struct pm_http_client {
bool use_ssl;	/* selects the type behind nclient */
bool keep_alive;	/* computed by send_responses() */
struct pm_http_ctx *ctx;
/* Borrowed from the transport client; not owned here. */
struct pm_buf *recv_buf;
struct pm_buf *send_buf;
void *nclient;	/* pm_net_tcp_client_t or pm_net_tcp_ssl_client_t */
/* Parallel arrays of nr_reqs queued requests and their responses. */
struct pm_http_req *req;
struct pm_http_res *res;
uint32_t nr_reqs;
};
const char *pm_http_method(uint8_t method)
{
switch (method) {
case PM_HTTP_METHOD_GET:
return "GET";
case PM_HTTP_METHOD_POST:
return "POST";
case PM_HTTP_METHOD_PUT:
return "PUT";
case PM_HTTP_METHOD_DELETE:
return "DELETE";
case PM_HTTP_METHOD_HEAD:
return "HEAD";
case PM_HTTP_METHOD_OPTIONS:
return "OPTIONS";
case PM_HTTP_METHOD_TRACE:
return "TRACE";
case PM_HTTP_METHOD_CONNECT:
return "CONNECT";
case PM_HTTP_METHOD_PATCH:
return "PATCH";
default:
return "UNKNOWN";
}
}
/*
 * Install the per-request handler @cb; @arg is passed back to it unchanged
 * on every call.
 */
void pm_http_ctx_set_req_cb(pm_http_ctx_t *ctx, pm_http_req_cb_t cb, void *arg)
{
	ctx->req_cb_arg = arg;
	ctx->req_cb = cb;
}
/*
 * Free the string owned by @str and reset it to empty. A @str without a
 * buffer is left as it is.
 */
static void pm_http_str_free(struct pm_http_str *str)
{
	if (str->str) {
		free(str->str);
		str->str = NULL;
		str->len = 0;
	}
}
static int pm_http_strdup(struct pm_http_str *str, const char *s, size_t len)
{
char *p;
pm_http_str_free(str);
p = malloc(len + 1);
if (!p)
return -ENOMEM;
memcpy(p, s, len);
p[len] = '\0';
str->str = p;
str->len = len;
return 0;
}
/*
 * Lower-case the NUL-terminated string @str in place and return @str.
 *
 * Each byte goes through unsigned char before tolower(): passing a negative
 * char (any byte >= 0x80 where plain char is signed) to a <ctype.h> function
 * is undefined behavior, and header bytes come straight from the network.
 */
static char *strtolower(char *str)
{
	unsigned char *p;

	for (p = (unsigned char *)str; *p; p++)
		*p = (unsigned char)tolower(*p);
	return str;
}
static int pm_http_client_add_req(struct pm_http_client *hc,
struct pm_http_req *req)
{
struct pm_http_req *new_reqs;
struct pm_http_res *new_res;
size_t new_nr_reqs;
new_nr_reqs = hc->nr_reqs + 1;
new_reqs = realloc(hc->req, new_nr_reqs * sizeof(*new_reqs));
if (!new_reqs)
return -ENOMEM;
hc->req = new_reqs;
new_res = realloc(hc->res, new_nr_reqs * sizeof(*new_res));
if (!new_res)
return -ENOMEM;
hc->res = new_res;
memset(&hc->res[hc->nr_reqs], 0, sizeof(hc->res[hc->nr_reqs]));
hc->req[hc->nr_reqs] = *req;
hc->nr_reqs = new_nr_reqs;
return 0;
}
/*
 * Free every key/value pair of @hdr and the pair array itself, then reset
 * @hdr to the empty state.
 *
 * The array is freed even when nr_pairs is 0: pm_http_hdr_add() grows
 * hdr->pairs before it commits the new count, so a failed first add leaves
 * an allocated array behind with nr_pairs still 0.
 */
static void pm_http_hdr_destroy(struct pm_http_hdr *hdr)
{
	size_t i;

	for (i = 0; i < hdr->nr_pairs; i++) {
		free(hdr->pairs[i].key);
		free(hdr->pairs[i].val);
	}
	free(hdr->pairs);	/* free(NULL) is a no-op */
	hdr->pairs = NULL;
	hdr->nr_pairs = 0;
}
/*
 * Release every queued request/response pair of @hc (strings, headers and
 * bodies) and the two queue arrays, leaving the queue empty.
 */
static void pm_http_client_free_reqs(struct pm_http_client *hc)
{
	size_t i;

	for (i = 0; i < hc->nr_reqs; i++) {
		struct pm_http_req *req = &hc->req[i];
		struct pm_http_res *res = &hc->res[i];

		pm_http_str_free(&req->uri);
		pm_http_str_free(&req->qs);
		pm_buf_destroy(&req->body);
		pm_http_hdr_destroy(&req->hdr);
		pm_buf_destroy(&res->body);
		pm_http_hdr_destroy(&res->hdr);
	}

	/*
	 * Free the arrays unconditionally: pm_http_client_add_req() may have
	 * grown hc->req (or both arrays) before failing, which leaves memory
	 * allocated while nr_reqs is unchanged, possibly still 0.
	 */
	free(hc->req);
	free(hc->res);
	hc->req = NULL;
	hc->res = NULL;
	hc->nr_reqs = 0;
}
/*
 * Append a copy of the header field @key: @val to @hdr. The stored key is
 * lower-cased; the value is kept as given.
 *
 * Return: 0 on success; -EOVERFLOW if either string is longer than the
 * uint16_t length fields can record (previously the lengths were silently
 * truncated); -ENOMEM on allocation failure. On error @hdr keeps its old
 * pairs (its array may have grown, which pm_http_hdr_destroy() handles).
 */
int pm_http_hdr_add(struct pm_http_hdr *hdr, const char *key, const char *val)
{
	struct pm_http_hdr_pair *pairs, *pair;
	size_t key_len = strlen(key);
	size_t val_len = strlen(val);

	if (key_len > UINT16_MAX || val_len > UINT16_MAX)
		return -EOVERFLOW;

	pairs = realloc(hdr->pairs, (hdr->nr_pairs + 1) * sizeof(*pairs));
	if (!pairs)
		return -ENOMEM;
	hdr->pairs = pairs;

	pair = &pairs[hdr->nr_pairs];
	pair->key = strdup(key);
	if (!pair->key)
		return -ENOMEM;
	pair->val = strdup(val);
	if (!pair->val) {
		free(pair->key);
		return -ENOMEM;
	}

	strtolower(pair->key);
	pair->key_len = (uint16_t)key_len;
	pair->val_len = (uint16_t)val_len;
	hdr->nr_pairs++;
	return 0;
}
/*
 * Look up @key in @hdr, ignoring ASCII case. On a match, *@val points at
 * the stored value (still owned by @hdr) and 0 is returned; otherwise
 * -ENOENT is returned and *@val is untouched. The first match wins.
 */
int pm_http_hdr_get(struct pm_http_hdr *hdr, const char *key, char **val)
{
	size_t n = hdr->nr_pairs;
	size_t idx = 0;

	while (idx < n) {
		struct pm_http_hdr_pair *pair = &hdr->pairs[idx];

		if (strcasecmp(pair->key, key) == 0) {
			*val = pair->val;
			return 0;
		}
		idx++;
	}
	return -ENOENT;
}
/*
 * Allocate an empty HTTP context (no transports, no request callback).
 *
 * Return: 0 and *@ctx_p set on success, -ENOMEM on allocation failure.
 * The error is now -ENOMEM instead of a bare -1, matching the negative
 * errno convention used by the rest of this module.
 */
int pm_http_ctx_init(pm_http_ctx_t **ctx_p)
{
	pm_http_ctx_t *ctx = calloc(1, sizeof(*ctx));

	if (!ctx)
		return -ENOMEM;
	*ctx_p = ctx;
	return 0;
}
/*
 * Create an HTTP context with a plain and/or TLS listener on the IPv6
 * wildcard address, as described by @arg.
 *
 * Return: 0 and *@ctx_p set on success. Errors:
 *   -EINVAL       neither listener requested, or TLS requested without both
 *                 a certificate and a key file;
 *   -EOPNOTSUPP   TLS requested but the file was built with PM_USE_SSL=0
 *                 (previously the request was silently ignored);
 *   -ENAMETOOLONG a certificate/key path does not fit the SSL argument
 *                 (previously it was silently truncated);
 *   otherwise the error of the failing init call.
 * Nothing is leaked on failure.
 */
int pm_http_ctx_easy_init(pm_http_ctx_t **ctx_p, const struct pm_http_easy_arg *arg)
{
	struct pm_http_net_ctx net_ctx;
	pm_http_ctx_t *ctx;
	int ret;

	if (!arg->use_plain && !arg->use_ssl)
		return -EINVAL;

#if PM_USE_SSL
	if (arg->use_ssl && (!arg->cert_file || !arg->key_file))
		return -EINVAL;
#else
	/* Refuse instead of starting without the requested TLS listener. */
	if (arg->use_ssl)
		return -EOPNOTSUPP;
#endif

	ret = pm_http_ctx_init(&ctx);
	if (ret)
		return ret;

	if (arg->use_plain) {
		struct pm_net_tcp_arg parg;

		memset(&parg, 0, sizeof(parg));
		parg.bind_addr.v6.sin6_addr = in6addr_any;
		parg.bind_addr.v6.sin6_port = htons(arg->plain_port);
		parg.bind_addr.v6.sin6_family = AF_INET6;
		parg.client_init_cap = 8192;
		parg.nr_workers = arg->nr_workers;
		parg.sock_backlog = 2048;

		ret = pm_net_tcp_ctx_init(&net_ctx.plain, &parg);
		if (ret) {
			pm_http_ctx_destroy(ctx);
			return ret;
		}
		ret = pm_http_ctx_add_net_ctx(ctx, &net_ctx, PM_HTTP_NET_CTX_PLAIN);
		if (ret) {
			pm_http_ctx_destroy(ctx);
			pm_net_tcp_ctx_destroy(net_ctx.plain);
			return ret;
		}
	}

#if PM_USE_SSL
	if (arg->use_ssl) {
		struct pm_net_tcp_ssl_arg sarg;
		struct pm_net_tcp_arg *parg = &sarg.net_arg;
		int n;

		memset(&sarg, 0, sizeof(sarg));
		parg->bind_addr.v6.sin6_addr = in6addr_any;
		parg->bind_addr.v6.sin6_port = htons(arg->ssl_port);
		parg->bind_addr.v6.sin6_family = AF_INET6;
		parg->client_init_cap = 8192;
		parg->nr_workers = arg->nr_workers;
		parg->sock_backlog = 2048;

		/* snprintf always terminates and reports truncation, unlike strncpy. */
		n = snprintf(sarg.cert_file, sizeof(sarg.cert_file), "%s", arg->cert_file);
		if (n < 0 || (size_t)n >= sizeof(sarg.cert_file)) {
			pm_http_ctx_destroy(ctx);
			return -ENAMETOOLONG;
		}
		n = snprintf(sarg.key_file, sizeof(sarg.key_file), "%s", arg->key_file);
		if (n < 0 || (size_t)n >= sizeof(sarg.key_file)) {
			pm_http_ctx_destroy(ctx);
			return -ENAMETOOLONG;
		}

		ret = pm_net_tcp_ssl_ctx_init(&net_ctx.ssl, &sarg);
		if (ret) {
			pm_http_ctx_destroy(ctx);
			return ret;
		}
		ret = pm_http_ctx_add_net_ctx(ctx, &net_ctx, PM_HTTP_NET_CTX_SSL);
		if (ret) {
			pm_http_ctx_destroy(ctx);
			pm_net_tcp_ssl_ctx_destroy(net_ctx.ssl);
			return ret;
		}
	}
#endif /* #if PM_USE_SSL */

	*ctx_p = ctx;
	return 0;
}
/*
 * Register a transport of kind @type (PM_HTTP_NET_CTX_*) with @ctx. The
 * context takes ownership and destroys it in pm_http_ctx_destroy().
 *
 * Return: 0, or -ENOMEM if either parallel array cannot grow (the
 * transport is then not registered).
 */
int pm_http_ctx_add_net_ctx(pm_http_ctx_t *ctx, pm_http_net_ctx_t *net_ctx, uint8_t type)
{
	struct pm_http_net_ctx_arr *arr = &ctx->net_ctx_arr;
	size_t slot = arr->nr_ctx;
	struct pm_http_net_ctx *ctxs;
	uint8_t *types;

	ctxs = realloc(arr->ctx_arr, (slot + 1) * sizeof(*ctxs));
	if (!ctxs)
		return -ENOMEM;
	/* The old pointer may be gone now; keep the grown block either way. */
	arr->ctx_arr = ctxs;

	types = realloc(arr->type_arr, (slot + 1) * sizeof(*types));
	if (!types)
		return -ENOMEM;
	arr->type_arr = types;

	ctxs[slot] = *net_ctx;
	types[slot] = type;
	arr->nr_ctx = slot + 1;
	return 0;
}
/*
 * Forward declarations: the per-transport start helpers are defined further
 * down, after the accept/recv/close callbacks they install.
 */
static void __pm_http_ctx_run(pm_http_ctx_t *ctx, pm_net_tcp_ctx_t *nctx);
#if PM_USE_SSL
static void __pm_http_ctx_run_ssl(pm_http_ctx_t *ctx, pm_net_tcp_ssl_ctx_t *nctx);
#endif
/*
 * Start every registered transport with the HTTP callbacks installed.
 * Non-plain entries are started as TLS only when built with PM_USE_SSL.
 */
void pm_http_ctx_run(pm_http_ctx_t *ctx)
{
	struct pm_http_net_ctx_arr *arr = &ctx->net_ctx_arr;
	size_t idx;

	for (idx = 0; idx < arr->nr_ctx; idx++) {
		struct pm_http_net_ctx *nc = &arr->ctx_arr[idx];

		if (arr->type_arr[idx] != PM_HTTP_NET_CTX_PLAIN) {
#if PM_USE_SSL
			__pm_http_ctx_run_ssl(ctx, nc->ssl);
#endif
			continue;
		}
		__pm_http_ctx_run(ctx, nc->plain);
	}
}
/* Wait on every registered transport, in registration order. */
void pm_http_ctx_wait(pm_http_ctx_t *ctx)
{
	struct pm_http_net_ctx_arr *arr = &ctx->net_ctx_arr;
	size_t idx;

	for (idx = 0; idx < arr->nr_ctx; idx++) {
		struct pm_http_net_ctx *nc = &arr->ctx_arr[idx];

		if (arr->type_arr[idx] != PM_HTTP_NET_CTX_PLAIN) {
#if PM_USE_SSL
			pm_net_tcp_ssl_ctx_wait(nc->ssl);
#endif
			continue;
		}
		pm_net_tcp_ctx_wait(nc->plain);
	}
}
/* Ask every registered transport to stop, in registration order. */
void pm_http_ctx_stop(pm_http_ctx_t *ctx)
{
	struct pm_http_net_ctx_arr *arr = &ctx->net_ctx_arr;
	size_t idx;

	for (idx = 0; idx < arr->nr_ctx; idx++) {
		struct pm_http_net_ctx *nc = &arr->ctx_arr[idx];

		if (arr->type_arr[idx] != PM_HTTP_NET_CTX_PLAIN) {
#if PM_USE_SSL
			pm_net_tcp_ssl_ctx_stop(nc->ssl);
#endif
			continue;
		}
		pm_net_tcp_ctx_stop(nc->plain);
	}
}
/*
 * Destroy every registered transport, then the transport arrays and the
 * HTTP context itself.
 */
void pm_http_ctx_destroy(pm_http_ctx_t *ctx_p)
{
	struct pm_http_net_ctx_arr *arr = &ctx_p->net_ctx_arr;
	size_t idx;

	for (idx = 0; idx < arr->nr_ctx; idx++) {
		struct pm_http_net_ctx *nc = &arr->ctx_arr[idx];

		if (arr->type_arr[idx] != PM_HTTP_NET_CTX_PLAIN) {
#if PM_USE_SSL
			pm_net_tcp_ssl_ctx_destroy(nc->ssl);
#endif
			continue;
		}
		pm_net_tcp_ctx_destroy(nc->plain);
	}

	free(arr->ctx_arr);
	free(arr->type_arr);
	free(ctx_p);
}
/*
 * Parse one HTTP/1.x request head (request line plus header fields) from the
 * start of @rbuf into @req, then drop the consumed bytes from @rbuf.
 *
 * @req:  zero-initialized request to fill: method, ver, uri, qs, hdr and,
 *        when a Content-Length field is present, content_length/cl_remain.
 * @rbuf: receive buffer holding rbuf->len bytes.
 *
 * Return:
 *   0        a complete head was parsed; @rbuf now starts at the body (or at
 *            the next pipelined request).
 *   -EAGAIN  the terminating empty line has not arrived yet; @rbuf is
 *            untouched and nothing has been allocated.
 *   -EINVAL  malformed request, unsupported method or version, or an
 *            invalid/conflicting Content-Length.
 *   other    allocation errors (-ENOMEM, or whatever pm_http_hdr_add()
 *            returns).
 * On the error returns @req may already own uri/qs/hdr allocations, which
 * the caller must release.
 *
 * Every scan is bounded by the end of the head, so a malformed line can
 * never make the parser read into the body or a following request. The
 * query string is located inside the request target only, and methods are
 * matched exactly (not by prefix).
 */
static int parse_http_hdr(struct pm_http_req *req, struct pm_buf *rbuf)
{
	static const struct {
		const char *name;
		size_t len;
		uint8_t id;
	} methods[] = {
		{ "GET",     3, PM_HTTP_METHOD_GET },
		{ "POST",    4, PM_HTTP_METHOD_POST },
		{ "PUT",     3, PM_HTTP_METHOD_PUT },
		{ "DELETE",  6, PM_HTTP_METHOD_DELETE },
		{ "HEAD",    4, PM_HTTP_METHOD_HEAD },
		{ "OPTIONS", 7, PM_HTTP_METHOD_OPTIONS },
		{ "TRACE",   5, PM_HTTP_METHOD_TRACE },
		{ "CONNECT", 7, PM_HTTP_METHOD_CONNECT },
		{ "PATCH",   5, PM_HTTP_METHOD_PATCH },
	};
	char *buf, *end, *p, *uri, *uri_end, *qs, *ver, *dcrlf, *line;
	uint8_t method = 0;
	bool have_cl = false;
	size_t i, mlen, eaten_len;
	int ret;

	if (rbuf->len < 6)
		return -EAGAIN;

	buf = rbuf->buf;
	end = buf + rbuf->len;

	/* Method: up to 8 upper-case letters followed by a single space. */
	for (p = buf; p < end && *p != ' '; p++) {
		if (*p < 'A' || *p > 'Z' || p - buf >= 8)
			return -EINVAL;
	}
	if (p >= end - 1)
		return -EAGAIN;
	mlen = (size_t)(p - buf);
	uri = p + 1;

	/* Nothing else is parsed until the whole head has arrived. */
	dcrlf = memmem(uri, (size_t)(end - uri), "\r\n\r\n", 4);
	if (!dcrlf)
		return -EAGAIN;

	for (i = 0; i < sizeof(methods) / sizeof(methods[0]); i++) {
		if (mlen == methods[i].len && !memcmp(buf, methods[i].name, mlen)) {
			method = methods[i].id;
			break;
		}
	}
	if (!method)
		return -EINVAL;
	req->method = method;

	/* Request target: origin-form, printable ASCII, terminated by a space. */
	if (*uri != '/')
		return -EINVAL;
	for (p = uri; p < dcrlf && *p != ' '; p++) {
		unsigned char c = (unsigned char)*p;

		if (c < 32 || c > 126)
			return -EINVAL;
	}
	if (p >= dcrlf)
		return -EINVAL;
	uri_end = p;
	qs = memchr(uri, '?', (size_t)(uri_end - uri));

	/* Version: exactly "HTTP/1.0" or "HTTP/1.1", immediately followed by CRLF. */
	ver = uri_end + 1;
	if (dcrlf - ver < 8 || memcmp(ver + 8, "\r\n", 2))
		return -EINVAL;
	if (!memcmp(ver, "HTTP/1.0", 8))
		req->ver = PM_HTTP_VER_10;
	else if (!memcmp(ver, "HTTP/1.1", 8))
		req->ver = PM_HTTP_VER_11;
	else
		return -EINVAL;

	if (pm_http_strdup(&req->uri, uri, (size_t)((qs ? qs : uri_end) - uri)))
		return -ENOMEM;
	if (qs && pm_http_strdup(&req->qs, qs + 1, (size_t)(uri_end - qs - 1)))
		return -ENOMEM;

	/*
	 * Header fields occupy [ver + 10, dcrlf + 2): each line is
	 * "name:" OWS value OWS CRLF. The range always ends with the CRLF at
	 * dcrlf, so every line has a terminator inside it.
	 */
	for (line = ver + 10; line < dcrlf + 2; ) {
		char *eol, *colon, *val, *val_end;

		eol = memmem(line, (size_t)(dcrlf + 2 - line), "\r\n", 2);
		if (!eol)
			return -EINVAL;

		colon = memchr(line, ':', (size_t)(eol - line));
		if (!colon || colon == line)
			return -EINVAL;
		/* Field names are tokens: no whitespace or control bytes. */
		for (p = line; p < colon; p++) {
			unsigned char c = (unsigned char)*p;

			if (c <= ' ' || c > 126)
				return -EINVAL;
		}

		val = colon + 1;
		while (val < eol && (*val == ' ' || *val == '\t'))
			val++;
		val_end = eol;
		while (val_end > val && (val_end[-1] == ' ' || val_end[-1] == '\t'))
			val_end--;

		/* The head is consumed below, so it may be NUL-split in place. */
		*colon = '\0';
		*val_end = '\0';

		if (!strcasecmp(line, "Content-Length")) {
			unsigned long long cl;
			char *cl_end;

			/* strtoull() would accept "-1" or " 5"; require plain digits. */
			if (*val < '0' || *val > '9')
				return -EINVAL;
			errno = 0;
			cl = strtoull(val, &cl_end, 10);
			if (errno || *cl_end != '\0')
				return -EINVAL;
			/* Conflicting duplicates make the body length ambiguous. */
			if (have_cl && cl != req->content_length)
				return -EINVAL;
			have_cl = true;
			req->content_length = cl;
			req->cl_remain = cl;
		}

		ret = pm_http_hdr_add(&req->hdr, line, val);
		if (ret)
			return ret;

		line = eol + 2;
	}

	/* Consume the head, including the terminating empty line. */
	p = dcrlf + 4;
	eaten_len = (size_t)(p - buf);
	assert(eaten_len <= rbuf->len);
	rbuf->len -= eaten_len;
	if (rbuf->len)
		memmove(buf, p, rbuf->len);
	return 0;
}
/*
 * Move up to req->cl_remain body bytes from the front of @rbuf into
 * req->body.
 *
 * Return: 0 when the body is complete (any following bytes stay in @rbuf);
 * -EAGAIN when all of @rbuf was consumed and more body bytes are expected;
 * another negative error from the buffer helpers on allocation failure.
 */
static int parse_http_body(struct pm_http_req *req, struct pm_buf *rbuf)
{
	int ret;

	if (!req->cl_remain)
		return 0;

	if (rbuf->len == req->cl_remain && req->cl_remain == req->content_length &&
	    !req->body.buf) {
		/*
		 * The buffer holds exactly the whole body: hand it over to
		 * req->body instead of copying, and give @rbuf a fresh buffer.
		 * Detach the pointer first so that a failing pm_buf_init() can
		 * never leave @rbuf and req->body owning the same memory (a
		 * double free later). The !req->body.buf check avoids leaking
		 * a body buffer that already exists.
		 */
		req->body = *rbuf;
		rbuf->buf = NULL;
		rbuf->len = 0;
		rbuf->cap = 0;
		req->cl_remain = 0;
		return pm_buf_init(rbuf, 4096);
	}

	if (req->cl_remain > rbuf->len) {
		/* Partial body: take everything and wait for more. */
		ret = pm_buf_append(&req->body, rbuf->buf, rbuf->len);
		if (ret)
			return ret;
		req->cl_remain -= rbuf->len;
		rbuf->len = 0;
		return -EAGAIN;
	}

	/* The rest of the body is here; keep whatever follows it. */
	ret = pm_buf_append(&req->body, rbuf->buf, (size_t)req->cl_remain);
	if (ret)
		return ret;
	rbuf->len -= (size_t)req->cl_remain;
	memmove(rbuf->buf, rbuf->buf + req->cl_remain, rbuf->len);
	req->cl_remain = 0;
	return 0;
}
/*
 * Ask the transport underneath @hc to close the connection, dispatching on
 * whether the client is a TLS or a plain TCP client.
 */
static void pm_http_client_close(struct pm_http_client *hc)
{
#if PM_USE_SSL
	if (hc->use_ssl) {
		pm_net_tcp_ssl_client_user_close(hc->nclient);
		return;
	}
#endif /* #if PM_USE_SSL */
	pm_net_tcp_client_user_close(hc->nclient);
}
static int collect_requests(struct pm_http_client *hc)
{
struct pm_buf *rbuf = hc->recv_buf;
struct pm_http_req req;
int ret;
memset(&req, 0, sizeof(req));
/*
* Ensure the buffer is null-terminated.
*/
if (rbuf->len == rbuf->cap) {
if (pm_buf_resize(rbuf, rbuf->cap + 1))
return -ENOMEM;
}
rbuf->buf[rbuf->len] = '\0';
ret = parse_http_hdr(&req, rbuf);
if (ret < 0) {
if (ret != -EAGAIN || hc->recv_buf->len >= 8192)
pm_http_client_close(hc);
return ret;
}
ret = parse_http_body(&req, rbuf);
if (ret < 0) {
if (ret != -EAGAIN)
pm_http_client_close(hc);
return ret;
}
ret = pm_http_client_add_req(hc, &req);
if (ret) {
pm_http_client_close(hc);
return ret;
}
return 0;
}
/*
 * Run the user request callback for every queued request.
 *
 * Before each call the request body is NUL-terminated, without counting the
 * terminator in body.len, so callbacks may treat it as a C string. The
 * response defaults to status 200 with the request's protocol version.
 *
 * Return: 0, or -ENOMEM if a body buffer cannot be prepared.
 */
static int handle_requests(struct pm_http_client *hc)
{
	pm_http_ctx_t *ctx = hc->ctx;
	size_t i;

	for (i = 0; i < hc->nr_reqs; i++) {
		struct pm_http_req *req = &hc->req[i];
		struct pm_http_res *res = &hc->res[i];
		struct pm_buf *body = &req->body;

		if (!body->buf && pm_buf_init(body, 2))
			return -ENOMEM;
		/*
		 * buf[len] is only addressable when len < cap: grow first if the
		 * body fills its buffer, instead of reading past the end.
		 */
		if (body->len >= body->cap && pm_buf_resize(body, body->len + 1))
			return -ENOMEM;
		body->buf[body->len] = '\0';

		/* Answer in the request's version unless the callback overrides it. */
		res->ver = req->ver;
		res->status_code = 200;
		if (ctx->req_cb)
			ctx->req_cb(req, res, ctx->req_cb_arg);
	}
	return 0;
}
/*
 * Map an HTTP status code to its reason phrase for the status line.
 * Codes not in the table yield "Unknown".
 */
static const char *translate_http_code(uint16_t code)
{
	static const struct {
		uint16_t code;
		const char *reason;
	} reasons[] = {
		{ 200, "OK" },
		{ 204, "No Content" },
		{ 400, "Bad Request" },
		{ 404, "Not Found" },
		{ 405, "Method Not Allowed" },
		{ 500, "Internal Server Error" },
		{ 501, "Not Implemented" },
		{ 503, "Service Unavailable" },
		{ 505, "HTTP Version Not Supported" },
	};
	size_t k;

	for (k = 0; k < sizeof(reasons) / sizeof(reasons[0]); k++) {
		if (reasons[k].code == code)
			return reasons[k].reason;
	}
	return "Unknown";
}
/*
 * Write the current UTC time into @buf in the form
 * "Sun, 06 Nov 1994 08:49:37 GMT" (English names in the C locale).
 * At most 30 bytes, including the terminator, are written.
 */
static void gen_date(char *buf)
{
	const size_t max = 30;
	time_t now = time(NULL);
	struct tm utc;

	gmtime_r(&now, &utc);
	strftime(buf, max, "%a, %d %b %Y %H:%M:%S GMT", &utc);
}
/*
 * Decide whether the connection should stay open after answering @req.
 * An explicit "Connection: close" / "Connection: keep-alive" (any case)
 * wins; otherwise persistence is the default only for HTTP/1.1.
 */
static bool should_keep_alive(struct pm_http_req *req)
{
	char *conn;

	if (pm_http_hdr_get(&req->hdr, "connection", &conn) == 0) {
		if (strcasecmp(conn, "close") == 0)
			return false;
		if (strcasecmp(conn, "keep-alive") == 0)
			return true;
	}
	return req->ver == PM_HTTP_VER_11;
}
static int construct_response(struct pm_http_req *req, struct pm_http_res *res,
struct pm_buf *sbuf, bool *keep_alive)
{
char date[32];
int ret = 0;
size_t i;
*keep_alive = should_keep_alive(req);
if (res->ver == PM_HTTP_VER_10)
ret = pm_buf_append(sbuf, "HTTP/1.0 ", 9);
else
ret = pm_buf_append(sbuf, "HTTP/1.1 ", 9);
ret |= pm_buf_append_fmt(sbuf, "%u %s\r\n", res->status_code, translate_http_code(res->status_code));
ret |= pm_buf_append_fmt(sbuf, "Server: Proxmasterd\r\n");
gen_date(date);
ret |= pm_buf_append_fmt(sbuf, "Date: %s\r\n", date);
ret |= pm_buf_append_fmt(sbuf, "Connection: %s\r\n", *keep_alive ? "keep-alive" : "close");
ret |= pm_buf_append_fmt(sbuf, "Content-Length: %zu\r\n", res->body.len);
if (ret)
return -ENOMEM;
for (i = 0; i < res->hdr.nr_pairs; i++) {
struct pm_http_hdr_pair *pair = &res->hdr.pairs[i];
ret |= pm_buf_append_fmt(sbuf, "%s: %s\r\n", pair->key, pair->val);
}
ret |= pm_buf_append(sbuf, "\r\n", 2);
ret |= pm_buf_append(sbuf, res->body.buf, res->body.len);
if (ret)
return -ENOMEM;
return 0;
}
/*
 * Serialize the responses to all queued requests into the client's send
 * buffer, then drop the queue. The connection is kept alive only if every
 * request allowed it.
 *
 * Return: 0, or the first serialization error. In that case the queue
 * and hc->keep_alive are left untouched.
 */
static int send_responses(struct pm_http_client *hc)
{
	bool all_keep = true;
	size_t idx;

	for (idx = 0; idx < hc->nr_reqs; idx++) {
		bool keep;
		int err;

		err = construct_response(&hc->req[idx], &hc->res[idx],
					 hc->send_buf, &keep);
		if (err)
			return err;
		all_keep = all_keep && keep;
	}

	pm_http_client_free_reqs(hc);
	hc->keep_alive = all_keep;
	return 0;
}
/* Allocate a zero-initialized per-connection HTTP state; NULL on failure. */
static struct pm_http_client *pm_http_alloc_client(void)
{
	struct pm_http_client *hc = calloc(1, sizeof(*hc));

	return hc;
}
/*
 * Receive handler shared by the plain and TLS transports.
 *
 * Collects requests until the receive buffer is drained, then runs the
 * request callbacks and queues the responses. If the connection is not to
 * be kept alive, it is closed. Any error (including -EAGAIN while a
 * request is incomplete) is returned before anything is answered.
 */
static int pm_http_handle_recv(struct pm_http_client *hc)
{
	struct pm_buf *in = hc->recv_buf;
	int err;

	do {
		err = collect_requests(hc);
		if (err)
			return err;
	} while (in->len);

	err = handle_requests(hc);
	if (!err)
		err = send_responses(hc);
	if (err)
		return err;

	if (!hc->keep_alive)
		pm_http_client_close(hc);
	return 0;
}
/*
 * Close handler shared by both transports: drop any queued requests and
 * responses, then free the per-connection state. Always returns 0.
 */
static int pm_http_handle_close(struct pm_http_client *hc)
{
	struct pm_http_client *victim = hc;

	pm_http_client_free_reqs(victim);
	free(victim);
	return 0;
}
/* Plain-TCP close callback: forwards to the shared HTTP close handler. */
static int pm_http_close_cb(pm_net_tcp_client_t *c)
{
	struct pm_http_client *hc = pm_net_tcp_client_get_udata(c);

	return pm_http_handle_close(hc);
}

/* Plain-TCP receive callback: forwards to the shared HTTP receive handler. */
static int pm_http_recv_cb(pm_net_tcp_client_t *c)
{
	struct pm_http_client *hc = pm_net_tcp_client_get_udata(c);

	return pm_http_handle_recv(hc);
}
/*
 * Plain-TCP accept callback: attach fresh HTTP state to client @c and
 * install the HTTP receive/close callbacks.
 *
 * Return: 0, or -ENOMEM when the HTTP state cannot be allocated. The
 * allocation result itself is checked now; the old code tested @c
 * instead and dereferenced a NULL pointer on allocation failure.
 */
static int pm_http_accept_cb(pm_net_tcp_ctx_t *ctx, pm_net_tcp_client_t *c)
{
	pm_http_ctx_t *hctx = pm_net_tcp_ctx_get_udata(ctx);
	struct pm_http_client *hc;

	hc = pm_http_alloc_client();
	if (!hc)
		return -ENOMEM;

	hc->ctx = hctx;
	hc->recv_buf = pm_net_tcp_client_get_recv_buf(c);
	hc->send_buf = pm_net_tcp_client_get_send_buf(c);
	hc->nclient = c;

	pm_net_tcp_client_set_udata(c, hc);
	pm_net_tcp_client_set_recv_cb(c, &pm_http_recv_cb);
	pm_net_tcp_client_set_close_cb(c, &pm_http_close_cb);
	return 0;
}
/*
 * Start a plain TCP transport for HTTP: the HTTP context becomes the
 * transport's udata so that the accept callback can find it.
 */
static void __pm_http_ctx_run(pm_http_ctx_t *ctx, pm_net_tcp_ctx_t *nctx)
{
	pm_net_tcp_accept_cb_t on_accept = &pm_http_accept_cb;

	pm_net_tcp_ctx_set_udata(nctx, ctx);
	pm_net_tcp_ctx_set_accept_cb(nctx, on_accept);
	pm_net_tcp_ctx_run(nctx);
}
#if PM_USE_SSL
/* TLS close callback: forwards to the shared HTTP close handler. */
static int pm_https_close_cb(pm_net_tcp_ssl_client_t *c)
{
	struct pm_http_client *hc = pm_net_tcp_ssl_client_get_udata(c);

	return pm_http_handle_close(hc);
}

/* TLS receive callback: forwards to the shared HTTP receive handler. */
static int pm_https_recv_cb(pm_net_tcp_ssl_client_t *c)
{
	struct pm_http_client *hc = pm_net_tcp_ssl_client_get_udata(c);

	return pm_http_handle_recv(hc);
}
/*
 * TLS accept callback: attach fresh HTTP state (flagged as TLS) to client
 * @c and install the HTTPS receive/close callbacks.
 *
 * Return: 0, or -ENOMEM if the HTTP state cannot be allocated.
 */
static int pm_https_accept_cb(pm_net_tcp_ssl_ctx_t *ctx, pm_net_tcp_ssl_client_t *c)
{
	struct pm_http_client *hc = pm_http_alloc_client();

	if (hc == NULL)
		return -ENOMEM;

	hc->ctx = pm_net_tcp_ssl_ctx_get_udata(ctx);
	hc->use_ssl = true;
	hc->recv_buf = pm_net_tcp_ssl_client_get_recv_buf(c);
	hc->send_buf = pm_net_tcp_ssl_client_get_send_buf(c);
	hc->nclient = c;

	pm_net_tcp_ssl_client_set_udata(c, hc);
	pm_net_tcp_ssl_client_set_recv_cb(c, &pm_https_recv_cb);
	pm_net_tcp_ssl_client_set_close_cb(c, &pm_https_close_cb);
	return 0;
}
/*
 * Start a TLS transport for HTTP: the HTTP context becomes the transport's
 * udata so that the TLS accept callback can find it.
 */
static void __pm_http_ctx_run_ssl(pm_http_ctx_t *ctx, pm_net_tcp_ssl_ctx_t *nctx)
{
	pm_net_tcp_ssl_accept_cb_t on_accept = &pm_https_accept_cb;

	pm_net_tcp_ssl_ctx_set_udata(nctx, ctx);
	pm_net_tcp_ssl_ctx_set_accept_cb(nctx, on_accept);
	pm_net_tcp_ssl_ctx_run(nctx);
}
#endif /* #if PM_USE_SSL */
#endif /* #if PM_USE_HTTP */
/*
 * Request handler installed by run_pm(). It currently leaves @res exactly
 * as the HTTP layer prepared it and ignores @req and @arg.
 *
 * The previous body returned a value from a void function, which is a
 * constraint violation in C and is rejected by current compilers.
 */
static void pm_web_handle_req(struct pm_http_req *req, struct pm_http_res *res,
			      void *arg)
{
	(void)req;
	(void)res;
	(void)arg;
}
#include <getopt.h>
/*
 * Command-line options as parsed by parse_arg(). A port of 0 leaves that
 * listener disabled; nr_workers 0 means "use the default" (see
 * prep_http_easy_arg()). The file paths point into argv.
 */
struct prog_arg {
int plain_port;
int ssl_port;
int nr_workers;
const char *ssl_cert_file;
const char *ssl_key_file;
};
/* getopt_long() option table; each long option maps to the short letter handled in parse_arg(). */
static const struct option long_opts[] = {
{ "help", no_argument, NULL, 'h' },
{ "ssl-cert-file", required_argument, NULL, 'c' },
{ "ssl-key-file", required_argument, NULL, 'k' },
{ "plain-port", required_argument, NULL, 'p' },
{ "ssl-port", required_argument, NULL, 'Z' },
{ "nr-workers", required_argument, NULL, 'w' },
{ NULL, 0, NULL, 0 }
};
/* Short-option spec: a trailing ':' marks options that take an argument. */
static const char short_opts[] = "hc:k:p:Z:w:";
/*
 * Translate the parsed command line into a pm_http_easy_arg.
 *
 * A listener is enabled when its port is positive, and a worker count of 0
 * becomes 4. Prints a message and returns -EINVAL when no listener is
 * enabled, or when TLS lacks a certificate or key file. @arg is filled in
 * either way.
 */
static int prep_http_easy_arg(struct pm_http_easy_arg *arg, struct prog_arg *parg)
{
	uint16_t workers = (uint16_t)parg->nr_workers;

	memset(arg, 0, sizeof(*arg));

	arg->use_plain = parg->plain_port > 0;
	arg->plain_port = parg->plain_port;
	arg->use_ssl = parg->ssl_port > 0;
	arg->ssl_port = parg->ssl_port;
	arg->nr_workers = workers ? workers : 4;
	arg->cert_file = parg->ssl_cert_file;
	arg->key_file = parg->ssl_key_file;

	if (!(arg->use_plain || arg->use_ssl)) {
		printf("At least one of plain or SSL port must be specified\n");
		return -EINVAL;
	}
	if (arg->use_ssl && !(arg->cert_file && arg->key_file)) {
		printf("SSL port requires both certificate and key files\n");
		return -EINVAL;
	}
	return 0;
}
/*
 * Build the HTTP context from the parsed command line, start it, wait via
 * pm_http_ctx_wait(), then stop and destroy it.
 *
 * Return: 0, or a negative errno-style code from argument validation or
 * context initialization.
 */
static int run_pm(struct prog_arg *parg)
{
	struct pm_http_easy_arg arg;
	pm_http_ctx_t *ctx;
	int err;

	err = prep_http_easy_arg(&arg, parg);
	if (err)
		return err;

	setvbuf(stdout, NULL, _IONBF, 0);
	err = pm_http_ctx_easy_init(&ctx, &arg);
	if (err)
		return err;

	/*
	 * pm_web_handle_req() ignores its opaque argument, so none is passed.
	 * The former "&pm" named an identifier not declared in this code.
	 */
	pm_http_ctx_set_req_cb(ctx, &pm_web_handle_req, NULL);
	pm_http_ctx_run(ctx);
	pm_http_ctx_wait(ctx);
	pm_http_ctx_stop(ctx);
	pm_http_ctx_destroy(ctx);
	return 0;
}
/* Print the usage summary for program name @app to stdout. */
static void show_help(const char *app)
{
	static const char *const lines[] = {
		"Options:\n",
		" -h, --help Show this help message\n",
		" -c, --ssl-cert-file=FILE SSL certificate file\n",
		" -k, --ssl-key-file=FILE SSL key file\n",
		" -p, --plain-port=PORT Plain port number\n",
		" -Z, --ssl-port=PORT SSL port number\n",
		" -w, --nr-workers=NUM Number of workers\n",
	};
	size_t i;

	printf("Usage: %s [OPTIONS]\n", app);
	for (i = 0; i < sizeof(lines) / sizeof(lines[0]); i++)
		fputs(lines[i], stdout);
}
/*
 * Parse @s as a complete base-10 integer into *@out.
 * Return: 0, or -EINVAL for empty input, trailing junk or overflow.
 */
static int parse_long_arg(const char *s, long *out)
{
	char *end;
	long v;

	errno = 0;
	v = strtol(s, &end, 10);
	if (end == s || *end != '\0' || errno == ERANGE)
		return -EINVAL;
	*out = v;
	return 0;
}

/*
 * Parse the command line into @parg (zeroed first).
 *
 * Numeric options must be whole integers; atoi() used to accept garbage
 * such as "80x" or "abc" (as 0).
 * Return: 0 on success; -ESHUTDOWN after printing help for -h/--help;
 * -EINVAL (after printing a message) for any invalid option.
 */
static int parse_arg(int argc, char *argv[], struct prog_arg *parg)
{
	long val;
	int c;

	memset(parg, 0, sizeof(*parg));
	while (1) {
		/* NULL, not C23's nullptr: the file is plain C11/gnu C. */
		c = getopt_long(argc, argv, short_opts, long_opts, NULL);
		if (c < 0)
			break;
		switch (c) {
		case 'h':
			show_help(argv[0]);
			return -ESHUTDOWN;
		case 'c':
			parg->ssl_cert_file = optarg;
			break;
		case 'k':
			parg->ssl_key_file = optarg;
			break;
		case 'p':
			if (parse_long_arg(optarg, &val) || val < 0 || val > 65535) {
				printf("Invalid plain port number: %s\n", optarg);
				return -EINVAL;
			}
			parg->plain_port = (int)val;
			break;
		case 'Z':
			if (parse_long_arg(optarg, &val) || val < 0 || val > 65535) {
				printf("Invalid SSL port number: %s\n", optarg);
				return -EINVAL;
			}
			parg->ssl_port = (int)val;
			break;
		case 'w':
			/* 0 is accepted and means "use the default worker count". */
			if (parse_long_arg(optarg, &val) || val < 0 || val > 4096) {
				printf("The number of workers must be in range [0, 4096] (0 selects the default)\n");
				return -EINVAL;
			}
			parg->nr_workers = (int)val;
			break;
		case '?':
		default:
			show_help(argv[0]);
			return -EINVAL;
		}
	}
	return 0;
}
/*
 * Program entry point. The exit status is 0 on success and after
 * -h/--help; otherwise it is the positive errno value of the failure.
 * Help used to exit with ESHUTDOWN, making scripts treat it as a failure.
 */
int main(int argc, char *argv[])
{
	struct prog_arg parg;
	int ret;

	ret = parse_arg(argc, argv, &parg);
	if (ret == -ESHUTDOWN)
		return 0;	/* help was requested and printed */
	if (ret)
		return -ret;
	return -run_pm(&parg);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment