Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 53 additions & 11 deletions fastdeploy/model_executor/layers/attention/dsa_attention_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,29 @@ def forward_mixed(
"""
Mixed模式的前向传播
"""
if not forward_meta.max_len_tensor_cpu[1] and not forward_meta.max_len_tensor_cpu[2]:
return None

latent_cache = forward_meta.caches[2 * layer.layer_id] if hasattr(forward_meta, "caches") else None
res = DSAAttentionBackend.forward_static(
q, v, compressed_kv, k_pe, forward_meta.caches[2 * layer.layer_id], forward_meta, self.attn_softmax_scale
)
return res

@staticmethod
def forward_static(
q: paddle.Tensor,
indexer_topk: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
latent_cache: paddle.Tensor,
forward_meta: ForwardMeta,
attn_softmax_scale: float,
) -> paddle.Tensor:

assert len(q.shape) == 3
assert len(compressed_kv.shape) == 2
assert len(k_pe.shape) == 3
assert len(latent_cache.shape) == 4

if current_platform.is_cuda():
import flash_mla
Expand All @@ -352,43 +373,64 @@ def forward_mixed(
"fp8_ds_mla",
)

assert len(q.shape) == 3
q_num_heads = q.shape[1]
ceil64_num_heads = (q_num_heads + 63) // 64 * 64

fmha_out_prefill = None
if forward_meta.max_len_tensor_cpu[1]: # max_enc_len_this_time
if ceil64_num_heads != q_num_heads:
new_q = paddle.empty([q.shape[0], ceil64_num_heads, q.shape[2]], dtype=q.dtype)
new_q[:, :q_num_heads, :] = q
else:
new_q = q

kv = paddle.concat([compressed_kv.unsqueeze(1), k_pe], axis=-1)
fmha_out_prefill, _, __ = flash_mla.flash_mla_sparse_fwd(
q, # q_input.contiguous(),
k, # kv.unsqueeze(1),
v, # indexer_top_k.unsqueeze(1),
sm_scale=self.attn_softmax_scale,
new_q, # q_input.contiguous(),
kv, # kv.unsqueeze(1),
indexer_topk, # indexer_top_k.unsqueeze(1),
sm_scale=attn_softmax_scale,
)

assert len(fmha_out_prefill.shape) == 3
fmha_out_prefill = fmha_out_prefill[:, :q_num_heads, :].contiguous()

# Decode
# if k is None:
if forward_meta.max_len_tensor_cpu[2]: # max_enc_len_this_time
if forward_meta.max_len_tensor_cpu[2]:

tile_scheduler_metadata, _ = flash_mla.get_mla_metadata()
new_cache_shape = latent_cache.shape
assert new_cache_shape[1] == 1
new_cache_shape[1], new_cache_shape[2] = new_cache_shape[2], new_cache_shape[1]

if ceil64_num_heads != q_num_heads:
new_q = paddle.empty([q.shape[0], ceil64_num_heads, q.shape[2]], dtype=q.dtype)
new_q[:, :q_num_heads, :] = q
else:
new_q = q

fmha_out_decode, _ = flash_mla.flash_mla_with_kvcache(
q.unsqueeze(1).contiguous(),
new_q.unsqueeze(1).contiguous(),
latent_cache.view(new_cache_shape),
None, # forward_meta.block_tables,
None, # cache_seqlens
512, # self.qk_nope_head_dim,
tile_scheduler_metadata,
None, # num_splits,
self.attn_softmax_scale,
attn_softmax_scale,
False, # casual
True, # is_fp8_kvcache
v, # indices,
indexer_topk, # indices,
None, # t.attn_sink,
None, # extra_k_cache,
None, # extra_indices_in_kvcache: Optional[torch.Tensor] = None,
None, # topk_length: Optional[torch.Tensor] = None,
None, # extra_topk_length: Optional[torch.Tensor] = None
)

fmha_out_decode = fmha_out_decode[:, :, :q_num_heads, :].contiguous()

if fmha_out_prefill is not None:

from fastdeploy.model_executor.ops.gpu import (
Expand All @@ -402,7 +444,7 @@ def forward_mixed(
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
self.num_heads * 4,
q_num_heads * 4,
128,
1,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,8 @@ def __init__(
logger.info(
"The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
)
# swa config
self.window_attn_skip_freq = getattr(fd_config.model_config, "window_attn_skip_freq", None)

def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attention metadata hence all layers in the forward pass can reuse it."""
Expand Down Expand Up @@ -618,8 +620,13 @@ def get_kv_cache_shape(
"""
Calculate kv cache shape for MLA
"""
key_cache_shape = [max_num_blocks, 1, self.block_size, self.kv_lora_rank + self.qk_rope_head_dim]
layer_id = self.layer_id

This comment was marked as outdated.

value_cache_shape = []
if self.window_attn_skip_freq is not None and self.window_attn_skip_freq[layer_id] == 1:
fp8_key_cahe_dim = self.kv_lora_rank + 4 * (self.kv_lora_rank // 128) + 2 * self.qk_rope_head_dim
key_cache_shape = [max_num_blocks, 1, self.block_size, fp8_key_cahe_dim]
else:
key_cache_shape = [max_num_blocks, 1, self.block_size, self.kv_lora_rank + self.qk_rope_head_dim]
return key_cache_shape, value_cache_shape

def create_kv_cache(
Expand Down
Loading
Loading