[core] reuse AttentionMixin for compatible classes #12463
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```python
from ...models.attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    AttentionProcessor,
```
I see that many of the tests in AudioLDM2PipelineFastTests currently fail in the CI with the following error, e.g.:

```
FAILED tests/pipelines/audioldm2/test_audioldm2.py::AudioLDM2PipelineFastTests::test_inference_batch_consistent - AttributeError: type object 'ClapConfig' has no attribute 'from_text_audio_configs'
```

This method is called in AudioLDM2PipelineFastTests.get_dummy_components:
diffusers/tests/pipelines/audioldm2/test_audioldm2.py, lines 141 to 145 in fa468c5:
```python
text_encoder_config = ClapConfig.from_text_audio_configs(
    text_config=text_branch_config,
    audio_config=audio_branch_config,
    projection_dim=16,
)
```
It looks like the ClapConfig.from_text_audio_configs method exists in transformers==4.57.0 but has been removed on main. Given that this method is going away, should we replace this call with something like the following?

```python
class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    ...
    def get_dummy_components(self):
        ...
        text_encoder_config = ClapConfig(
            text_config=text_branch_config,
            audio_config=audio_branch_config,
            projection_dim=16,
        )
        ...
    ...
```
Similarly, the following tests fail because CLIPFeatureExtractor has been removed:

```
FAILED tests/pipelines/test_pipelines.py::DownloadTests::test_download_bin_only_variant_exists_for_model - AttributeError: module transformers has no attribute CLIPFeatureExtractor
FAILED tests/pipelines/test_pipelines.py::DownloadTests::test_download_bin_variant_does_not_exist_for_model - AttributeError: module transformers has no attribute CLIPFeatureExtractor
FAILED tests/pipelines/test_pipelines.py::PipelineFastTests::test_wrong_model - AttributeError: module transformers has no attribute CLIPFeatureExtractor
```

Should we replace the calls to CLIPFeatureExtractor with CLIPImageProcessor, or do you think that should be separated into a new PR?
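If we do handle it here, the swap should be mechanical, since CLIPImageProcessor is the successor class. A minimal sketch, assuming the dummy components keep their current arguments:

```python
from transformers import CLIPImageProcessor

# CLIPImageProcessor supersedes the removed CLIPFeatureExtractor; call sites
# should only need the class name changed, e.g. in get_dummy_components:
feature_extractor = CLIPImageProcessor(crop_size=32)
```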
For #12463 (comment) see #12455
nit: Unrelated to this PR. Prefer discussing these separately.
```python
for name, module in self.named_children():
    fn_recursive_attn_processor(name, module, processor)

# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
```
Perhaps it's out of scope for this PR, but I see that a lot of models additionally have a set_default_attn_processor method, usually marked # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor. Do you think it makes sense to add this method to AttentionMixin as well?
IMO, not yet, since AttentionMixin is fairly agnostic to the model type, but set_default_attn_processor relies on some custom attention processor types. For UNet2DConditionModel, we have (diffusers/src/diffusers/models/unets/unet_2d_condition.py, lines 762 to 769 in fa468c5):
```python
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
    processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
    processor = AttnProcessor()
else:
    raise ValueError(
        f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
    )
```
However, for AutoencoderKLTemporalDecoder (diffusers/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py, lines 269 to 274 in fa468c5):
```python
if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
    processor = AttnProcessor()
else:
    raise ValueError(
        f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
    )
```
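If we did consolidate later, the UNet2DConditionModel branch already subsumes the decoder-only case, so a shared method could look roughly like this sketch (hypothetical, not something this PR adds):

```python
from diffusers.models.attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    AttnAddedKVProcessor,
    AttnProcessor,
)


def set_default_attn_processor(self):
    # Pick the default processor family based on what is currently registered;
    # models whose processors fit neither family keep raising, as they do today.
    current = self.attn_processors.values()
    if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in current):
        processor = AttnAddedKVProcessor()
    elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in current):
        processor = AttnProcessor()
    else:
        raise ValueError(
            f"Cannot call `set_default_attn_processor` when attention processors "
            f"are of type {next(iter(current))}"
        )
    self.set_attn_processor(processor)
```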
I'd be down for the refactoring, though. Cc: @DN6
Looks good to me! I think AuraFlowTransformer2DModel and AudioLDM2UNet2DConditionModel have their attn_processors/set_attn_processor methods deleted but are missing the corresponding change to inherit from AttentionMixin.
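That is, each of those classes still needs the mixin added to its bases, roughly like this (a sketch only; the exact base list varies per model):

```python
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.attention import AttentionMixin
from diffusers.models.modeling_utils import ModelMixin


# Sketch: once the copied-over methods are deleted, each affected model picks
# the shared implementations back up by listing AttentionMixin as a base.
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
    ...
```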
Thanks for those catches, @dg845. They should be fixed now.
LGTM :)
What does this PR do?

Many models use "# Copied from ..." implementations of attn_processors and set_attn_processor. They are basically the same as what we have implemented in diffusers/src/diffusers/models/attention.py (line 39 in 693d8a3).

This PR makes those models inherit from AttentionMixin and removes the copied-over implementations.

I decided to leave fuse_qkv_projections and unfuse_qkv_projections out of this PR because some models don't have attention processors implemented in a way that would make this seamless. But the methods removed in this PR should be harmless.
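For context, the shared implementations being reused are roughly the following (paraphrased from AttentionMixin in diffusers/src/diffusers/models/attention.py; see the file for the exact code):

```python
from typing import Dict, Union

import torch

from diffusers.models.attention_processor import AttentionProcessor


class AttentionMixin:
    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        # Recursively collect every registered attention processor, keyed by
        # the dotted module path (e.g. "down_blocks.0.attentions.0.processor").
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()
            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)
        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        # Accept either a single processor (applied everywhere) or a dict keyed
        # like `attn_processors`; recurse and call `set_processor` on each layer.
        count = len(self.attn_processors.keys())
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} "
                f"does not match the number of attention layers: {count}."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))
            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
```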