Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 234 additions & 21 deletions src/diffusers/models/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)

def fuse_qkv_projections(self):
def fuse_qkv_projections(self, inplace: bool = False):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
Expand All @@ -106,7 +106,7 @@ def fuse_qkv_projections(self):

for module in self.modules():
if isinstance(module, AttentionModuleMixin) and module._supports_qkv_fusion:
module.fuse_projections()
module.fuse_projections(inplace=inplace)

def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
Expand All @@ -117,11 +117,28 @@ def unfuse_qkv_projections(self):
if isinstance(module, AttentionModuleMixin) and module._supports_qkv_fusion:
module.unfuse_projections()

def restore_checkpoint_fusion_state(self, inplace: bool = False):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you provide an example of where / how this is used?

My mental model says:

  • Load the pipeline
  • Call the fuse_qkv_projections() method on pipe.transformer.
  • Then, before loading the LoRA weights, we call this method.
  • And then load the LoRA weights?

"""
Restores the QKV fusion state back to that of the original model checkpoint (unlike `fuse_qkv_projections`,
which will fuse all eligible projections). This can be undone by `unfuse_qkv_projections`. The original
checkpoint fusion info is held on each `AttentionModuleMixin` module in the _native_fused_projections
attribute.

> [!WARNING] > This API is 🧪 experimental.
"""
for module in self.modules():
if isinstance(module, AttentionModuleMixin) and module._supports_qkv_fusion:
if module._native_fused_projections is True:
module.fuse_projections(inplace=inplace)
elif module._native_fused_projections is False:
module.unfuse_projections()


class AttentionModuleMixin:
_default_processor_cls = None
_available_processors = []
_supports_qkv_fusion = True
_native_fused_projections = None
fused_projections = False

def set_processor(self, processor: AttentionProcessor) -> None:
Expand Down Expand Up @@ -244,11 +261,34 @@ def set_use_memory_efficient_attention_xformers(

self.set_attention_backend("xformers")

@staticmethod
def _has_active_lora(module: nn.Module) -> bool:
"""Checks for the presence of PEFT-style LoRA modules without needing to import `peft`."""
return any("lora_A" in name or "lora_B" in name for name, _ in module.named_modules())

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a better way to detect it is:

for name, mod in module.named_modules():
    if isinstance(mod, BaseTunerLayer): ...

This way, we can also cater to non-LoRA adapters in the future.


@torch.no_grad()
def fuse_projections(self):
def fuse_projections(self, inplace: bool = False):
"""
Fuse the query, key, and value projections into a single projection for efficiency.
"""
# Do not fuse if LoRA adapters are active on the Q,K,V projections.
possible_qkv_modules = [
("to_q", getattr(self, "to_q", None)),
("to_k", getattr(self, "to_k", None)),
("to_v", getattr(self, "to_v", None)),
("add_q_proj", getattr(self, "add_q_proj", None)),
("add_k_proj", getattr(self, "add_k_proj", None)),
("add_v_proj", getattr(self, "add_v_proj", None)),
]
active_lora_modules = [
name for name, mod in possible_qkv_modules if mod is not None and self._has_active_lora(mod)
]
if active_lora_modules:
raise ValueError(
f"Cannot fuse QKV projections: LoRA adapters are active on {active_lora_modules}. "
"Please detach the LoRA or call `merge_and_unload()` to merge LoRA weights first."
)

# Skip if the AttentionModuleMixin subclass does not support fusion (for example, the QKV projections in Flux2
# single stream blocks are always fused)
if not self._supports_qkv_fusion:
Expand All @@ -275,6 +315,16 @@ def fuse_projections(self):
if hasattr(self, "use_bias") and self.use_bias:
concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
self.to_kv.bias.copy_(concatenated_bias)

if inplace:
# Keep the necessary K,V dims so that the individual projections can be reconstructed.
self._qkv_split_dims = (
self.to_k.weight.shape[0],
self.to_v.weight.shape[0],
self.to_k.weight.shape[1],
)
delattr(self, "to_k")
delattr(self, "to_v")
Comment on lines +326 to +327

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it free the memory? If not, what is the purpose of deleting these attributes? Also, from what I understand, this and some of the other refactors introduced in this PR aren't specifically for LoRA-awareness?

else:
# Fuse self-attention projections
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
Expand All @@ -287,27 +337,68 @@ def fuse_projections(self):
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
self.to_qkv.bias.copy_(concatenated_bias)

if inplace:
# Keep the necessary Q,K,V dims so that the individual projections can be reconstructed.
self._qkv_split_dims = (
self.to_q.weight.shape[0],
self.to_k.weight.shape[0],
self.to_v.weight.shape[0],
self.to_q.weight.shape[1],
)
delattr(self, "to_q")
delattr(self, "to_k")
delattr(self, "to_v")
Comment on lines +348 to +350

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.


# Handle added projections for models like SD3, Flux, etc.
if (
getattr(self, "add_q_proj", None) is not None
and getattr(self, "add_k_proj", None) is not None
and getattr(self, "add_v_proj", None) is not None
):
concatenated_weights = torch.cat(
[self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
)
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
if getattr(self, "add_k_proj", None) is not None and getattr(self, "add_v_proj", None) is not None:
if getattr(self, "add_q_proj", None) is not None:
# Added Self Attention (e.g. Flux)
concatenated_weights = torch.cat(
[self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
)
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]

self.to_added_qkv = nn.Linear(
in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
)
self.to_added_qkv.weight.copy_(concatenated_weights)
if self.added_proj_bias:
concatenated_bias = torch.cat(
[self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
self.to_added_qkv = nn.Linear(
in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
)
self.to_added_qkv.bias.copy_(concatenated_bias)
self.to_added_qkv.weight.copy_(concatenated_weights)
if self.added_proj_bias:
concatenated_bias = torch.cat(
[self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
)
self.to_added_qkv.bias.copy_(concatenated_bias)

if inplace:
self._added_qkv_split_dims = (
self.add_q_proj.weight.shape[0],
self.add_k_proj.weight.shape[0],
self.add_v_proj.weight.shape[0],
self.add_q_proj.weight.shape[1],
)
delattr(self, "add_q_proj")
delattr(self, "add_k_proj")
delattr(self, "add_v_proj")
else:
# Added Cross Attention (e.g. Wan)
concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]

self.to_added_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_added_kv.weight.copy_(concatenated_weights)
if hasattr(self, "use_bias") and self.use_bias:
concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
self.to_added_kv.bias.copy_(concatenated_bias)

if inplace:
self._added_qkv_split_dims = (
self.add_k_proj.weight.shape[0],
self.add_v_proj.weight.shape[0],
self.add_k_proj.weight.shape[1],
)
delattr(self, "add_k_proj")
delattr(self, "add_v_proj")

self.fused_projections = True

Expand All @@ -316,6 +407,22 @@ def unfuse_projections(self):
"""
Unfuse the query, key, and value projections back to separate projections.
"""
# Do not unfuse if LoRA adapters are active on the Q,K,V projections.
possible_fused_modules = [
("to_qkv", getattr(self, "to_qkv", None)),
("to_kv", getattr(self, "to_kv", None)),
("to_added_qkv", getattr(self, "to_added_qkv", None)),
("to_added_kv", getattr(self, "to_added_kv", None)),
]
active_lora_modules = [
name for name, mod in possible_fused_modules if mod is not None and self._has_active_lora(mod)
]
if active_lora_modules:
raise ValueError(
f"Cannot unfuse QKV projections: LoRA adapters are active on {active_lora_modules}. "
"Please detach the LoRA or call `merge_and_unload()` to merge LoRA weights first."
)

# Skip if the AttentionModuleMixin subclass does not support fusion (for example, the QKV projections in Flux2
# single stream blocks are always fused)
if not self._supports_qkv_fusion:
Expand All @@ -327,16 +434,122 @@ def unfuse_projections(self):

# Remove fused projection layers
if hasattr(self, "to_qkv"):
if not hasattr(self, "to_q"):
# QKV fused in-place, need to reconstruct the individual Q,K,V projections
has_bias = self.to_qkv.bias is not None
d_q, d_k, d_v, d_in = self._qkv_split_dims
self.to_q = nn.Linear(d_in, d_q, bias=has_bias)
self.to_k = nn.Linear(d_in, d_k, bias=has_bias)
self.to_v = nn.Linear(d_in, d_v, bias=has_bias)
# Avoid copying by using a view which shares storage with the fused projection
self.to_q.weight = nn.Parameter(self.to_qkv.weight[:d_q])
self.to_k.weight = nn.Parameter(self.to_qkv.weight[d_q : d_q + d_k])
self.to_v.weight = nn.Parameter(self.to_qkv.weight[d_q + d_k :])
if has_bias:
self.to_q.bias = nn.Parameter(self.to_qkv.bias[:d_q])
self.to_k.bias = nn.Parameter(self.to_qkv.bias[d_q : d_q + d_k])
self.to_v.bias = nn.Parameter(self.to_qkv.bias[d_q + d_k :])
delattr(self, "to_qkv")

if hasattr(self, "to_kv"):
if not hasattr(self, "to_k"):
has_bias = self.to_kv.bias is not None
d_k, d_v, d_in = self._qkv_split_dims
self.to_k = nn.Linear(d_in, d_k, bias=has_bias)
self.to_v = nn.Linear(d_in, d_v, bias=has_bias)
self.to_k.weight = nn.Parameter(self.to_kv.weight[:d_k])
self.to_v.weight = nn.Parameter(self.to_kv.weight[d_k:])
if has_bias:
self.to_k.bias = nn.Parameter(self.to_kv.bias[:d_k])
self.to_v.bias = nn.Parameter(self.to_kv.bias[d_k:])
delattr(self, "to_kv")

if hasattr(self, "to_added_qkv"):
if not hasattr(self, "add_q_proj"):
has_bias = self.to_added_qkv.bias is not None
d_q, d_k, d_v, d_in = self._added_qkv_split_dims
self.add_q_proj = nn.Linear(d_in, d_q, bias=has_bias)
self.add_k_proj = nn.Linear(d_in, d_k, bias=has_bias)
self.add_v_proj = nn.Linear(d_in, d_v, bias=has_bias)
# Avoid copying by using a view which shares storage with the fused projection
self.add_q_proj.weight = nn.Parameter(self.to_added_qkv.weight[:d_q])
self.add_k_proj.weight = nn.Parameter(self.to_added_qkv.weight[d_q : d_q + d_k])
self.add_v_proj.weight = nn.Parameter(self.to_added_qkv.weight[d_q + d_k :])
if has_bias:
self.add_q_proj.bias = nn.Parameter(self.to_added_qkv.bias[:d_q])
self.add_k_proj.bias = nn.Parameter(self.to_added_qkv.bias[d_q : d_q + d_k])
self.add_v_proj.bias = nn.Parameter(self.to_added_qkv.bias[d_q + d_k :])
delattr(self, "to_added_qkv")

if hasattr(self, "to_added_kv"):
if not hasattr(self, "add_k_proj"):
has_bias = self.to_added_kv.bias is not None
d_k, d_v, d_in = self._added_qkv_split_dims
self.add_k_proj = nn.Linear(d_in, d_k, bias=has_bias)
self.add_v_proj = nn.Linear(d_in, d_v, bias=has_bias)
self.add_k_proj.weight = nn.Parameter(self.to_added_kv.weight[:d_k])
self.add_v_proj.weight = nn.Parameter(self.to_added_kv.weight[d_k:])
if has_bias:
self.add_k_proj.bias = nn.Parameter(self.to_added_kv.bias[:d_k])
self.add_v_proj.bias = nn.Parameter(self.to_added_kv.bias[d_k:])
delattr(self, "to_added_kv")

if hasattr(self, "_qkv_split_dims"):
delattr(self, "_qkv_split_dims")
if hasattr(self, "_added_qkv_split_dims"):
delattr(self, "_added_qkv_split_dims")
self.fused_projections = False

def get_qkv(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Get the query, key, and value from the Q,K,V projections, handling both the split and fused cases.
"""
if self.fused_projections:
if hasattr(self, "to_kv"):
query = self.to_q(hidden_states)
key, value = self.to_kv(encoder_hidden_states).chunk(2, dim=-1)
elif hasattr(self, "to_qkv"):
query, key, value = self.to_qkv(hidden_states).chunk(3, dim=-1)
else:
raise RuntimeError("Cannot find fused self-attn proj `to_qkv` or cross-attn proj `to_kv`.")
else:
query = self.to_q(hidden_states)
kv_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.to_k(kv_states)
value = self.to_v(kv_states)
return query, key, value

def get_added_qkv(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Get the added query, key, and value from added Q,K,V projections (for example, second stream projections in a
MM-DiT-style model like Flux). Note that for models with only `add_k_proj`/`add_v_proj` such as Wan, Q comes
from the normal `to_q` projection.
"""
if self.fused_projections:
if hasattr(self, "to_added_kv"):
query = self.to_q(hidden_states)
key, value = self.to_added_kv(encoder_hidden_states).chunk(2, dim=-1)
elif hasattr(self, "to_added_qkv"):
query, key, value = self.to_added_qkv(hidden_states).chunk(3, dim=-1)
else:
raise RuntimeError(
"Cannot find added fused self-attn proj `to_added_qkv` or cross-attn proj `to_added_kv`."
)
else:
query = self.add_q_proj(hidden_states) if hasattr(self, "add_q_proj") else self.to_q(hidden_states)
kv_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.add_k_proj(kv_states)
value = self.add_v_proj(kv_states)
return query, key, value

def set_attention_slice(self, slice_size: int) -> None:
"""
Set the slice size for attention computation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def __init__(

self.set_processor(processor)

def fuse_projections(self):
def fuse_projections(self, inplace: bool = False):
if getattr(self, "fused_projections", False):
return

Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/transformer_helios.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def __init__(
self.history_scale_mode = history_scale_mode
self.max_scale = 10.0

def fuse_projections(self):
def fuse_projections(self, inplace: bool = False):
if getattr(self, "fused_projections", False):
return

Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/transformer_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def __init__(

self.set_processor(processor)

def fuse_projections(self):
def fuse_projections(self, inplace: bool = False):
if getattr(self, "fused_projections", False):
return

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ def __init__(

self.set_processor(processor)

def fuse_projections(self):
def fuse_projections(self, inplace: bool = False):
if getattr(self, "fused_projections", False):
return

Expand Down
Loading
Loading