Assumptions:
- The head node is firewalled, but the worker nodes are not.
- The nodes can SSH into each other.
On the head node, run the following commands:
ray stop
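
The excerpt stops at ray stop; the commands that actually bring the cluster up are not shown here. A minimal sketch of the usual follow-up (assuming Ray's default GCS port 6379 and a placeholder HEAD_IP for the head node's address, neither of which is given in the original):

# Still on the head node, after ray stop: start a fresh head.
ray start --head --port=6379

# On each worker node: join the cluster at the head's address.
ray start --address=HEAD_IP:6379

Because the head node is firewalled, the workers still need to reach that port; opening it, or tunnelling it over SSH (the nodes can SSH into each other), is a deployment detail outside this sketch.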
# Torch version: 2.1.0.dev20230403+cu117
# CUDA: 11.7
# Issue summary:
# PyTorch's SDPA function (scaled_dot_product_attention) is one way to use flash attention. It does not work on sm_86 in some scenarios:
# - with bs=1 there is no issue for most sequence lengths (though it was found to error for seq len 3)
# - with bs>1 the module throws an error during loss.backward()
# - both of these happen when head_dim > 64; this repro uses codegen-2B, which has head_dim=80
#
# Error log: https://pastebin.com/t2Xdyb0d
#
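# A minimal sketch of the failure mode described above, not the full codegen-2B
# repro: the shapes (bs > 1, head_dim = 80 > 64, fp16, CUDA device) are illustrative
# assumptions chosen to match the summary, and the flash-attention backend is
# forced so the failing path is exercised on an sm_86 GPU.
import torch
import torch.nn.functional as F

bs, n_heads, seq_len, head_dim = 2, 16, 128, 80  # bs > 1, head_dim > 64

q = torch.randn(bs, n_heads, seq_len, head_dim,
                device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# Restrict SDPA to the flash-attention backend (torch 2.1-era API).
with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                    enable_math=False,
                                    enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Per the summary, the error surfaces during the backward pass when bs > 1.
out.sum().backward()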