A simple hashtable for Linux kernel drivers
/**
 * Lockfree is a set of lockfree containers for Linux and Linux kernel
 * Copyright (C) <2017> Arkady Miasnikov
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */
/**
 * Implementation of a lock-free linear-probing hashtable.
 * The hashtable targets the SystemTap environment, where a typical key is a thread ID.
 * The number of probes is limited by a constant. The index does not wrap around;
 * instead the hashtable allocates enough extra memory to handle linear probing past
 * the end of the table.
 *
 * Limitation: a specific entry (a specific key) may be inserted and deleted by only one thread.
 *
 * Performance: a single core can perform over 13M add&remove operations per second;
 * the cost of a single operation is under 20 ns, the equivalent of roughly 50-100 opcodes.
 */
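/*
 * Sizing sketch (illustrative numbers, not part of the original text): with
 * bits = 10 and max_tries = 4, the DECLARE_HASHTABLE macro below allocates
 *     slots  = (1 << 10) + 4 = 1028
 *     memory = slots * sizeof(slot)
 * so a probe sequence that starts at the last hashed index, 1023, touches at
 * most slots 1023..1026 and stays inside the 1028-slot allocation without
 * wrapping around.
 */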
#pragma once
#ifdef __KERNEL__
# include <linux/types.h>   /* size_t, uint32_t, uint64_t */
# include <linux/kernel.h>  /* snprintf */
# include <linux/atomic.h>  /* cmpxchg */
# include <linux/vmalloc.h>
# include <linux/mm.h>      /* PAGE_ALIGN, vmalloc_to_page, Set/ClearPageReserved */
# include <linux/printk.h>
# define DEV_NAME "lockless"
# define PRINTF(s, ...) printk(KERN_ALERT DEV_NAME ": %s: " s "\n", __func__, __VA_ARGS__)
# define PRIu64 "llu"
#else
# include <stdlib.h>
# include <stdio.h>
# include <inttypes.h>
# define PRINTF(s, ...) printf("%s: " s "\n", __func__, __VA_ARGS__)
# define likely(x) __builtin_expect(!!(x), 1) // !!(x) will return 1 for any x != 0
# define unlikely(x) __builtin_expect(!!(x), 0)
# define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0]))
#endif
#define __sync_access(x) (*(volatile __typeof__(*x) *) (x))
/**
 * Based on https://gist.github.com/badboy/6267743
 * http://burtleburtle.net/bob/hash/integer.html
 * I replace the 'long key' by 'unsigned key' and the unsigned Java right shifts
 * by regular C/C++ right shifts.
 * I adapt the hash function for 22-bit (PID_MAX_LIMIT) integers.
 */
static uint32_t hash32shift(uint32_t key)
{
    key = ~key + (key << 10); // equivalent to (key << 10) - key - 1
    key = key ^ (key >> 7);
    key = key + (key << 1);
    key = key ^ (key >> 2);
    key = key * 187;          // the original hash uses (key + (key << 3)) + (key << 11), i.e. key * 2057
    key = key ^ (key >> 11);
    return key;
}
static uint32_t hash_none(uint32_t key)
{
    return key;
}
typedef struct
{
    uint64_t insert;
    uint64_t remove;
    uint64_t search;
    uint64_t collision;
    uint64_t insert_err;
    uint64_t remove_err;
} hashtable_stat_t;
typedef struct
{
    const char *name;
    size_t bits;
    uint32_t (*hashfunction)(uint32_t);
    size_t __size;
    size_t __memory_size;
    hashtable_stat_t __stat;
    void *__table;
} hashtable_t;
static hashtable_t *hashtable_registry[64];
static int hashtable_show(char *buf, size_t len)
{
    size_t i;
    int rc;
    size_t chars = 0;
    rc = snprintf(buf+chars, len-chars, "\n%-25s %12s %12s %12s %12s %12s %12s %12s %12s %12s \n",
        "Name", "Size", "Memory", "Ops", "Insert", "Remove", "Search", "Collision", "InsertErr", "RemoveErr");
    chars += rc;
    for (i = 0;i < ARRAY_SIZE(hashtable_registry);i++)
    {
        hashtable_t *hashtable = hashtable_registry[i];
        hashtable_stat_t *stat;
        if (!hashtable)
            continue;
        stat = &hashtable->__stat;
        rc = snprintf(buf+chars, len-chars, "%-25s %12zu %12zu %12" PRIu64 " %12" PRIu64 " %12" PRIu64 " %12" PRIu64 " %12" PRIu64 " %12" PRIu64 " %12" PRIu64 "\n",
            hashtable->name, hashtable->__size, hashtable->__memory_size,
            stat->insert+stat->remove+stat->search, stat->insert, stat->remove, stat->search, stat->collision, stat->insert_err, stat->remove_err
            );
        chars += rc;
    }
    return chars;
}
static void hashtable_registry_add(hashtable_t *table)
{
    size_t i;
    for (i = 0;i < ARRAY_SIZE(hashtable_registry);i++)
    {
        hashtable_t *registry = hashtable_registry[i];
        if (registry == table)
        {
            PRINTF("Hashtable %s already registered", table->name);
            break;
        }
        if (!registry)
        {
            PRINTF("Register hashtable %s", table->name);
            hashtable_registry[i] = table;
            break;
        }
    }
}
static void hashtable_registry_remove(hashtable_t *table)
{
    size_t i;
    for (i = 0;i < ARRAY_SIZE(hashtable_registry);i++)
    {
        hashtable_t *registry = hashtable_registry[i];
        if (registry == table)
        {
            PRINTF("Remove hashtable %s from the registry", table->name);
            hashtable_registry[i] = NULL;
        }
    }
}
static inline size_t hashtable_get_index(const hashtable_t *hashtable, const uint32_t hash)
{
    return hash & (hashtable->__size - 1);
}
static void *hashtable_alloc(size_t size)
{
#ifdef __KERNEL__
    void *p;
    size = PAGE_ALIGN(size);
    p = vmalloc(size);
    if (p)
    {
        unsigned long long adr = (unsigned long long) p;
        while (size > 0)
        {
            SetPageReserved(vmalloc_to_page((void *)adr));
            adr += PAGE_SIZE;
            size -= PAGE_SIZE;
        }
    }
    return p;
#else
    return malloc(size);
#endif
}
static void hashtable_free(void *p, size_t size)
{
#ifdef __KERNEL__
    if (p)
    {
        unsigned long long adr = (unsigned long long) p;
        while ((long) size > 0) {
            ClearPageReserved(vmalloc_to_page((void *)adr));
            adr += PAGE_SIZE;
            size -= PAGE_SIZE;
        }
        vfree(p);
    }
#else
    free(p);
#endif
}
static void hashtable_close(hashtable_t *hashtable)
{
    if (hashtable->__table)
    {
        hashtable_free(hashtable->__table, hashtable->__memory_size);
    }
    else
    {
        PRINTF("Failed to free null pointer for the hashtable %s", hashtable->name);
    }
    hashtable_registry_remove(hashtable);
}
#ifdef __KERNEL__
# define HASHTABLE_CMPXCHG(key, val, new_val) cmpxchg(key, val, new_val)
#else
# define HASHTABLE_CMPXCHG(key, val, new_val) __sync_val_compare_and_swap(key, val, new_val)
#endif
#if 1
#ifdef __KERNEL__
# define HASHTABLE_BARRIER() barrier()
#else
# define HASHTABLE_BARRIER() __sync_synchronize()
#endif
#else
# define HASHTABLE_BARRIER()
#endif
#define HASHTABLE_SLOT_ADDR(hashtable, tokn, index) &(((hashtable_## tokn ##_slot_t*)hashtable->__table)[index])
/**
 * The illegal key must be a value no real key can take, e.g. (PID_MAX_LIMIT+1) for a TID.
 * The illegal data is 0 for a TID, -1 for an FD, etc. (setting it is optional).
 * See the usage sketch after the macro for an example instantiation.
 */
#define DECLARE_HASHTABLE(tokn, data_type, max_tries, illegal_key, illegal_data) \
\
typedef struct \
{ \
    volatile uint32_t key; \
    data_type data; \
} hashtable_## tokn ## _slot_t; \
\
static void hashtable_## tokn ## _init_slot(hashtable_## tokn ## _slot_t *slot) \
{ \
    slot->key = illegal_key; \
    slot->data = illegal_data; \
} \
\
/** \
 * Calculate the number of slots in the hash table \
 * I add max_tries on top to ensure that there are max_tries slots after the \
 * end of the table \
 */ \
static size_t hashtable_## tokn ##_memory_size(const int bits) \
{ \
    size_t slots = (1 << bits) + max_tries; \
    return (sizeof(hashtable_## tokn ## _slot_t) * slots); \
} \
\
static int hashtable_## tokn ##_init(hashtable_t *hashtable) \
{ \
    size_t memory_size = hashtable_## tokn ## _memory_size(hashtable->bits); \
    void *p = hashtable_alloc(memory_size); \
    size_t i; \
    if (p) \
    { \
        if (hashtable->hashfunction == NULL) \
        { \
            hashtable->hashfunction = hash32shift; \
        } \
        hashtable->__size = (1 << hashtable->bits); \
        hashtable->__memory_size = memory_size; \
        hashtable->__table = p; \
        /* initialise the max_tries overflow slots past __size as well */ \
        for (i = 0;i < hashtable->__size + max_tries;i++) \
        { \
            hashtable_## tokn ## _init_slot(HASHTABLE_SLOT_ADDR(hashtable, tokn, i)); \
        } \
        hashtable_registry_add(hashtable); \
        return 1; \
    } \
    PRINTF("Failed to allocate %zu for the hashtable %s", memory_size, hashtable->name); \
    return 0; \
} \
\
/** \
 * Hash the key, get an index in the hashtable, try compare-and-set. \
 * If that fails (unlikely), try again with the next slot (linear probing); \
 * continue until success or until max_tries is hit \
 */ \
static int hashtable_## tokn ##_insert(hashtable_t *hashtable, const uint32_t key, const data_type data) \
{ \
    const uint32_t hash = hashtable->hashfunction(key); \
    const uint32_t index = hashtable_get_index(hashtable, hash); \
    /* This may run past the last hashed slot - I allocated max_tries more slots */ \
    const uint32_t index_max = index+max_tries; \
    uint32_t i; \
    hashtable->__stat.insert++; \
    for (i = index;i < index_max;i++) \
    { \
        hashtable_## tokn ##_slot_t *slot = HASHTABLE_SLOT_ADDR(hashtable, tokn, i); \
        uint32_t old_key = HASHTABLE_CMPXCHG(&slot->key, illegal_key, key); \
        if (likely(old_key == illegal_key)) /* Success */ \
        { \
            slot->data = data; \
            return 1; \
        } \
        else \
        { \
            hashtable->__stat.collision++; \
        } \
    } \
\
    hashtable->__stat.insert_err++; \
    return 0; \
} \
\
/** \
 * Hash the key, get an index in the hashtable, find the relevant entry, \
 * read the data, then release the slot using atomic operations. \
 * Only one context is allowed to remove a specific entry \
 */ \
static int hashtable_## tokn ##_remove(hashtable_t *hashtable, const uint32_t key, data_type *data) \
{ \
    const uint32_t hash = hashtable->hashfunction(key); \
    const uint32_t index = hashtable_get_index(hashtable, hash); \
    /* This may run past the last hashed slot - I allocated max_tries more slots */ \
    const uint32_t index_max = index+max_tries; \
    uint32_t i; \
    hashtable->__stat.remove++; \
    for (i = index;i < index_max;i++) \
    { \
        hashtable_## tokn ##_slot_t *slot = HASHTABLE_SLOT_ADDR(hashtable, tokn, i); \
        uint32_t old_key = slot->key; \
        if (likely(old_key == key)) \
        { \
            if (data) \
            { \
                *data = slot->data; \
            } \
            __sync_access(&slot->data) = illegal_data; \
            HASHTABLE_BARRIER(); \
            __sync_access(&slot->key) = illegal_key; \
            return 1; \
        } \
    } \
\
    hashtable->__stat.remove_err++; \
    return 0; \
} \
\
/** \
 * Hash the key, get an index in the hashtable, find the relevant entry, \
 * read the data \
 */ \
static int hashtable_## tokn ##_find(hashtable_t *hashtable, const uint32_t key, data_type *data) \
{ \
    const uint32_t hash = hashtable->hashfunction(key); \
    const uint32_t index = hashtable_get_index(hashtable, hash); \
    /* This may run past the last hashed slot - I allocated max_tries more slots */ \
    const uint32_t index_max = index+max_tries; \
    uint32_t i; \
    hashtable->__stat.search++; \
    for (i = index;i < index_max;i++) \
    { \
        hashtable_## tokn ##_slot_t *slot = HASHTABLE_SLOT_ADDR(hashtable, tokn, i); \
        uint32_t old_key = slot->key; \
        if (old_key == key) \
        { \
            *data = slot->data; \
            return 1; \
        } \
    } \
\
    return 0; \
}
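
/*
 * Usage sketch - not part of the original gist. It instantiates the macro for
 * a TID -> uint64_t mapping and exercises init/insert/find/remove. The guard
 * macro HASHTABLE_USAGE_EXAMPLE, the instance token "tid", the table name,
 * and the illegal key/data values below are illustrative assumptions, not a
 * fixed API; pick values that suit the actual driver.
 */
#ifdef HASHTABLE_USAGE_EXAMPLE

/* 0x400001 plays the role of (PID_MAX_LIMIT + 1): no real TID can take it. */
DECLARE_HASHTABLE(tid, uint64_t, 4, 0x400001u, 0)

static hashtable_t tid_table = {
    .name = "tid_table",
    .bits = 10,            /* 1024 hashed slots (+4 overflow slots)  */
    .hashfunction = NULL,  /* NULL selects the default hash32shift   */
};

static void tid_table_example(void)
{
    uint64_t data;

    if (!hashtable_tid_init(&tid_table))
        return;

    hashtable_tid_insert(&tid_table, 1234, 0xABCD);   /* add an entry for TID 1234 */
    if (hashtable_tid_find(&tid_table, 1234, &data))  /* data is now 0xABCD        */
        PRINTF("found data %" PRIu64, data);
    hashtable_tid_remove(&tid_table, 1234, &data);    /* releases the slot         */

    hashtable_close(&tid_table);                      /* free memory, unregister   */
}

#endif /* HASHTABLE_USAGE_EXAMPLE */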