This gist is out of date. Please see this gist instead.
When using Diffusers with xFormers enabled for Stable Diffusion, I noticed that the generated pictures are slightly different (like a spot-the-difference puzzle) even when the seed value is fixed.
According to the xFormers team, the default backend, Cutlass, does not guarantee deterministic behavior. (Incidentally, a patch has since been merged into the xFormers main branch that documents the non-deterministic behavior and adds a warning when `torch.use_deterministic_algorithms` is enabled.)
In the same thread, I was told that another xFormers backend, Flash Attention, behaves deterministically, so I wrote a patch that prioritizes the Flash Attention backend.
The result of the test is shown in the figure above. In each row, the leftmost image is run #0, and the remaining images show the difference from run #0 (the closer to black, the smaller the difference). The seed value is fixed and the generation is repeated. With the default (Cutlass) backend, the cat's paw changes between runs, but when Flash Attention is prioritized, almost the same image is generated every time.
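For reference, difference images like these can be produced by simple per-pixel subtraction. This is only a sketch; the file names and the use of NumPy/PIL are my own assumptions, not taken from the attached scripts:

```python
import numpy as np
from PIL import Image

def diff_image(path_run0: str, path_run_n: str) -> Image.Image:
    """Absolute per-pixel difference between two runs; black means identical pixels."""
    a = np.asarray(Image.open(path_run0)).astype(np.int16)
    b = np.asarray(Image.open(path_run_n)).astype(np.int16)
    return Image.fromarray(np.abs(a - b).astype(np.uint8))

# Hypothetical file names for run #0 and run #1 generated with the same seed.
diff_image("run0.png", "run1.png").save("diff_run1.png")
```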
There is a reason I wrote "almost." Flash Attention can be used for the U-Net inference but not for the VAE inference, because of the shape of the VAE's attention inputs. For this reason, I changed the code to fall back to the existing Cutlass backend when Flash Attention cannot be used. However, using the Cutlass backend during VAE inference produces only a very small difference (about 1 out of a maximum luminance of 255), which is almost imperceptible to humans.
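To see why the VAE path has to fall back, you can ask the Flash Attention forward op directly whether it supports a given input, using the same `supports()` call the patch relies on. This is a sketch: the concrete shapes (a multi-head U-Net-like input versus the VAE's single attention head of width 512) and the head-dimension limit are my assumptions about typical Stable Diffusion 1.x configurations, not something taken from the patch:

```python
import torch
import xformers.ops

# MemoryEfficientAttentionFlashAttentionOp is a (forward, backward) op pair,
# exactly as it is unpacked in the patch shown below.
fw, _bw = xformers.ops.MemoryEfficientAttentionFlashAttentionOp

def flash_supported(batch: int, heads: int, seq_len: int, head_dim: int) -> bool:
    # Diffusers passes xFormers tensors shaped (batch * heads, seq_len, head_dim).
    q = torch.empty(batch * heads, seq_len, head_dim,
                    dtype=torch.float16, device="cuda")
    return fw.supports(xformers.ops.fmha.Inputs(query=q, key=q, value=q, attn_bias=None))

# U-Net self-attention at 64x64 latent resolution (head_dim=40 is an assumed,
# ballpark value for SD 1.x): expected to be supported on a suitable GPU.
print(flash_supported(batch=2, heads=8, seq_len=64 * 64, head_dim=40))

# SD 1.x VAE attention is a single head 512 wide, which exceeds the head
# dimension Flash Attention accepts, so this should print False.
print(flash_supported(batch=1, heads=1, seq_len=64 * 64, head_dim=512))
```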
https://github.com/takuma104/diffusers/tree/force_xformers_flash_attention
The essential part of this patch is the following:
If `use_flash_attention` is `True`, the `op` argument of `memory_efficient_attention()` is set to prioritize Flash Attention. If Flash Attention does not support the given arguments (shape, dtype, and so on), it falls back to `None`, which means the default (Cutlass) backend.
```python
if use_flash_attention:
    # MemoryEfficientAttentionFlashAttentionOp is a (forward, backward) op pair.
    op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
    fw, bw = op
    # Ask the forward op whether it can handle these inputs; if not, fall back.
    if not fw.supports(xformers.ops.fmha.Inputs(query=query, key=key, value=value, attn_bias=attention_mask)):
        logger.warning('Flash Attention is not available for the input arguments. Fallback to default xFormers\' backend.')
        op = None
else:
    op = None
# op=None lets xFormers choose its default backend (Cutlass).
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask, op=op)
```
The comparison image above can be reproduced with the attached `diffusers_sd_xformers_flash_attention.py`.
However, as I wrote above, you may not be able to reproduce exactly the same results with the default (Cutlass) backend.
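A minimal check along these lines should show the same effect without the attached script; the model id, prompt, and step count here are my assumptions (the gist's attachment may differ):

```python
import numpy as np
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.enable_xformers_memory_efficient_attention()

def generate(seed: int) -> np.ndarray:
    generator = torch.Generator(device="cuda").manual_seed(seed)
    image = pipe("a photo of a cat", generator=generator,
                 num_inference_steps=15).images[0]
    return np.asarray(image).astype(np.int16)

# Two runs with the same seed; with a deterministic attention backend the
# maximum per-pixel difference should be 0 (or ~1 where the VAE falls back).
diff = np.abs(generate(42) - generate(42))
print("max per-pixel difference:", diff.max())
```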
The profiling below was done with the attached `diffusers_sd_xformers_flash_attention_profile.py`.
In my environment (RTX 3060), the results were as follows. Using Flash Attention for the U-Net inference is slightly faster, and peak memory usage is the same as the default.
```
$ python diffusers_xformers_profile.py
A matching Triton is not available, some optimizations will not be enabled.
Error caught was: No module named 'triton'
Fetching 16 files: 100%|███████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 26163.30it/s]
default auto backend ---
100%|█████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:04<00:00, 3.62it/s]
Peak memory use: 3849MB
flash attention backend ---
100%|█████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:03<00:00, 4.09it/s]
Flash Attention is not available for the input arguments. Fallback to default xFormers' backend.
Peak memory use: 3849MB
```
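For reference, timing and peak-memory numbers like these can be collected with standard `torch.cuda` utilities. This is a sketch of the approach, not the attached script itself; `pipe` and the prompt are assumed to be set up as in the earlier snippet:

```python
import time
import torch

def profile_run(pipe, prompt: str, seed: int, label: str) -> None:
    """Time one generation and report the peak CUDA memory allocated during it."""
    torch.cuda.reset_peak_memory_stats()
    generator = torch.Generator(device="cuda").manual_seed(seed)
    start = time.perf_counter()
    pipe(prompt, generator=generator, num_inference_steps=15)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    peak_mb = torch.cuda.max_memory_allocated() // (1024 * 1024)
    print(f"{label}: {elapsed:.2f}s, peak memory {peak_mb}MB")

# Example (assuming `pipe` from the earlier snippet):
# profile_run(pipe, "a photo of a cat", seed=42, label="flash attention backend")
```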