This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| with torch.Stream(device='cuda') as s_cuda: | |
| a = torch.randn(10, 5, device='cuda') | |
| b = torch.randn(5, 10, device='cuda') | |
| c = torch.mm(a, b) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Here's the repro: | |
| import torch | |
| torch._dynamo.config.capture_scalar_outputs = True | |
| def fn(x, val_tensor): | |
| val = val_tensor.item() # Creates an unbacked float (fp64) | |
| scaled = val * 2.0 # fp64 * fp64 = fp64 computation | |
| return x * scaled # fp32 * fp64 -> needs downcast |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from torch._dynamo.testing import AotEagerAndRecordGraphs | |
| import torch.fx.traceback as fx_traceback | |
| def forward(x): | |
| with fx_traceback.annotate({"pp_stage": 0}): | |
| with fx_traceback.annotate({"fdsp_bucket": 0}): | |
| sin = torch.sin(x) | |
| sub = sin - 2 | |
| with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| File "/data/users/mlazos/pytorch/torch/_dynamo/convert_frame.py", line 791, in trace_frame | |
| tracer = InstructionTranslator( | |
| ^^^^^^^^^^^^^^^^^^^^^^ | |
| File "/data/users/mlazos/pytorch/torch/_dynamo/symbolic_convert.py", line 4461, in __init__ | |
| self.symbolic_stream_state = SymbolicStreamState() | |
| ^^^^^^^^^^^^^^^^^^^^^ | |
| File "/data/users/mlazos/pytorch/torch/_dynamo/variables/streams.py", line 76, in __init__ | |
| if torch.accelerator.is_available(): | |
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | |
| File "/data/users/mlazos/pytorch/torch/accelerator/__init__.py", line 90, in is_available |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def test_get_current_stream_return(self): | |
| def fn(x, s): | |
| with s: | |
| s0 = torch.cuda.current_stream() | |
| return x, s0 | |
| s_inp = torch.Stream(device="cuda") | |
| inp = (torch.ones(2, 2) + 1, s_inp) | |
| fn_opt = torch.compile(fn, fullgraph=True) | |
| _, s0 = fn_opt(*inp) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def test_stream_enter_exit(self): | |
| def fn(x, y): | |
| s2 = torch.cuda.Stream() | |
| s1 = torch.cuda.Stream() | |
| with s2: | |
| z1 = torch.add(x, y) | |
| with s1: | |
| z = torch.add(x, y) | |
| y = z + 2 + z1 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def llama_shapes(): | |
| # batch sizes * seq lengths | |
| BS = [2**i for i in range(4, 17)] | |
| #BS = [2**i for i in range(16, 17)] | |
| # attn: wqkv, wo; ffn: w13, w2 | |
| KN = [ | |
| (4096, 12288), | |
| (4096, 4096), | |
| (4096, 22016), | |
| (11008, 4096), |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| ''' | |
| Online Python Interpreter. | |
| Code, Compile, Run and Debug python program online. | |
| Write your code in this editor and press "Run" button to execute it. | |
| ''' | |
| from dataclasses import dataclass | |
| from collections import deque | |
| from copy import deepcopy |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| ''' | |
| Online Python Interpreter. | |
| Code, Compile, Run and Debug python program online. | |
| Write your code in this editor and press "Run" button to execute it. | |
| ''' | |
| from dataclasses import dataclass | |
| from copy import deepcopy | |
| # blocks |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| 0 MAKE_CELL 0 (self) | |
| 2 MAKE_CELL 12 (kwarg_keys) | |
| 4 RESUME 0 | |
| 6 LOAD_GLOBAL 21 (NULL + __import_torch_dot__dynamo_dot_utils) | |
| 16 LOAD_ATTR 22 (store_user_object_weakrefs) | |
| 36 COPY 1 | |
| 38 STORE_FAST 5 (tmp_0) | |
| 40 LOAD_FAST 0 (self) | |
| 42 LOAD_ATTR 24 (_modules) | |
| 62 LOAD_CONST 2 ('layers') |
Newer | Older