diff --git a/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py b/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py index 1bc8c8e8dcd..a416868b0a0 100644 --- a/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py @@ -335,8 +335,29 @@ def forward_mixed( """ Mixed模式的前向传播 """ + if not forward_meta.max_len_tensor_cpu[1] and not forward_meta.max_len_tensor_cpu[2]: + return None - latent_cache = forward_meta.caches[2 * layer.layer_id] if hasattr(forward_meta, "caches") else None + res = DSAAttentionBackend.forward_static( + q, v, compressed_kv, k_pe, forward_meta.caches[2 * layer.layer_id], forward_meta, self.attn_softmax_scale + ) + return res + + @staticmethod + def forward_static( + q: paddle.Tensor, + indexer_topk: paddle.Tensor, + compressed_kv: paddle.Tensor, + k_pe: paddle.Tensor, + latent_cache: paddle.Tensor, + forward_meta: ForwardMeta, + attn_softmax_scale: float, + ) -> paddle.Tensor: + + assert len(q.shape) == 3 + assert len(compressed_kv.shape) == 2 + assert len(k_pe.shape) == 3 + assert len(latent_cache.shape) == 4 if current_platform.is_cuda(): import flash_mla @@ -352,36 +373,55 @@ def forward_mixed( "fp8_ds_mla", ) + assert len(q.shape) == 3 + q_num_heads = q.shape[1] + ceil64_num_heads = (q_num_heads + 63) // 64 * 64 + fmha_out_prefill = None if forward_meta.max_len_tensor_cpu[1]: # max_enc_len_this_time + if ceil64_num_heads != q_num_heads: + new_q = paddle.empty([q.shape[0], ceil64_num_heads, q.shape[2]], dtype=q.dtype) + new_q[:, :q_num_heads, :] = q + else: + new_q = q + kv = paddle.concat([compressed_kv.unsqueeze(1), k_pe], axis=-1) fmha_out_prefill, _, __ = flash_mla.flash_mla_sparse_fwd( - q, # q_input.contiguous(), - k, # kv.unsqueeze(1), - v, # indexer_top_k.unsqueeze(1), - sm_scale=self.attn_softmax_scale, + new_q, # q_input.contiguous(), + kv, # kv.unsqueeze(1), + indexer_topk, # indexer_top_k.unsqueeze(1), + sm_scale=attn_softmax_scale, ) + assert len(fmha_out_prefill.shape) == 3 + fmha_out_prefill = fmha_out_prefill[:, :q_num_heads, :].contiguous() + # Decode - # if k is None: - if forward_meta.max_len_tensor_cpu[2]: # max_enc_len_this_time + if forward_meta.max_len_tensor_cpu[2]: tile_scheduler_metadata, _ = flash_mla.get_mla_metadata() new_cache_shape = latent_cache.shape assert new_cache_shape[1] == 1 new_cache_shape[1], new_cache_shape[2] = new_cache_shape[2], new_cache_shape[1] + + if ceil64_num_heads != q_num_heads: + new_q = paddle.empty([q.shape[0], ceil64_num_heads, q.shape[2]], dtype=q.dtype) + new_q[:, :q_num_heads, :] = q + else: + new_q = q + fmha_out_decode, _ = flash_mla.flash_mla_with_kvcache( - q.unsqueeze(1).contiguous(), + new_q.unsqueeze(1).contiguous(), latent_cache.view(new_cache_shape), None, # forward_meta.block_tables, None, # cache_seqlens 512, # self.qk_nope_head_dim, tile_scheduler_metadata, None, # num_splits, - self.attn_softmax_scale, + attn_softmax_scale, False, # casual True, # is_fp8_kvcache - v, # indices, + indexer_topk, # indices, None, # t.attn_sink, None, # extra_k_cache, None, # extra_indices_in_kvcache: Optional[torch.Tensor] = None, @@ -389,6 +429,8 @@ def forward_mixed( None, # extra_topk_length: Optional[torch.Tensor] = None ) + fmha_out_decode = fmha_out_decode[:, :, :q_num_heads, :].contiguous() + if fmha_out_prefill is not None: from fastdeploy.model_executor.ops.gpu import ( @@ -402,7 +444,7 @@ def forward_mixed( forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, forward_meta.cu_seqlens_q, - self.num_heads * 4, + q_num_heads * 4, 128, 1, ) diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index ba1ef6fab0c..855c34c8343 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -542,6 +542,8 @@ def __init__( logger.info( "The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead." ) + # swa config + self.window_attn_skip_freq = getattr(fd_config.model_config, "window_attn_skip_freq", None) def init_attention_metadata(self, forward_meta: ForwardMeta): """Initialize attention metadata hence all layers in the forward pass can reuse it.""" @@ -618,8 +620,13 @@ def get_kv_cache_shape( """ Calculate kv cache shape for MLA """ - key_cache_shape = [max_num_blocks, 1, self.block_size, self.kv_lora_rank + self.qk_rope_head_dim] + layer_id = self.layer_id value_cache_shape = [] + if self.window_attn_skip_freq is not None and self.window_attn_skip_freq[layer_id] == 1: + fp8_key_cahe_dim = self.kv_lora_rank + 4 * (self.kv_lora_rank // 128) + 2 * self.qk_rope_head_dim + key_cache_shape = [max_num_blocks, 1, self.block_size, fp8_key_cahe_dim] + else: + key_cache_shape = [max_num_blocks, 1, self.block_size, self.kv_lora_rank + self.qk_rope_head_dim] return key_cache_shape, value_cache_shape def create_kv_cache( diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py index 1a89d6a756e..5d04e4e41df 100644 --- a/fastdeploy/model_executor/models/deepseek_v3.py +++ b/fastdeploy/model_executor/models/deepseek_v3.py @@ -72,6 +72,78 @@ ) +import triton +import triton.language as tl + + +@enable_compat_on_triton_kernel +@triton.jit +def get_swa_indexer_top_k_kernel( + indexer_top_k, + block_tables, + cu_seqlens_q, + seq_lens_encoder, + seq_lens_decoder, + batch_id_per_token, + max_page_per_seq: tl.constexpr, + window_size: tl.constexpr, + page_size: tl.constexpr, +): + token_id = tl.program_id(0) + + indexer_top_k += token_id * window_size + + batch_id = tl.load(batch_id_per_token + token_id) + if batch_id < 0: + return + + block_tables += batch_id * max_page_per_seq + + kv_len = tl.load(seq_lens_decoder + batch_id) + encoder_len = tl.load(seq_lens_encoder + batch_id) + cu_q_len = tl.load(cu_seqlens_q + batch_id) + token_id_in_this_batch = token_id - cu_q_len + kv_len + + valid_window_size = min(token_id_in_this_batch + 1, window_size) + + for idx in range(token_id_in_this_batch, token_id_in_this_batch - valid_window_size, -1): + if encoder_len > 0: + # encoder case. + tmp = cu_q_len + idx + tl.store(indexer_top_k + token_id_in_this_batch - idx, tmp) + else: + tmp = tl.load(block_tables + idx // page_size) + tmp = tmp * page_size + idx % page_size + tl.store(indexer_top_k + token_id_in_this_batch - idx, tmp) + + +def get_swa_indexer_top_k( + indexer_top_k, + block_tables, + cu_seqlens_q, + seq_lens_encoder, + seq_lens_decoder, + batch_id_per_token, +): + assert indexer_top_k.ndim == 3 + assert indexer_top_k.shape[1] == 1 + + token_num = indexer_top_k.shape[0] + grid = (token_num,) + + get_swa_indexer_top_k_kernel[grid]( + indexer_top_k, + block_tables, + cu_seqlens_q, + seq_lens_encoder, + seq_lens_decoder, + batch_id_per_token, + max_page_per_seq=block_tables.shape[1], + window_size=indexer_top_k.shape[2], + page_size=64, + ) + + class DeepSeekV3MLP(nn.Layer): """ DeepSeekV3MLP, for Dense FFN and Shared Experts Layer. @@ -226,6 +298,10 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None self.q_lora_rank = fd_config.model_config.q_lora_rank self.kv_lora_rank = fd_config.model_config.kv_lora_rank + # swa + self.swa_layer_list = getattr(fd_config.model_config, "window_attn_skip_freq", None) + self.sliding_window = getattr(fd_config.model_config, "sliding_window", 0) + self.attn_softmax_scale = self.qk_head_dim**-0.5 if fd_config.model_config.model_type == "glm_moe_dsa": @@ -361,6 +437,58 @@ def yarn_get_mscale(scale=1, mscale=1): return 1.0 return 0.1 * mscale * math.log(scale) + 1.0 + def forward_swa_static( + self, + forward_meta: ForwardMeta, + query_nope: paddle.Tensor, + query_pe: paddle.Tensor, + compressed_kv: paddle.Tensor, + key_pe: paddle.Tensor, + ): + """MLA static attention with sliding window indexer.""" + q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2]) + + q_input = paddle.concat([q_nope_out, query_pe], axis=-1) + q_input.reshape_( + [ + -1, + self.num_attention_heads_tp, + self.kv_lora_rank + self.qk_rope_head_dim, + ] + ) + + self.index_topk = self.sliding_window + indexer_top_k = paddle.full([q_input.shape[0], 1, self.index_topk], -1, dtype="int32") + + get_swa_indexer_top_k( + indexer_top_k, + forward_meta.block_tables, + forward_meta.cu_seqlens_q, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.batch_id_per_token, + ) + + from fastdeploy.model_executor.layers.attention import DSAAttentionBackend + + fmqa_out = DSAAttentionBackend.forward_static( + q=q_input.contiguous(), + indexer_topk=indexer_top_k, + compressed_kv=compressed_kv, + k_pe=key_pe, + latent_cache=forward_meta.caches[self.layer_id], + forward_meta=forward_meta, + attn_softmax_scale=self.attn_softmax_scale, + ) + + fmqa_out = fmqa_out.reshape_([-1, self.num_attention_heads_tp, self.kv_lora_rank]).transpose([1, 0, 2]) + + return ( + self.kv_b_proj_bmm(fmqa_out, proj_type="v") + .transpose([1, 0, 2]) + .reshape_([-1, self.num_attention_heads_tp * self.v_head_dim]) + ) + def forward( self, forward_meta: ForwardMeta, @@ -398,142 +526,156 @@ def forward( need_do_prefill = forward_meta.max_len_tensor_cpu[1] > 0 need_do_decode = forward_meta.max_len_tensor_cpu[2] > 0 - if need_do_prefill: - # Handle prefix cache: read cached latent from paged cache and interleave - # with the new-token latent in a single fused kernel call. - full_compressed_kv = compressed_kv - full_k_pe = key_pe.squeeze(1) - if self.enable_chunked_prefill or self.enable_prefix_caching: - - full_compressed_kv, full_k_pe = fused_read_cache_and_interleave( - forward_meta.caches[self.layer_id], - forward_meta.block_tables, - compressed_kv, - key_pe.squeeze(1), - forward_meta.cu_seqlens_k, - forward_meta.cu_seqlens_q, - self.kv_lora_rank, - self.qk_rope_head_dim, - self.block_size, - ) + window_attn_skip_freq = getattr(self.fd_config.model_config, "window_attn_skip_freq", None) - # Project latent KV to full key and value - key_value = self.kv_b_proj(full_compressed_kv) - key_value.reshape_( - [ - -1, - self.num_attention_heads_tp, - self.qk_nope_head_dim + self.v_head_dim, - ] - ) - key_nope, value = key_value.split([self.qk_nope_head_dim, self.v_head_dim], axis=-1) - - query[..., self.qk_nope_head_dim :] = query_pe - key = paddle.empty([full_k_pe.shape[0], self.num_attention_heads_tp, self.qk_head_dim], dtype=query.dtype) - key[..., : self.qk_nope_head_dim] = key_nope - key[..., self.qk_nope_head_dim :] = full_k_pe.unsqueeze(1) - if self.qk_head_dim - self.v_head_dim != 0: - value = paddle.nn.functional.pad(value, [0, self.qk_head_dim - self.v_head_dim], value=0) - - fmha_out = self.mla_attn( - q=query, - k=key, - v=value, - qkv=None, - compressed_kv=compressed_kv, # Pass original (new only) for cache writing - k_pe=key_pe, # Pass original (new only) for cache writing + if self.sliding_window > 0 and window_attn_skip_freq is not None and window_attn_skip_freq[self.layer_id] == 1: + attn_out = self.forward_swa_static( forward_meta=forward_meta, + query_nope=query_nope, + query_pe=query_pe, + compressed_kv=compressed_kv, + key_pe=key_pe, ) - - fmha_out.reshape_([-1, self.num_attention_heads_tp, self.qk_head_dim]) - fmha_out = fmha_out[:, :, : self.v_head_dim] - fmha_out.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim]) - attn_out = fmha_out - - if need_do_decode: # max_dec_len_this_time - - if int(os.getenv("USE_FLASH_MLA", "0")) == 0 and self.prop.major == 9: - pass - else: - from fastdeploy.model_executor.layers.attention.mla_attention_backend import ( - extract_decoder_token_from_q, - insert_decoder_result_back, + else: + if need_do_prefill: + # Handle prefix cache: read cached latent from paged cache and interleave + # with the new-token latent in a single fused kernel call. + full_compressed_kv = compressed_kv + full_k_pe = key_pe.squeeze(1) + if self.enable_chunked_prefill or self.enable_prefix_caching: + + full_compressed_kv, full_k_pe = fused_read_cache_and_interleave( + forward_meta.caches[self.layer_id], + forward_meta.block_tables, + compressed_kv, + key_pe.squeeze(1), + forward_meta.cu_seqlens_k, + forward_meta.cu_seqlens_q, + self.kv_lora_rank, + self.qk_rope_head_dim, + self.block_size, + ) + + # Project latent KV to full key and value + key_value = self.kv_b_proj(full_compressed_kv) + key_value.reshape_( + [ + -1, + self.num_attention_heads_tp, + self.qk_nope_head_dim + self.v_head_dim, + ] ) + key_nope, value = key_value.split([self.qk_nope_head_dim, self.v_head_dim], axis=-1) - decoder_query_nope, cache_seqlens = extract_decoder_token_from_q( - query_nope.reshape([0, -1]), - forward_meta.cu_seqlens_q, - forward_meta.seq_lens_encoder, - forward_meta.seq_lens_decoder, + query[..., self.qk_nope_head_dim :] = query_pe + key = paddle.empty( + [full_k_pe.shape[0], self.num_attention_heads_tp, self.qk_head_dim], dtype=query.dtype ) - - decoder_query_pe, cache_seqlens = extract_decoder_token_from_q( - query_pe.reshape([0, -1]), - forward_meta.cu_seqlens_q, - forward_meta.seq_lens_encoder, - forward_meta.seq_lens_decoder, + key[..., : self.qk_nope_head_dim] = key_nope + key[..., self.qk_nope_head_dim :] = full_k_pe.unsqueeze(1) + if self.qk_head_dim - self.v_head_dim != 0: + value = paddle.nn.functional.pad(value, [0, self.qk_head_dim - self.v_head_dim], value=0) + + fmha_out = self.mla_attn( + q=query, + k=key, + v=value, + qkv=None, + compressed_kv=compressed_kv, # Pass original (new only) for cache writing + k_pe=key_pe, # Pass original (new only) for cache writing + forward_meta=forward_meta, ) - assert decoder_query_nope.shape[0] == forward_meta.seq_lens_encoder.shape[0] - assert decoder_query_pe.shape[0] == forward_meta.seq_lens_encoder.shape[0] - - forward_meta.cache_seqlens = cache_seqlens - query_nope = decoder_query_nope.reshape([0, -1, self.qk_nope_head_dim]) - query_pe = decoder_query_pe.reshape([0, -1, self.qk_rope_head_dim]) + fmha_out.reshape_([-1, self.num_attention_heads_tp, self.qk_head_dim]) + fmha_out = fmha_out[:, :, : self.v_head_dim] + fmha_out.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim]) + attn_out = fmha_out - q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2]) - - q_input = paddle.concat([q_nope_out, query_pe], axis=-1) - q_input.reshape_( - [ - -1, - self.num_attention_heads_tp * (self.kv_lora_rank + self.qk_rope_head_dim), - ] - ) + if need_do_decode: # max_dec_len_this_time - fmqa_out = self.mla_attn( - q=q_input, - k=None, - v=None, - qkv=None, - compressed_kv=compressed_kv, - k_pe=key_pe, - forward_meta=forward_meta, - ) + if int(os.getenv("USE_FLASH_MLA", "0")) == 0 and self.prop.major == 9: + pass + else: + from fastdeploy.model_executor.layers.attention.mla_attention_backend import ( + extract_decoder_token_from_q, + insert_decoder_result_back, + ) + + decoder_query_nope, cache_seqlens = extract_decoder_token_from_q( + query_nope.reshape([0, -1]), + forward_meta.cu_seqlens_q, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + ) + + decoder_query_pe, cache_seqlens = extract_decoder_token_from_q( + query_pe.reshape([0, -1]), + forward_meta.cu_seqlens_q, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + ) + assert decoder_query_nope.shape[0] == forward_meta.seq_lens_encoder.shape[0] + assert decoder_query_pe.shape[0] == forward_meta.seq_lens_encoder.shape[0] + + forward_meta.cache_seqlens = cache_seqlens + + query_nope = decoder_query_nope.reshape([0, -1, self.qk_nope_head_dim]) + query_pe = decoder_query_pe.reshape([0, -1, self.qk_rope_head_dim]) + + q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2]) + + q_input = paddle.concat([q_nope_out, query_pe], axis=-1) + q_input.reshape_( + [ + -1, + self.num_attention_heads_tp * (self.kv_lora_rank + self.qk_rope_head_dim), + ] + ) - fmqa_out = fmqa_out.reshape_([-1, self.num_attention_heads_tp, self.kv_lora_rank]).transpose([1, 0, 2]) + fmqa_out = self.mla_attn( + q=q_input, + k=None, + v=None, + qkv=None, + compressed_kv=compressed_kv, + k_pe=key_pe, + forward_meta=forward_meta, + ) - fmqa_out = ( - self.kv_b_proj_bmm(fmqa_out, proj_type="v") - .transpose([1, 0, 2]) - .reshape_([-1, self.num_attention_heads_tp * self.v_head_dim]) - ) + fmqa_out = fmqa_out.reshape_([-1, self.num_attention_heads_tp, self.kv_lora_rank]).transpose([1, 0, 2]) - if int(os.getenv("USE_FLASH_MLA", "0")) == 0 and self.prop.major == 9: - pass - else: - fmqa_out = insert_decoder_result_back( - fmqa_out.reshape([0, 1, self.num_attention_heads_tp, self.v_head_dim]), - forward_meta.cu_seqlens_q, - forward_meta.seq_lens_encoder, - forward_meta.seq_lens_decoder, - q_total_token_num, + fmqa_out = ( + self.kv_b_proj_bmm(fmqa_out, proj_type="v") + .transpose([1, 0, 2]) + .reshape_([-1, self.num_attention_heads_tp * self.v_head_dim]) ) - if need_do_prefill: - merge_prefill_decode_output( - attn_out, - fmqa_out, - forward_meta.seq_lens_encoder, - forward_meta.seq_lens_decoder, - forward_meta.seq_lens_this_time, - forward_meta.cu_seqlens_q, - self.num_attention_heads_tp, - self.v_head_dim, - 1, - ) - else: - attn_out = fmqa_out + if int(os.getenv("USE_FLASH_MLA", "0")) == 0 and self.prop.major == 9: + pass + else: + fmqa_out = insert_decoder_result_back( + fmqa_out.reshape([0, 1, self.num_attention_heads_tp, self.v_head_dim]), + forward_meta.cu_seqlens_q, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + q_total_token_num, + ) + + if need_do_prefill: + merge_prefill_decode_output( + attn_out, + fmqa_out, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.cu_seqlens_q, + self.num_attention_heads_tp, + self.v_head_dim, + 1, + ) + else: + attn_out = fmqa_out + if self.use_gated_attn: gated_attn_act = getattr(self.fd_config.model_config, "gated_attn_act", "sigmoid") if gated_attn_act == "sigmoid": @@ -547,7 +689,6 @@ def forward( import triton -import triton.language as tl @enable_compat_on_triton_kernel @@ -894,12 +1035,12 @@ def forward( q_input = paddle.concat([q_nope_out.transpose([1, 0, 2]).contiguous(), query_pe], axis=-1) compressed_kv = self.kv_a_layernorm(compressed_kv)[0] - kv = paddle.concat([compressed_kv, key_pe.squeeze(1)], axis=-1) + # kv = paddle.concat([compressed_kv, key_pe.squeeze(1)], axis=-1) # dsa attention fmha_out = self.dsa_attn( q=q_input.contiguous(), - k=kv.unsqueeze(1).contiguous(), + k=None, # kv.unsqueeze(1).contiguous(), v=indexer_top_k.unsqueeze(1).contiguous(), qkv=None, compressed_kv=compressed_kv, diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index b51415dbb68..a8df36acda7 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -308,6 +308,8 @@ def __init__( if self.enable_overlap_schedule: logger.info("Using overlap schedule") self.current_launch_token_num = 0 + # swa config + self.window_attn_skip_freq = getattr(self.fd_config.model_config, "window_attn_skip_freq", None) def _async_output_busy_loop(self): """Entrypoint for the thread which handles outputs asynchronously.""" @@ -1598,7 +1600,8 @@ def initialize_kv_cache(self, profile: bool = False) -> None: key_cache_shapes = [] value_cache_shapes = [] indexer_cache_shapes = [] - for attn_backend in self.attn_backends: + for layer_id, attn_backend in enumerate(self.attn_backends): + attn_backend.layer_id = layer_id kv_cache_shape = attn_backend.get_kv_cache_shape( max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type ) @@ -1648,6 +1651,23 @@ def initialize_kv_cache(self, profile: bool = False) -> None: logger.info( f"..creating kv cache for layer {i}: key:{key_cache_shape}, value:{value_cache_shape}, indexer:{indexer_cache_shape}" ) + # swa mla cache type + if self.mla_cache and self.window_attn_skip_freq is not None and self.window_attn_skip_freq[i] == 1: + cache_type = "uint8" + kv_cache_quant_type = "uint8" + else: + # Get kv cache dtype + cache_type = self.model_config.dtype + kv_cache_quant_type = None + + if ( + self.quant_config + and hasattr(self.quant_config, "kv_cache_quant_type") + and self.quant_config.kv_cache_quant_type is not None + ): + cache_type = "uint8" + kv_cache_quant_type = self.quant_config.kv_cache_quant_type + key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_type) set_data_ipc(key_cache, key_cache_name) self.cache_kvs_map[key_cache_name] = key_cache diff --git a/tests/layers/test_dsa_attention_backend.py b/tests/layers/test_dsa_attention_backend.py index 48553643593..b3533bfeab6 100644 --- a/tests/layers/test_dsa_attention_backend.py +++ b/tests/layers/test_dsa_attention_backend.py @@ -631,7 +631,7 @@ def test_forward_mixed_decode_only(self, mock_abs, mock_randn, mock_init_rank, m mock_flash_mla = MagicMock() mock_flash_mla.get_mla_metadata.return_value = ("tile_meta", None) - mock_flash_mla.flash_mla_with_kvcache.return_value = ("decode_output", None) + mock_flash_mla.flash_mla_with_kvcache.return_value = (paddle.zeros([2, 1, 16, 512], dtype="float32"), None) mock_dsk_write = MagicMock() gpu_module = MagicMock() @@ -648,17 +648,17 @@ def test_forward_mixed_decode_only(self, mock_abs, mock_randn, mock_init_rank, m }, ): result = backend.forward_mixed( - q=MagicMock(), + q=paddle.zeros([2, 16, 192], dtype="float32"), k=None, - v=MagicMock(), + v=paddle.zeros([2, 1, 8], dtype="int32"), qkv=None, - compressed_kv=MagicMock(), - k_pe=MagicMock(), + compressed_kv=paddle.zeros([2, 512], dtype="float32"), + k_pe=paddle.zeros([2, 1, 64], dtype="float32"), layer=layer, forward_meta=forward_meta, ) - self.assertEqual(result, "decode_output") + self.assertEqual(result.shape, [2, 1, 16, 512]) @patch("fastdeploy.model_executor.layers.attention.dsa_attention_backend.current_platform") @patch( @@ -720,9 +720,9 @@ def test_forward_mixed_both_prefill_and_decode(self, mock_abs, mock_randn, mock_ mock_abs.return_value.max.return_value = scale_mock mock_flash_mla = MagicMock() - mock_flash_mla.flash_mla_sparse_fwd.return_value = ("prefill_out", None, None) + mock_flash_mla.flash_mla_sparse_fwd.return_value = (paddle.zeros([2, 64, 512], dtype="float32"), None, None) mock_flash_mla.get_mla_metadata.return_value = ("tile_meta", None) - mock_flash_mla.flash_mla_with_kvcache.return_value = ("decode_out", None) + mock_flash_mla.flash_mla_with_kvcache.return_value = (paddle.zeros([2, 1, 16, 512], dtype="float32"), None) mock_dsk_write = MagicMock() mock_merge = MagicMock() @@ -741,18 +741,18 @@ def test_forward_mixed_both_prefill_and_decode(self, mock_abs, mock_randn, mock_ }, ): result = backend.forward_mixed( - q=MagicMock(), + q=paddle.zeros([2, 16, 192], dtype="float32"), k=MagicMock(), - v=MagicMock(), + v=paddle.zeros([2, 1, 8], dtype="int32"), qkv=None, - compressed_kv=MagicMock(), - k_pe=MagicMock(), + compressed_kv=paddle.zeros([2, 512], dtype="float32"), + k_pe=paddle.zeros([2, 1, 64], dtype="float32"), layer=layer, forward_meta=forward_meta, ) # When both prefill and decode, returns fmha_out_prefill after merge - self.assertEqual(result, "prefill_out") + self.assertEqual(result.shape, [2, 16, 512]) mock_merge.assert_called_once() @patch("fastdeploy.model_executor.layers.attention.dsa_attention_backend.current_platform") diff --git a/tests/layers/test_mla_attention_kv_cache.py b/tests/layers/test_mla_attention_kv_cache.py index 4bbb09b4334..d3e15eb8527 100644 --- a/tests/layers/test_mla_attention_kv_cache.py +++ b/tests/layers/test_mla_attention_kv_cache.py @@ -29,6 +29,8 @@ def _make_mla_backend(block_size=64, kv_lora_rank=512, qk_rope_head_dim=64): backend.block_size = block_size backend.kv_lora_rank = kv_lora_rank backend.qk_rope_head_dim = qk_rope_head_dim + backend.layer_id = 0 + backend.window_attn_skip_freq = None return backend