Reproducing HNSW KNN search exception memory leak
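Minimal standalone reproduction: when a worker lambda throws inside ParallelFor, the exception is captured and re-thrown on the calling thread after all workers are joined. If the caller allocated its result buffers with raw new[] and frees them only after the parallel loop (as allocateAndThrowMoreRealistic does), the delete[] calls are skipped and the buffers leak on every failed query. The file also demonstrates two fixes, a try/catch around the work and std::unique_ptr ownership of the buffers, and includes a small RSS probe so the growth (or lack of it) can be observed while main loops.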
#include <iostream>
#include <stdexcept>
#include <exception>
#include <vector>
#include <mutex>
#include <memory>
#include <cstdio>
#if defined(__linux__)
#include <unistd.h>  // sysconf(_SC_PAGESIZE)
#endif
#if defined(__APPLE__)
#include <mach/mach.h>
#endif
#include <thread>
#include <atomic>
#include <stdlib.h>

typedef size_t labeltype;
typedef float dist_t;
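// Parallel-for helper in the style of the one used by hnswlib's Python bindings:
// worker threads pull indices from a shared atomic counter, and the first
// exception thrown by any worker is captured and re-thrown on the calling
// thread after all workers have been joined.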
template<class Function>
inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) {
    if (numThreads == 0) {
        numThreads = std::thread::hardware_concurrency();
    }
    if (numThreads == 1) {
        for (size_t id = start; id < end; id++) {
            fn(id, 0);
        }
    } else {
        std::vector<std::thread> threads;
        std::atomic<size_t> current(start);
        // keep track of exceptions in threads
        // https://stackoverflow.com/a/32428427/1713196
        std::exception_ptr lastException = nullptr;
        std::mutex lastExceptMutex;
        for (size_t threadId = 0; threadId < numThreads; ++threadId) {
            threads.push_back(std::thread([&, threadId] {
                while (true) {
                    size_t id = current.fetch_add(1);
                    if (id >= end) {
                        break;
                    }
                    try {
                        fn(id, threadId);
                    } catch (...) {
                        std::unique_lock<std::mutex> lastExcepLock(lastExceptMutex);
                        lastException = std::current_exception();
                        /*
                         * This works even when current is the largest value that
                         * size_t can hold, because fetch_add returns the value
                         * prior to the increment (the increment itself would
                         * overflow and wrap to 0 instead of current + 1).
                         */
                        current = end;
                        break;
                    }
                }
            }));
        }
        for (auto &thread: threads) {
            thread.join();
        }
        if (lastException) {
            std::rethrow_exception(lastException);
        }
    }
}
// Simulate allocating a result buffer and then throwing from inside ParallelFor.
// The exception re-thrown by ParallelFor propagates out of this function before
// delete[] runs, so every call leaks the 10 MB buffer.
void allocateAndThrow() {
    const size_t allocationSize = 10 * 1024 * 1024;  // Allocate 10 MB
    int *data = new int[allocationSize / sizeof(int)];
    for (size_t i = 0; i < allocationSize / sizeof(int); ++i) {
        data[i] = i;
    }
    auto rows = 1;
    auto num_threads = std::thread::hardware_concurrency();
    if (rows <= num_threads * 4) {
        num_threads = 1;
    }
    ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
        throw std::runtime_error("Simulated failure!");
    });
    delete[] data;  // Never reached: the exception above skips this cleanup
}
// Fix #1: wrap the work in try/catch so delete[] runs even when ParallelFor throws.
// (The exception is swallowed here; a real implementation would clean up and rethrow.)
void allocateAndThrowFixed() {
    const size_t allocationSize = 10 * 1024 * 1024;  // Allocate 10 MB
    int *data = new int[allocationSize / sizeof(int)];
    try {
        for (size_t i = 0; i < allocationSize / sizeof(int); ++i) {
            data[i] = i;
        }
        auto rows = 1;
        auto num_threads = std::thread::hardware_concurrency();
        if (rows <= num_threads * 4) {
            num_threads = 1;
        }
        ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
            throw std::runtime_error("Simulated failure!");
        });
    } catch (const std::exception &e) {
        delete[] data;
    }
}
// Fix #2: let std::unique_ptr own the buffer so it is released during stack unwinding.
void allocateAndThrowFixedUniqPtr() {
    const size_t allocationSize = 10 * 1024 * 1024;  // Allocate 10 MB
    std::unique_ptr<int[]> data(new int[allocationSize / sizeof(int)]);
    for (size_t i = 0; i < allocationSize / sizeof(int); ++i) {
        data[i] = i;
    }
    throw std::runtime_error("Simulated failure!");
}
// Closer to the real KNN-query code path: the label and distance result buffers
// are raw new[] allocations and the delete[] calls sit after ParallelFor, so the
// re-thrown exception leaks both arrays on every call.
void allocateAndThrowMoreRealistic(int rows, int k) {
    labeltype *data_numpy_l;
    dist_t *data_numpy_d;
    data_numpy_l = new labeltype[rows * k];
    data_numpy_d = new dist_t[rows * k];
    auto num_threads = std::thread::hardware_concurrency();
    if (rows <= num_threads * 4) {
        num_threads = 1;
    }
    ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
        throw std::runtime_error("Simulated failure!");
    });
    delete[] data_numpy_l;  // Never reached
    delete[] data_numpy_d;  // Never reached
}
// Same allocation pattern, fixed with try/catch. The pointers are initialised to
// nullptr so the delete[] calls in the handler are safe even if an allocation throws.
void allocateAndThrowMoreRealisticWithTryCatch(int rows, int k) {
    labeltype *data_numpy_l = nullptr;
    dist_t *data_numpy_d = nullptr;
    try {
        data_numpy_l = new labeltype[rows * k];
        data_numpy_d = new dist_t[rows * k];
        for (int i = 0; i < rows * k; ++i) {
            data_numpy_l[i] = i;
            data_numpy_d[i] = i;
        }
        auto num_threads = std::thread::hardware_concurrency();
        if (rows <= num_threads * 4) {
            num_threads = 1;
        }
        ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
            throw std::runtime_error("Simulated failure!");
        });
    } catch (const std::exception &e) {
        delete[] data_numpy_l;
        delete[] data_numpy_d;
    }
}
// Same allocation pattern, fixed with std::unique_ptr: both arrays are released
// automatically when the exception unwinds the stack.
void allocateAndThrowMoreRealisticWithUniqPtr(int rows, int k) {
    std::unique_ptr<labeltype[]> data_numpy_l(new labeltype[rows * k]);
    std::unique_ptr<dist_t[]> data_numpy_d(new dist_t[rows * k]);
    for (int i = 0; i < rows * k; ++i) {
        data_numpy_l[i] = i;
        data_numpy_d[i] = i;
    }
    auto num_threads = std::thread::hardware_concurrency();
    if (rows <= num_threads * 4) {
        num_threads = 1;
    }
    ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
        throw std::runtime_error("Simulated failure!");
    });
}
// Get the resident set size of the current process (Linux and macOS; 0 elsewhere)
size_t getMemoryUsage() {
#if defined(__linux__)
    // Linux: read the resident page count from /proc/self/statm
    FILE *file = fopen("/proc/self/statm", "r");
    if (!file) return 0;
    long rss;
    if (fscanf(file, "%*s %ld", &rss) != 1) {
        fclose(file);
        return 0;
    }
    fclose(file);
    return rss * sysconf(_SC_PAGESIZE);  // Convert pages to bytes
#elif defined(__APPLE__)
    struct task_basic_info info;
    mach_msg_type_number_t info_count = TASK_BASIC_INFO_COUNT;
    if (task_info(mach_task_self(), TASK_BASIC_INFO, (task_info_t) &info, &info_count) != KERN_SUCCESS) {
        return 0;
    }
    return info.resident_size;  // Memory in bytes
#else
    // Fallback for other platforms
    return 0;  // Implement an OS-specific method if needed
#endif
}
int main() {
    const int iterations = 1000000;  // Number of iterations
    size_t initialMemory = getMemoryUsage();
    std::cout << "Initial memory usage: " << initialMemory << " bytes" << std::endl;
    for (int i = 0; i < iterations; ++i) {
        try {
            // allocateAndThrowMoreRealistic(1, 10);              // leaks memory
            // allocateAndThrowMoreRealisticWithTryCatch(1, 10);  // does not leak
            allocateAndThrowMoreRealisticWithUniqPtr(1, 10);       // does not leak
        } catch (const std::exception &e) {
            // Swallow the exception so the loop keeps running;
            // cleanup is the responsibility of the called function.
        }
        // Periodically print memory usage
        if (i % 10000 == 0) {
            size_t currentMemory = getMemoryUsage();
            std::cout << "Iteration " << i << ", memory usage: " << currentMemory << " bytes" << std::endl;
        }
    }
    size_t finalMemory = getMemoryUsage();
    std::cout << "Final memory usage: " << finalMemory << " bytes" << std::endl;
    return 0;
}
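To reproduce, build the file as a single translation unit with a C++11-capable compiler and thread support, for example something like: c++ -std=c++11 -pthread repro.cpp -o repro (the source file name is arbitrary). Running the binary with allocateAndThrowMoreRealistic enabled in main shows the reported memory usage climbing across iterations; with either of the fixed variants it stays roughly flat.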