Last active
July 14, 2022 21:51
-
-
Save ed1d1a8d/424e5bc83325c93037cfe2de9e457a68 to your computer and use it in GitHub Desktop.
ffcv-tqdm-thread-leak
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.utils.data
import torchvision

from ffcv.fields import IntField, RGBImageField
from ffcv.writer import DatasetWriter

# Use a tiny 64-sample slice of the CIFAR-10 test split so the resulting
# .beton file stays small for the repro.
cifar_subset = torch.utils.data.Subset(
    dataset=torchvision.datasets.CIFAR10(
        "/var/tmp", train=False, download=True
    ),
    indices=range(64),
)

# Serialize the subset to an FFCV .beton file; images are stored raw
# (uncompressed) and labels as plain ints.
beton_writer = DatasetWriter(
    "test.beton",
    {"image": RGBImageField(write_mode="raw"), "label": IntField()},
    num_workers=4,
)
beton_writer.from_indexed_dataset(cifar_subset)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import psutil | |
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder | |
from ffcv.loader import Loader | |
from ffcv.transforms import ToDevice, ToTensor | |
from tqdm.auto import tqdm | |
# Handle to the current OS process; used to sample its live thread count.
CUR_PROCESS = psutil.Process(os.getpid())
# Running maximum of the thread counts observed across epochs.
MAX_THREADS: int = 0


def print_max_threads_encountered(idx: int) -> None:
    """Sample this process's thread count and report the running maximum.

    Args:
        idx: Epoch counter, echoed in the printed line for context.
    """
    global MAX_THREADS
    current = CUR_PROCESS.num_threads()
    if current > MAX_THREADS:
        MAX_THREADS = current
    print(f"Max threads: {MAX_THREADS}; cur threads: {current}; idx={idx}")
def get_loader(
    batch_size: int,
    num_workers: int,
    device: str = "cpu",  # BUG occurs on "cpu" or "cuda"!
) -> Loader:
    """Build an FFCV ``Loader`` over the pre-generated ``test.beton`` file.

    Each field gets its own three-stage pipeline: decode, convert to a
    torch tensor, move to ``device``.

    Args:
        batch_size: Samples per batch.
        num_workers: Worker count passed through to the loader.
        device: Target device for the decoded tensors.

    Returns:
        The configured ``Loader``.
    """
    def make_pipeline(decoder):
        # Fresh ToTensor/ToDevice instances per field, matching the
        # one-pipeline-per-field setup FFCV expects.
        return [decoder, ToTensor(), ToDevice(device)]

    return Loader(
        "test.beton",
        batch_size=batch_size,
        num_workers=num_workers,
        os_cache=True,  # BUG occurs with os_cache = False or True
        pipelines={
            "image": make_pipeline(SimpleRGBImageDecoder()),
            "label": make_pipeline(IntDecoder()),
        },
    )
def main() -> None:
    """Iterate the loader forever, printing the process thread count each epoch.

    Repro for a thread leak: wrapping the FFCV loader directly in
    ``tqdm(loader)`` makes the thread count grow every epoch.
    """
    loader = get_loader(batch_size=4, num_workers=4)
    epoch: int = 0
    while True:
        # Leaks threads: tqdm iterating the loader directly.
        for _ in tqdm(loader):
            pass
        # Also leaks threads:
        # with tqdm(loader) as pbar:
        #     for _ in pbar:
        #         pass
        # No leak when tqdm is removed:
        # for _ in loader:
        #     pass
        # No leak with a manually-driven tqdm either:
        # with tqdm(total=len(loader)) as pbar:
        #     for _ in loader:
        #         pbar.update(1)
        epoch += 1
        print_max_threads_encountered(idx=epoch)


if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
ffcv==0.0.3
torch==1.12.0
torchvision==0.13.0
tqdm==4.64.0
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
You can either run the gen_beton.py script to generate this file, or you can download it from
https://drive.google.com/file/d/1d_XT5CAG9MxZ8gIao5qOxNVbF8HNuXfg/view?usp=sharing
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment