@tazarov
Created November 26, 2024 17:24
Reproducing HNSW Knn search exception memory leak
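// Standalone reproduction: each variant below allocates result buffers and then
// throws from inside ParallelFor, mimicking an exception raised during an HNSW KNN
// search. main() runs one variant in a loop and prints the process RSS so you can
// watch whether the allocations leak.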
#include <iostream>
#include <stdexcept>
#include <thread>
#include <atomic>
#include <vector>
#include <mutex>
#include <memory>
#include <stdlib.h>
#if defined(__APPLE__)
#include <mach/mach.h>
#elif defined(__linux__)
#include <cstdio>
#include <unistd.h>
#endif

typedef size_t labeltype;
typedef float dist_t;
template<class Function>
inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) {
    if (numThreads <= 0) {
        numThreads = std::thread::hardware_concurrency();
    }

    if (numThreads == 1) {
        for (size_t id = start; id < end; id++) {
            fn(id, 0);
        }
    } else {
        std::vector<std::thread> threads;
        std::atomic<size_t> current(start);

        // keep track of exceptions in threads
        // https://stackoverflow.com/a/32428427/1713196
        std::exception_ptr lastException = nullptr;
        std::mutex lastExceptMutex;

        for (size_t threadId = 0; threadId < numThreads; ++threadId) {
            threads.push_back(std::thread([&, threadId] {
                while (true) {
                    size_t id = current.fetch_add(1);
                    if (id >= end) {
                        break;
                    }
                    try {
                        fn(id, threadId);
                    } catch (...) {
                        std::unique_lock<std::mutex> lastExcepLock(lastExceptMutex);
                        lastException = std::current_exception();
                        /*
                         * This works even when current holds the largest value a
                         * size_t can represent, because fetch_add returns the value
                         * held before the increment (the increment itself would
                         * overflow and wrap to 0 instead of current + 1).
                         */
                        current = end;
                        break;
                    }
                }
            }));
        }
        for (auto &thread: threads) {
            thread.join();
        }
        if (lastException) {
            std::rethrow_exception(lastException);
        }
    }
}
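// If fn throws, the exception ends up propagating out of ParallelFor on the calling
// thread (directly in the single-threaded path, or captured, joined, and rethrown in
// the multi-threaded path). Any raw new[] in the caller whose matching delete[] comes
// after the ParallelFor call is then skipped and leaks.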
// Simulate allocation and throwing an exception
void allocateAndThrow() {
    const size_t allocationSize = 10 * 1024 * 1024;  // Allocate 10 MB
    int *data = new int[allocationSize / sizeof(int)];
    for (size_t i = 0; i < allocationSize / sizeof(int); ++i) {
        data[i] = i;
    }
    auto rows = 1;
    auto num_threads = std::thread::hardware_concurrency();
    if (rows <= num_threads * 4) {
        num_threads = 1;
    }
    ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
        throw std::runtime_error("Simulated failure!");
    });
    delete[] data;  // This line will never be reached
}
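// Fixed variant: wrap the work in try/catch and delete[] the buffer when the
// rethrown exception arrives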
void allocateAndThrowFixed() {
    const size_t allocationSize = 10 * 1024 * 1024;  // Allocate 10 MB
    int *data = new int[allocationSize / sizeof(int)];
    try {
        for (size_t i = 0; i < allocationSize / sizeof(int); ++i) {
            data[i] = i;
        }
        auto rows = 1;
        auto num_threads = std::thread::hardware_concurrency();
        if (rows <= num_threads * 4) {
            num_threads = 1;
        }
        ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
            throw std::runtime_error("Simulated failure!");
        });
    } catch (const std::exception &e) {
        delete[] data;
    }
}
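// Fixed variant: std::unique_ptr frees the buffer during stack unwinding, so no
// explicit catch is needed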
void allocateAndThrowFixedUniqPtr() {
    const size_t allocationSize = 10 * 1024 * 1024;  // Allocate 10 MB
    std::unique_ptr<int[]> data(new int[allocationSize / sizeof(int)]);
    for (size_t i = 0; i < allocationSize / sizeof(int); ++i) {
        data[i] = i;
    }
    throw std::runtime_error("Simulated failure!");
}
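// Closer to the real HNSW search path: raw label/distance buffers (the sizes a
// KNN query would return for `rows` queries with `k` neighbours each) are allocated
// before ParallelFor and only freed after it, so a worker exception leaks both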
void allocateAndThrowMoreRealistic(int rows, int k) {
    labeltype *data_numpy_l;
    dist_t *data_numpy_d;
    data_numpy_l = new labeltype[rows * k];
    data_numpy_d = new dist_t[rows * k];
    auto num_threads = std::thread::hardware_concurrency();
    if (rows <= num_threads * 4) {
        num_threads = 1;
    }
    ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
        throw std::runtime_error("Simulated failure!");
    });
    delete[] data_numpy_l;
    delete[] data_numpy_d;
}
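// Same allocation pattern, but the exception is caught locally and both buffers
// are freed in the handler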
void allocateAndThrowMoreRealisticWithTryCatch(int rows, int k) {
    // Initialize to nullptr so the catch block can safely delete[] even if an
    // allocation itself throws
    labeltype *data_numpy_l = nullptr;
    dist_t *data_numpy_d = nullptr;
    try {
        data_numpy_l = new labeltype[rows * k];
        data_numpy_d = new dist_t[rows * k];
        for (int i = 0; i < rows * k; ++i) {
            data_numpy_l[i] = i;
            data_numpy_d[i] = i;
        }
        auto num_threads = std::thread::hardware_concurrency();
        if (rows <= num_threads * 4) {
            num_threads = 1;
        }
        ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
            throw std::runtime_error("Simulated failure!");
        });
    } catch (const std::exception &e) {
        delete[] data_numpy_l;
        delete[] data_numpy_d;
    }
}
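// Same allocation pattern with std::unique_ptr ownership: the buffers are released
// automatically when the exception unwinds out of this function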
void allocateAndThrowMoreRealisticWithUniqPtr(int rows, int k) {
    std::unique_ptr<labeltype[]> data_numpy_l(new labeltype[rows * k]);
    std::unique_ptr<dist_t[]> data_numpy_d(new dist_t[rows * k]);
    for (int i = 0; i < rows * k; ++i) {
        data_numpy_l[i] = i;
        data_numpy_d[i] = i;
    }
    auto num_threads = std::thread::hardware_concurrency();
    if (rows <= num_threads * 4) {
        num_threads = 1;
    }
    ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
        throw std::runtime_error("Simulated failure!");
    });
}
// Get the resident set size of the current process (Linux and macOS)
size_t getMemoryUsage() {
#if defined(__linux__)
    // Linux: read resident pages from /proc/self/statm
    FILE *file = fopen("/proc/self/statm", "r");
    if (!file) return 0;
    long rss;
    if (fscanf(file, "%*s %ld", &rss) != 1) {
        fclose(file);
        return 0;
    }
    fclose(file);
    return rss * sysconf(_SC_PAGESIZE);  // Convert pages to bytes
#elif defined(__APPLE__)
    struct task_basic_info info;
    mach_msg_type_number_t info_count = TASK_BASIC_INFO_COUNT;
    if (task_info(mach_task_self(), TASK_BASIC_INFO, (task_info_t) &info, &info_count) != KERN_SUCCESS) {
        return 0;
    }
    return info.resident_size;  // Memory in bytes
#else
    // Fallback for other platforms
    return 0;  // Implement an OS-specific method if needed
#endif
}
int main() {
    const int iterations = 1000000;  // Number of iterations
    size_t initialMemory = getMemoryUsage();
    std::cout << "Initial memory usage: " << initialMemory << " bytes" << std::endl;

    for (int i = 0; i < iterations; ++i) {
        try {
            // allocateAndThrowMoreRealistic(1, 10);              // this will leak memory
            // allocateAndThrowMoreRealisticWithTryCatch(1, 10);  // this won't leak memory
            allocateAndThrowMoreRealisticWithUniqPtr(1, 10);      // this won't leak memory
        } catch (const std::exception &e) {
            // Catch the exception, but don't clean up the allocated memory
        }
        // Periodically print memory usage
        if (i % 10000 == 0) {
            size_t currentMemory = getMemoryUsage();
            std::cout << "Iteration " << i << ", memory usage: " << currentMemory << " bytes" << std::endl;
        }
    }

    size_t finalMemory = getMemoryUsage();
    std::cout << "Final memory usage: " << finalMemory << " bytes" << std::endl;
    return 0;
}
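// To try this locally, a plain C++11 build with pthreads should be enough; the exact
// command is an assumption, adjust the compiler, flags, and source file name to your setup:
//   g++ -std=c++11 -O2 -pthread repro.cpp -o repro && ./repro
// Swap the call in main() between the variants to compare a leaking run against a
// non-leaking one in the printed memory figures.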