Skip to content

Instantly share code, notes, and snippets.

@bityob
Created December 25, 2023 23:47
Show Gist options
  • Save bityob/cab6fa78ef6382d1f55082a60dd6fc62 to your computer and use it in GitHub Desktop.
Save bityob/cab6fa78ef6382d1f55082a60dd6fc62 to your computer and use it in GitHub Desktop.
Run torch with multiprocessing on Jupyter Notebook (aka IPython, Jupyter Lab etc.)
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "62ff665d-63e1-472e-bace-deb2c47a7c8f",
"metadata": {},
"outputs": [],
"source": [
"# Source: https://github.com/mszhanyi/pymultiprocessdemo/blob/main/demos/mycellmagic.py\n",
"from IPython.core.magic import register_cell_magic\n",
"\n",
"@register_cell_magic\n",
"def save2file(line, cell):\n",
" 'save python code block to a file'\n",
" with open(line, 'wt') as fd:\n",
" fd.write(cell)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "edac3158-4fcc-4f21-9ba3-ba1ccbd03057",
"metadata": {},
"outputs": [],
"source": [
"%%save2file tmp.py\n",
"print(\"Loading setup\")\n",
"import torch\n",
"import torchaudio\n",
"import soundfile\n",
"from torchaudio.utils import download_asset\n",
"import torch.multiprocessing as multiprocessing\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H\n",
"model = bundle.get_model().to(device)\n",
"\n",
"class GreedyCTCDecoder(torch.nn.Module):\n",
" def __init__(self, labels, blank=0):\n",
" super().__init__()\n",
" self.labels = labels\n",
" self.blank = blank\n",
"\n",
" def forward(self, emission: torch.Tensor) -> str:\n",
" indices = torch.argmax(emission, dim=-1) # [num_seq,]\n",
" indices = torch.unique_consecutive(indices, dim=-1)\n",
" indices = [i for i in indices if i != self.blank]\n",
" return \"\".join([self.labels[i] for i in indices])\n",
"\n",
"\n",
"decoder = GreedyCTCDecoder(labels=bundle.get_labels())\n",
"\n",
"\n",
"def flow(file_object):\n",
" # Source: torchaudio doc\n",
" waveform, sample_rate = torchaudio.load(file_object)\n",
" waveform = waveform.to(device)\n",
" if sample_rate != bundle.sample_rate:\n",
" waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)\n",
" with torch.inference_mode():\n",
" features, _ = model.extract_features(waveform)\n",
" with torch.inference_mode():\n",
" emission, _ = model(waveform)\n",
" # decoder = GreedyCTCDecoder(labels=bundle.get_labels())\n",
" transcript = decoder(emission[0])\n",
" print(file_object.split(\"\\\\\")[-1], transcript)\n",
"\n",
"print(\"Done loading module\")\n",
"\n",
"def run():\n",
" files = [\n",
" download_asset(\"tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav\"),\n",
" download_asset(\"tutorial-assets/steam-train-whistle-daniel_simon.wav\"),\n",
" download_asset(\"tutorial-assets/Lab41-SRI-VOiCES-rm1-impulse-mc01-stu-clo-8000hz.wav\"),\n",
" download_asset(\"tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042-8000hz.wav\"),\n",
" download_asset(\"tutorial-assets/Lab41-SRI-VOiCES-rm1-babb-mc01-stu-clo-8000hz.wav\"),\n",
" ]\n",
" ctx = multiprocessing.get_context(\"spawn\")\n",
" with ctx.Pool(5) as p:\n",
" for _ in p.map(flow, files):\n",
" pass\n",
"\n",
"if __name__ == '__main__':\n",
" # Added `__spec__` to fix `AttributeError: module '__main__' has no attribute '__spec__' on running the %run cell multiple times \n",
" # See too: https://stackoverflow.com/questions/45720153/python-multiprocessing-error-attributeerror-module-main-has-no-attribute\n",
" __spec__ = \"main\"\n",
" import time\n",
" start = time.time()\n",
" print(\"Starting...\")\n",
" multiprocessing.freeze_support()\n",
" print(\"Done freeze support, running pool...\")\n",
"\n",
" torch.random.manual_seed(0)\n",
"\n",
" run()\n",
" end = time.time()\n",
" print(\"Done all on {} seconds\".format(end - start))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "97277d42-cdfd-412c-baf6-74e19cbb150c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading setup\n",
"Done loading module\n",
"Starting...\n",
"Done freeze support, running pool...\n",
"Done all on 12.184009552001953 seconds\n"
]
}
],
"source": [
"%run tmp.py"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "python3.11",
"language": "python",
"name": "python3.11"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
@bityob
Copy link
Author

bityob commented Dec 25, 2023

References:

GitHub Issues

pytorch/pytorch#17680
pytorch/pytorch#62943
ipython/ipython#10894

StackOverFlow

https://stackoverflow.com/questions/41385708/multiprocessing-example-giving-attributeerror

Demo Repository regarding issues of using multiprocessing and IPython/Jupyter

https://github.com/mszhanyi/pymultiprocessdemo

I also liked the https://pypi.org/project/multiprocess/ fork, but it can't be used with PyTorch, since PyTorch ships its own multiprocessing fork (torch.multiprocessing).

The %%save2file tmp.py and %run tmp.py magic usages are the most practical solution I have found for this.

It works just fine now from inside my notebook

@IliasAarab
Copy link

Smart! What about rerunning the last cell without restarting the kernel, is this possible?

@bityob
Copy link
Author

bityob commented Jun 17, 2024

@IliasAarab

Not sure I understand your question.

The last cell doesn't need a kernel restart anyway, since it runs in a separate Python process via the %run magic.

So you can update the second cell (the one with %%save2file tmp.py at the top) and then run the third (last) cell, and it will execute the script in a fresh Python process.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment