Skip to content

Instantly share code, notes, and snippets.

@thomasahle
Created January 23, 2025 11:56
Show Gist options
  • Save thomasahle/f5e8c38c8adfaa5a7da58ba46c21d7b0 to your computer and use it in GitHub Desktop.
Save thomasahle/f5e8c38c8adfaa5a7da58ba46c21d7b0 to your computer and use it in GitHub Desktop.
1: def _generated_forward(i:int, k:int, _var_5305017088:torch.Tensor, _var_5305017136:torch.Tensor, _var_5305023376:torch.Tensor) -> tuple[torch.Tensor]:
2: sum_ = 0 # (i, k)
3: # Product of 2 tensors
4: var_x = _var_5305023376 # (i, j)
5: var_w = _var_5305017088 # (j, k)
6: # var_x: (i, j)
7: # var_w: (j, k)
8: prod_ = ctg.array_contract(
9: arrays=[var_x, var_w],
10: inputs=[('i', 'j'), ('j', 'k')],
11: output=('i', 'k'),
12: optimize='auto'
13: )
14: sum_ += prod_
15: var_b = _var_5305017136 # (i, k)
16: sum_ += var_b
17: fn_relu = torch.relu(sum_) # (i, k)
18: # Product of 4 tensors
19: fn_gt0 = (sum_ >= 0).float() # (i_0, k_0)
20: indices = torch.arange(k, dtype=torch.int64)
21: values = torch.ones(k, dtype=torch.float32)
22: delta_ = torch.sparse_csr_tensor(
23: crow_indices=torch.arange(k + 1, dtype=torch.int64),
24: col_indices=indices,
25: values=values,
26: size=(k, k, k)
27: )
28: indices = torch.arange(i, dtype=torch.int64)
29: values = torch.ones(i, dtype=torch.float32)
30: delta__1 = torch.sparse_csr_tensor(
31: crow_indices=torch.arange(i + 1, dtype=torch.int64),
32: col_indices=indices,
33: values=values,
34: size=(i, i, i)
35: )
36: # fn_gt0: (i_0, k_0)
37: # var_w: (j, k_1)
38: # delta_: (k, k_0, k_1)
39: # delta__1: (i_, i, i_0)
40: prod__1 = ctg.array_contract(
41: arrays=[fn_gt0, var_w, delta_, delta__1],
42: inputs=[('i_0', 'k_0'), ('j', 'k_1'), ('k', 'k_0', 'k_1'), ('i_', 'i', 'i_0')],
43: output=('j', 'k', 'i_', 'i'),
44: optimize='auto'
45: )
46: return fn_relu, prod__1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment