Chroma as a FLUX.1 variant #11566
```diff
@@ -191,6 +191,7 @@ def __init__(
         transformer: FluxTransformer2DModel,
         image_encoder: CLIPVisionModelWithProjection = None,
         feature_extractor: CLIPImageProcessor = None,
+        variant: str = "flux",
     ):
         super().__init__()

```
```diff
@@ -213,6 +214,17 @@ def __init__(
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
         self.default_sample_size = 128
+        if variant not in {"flux", "chroma"}:
+            raise ValueError("`variant` must be `'flux' or `'chroma'`.")
+
+        self.variant = variant
+
+    def _get_chroma_attn_mask(self, length: torch.Tensor, max_sequence_length: int) -> torch.Tensor:
+        attention_mask = torch.zeros((length.shape[0], max_sequence_length), dtype=torch.bool, device=length.device)
+        for i, n_tokens in enumerate(length):
+            n_tokens = torch.max(n_tokens + 1, max_sequence_length)
```
Collaborator: Does Chroma support tokens beyond the max length for T5? Wouldn't this operation result in a mask that is 512 tokens in length, all True/1, whenever `n_tokens < max_sequence_length`? Also, is it not possible to use the attention mask returned by the tokenizer?

Reply: IIUC, Chroma needs an attention mask that's equivalent to the tokenizer's mask with one extra (pad) token left unmasked.

Reply: Though, the discussion here suggests the extra pad token is a mistake in the ComfyUI implementation, so it may not be needed.

Reply: Hmm, https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training seems to suggest otherwise.
```diff
+            attention_mask[i, :n_tokens] = True
+        return attention_mask

     def _get_t5_prompt_embeds(
         self,
```
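On the mask question above: for a prompt with, say, 20 real tokens and `max_sequence_length=512`, `max(n_tokens + 1, max_sequence_length)` evaluates to 512, so the whole row is set to True and no padding is masked. Below is a minimal sketch, not part of this PR, of the alternative the reviewer raises: deriving the mask from the `attention_mask` the tokenizer already returns. The `keep_extra_pad` switch is hypothetical and only encodes the unresolved question about the extra pad token.

```python
import torch


def chroma_attn_mask_from_tokenizer(attention_mask: torch.Tensor, keep_extra_pad: bool = True) -> torch.Tensor:
    # `attention_mask` is the [batch, max_sequence_length] tensor the tokenizer returns:
    # 1 for real tokens, 0 for padding.
    mask = attention_mask.bool().clone()
    if keep_extra_pad:
        # Optionally leave the first pad token unmasked, mirroring the `n_tokens + 1` intent
        # in the diff; whether Chroma actually needs this is what the thread above debates.
        lengths = mask.sum(dim=1)  # number of real tokens per row
        first_pad = lengths.clamp(max=mask.shape[1] - 1)  # stays in range for unpadded rows
        mask[torch.arange(mask.shape[0], device=mask.device), first_pad] = True
    return mask
```

For a fully unpadded row the clamp makes the extra assignment a no-op, so truncated prompts are unaffected.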
```diff
@@ -236,7 +248,7 @@ def _get_t5_prompt_embeds(
             padding="max_length",
             max_length=max_sequence_length,
             truncation=True,
-            return_length=False,
+            return_length=(self.variant == "chroma"),
             return_overflowing_tokens=False,
             return_tensors="pt",
         )
```
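For reference, here is a standalone version of the tokenizer call this hunk modifies (the checkpoint id is only illustrative; the pipeline uses its own `tokenizer_2`). It prints the `length` field that `return_length=True` adds, alongside the `attention_mask` the reviewer points to:

```python
from transformers import AutoTokenizer

# Illustrative T5 tokenizer; any checkpoint with the same tokenizer works for this demo.
tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xxl")

text_inputs = tokenizer(
    ["a photo of a cat"],
    padding="max_length",
    max_length=512,
    truncation=True,
    return_length=True,
    return_overflowing_tokens=False,
    return_tensors="pt",
)
print(text_inputs.length)                     # per-sample length reported by the tokenizer
print(text_inputs.attention_mask.sum(dim=1))  # number of non-padding tokens per sample
```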
```diff
@@ -250,7 +262,15 @@ def _get_t5_prompt_embeds(
                 f" {max_sequence_length} tokens: {removed_text}"
             )

-        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+        prompt_embeds = self.text_encoder_2(
+            text_input_ids.to(device),
+            output_hidden_states=False,
+            attention_mask=(
+                self._get_chroma_attn_mask(text_inputs.length, max_sequence_length).to(device)
+                if self.variant == "chroma"
+                else None
+            ),
+        )[0]

         dtype = self.text_encoder_2.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
```
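To make the effect of the new `attention_mask` argument concrete, here is a small, purely illustrative experiment using a small T5 encoder rather than the pipeline's `text_encoder_2`: encoding the same padded prompt with and without a mask produces different embeddings for the real tokens, because without a mask the padding positions are attended to.

```python
import torch
from transformers import AutoTokenizer, T5EncoderModel

# Illustrative small checkpoint; the pipeline's text_encoder_2 is a much larger T5 encoder.
name = "google/t5-v1_1-small"
tokenizer = AutoTokenizer.from_pretrained(name)
encoder = T5EncoderModel.from_pretrained(name).eval()

inputs = tokenizer(
    "a photo of a cat", padding="max_length", max_length=64, truncation=True, return_tensors="pt"
)

with torch.no_grad():
    masked = encoder(inputs.input_ids, attention_mask=inputs.attention_mask)[0]
    unmasked = encoder(inputs.input_ids)[0]  # no mask: padding positions are attended to

n_real = int(inputs.attention_mask.sum())
# The real-token embeddings differ once padding can be attended to.
print((masked[0, :n_real] - unmasked[0, :n_real]).abs().max())
```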
Comment: Should this not be 3072? Am I missing something? (It currently computes to 344, but that doesn't fit `distilled_guidance_layer.out_proj`, and comfy sets it to 3072.)

Reply: `mod_index_length` / `mod_proj.shape[0]` should be 344, though.