/*
 * Source: GitHub gist argv0/269849 (created January 5, 2010 23:19).
 * GitHub page text captured with the code, kept here as a comment:
 * "Skip to content" / "Instantly share code, notes, and snippets." /
 * "Show Gist options" / "Save argv0/269849 to your computer and use it in
 * GitHub Desktop."
 */
#include <vector>
#include <erl_driver.h>
#include <ei.h>
#include "lru_item.hpp"
#include "lru_cache.hpp"
/* Byte string used as a cache key; built by copying bytes out of a request
   buffer (see make_charvec). */
typedef std::vector<char> char_vec;
/* The driver's cache type: the lru_cache template (declared in
   lru_cache.hpp, not visible here) keyed by byte strings with item_t values
   (from lru_item.hpp). */
typedef lru_cache<char_vec, item_t> cache_t;
/*
 * Request tags: the first byte of every control request (decoded by
 * deserialize_type / decode_from).  The numeric values are the byte values
 * sent by the client, so they must never be renumbered.  REQ_INVALID (255)
 * is also what the decoder yields for unknown or truncated requests.
 *
 * The enum tag no longer begins with an underscore: identifiers starting
 * with `_` are reserved to the implementation at global scope.
 */
typedef enum request_type {
    REQ_INIT       = 0,
    REQ_GET        = 1,
    REQ_PUT        = 2,
    REQ_REMOVE     = 3,
    REQ_ITEM_COUNT = 4,
    REQ_CAPACITY   = 5,
    REQ_BYTES_USED = 6,
    REQ_INVALID    = 255
} request_type;
/*
 * One decoded control request (produced by decode_from).
 *
 * `type` selects which member of the anonymous union is meaningful:
 *   REQ_INIT   -> init
 *   REQ_GET    -> get
 *   REQ_PUT    -> put
 *   REQ_REMOVE -> remove
 * REQ_ITEM_COUNT, REQ_CAPACITY, REQ_BYTES_USED and REQ_INVALID use none of
 * them.  The key/value pointers point into the request buffer handed to the
 * decoder (nothing is copied), so they must not be used once that buffer is
 * gone -- presumably the end of the control call; confirm against ERTS docs.
 *
 * The struct tag no longer begins with an underscore (reserved identifier).
 */
typedef struct from_emulator {
    request_type type;
    union {
        struct {
            uint32_t max_bytes;     /* passed to the cache_t constructor */
        } init;
        struct {
            uint32_t key_size;
            const void *key;
            uint32_t value_size;
            const void *value;
        } put;
        struct {
            uint32_t key_size;
            const void *key;
        } get;
        struct {
            uint32_t key_size;
            const void *key;
        } remove;
    };
} from_emulator;
/*
 * Copy the `buflen` bytes starting at `buf` into a freshly built byte
 * vector (used to turn request keys into cache keys).  `buflen` must be
 * non-negative; a negative length would describe an invalid range.
 */
char_vec make_charvec(const void *buf, int buflen) {
    const char *first = static_cast<const char *>(buf);
    char_vec bytes(first, first + buflen);
    return bytes;
}
/*
 * Per-port driver state, allocated and zeroed in lru_driver_data_new.
 *   port   the port this state belongs to
 *   cache  NULL until the first REQ_INIT request creates it (handle_init)
 *
 * The struct tag no longer begins with an underscore (reserved identifier).
 */
typedef struct lru_driver_t {
    ErlDrvPort port;
    cache_t *cache;
} lru_driver_t;
/*
 * deserialize_uint8(buf, buflen) -- GNU statement-expression macro.
 * Reads one raw byte at the cursor `buf`, advances `buf` by one and
 * decrements `buflen` by one, and evaluates to the byte read.  If no byte
 * remains it executes `goto ERROR`, so the expanding function must define an
 * `ERROR:` label.  `buf` and `buflen` must be modifiable lvalues.
 */
#define deserialize_uint8(buf, buflen) \
({ \
    uint8_t u8_value_; \
    if ((buflen) < (int) sizeof u8_value_) \
        goto ERROR; \
    memcpy (&u8_value_, (buf), sizeof u8_value_); \
    (buf) += sizeof u8_value_; \
    (buflen) -= (int) sizeof u8_value_; \
    u8_value_; \
})
/*
 * deserialize_uint32(buf, buflen) -- GNU statement-expression macro.
 * Copies four bytes from the cursor `buf` into a uint32_t (no byte swapping,
 * so the value is read in host byte order), advances `buf` by four,
 * decrements `buflen` by four, and evaluates to the value.  With fewer than
 * four bytes left it executes `goto ERROR`; the expanding function must
 * define an `ERROR:` label.  `buf` and `buflen` must be modifiable lvalues.
 */
#define deserialize_uint32(buf, buflen) \
({ \
    uint32_t u32_value_; \
    if ((buflen) < (int) sizeof u32_value_) \
        goto ERROR; \
    memcpy (&u32_value_, (buf), sizeof u32_value_); \
    (buf) += sizeof u32_value_; \
    (buflen) -= (int) sizeof u32_value_; \
    u32_value_; \
})
/*
 * deserialize_uint64(buf, buflen) -- GNU statement-expression macro.
 * Copies eight bytes from the cursor `buf` into a uint64_t (host byte
 * order, no swapping), advances `buf` by eight, decrements `buflen` by
 * eight, and evaluates to the value.  With fewer than eight bytes left it
 * executes `goto ERROR`; the expanding function must define an `ERROR:`
 * label.  `buf` and `buflen` must be modifiable lvalues.
 */
#define deserialize_uint64(buf, buflen) \
({ \
    uint64_t u64_value_; \
    if ((buflen) < (int) sizeof u64_value_) \
        goto ERROR; \
    memcpy (&u64_value_, (buf), sizeof u64_value_); \
    (buf) += sizeof u64_value_; \
    (buflen) -= (int) sizeof u64_value_; \
    u64_value_; \
})
/*
 * deserialize_type(buf, buflen) -- reads the one-byte request tag with
 * deserialize_uint8 (so it may `goto ERROR` on an empty buffer) and
 * evaluates to the matching request_type.  A byte that is not one of the
 * listed enumerators decodes as REQ_INVALID.
 */
#define deserialize_type(buf, buflen) \
({ \
    const uint8_t tag_byte_ = deserialize_uint8(buf, buflen); \
    const request_type tag_ = static_cast<request_type>(tag_byte_); \
    request_type decoded_; \
    switch (tag_) { \
    case REQ_INIT: \
    case REQ_GET: \
    case REQ_PUT: \
    case REQ_REMOVE: \
    case REQ_ITEM_COUNT: \
    case REQ_CAPACITY: \
    case REQ_BYTES_USED: \
    case REQ_INVALID: \
        decoded_ = tag_; \
        break; \
    default: \
        decoded_ = REQ_INVALID; \
        break; \
    } \
    decoded_; \
})
/*
 * deserialize_bytes(buf, buflen, len) -- GNU statement-expression macro.
 * Consumes `len` bytes at the cursor and evaluates to a pointer to the
 * first of them.  No copy is made: the pointer aliases the input buffer.
 * Advances `buf` by `len` and decrements `buflen` by `len`.  Executes
 * `goto ERROR` (the expanding function must define the label) when
 * `buflen` is negative or smaller than `len`.
 *
 * `len` comes straight off the wire, so the bounds check is done in a wide
 * unsigned type.  A signed check such as `buflen < (int) len` turns any
 * length >= 2^31 into a negative int, passes, and moves `buf` far past the
 * end of the input.  `len` is evaluated exactly once.
 */
#define deserialize_bytes(buf, buflen, len) \
({ \
    const unsigned long long bytes_len_ = (len); \
    const void *bytes_start_; \
    if ((buflen) < 0 || (unsigned long long) (buflen) < bytes_len_) \
        goto ERROR; \
    bytes_start_ = (buf); \
    (buf) += bytes_len_; \
    (buflen) -= (int) bytes_len_; \
    bytes_start_; \
})
/*
 * deserialize_string(buf, buflen) -- reads a 4-byte length with
 * deserialize_uint32, then consumes that many bytes with deserialize_bytes.
 * Evaluates to a pointer to the payload inside the input buffer.  Either
 * step may `goto ERROR`.  decode_from in this file does not use it.
 */
#define deserialize_string(buf, buflen) \
({ \
    uint32_t str_len_ = deserialize_uint32 (buf, buflen); \
    deserialize_bytes (buf, buflen, str_len_); \
})
/*
 * Build a control reply consisting of `prefix` followed by `data`, handed
 * back through *rbuf as a driver binary.
 *
 * rbuf        out: the allocated ErlDrvBinary (cast to char *), or NULL if
 *             the allocation failed.
 * prefix      prefix_len bytes written first.  Callers in this file pass a
 *             one-byte status tag: "\000" on success, "\001" for errors.
 * data, len   payload copied after the prefix; `data` may be NULL when
 *             `len` is 0.
 *
 * Returns the total size of the reply (prefix_len + len).  If the binary
 * cannot be allocated, it returns 0 with *rbuf = NULL, which is the same
 * reply reply_empty_list_rbuf produces.
 */
static int
reply_prefixed_binary_rbuf (char** rbuf,
                            const void* prefix,
                            size_t prefix_len,
                            const void* data,
                            size_t len)
{
    const size_t total = prefix_len + len;
    ErlDrvBinary* bin = driver_alloc_binary (total);

    if (bin == NULL) {
        /* Writing through a NULL binary would crash the whole emulator. */
        *rbuf = NULL;
        return 0;
    }
    /* memcpy from a NULL pointer is undefined even for zero bytes (e.g. an
       empty cached item), so only copy non-empty parts. */
    if (prefix_len > 0)
        memcpy (bin->orig_bytes, prefix, prefix_len);
    if (len > 0)
        memcpy (bin->orig_bytes + prefix_len, data, len);
    bin->orig_size = total;
    *rbuf = (char *) bin;
    /* Report the size of the whole reply.  Returning just `len` would
       under-report it by prefix_len bytes. */
    return (int) total;
}
/*
 * Reply with an error binary: the 0x01 status byte followed by the bytes of
 * `msg`, without its terminating NUL.  Returns whatever
 * reply_prefixed_binary_rbuf returns.
 */
static int
reply_error_rbuf (char** rbuf,
                  const char* msg)
{
    static const char error_tag[] = "\001";
    const size_t msg_len = strlen (msg);

    return reply_prefixed_binary_rbuf (rbuf, error_tag, sizeof error_tag - 1,
                                       msg, msg_len);
}
/*
 * Reply with no payload: clear *rbuf and report zero bytes.  The function
 * name indicates the Erlang side sees this as an empty list.
 */
static int
reply_empty_list_rbuf (char** rbuf)
{
    char *const no_reply = NULL;

    *rbuf = no_reply;
    return 0;
}
/*
 * Parse one control request from the raw bytes `buf`/`buflen`.
 *
 * Layout: a one-byte request tag, then a tag-specific payload.  Integers
 * are 4-byte values read in host byte order.  Keys and values are
 * length-prefixed byte runs:
 *   REQ_INIT    max_bytes:u32
 *   REQ_GET     key_size:u32 key[key_size]
 *   REQ_PUT     key_size:u32 key[key_size] value_size:u32 value[value_size]
 *   REQ_REMOVE  key_size:u32 key[key_size]
 *   other tags  no payload
 *
 * Key and value pointers alias `buf`; nothing is copied.  Truncated input
 * makes a deserialize_* macro jump to ERROR, which yields REQ_INVALID.
 * Bytes left over after the payload are ignored.
 */
static from_emulator
decode_from (const char *buf, int buflen) {
    from_emulator req;

    req.type = deserialize_type(buf, buflen);

    if (req.type == REQ_INIT) {
        req.init.max_bytes = deserialize_uint32(buf, buflen);
    } else if (req.type == REQ_GET) {
        req.get.key_size = deserialize_uint32(buf, buflen);
        req.get.key = deserialize_bytes(buf, buflen, req.get.key_size);
    } else if (req.type == REQ_PUT) {
        req.put.key_size = deserialize_uint32(buf, buflen);
        req.put.key = deserialize_bytes(buf, buflen, req.put.key_size);
        req.put.value_size = deserialize_uint32(buf, buflen);
        req.put.value = deserialize_bytes(buf, buflen, req.put.value_size);
    } else if (req.type == REQ_REMOVE) {
        req.remove.key_size = deserialize_uint32(buf, buflen);
        req.remove.key = deserialize_bytes(buf, buflen, req.remove.key_size);
    }
    /* REQ_ITEM_COUNT, REQ_CAPACITY, REQ_BYTES_USED and REQ_INVALID carry
       no payload. */
    return req;

ERROR:
    req.type = REQ_INVALID;
    return req;
}
/*
 * Allocate and zero the per-port state for `port`, and set
 * PORT_CONTROL_FLAG_BINARY so control replies are returned as binaries.
 * The `cache` pointer starts as NULL; handle_init creates it when the first
 * REQ_INIT request arrives.
 *
 * Returns the new state, or NULL if driver_alloc fails.
 */
extern "C"
lru_driver_t*
lru_driver_data_new (ErlDrvPort port) {
    lru_driver_t* d = (lru_driver_t *)driver_alloc(sizeof(lru_driver_t));
    if (d == NULL)
        return NULL;   /* memset through NULL would crash the emulator */
    memset (d, 0, sizeof (*d));
    d->port = port;
    set_port_control_flags (port, PORT_CONTROL_FLAG_BINARY);
    return d;
}
/*
 * Release per-port state created by lru_driver_data_new, including the
 * cache created by REQ_INIT (if any).  A NULL pointer is ignored.
 */
extern "C"
void
lru_driver_data_free (lru_driver_t* d) {
    if (d == NULL)
        return;
    /* handle_init allocates the cache with `new`.  Without this delete, it
       leaked every time a port closed.  NOTE(review): whether ~cache_t also
       frees the malloc'd item buffers is decided in lru_cache.hpp --
       confirm there. */
    delete d->cache;
    d->cache = NULL;
    driver_free (d);
}
/*
 * Driver-wide init callback (the `init` slot of lru_driver_entry).  There is
 * no global state to set up; it always returns 0.
 */
extern "C"
int
lru_init(void) {
    const int status = 0;
    return status;
}
/*
 * Port start callback (the `startup` slot of lru_driver_entry): allocate
 * the per-port state.  `buf` is not used.
 *
 * Returns the state as ErlDrvData.  If it cannot be allocated, returns
 * ERL_DRV_ERROR_GENERAL, so opening the port fails instead of the emulator
 * receiving a NULL handle that every later callback would dereference.
 */
extern "C"
ErlDrvData
lru_start(ErlDrvPort port, char *buf) {
    lru_driver_t *d = lru_driver_data_new(port);
    if (d == NULL)
        return ERL_DRV_ERROR_GENERAL;
    return (ErlDrvData) d;
}
/*
 * Port stop callback (the `shutdown` slot of lru_driver_entry): free the
 * per-port state that lru_start returned as the handle.
 */
extern "C"
void
lru_stop(ErlDrvData handle) {
    lru_driver_t *state = reinterpret_cast<lru_driver_t *>(handle);
    lru_driver_data_free(state);
}
/*
 * REQ_INIT: create the cache, passing from.init.max_bytes to the cache_t
 * constructor, and reply with an empty list.  A repeated REQ_INIT replaces
 * the existing cache; the old one is now deleted first instead of being
 * leaked.
 */
static int handle_init(lru_driver_t *lru, from_emulator from, char **rbuf) {
    delete lru->cache;
    /* Clear the pointer before `new`, so a throwing allocation cannot leave
       a dangling pointer behind. */
    lru->cache = NULL;
    lru->cache = new cache_t(from.init.max_bytes);
    return reply_empty_list_rbuf(rbuf);
}
/*
 * REQ_GET: look the key up in the cache.  On a hit, reply with the "\000"
 * status byte followed by the item's bytes (item->data, item->size).  On a
 * miss, reply with an empty list.
 */
static int handle_get(lru_driver_t *lru, from_emulator from, char **rbuf) {
    char_vec key = make_charvec(from.get.key, from.get.key_size);
    item_t *hit = lru->cache->fetch_pointer(key);

    if (hit == NULL)
        return reply_empty_list_rbuf(rbuf);
    return reply_prefixed_binary_rbuf(rbuf, "\000", 1, hit->data, hit->size);
}
/*
 * REQ_PUT: copy the value into a malloc'd buffer and insert it into the
 * cache under the given key.  Reply with an empty list on success, or with
 * the "out_of_memory" error if the value buffer cannot be allocated.
 *
 * The buffer is not freed here.  NOTE(review): whether the cache frees it on
 * eviction, removal, or overwrite of an existing key is decided in
 * lru_cache.hpp / lru_item.hpp -- confirm there.
 */
static int handle_put(lru_driver_t *lru, from_emulator from, char **rbuf) {
    char_vec k = make_charvec(from.put.key, from.put.key_size);
    item_t v;
    v.size = from.put.value_size;
    v.data = malloc(v.size);
    /* malloc(0) may legitimately return NULL, so only a failed non-empty
       allocation is an error.  Copying into NULL would crash the emulator. */
    if (v.data == NULL && v.size > 0)
        return reply_error_rbuf(rbuf, "out_of_memory");
    if (v.size > 0)
        memcpy(v.data, from.put.value, from.put.value_size);
    lru->cache->insert(k, v);
    return reply_empty_list_rbuf(rbuf);
}
/*
 * REQ_REMOVE: call cache_t::remove with the request key, then reply with an
 * empty list.
 */
static int handle_remove(lru_driver_t *lru, from_emulator from, char **rbuf) {
    char_vec doomed_key(make_charvec(from.remove.key, from.remove.key_size));

    lru->cache->remove(doomed_key);
    return reply_empty_list_rbuf(rbuf);
}
/*
 * REQ_ITEM_COUNT: reply with the "\000" status byte followed by the value of
 * cache_t::count() as a native-endian 8-byte unsigned integer.
 */
static int handle_item_count(lru_driver_t *lru,
                             from_emulator from,
                             char **rbuf) {
    const uint64_t n_items = static_cast<uint64_t>(lru->cache->count());

    return reply_prefixed_binary_rbuf(rbuf, "\000", 1,
                                      &n_items, sizeof n_items);
}
/*
 * REQ_CAPACITY: reply with the "\000" status byte followed by the value of
 * cache_t::max_size() as a native-endian 8-byte unsigned integer.
 */
static int handle_capacity(lru_driver_t *lru, from_emulator from, char **rbuf) {
    const uint64_t limit = static_cast<uint64_t>(lru->cache->max_size());

    return reply_prefixed_binary_rbuf(rbuf, "\000", 1, &limit, sizeof limit);
}
/*
 * REQ_BYTES_USED: reply with the "\000" status byte followed by the value of
 * cache_t::cur_size() as a native-endian 8-byte unsigned integer.
 */
static int handle_bytes_used(lru_driver_t *lru,
                             from_emulator from,
                             char **rbuf)
{
    const uint64_t in_use = static_cast<uint64_t>(lru->cache->cur_size());

    return reply_prefixed_binary_rbuf(rbuf, "\000", 1, &in_use, sizeof in_use);
}
/*
 * Port control callback (the `control` slot of lru_driver_entry).  Decodes
 * one request from `buf`/`buflen` and dispatches it to the matching
 * handle_* function.  The request kind comes from the first byte of `buf`;
 * `command` and `rlen` are not used.  The reply is produced through *rbuf
 * by the reply_* helpers, and their return value is returned.
 *
 * Error replies (status byte 0x01 followed by the text):
 *   "invalid_request"  truncated request, unknown tag, or REQ_INVALID
 *   "not_initialized"  a cache operation before the first REQ_INIT
 *   "internal_error"   a C++ exception was thrown while handling the request
 */
extern "C"
int lru_control (ErlDrvData handle,
                 unsigned int command,
                 char* buf,
                 int buflen,
                 char** rbuf,
                 int rlen) {
    lru_driver_t* lru = (lru_driver_t*) handle;
    (void) command;
    (void) rlen;

    /* The emulator calls this through a C function pointer, so no C++
       exception (e.g. std::bad_alloc from building a key vector or from
       `new cache_t`) may escape.  Turn any exception into an error reply. */
    try {
        from_emulator from = decode_from (buf, buflen);

        /* Every handler except handle_init dereferences lru->cache, which
           stays NULL until the first REQ_INIT.  Without this check, such a
           request would crash the whole VM. */
        if (lru->cache == NULL
            && from.type != REQ_INIT && from.type != REQ_INVALID)
            return reply_error_rbuf (rbuf, "not_initialized");

        switch (from.type) {
        case REQ_INIT:
            return handle_init(lru, from, rbuf);
        case REQ_GET:
            return handle_get(lru, from, rbuf);
        case REQ_PUT:
            return handle_put(lru, from, rbuf);
        case REQ_REMOVE:
            return handle_remove(lru, from, rbuf);
        case REQ_ITEM_COUNT:
            return handle_item_count(lru, from, rbuf);
        case REQ_CAPACITY:
            return handle_capacity(lru, from, rbuf);
        case REQ_BYTES_USED:
            return handle_bytes_used(lru, from, rbuf);
        case REQ_INVALID:
        default:
            return reply_error_rbuf (rbuf, "invalid_request");
        }
    } catch (...) {
        return reply_error_rbuf (rbuf, "internal_error");
    }
}
/*
 * Driver entry table returned to the emulator by DRIVER_INIT(lrucache_drv).
 * Only init, startup, shutdown and control are implemented, so the driver
 * is driven entirely through port control requests.  Any ErlDrvEntry fields
 * after driver_flags (if the installed erl_driver.h declares more) are
 * zero-initialised by aggregate initialisation.
 */
static ErlDrvEntry lru_driver_entry = {
    lru_init,                        /* init */
    lru_start,                       /* startup */
    lru_stop,                        /* shutdown */
    NULL,                            /* output */
    NULL,                            /* ready_input */
    NULL,                            /* ready_output */
    /* Cast because the name field may be a non-const `char *`, and C++ does
       not convert string literals to that. */
    (char *) "lrucache_drv",         /* the name of the driver */
    NULL,                            /* finish */
    NULL,                            /* handle */
    lru_control,                     /* control */
    NULL,                            /* timeout */
    NULL,                            /* process */
    NULL,                            /* ready_async */
    NULL,                            /* flush */
    NULL,                            /* call */
    NULL,                            /* event */
    ERL_DRV_EXTENDED_MARKER,         /* ERL_DRV_EXTENDED_MARKER */
    ERL_DRV_EXTENDED_MAJOR_VERSION,  /* ERL_DRV_EXTENDED_MAJOR_VERSION */
    /* Must be the MINOR constant; the major one was used here before, which
       gave the emulator a wrong minor version number. */
    ERL_DRV_EXTENDED_MINOR_VERSION,  /* ERL_DRV_EXTENDED_MINOR_VERSION */
    ERL_DRV_FLAG_USE_PORT_LOCKING    /* ERL_DRV_FLAGs */
};
/*
 * Driver entry point, whose signature is generated by the DRIVER_INIT macro
 * from erl_driver.h.  It returns the address of lru_driver_entry;
 * presumably the emulator calls it when the driver library is loaded
 * (confirm against the erl_driver docs).  The macro argument must match the
 * driver name stored in lru_driver_entry ("lrucache_drv").
 */
extern "C"
DRIVER_INIT (lrucache_drv) /* must match name in driver_entry */
{
    return &lru_driver_entry;
}
/* GitHub page footer captured with the code: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment" */