
Flash/Sage varlen does not work with torch.compile #11957

@a-r-r-o-w

Description

@a-r-r-o-w @DN6 @tolgacangoz
I got a bit excited about this PR and wanted to give it a go. I love the syntax of both the setter function and the context manager. Great work!

I also wanted to see whether it would still compile, but got the following error logs:

[t+28s648ms]   0%|          | 0/30 [00:00<?, ?it/s]/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:1601: UserWarning: Dynamo detected a call to a `functools.lru_cache`-wrapped function. Dynamo ignores the cache wrapper and directly traces the wrapped function. Silent incorrectness is only a *potential* risk, not something we have observed. Enable TORCH_LOGS="+dynamo" for a DEBUG stack trace.
[t+28s650ms]   torch._dynamo.utils.warn_once(msg)
[t+28s666ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0] Graph break from `Tensor.item()`, consider setting:
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]     torch._dynamo.config.capture_scalar_outputs = True
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0] or:
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]     env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0] to include these operations in the captured graph.
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0] Graph break: from user code at:
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/diffusers/models/transformers/transformer_flux.py", line 733, in forward
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]     encoder_hidden_states, hidden_states = block(
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/diffusers/models/transformers/transformer_flux.py", line 456, in forward
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]     attention_outputs = self.attn(
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/diffusers/models/transformers/transformer_flux.py", line 343, in forward
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]     return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/diffusers/models/transformers/transformer_flux.py", line 117, in __call__
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]     hidden_states = dispatch_attention_fn(
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/diffusers/models/attention_dispatch.py", line 241, in dispatch_attention_fn
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]     return backend_fn(**kwargs)
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/diffusers/models/attention_dispatch.py", line 962, in _sage_varlen_attention
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]     _prepare_for_flash_attn_or_sage_varlen(
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/diffusers/models/attention_dispatch.py", line 351, in _prepare_for_flash_attn_or_sage_varlen
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]     return _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, device)
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/torch/_dynamo/polyfills/__init__.py", line 253, in getattr_and_trace
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]     return fn(*args[2:], **kwargs)
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/diffusers/models/attention_dispatch.py", line 321, in _prepare_for_flash_attn_or_sage_varlen_without_mask
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]     max_seqlen_q = seqlens_q.max().item()
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]
[t+28s667ms] W0718 20:32:55.420000 1 /inferencesh/cache/gpu/uv/archive-v0/bZ7CJbyNBkaHsGR2jn6Vv/torch/_dynamo/variables/tensor.py:1048] [0/0]
[t+28s669ms]   0%|          | 0/30 [00:00<?, ?it/s]
[t+28s670ms] [ERROR] Traceback (most recent call last):
[t+28s671ms]   File "/server/tasks.py", line 50, in run_task
[t+28s671ms]     output = await result
[t+28s671ms]              ^^^^^^^^^^^^
[t+28s671ms]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/src/inference.py", line 248, in run
[t+28s671ms]     result = self.pipeline(
[t+28s671ms]              ^^^^^^^^^^^^^^
[t+28s671ms]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[t+28s671ms]     return func(*args, **kwargs)
[t+28s671ms]            ^^^^^^^^^^^^^^^^^^^^^
[t+28s671ms]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_kontext.py", line 1063, in __call__
[t+28s671ms]     noise_pred = self.transformer(
[t+28s671ms]                  ^^^^^^^^^^^^^^^^^
[t+28s671ms]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 411, in __call__
[t+28s671ms]     return super().__call__(*args, **kwargs)
[t+28s671ms]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+28s671ms]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[t+28s671ms]     return self._call_impl(*args, **kwargs)
[t+28s671ms]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+28s671ms]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[t+28s671ms]     return forward_call(*args, **kwargs)
[t+28s671ms]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+28s672ms]   File "/inferencesh/apps/gpu/31s80r2z5hw6jskcs081xh8wdb/venv/3.12/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 812, in compile_wrapper
[t+28s672ms]     raise e.with_traceback(None) from e.__cause__  # User compiler error
[t+28s672ms]     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+28s672ms] torch._dynamo.exc.Unsupported: Unsupported Tensor.item() call with capture_scalar_outputs=False
[t+28s672ms]   Explanation: Dynamo does not support tracing `Tensor.item()` with config.capture_scalar_outputs=False.
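
Per the trace, the failure bottoms out in diffusers' attention_dispatch.py at line 321, where _prepare_for_flash_attn_or_sage_varlen_without_mask calls seqlens_q.max().item(). A minimal sketch that reproduces the same error under fullgraph=True (hypothetical function, not the diffusers code itself):

import torch

def prepare_varlen_metadata(seqlens_q):
    # Same pattern as attention_dispatch.py:321. .item() turns a tensor
    # into a Python int; with capture_scalar_outputs=False Dynamo cannot
    # capture it, and under fullgraph=True it raises Unsupported instead
    # of falling back to a graph break.
    max_seqlen_q = seqlens_q.max().item()
    return torch.zeros(max_seqlen_q)

compiled = torch.compile(prepare_varlen_metadata, fullgraph=True)
compiled(torch.tensor([3, 5, 7]))  # raises torch._dynamo.exc.Unsupported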

Compiling with either

self.pipeline.transformer.compile_repeated_blocks(fullgraph=True)

or

self.pipeline.transformer.to(memory_format=torch.channels_last)
self.pipeline.transformer = torch.compile(
    self.pipeline.transformer, mode="max-autotune", fullgraph=True
)

yields the same error after calling

self.pipeline.transformer.set_attention_backend("sage_varlen")
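
For what it's worth, the Dynamo warning in the log suggests a possible workaround: enabling capture_scalar_outputs so that .item() is traced as a symbolic int instead of erroring out. A sketch of that, untested here and assuming no further data-dependent-shape issues downstream in the varlen path:

import torch

# Suggested by the warning in the log above: lets Dynamo capture
# Tensor.item() as an unbacked symbolic int rather than failing
# under fullgraph=True.
torch._dynamo.config.capture_scalar_outputs = True

self.pipeline.transformer.set_attention_backend("sage_varlen")
self.pipeline.transformer = torch.compile(
    self.pipeline.transformer, mode="max-autotune", fullgraph=True
)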

Originally posted by @okaris in #11916 (comment)
