Skip to content

[OP] Support MLA sliding-window attention#8060

Merged
EmmonsCurse merged 5 commits into
PaddlePaddle:developfrom
chang-wenbin:mla-swa1
Jun 17, 2026
Merged

[OP] Support MLA sliding-window attention#8060
EmmonsCurse merged 5 commits into
PaddlePaddle:developfrom
chang-wenbin:mla-swa1

Conversation

@chang-wenbin

@chang-wenbin chang-wenbin commented Jun 16, 2026

Copy link
Copy Markdown
Collaborator

Motivation

为 DeepSeek V3 MLA attention 增加 sliding-window attention (SWA) 支持,并复用 DSA/FlashMLA static attention 路径以降低 SWA 层 KV cache 开销。

Modifications

  • 在 DeepSeek V3 MLA attention 中新增 SWA indexer 生成 kernel 和 forward_swa_static 路径。
  • DSAAttentionBackend 的 mixed attention 逻辑拆出 forward_static,供 MLA SWA 路径复用。
  • 根据 window_attn_skip_freq 为 MLA SWA 层使用 packed fp8 key cache shape,并在 GPU cache 初始化中为 SWA 层创建 uint8 cache。
  • 更新 KV cache 理论显存估算,删除 test_dsa_attention_backend.py 并保留 MLA KV cache 单测。

Usage or Command

N/A

Accuracy Tests

N/A

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@codecov-commenter

codecov-commenter commented Jun 16, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 15.66265% with 140 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@cbb0811). Learn more about missing BASE report.

Files with missing lines Patch % Lines
fastdeploy/model_executor/models/deepseek_v3.py 8.13% 79 Missing ⚠️
...executor/layers/attention/dsa_attention_backend.py 15.00% 51 Missing ⚠️
fastdeploy/worker/gpu_model_runner.py 57.14% 3 Missing and 3 partials ⚠️
...executor/layers/attention/mla_attention_backend.py 33.33% 3 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #8060   +/-   ##
==========================================
  Coverage           ?   67.24%           
==========================================
  Files              ?      475           
  Lines              ?    66867           
  Branches           ?    10312           
==========================================
  Hits               ?    44968           
  Misses             ?    19024           
  Partials           ?     2875           
Flag Coverage Δ
GPU 77.26% <15.66%> (?)
XPU 6.96% <0.00%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

PaddlePaddle-bot

This comment was marked as outdated.

@chang-wenbin chang-wenbin changed the title Mla swa1 [OP] Support MLA sliding-window attention Jun 17, 2026
@PaddlePaddle-bot

Copy link
Copy Markdown

🤖 Paddle-CI-Agent | ci_status_monitor | 2026-06-17 14:04:55

CI报告基于以下代码生成(30分钟更新一次):
PR commit: 38424bf | Merge base: cbb0811 (branch: develop)


1 Required任务 : 9/10 通过

总执行(rerun次数) 总任务 ✅ 通过 ❌ 失败 ⏳ 运行中 ⏸️ 等待中 跳过
41(0) 41 36 5 0 0 0
任务 错误类型 置信度 日志
Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage PR问题 Job

2 失败详情

🔴 Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage — PR问题(置信度: 高)

错误类型: PR问题 | 置信度: 高
分析器: ci_analyze_unittest_fastdeploy
失败用例: 覆盖率阈值检查失败

用例 错误摘要
Verify Code Coverage Threshold (80%) 单测通过,但 diff coverage 仅 13%,低于 80% 阈值,覆盖率步骤以 exit code 9 失败

关键日志:

TEST_EXIT_CODE: 0
COVERAGE_EXIT_CODE: 9
All tests passed
Coverage generation failed (exit code 9)
fastdeploy/model_executor/layers/attention/mla_attention_backend.py: 50.0%, violations [546, 626, 627]
fastdeploy/model_executor/layers/attention/dsa_attention_backend.py: 4.878%, 39 条 violation
fastdeploy/model_executor/models/deepseek_v3.py: 8.140%, 79 条 violation
fastdeploy/worker/gpu_model_runner.py: 53.846%, violations [1652,1653,1664,1665,3065,3066]
total_num_violations: 127, total_percent_covered: 13, num_changed_lines: 410
  • 根因摘要: Diff coverage 13% 未达 80%

本次失败不是单测断言失败或 CI 环境异常:Run FastDeploy Unit Tests and Coverage 步骤完成后 TEST_EXIT_CODE=0,随后覆盖率校验步骤因 COVERAGE_EXIT_CODE=9 失败。覆盖率报告显示 PR 新增/修改的 410 行中有 127 行未覆盖,总 diff coverage 只有 13%,低于 required job 的 80% 阈值。

未覆盖代码集中在本 PR 的核心变更:deepseek_v3.py 新增 get_swa_indexer_top_k_kernel/get_swa_indexer_top_kforward_swa_static 路径,dsa_attention_backend.py 抽出的 DSAAttentionBackend.forward_static prefill/decode/mixed 分支,mla_attention_backend.py 的 SWA cache shape 分支,以及 gpu_model_runner.py 中 SWA MLA cache dtype 和显存估算逻辑。PR 同时删除了 tests/layers/test_dsa_attention_backend.py 843 行测试,仅在 test_mla_attention_kv_cache.py 中补了 layer_id/window_attn_skip_freq=None 两行,未覆盖这些新增 SWA/DSA static 分支。

修复建议:

  1. tests/layers/test_mla_attention_kv_cache.py 或新增测试中覆盖 MLAAttentionBackend.get_kv_cache_shapewindow_attn_skip_freq[layer_id] == 1 分支,断言 fp8 key cache shape 为 kv_lora_rank + 4 * (kv_lora_rank // 128) + 2 * qk_rope_head_dim
  2. 补充 DSAAttentionBackend.forward_static 的 mock 单测,覆盖 prefill、decode、prefill+decode、head 数补齐到 64、insert_decoder_result_back 回填等分支,替代被删除的 tests/layers/test_dsa_attention_backend.py 覆盖面。
  3. DeepseekV3MLAAttention.forward_swa_staticget_swa_indexer_top_k 增加最小张量/mock 测试,至少覆盖 encoder 与 decoder 两类 indexer 生成,以及 window_attn_skip_freq[self.layer_id] == 1forward 路由到 SWA static 路径。
  4. GPUModelRunner cache 初始化和显存估算补充单测,覆盖 SWA 层 cache_type="uint8"、非 SWA 层保持模型 dtype、以及 window_attn_skip_freq 混合层数时的 required memory 计算。

关联变更: fastdeploy/model_executor/models/deepseek_v3.pyfastdeploy/model_executor/layers/attention/dsa_attention_backend.pyfastdeploy/model_executor/layers/attention/mla_attention_backend.pyfastdeploy/worker/gpu_model_runner.pytests/layers/test_dsa_attention_backend.pytests/layers/test_mla_attention_kv_cache.py

@PaddlePaddle-bot PaddlePaddle-bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖 Paddle-CI-Agent | pr_review | 2026-06-17 14:42:51

📋 Review 摘要

PR 概述:为 DeepSeek V3 MLA attention 增加 SWA static attention 路径,并调整 MLA/SWA KV cache 形状与 dtype。
变更范围deepseek_v3.pylayers/attention/*gpu_model_runner.py、相关 KV cache 单测。
影响面 Tag[Models] [OP] [KVCache]

问题

级别 文件 概述
🔴 Bug fastdeploy/model_executor/models/deepseek_v3.py:143 SWA indexer 硬编码 64 作为 page size,非默认 --block-size 下会读错 KV cache 位置

历史 Findings 修复情况

Finding 问题 状态
F1 SWA prefill 路径没有读取 prefix/chunked prefill 的历史 KV。 ⚠️ 仍存在
F2 layer_id 和 SWA uint8 cache dtype 只在非 V1 cache 初始化路径生效。 ⚠️ 仍存在

📝 PR 规范检查

已修复:标题已包含官方 [OP] Tag,PR 描述也已补齐 Motivation / Modifications / Usage or Command / Accuracy Tests / Checklist 结构。

总体评价

当前实现仍有一个会导致非默认 block size 下 SWA decode 访问错误 KV cache 的阻塞问题,需要在合入前修复。另有两条历史 finding 在当前 diff 中仍未解决,本轮未重复发 inline comment。

batch_id_per_token,
max_page_per_seq=block_tables.shape[1],
window_size=indexer_top_k.shape[2],
page_size=64,

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Bug 这里把 SWA indexer 的 page_size 固定为 64,但 KV block size 是可配置的 cache_config.block_size

当前 block_tables 的页号按实际 block size 分配,其他 cache 读写路径也使用 self.block_size。当用户通过 --block-size 配置为非 64 时,这里的 idx // 64idx % 64 会生成错误的物理 token id,decode SWA 会从错误 cache 位置取 KV,直接导致 attention 结果错误。

建议把 block_size 作为参数从 DeepseekV3MLAAttention.forward_swa_static() 传入 get_swa_indexer_top_k(),并在 kernel launch 中使用该值作为 page_size,不要硬编码 64。

@EmmonsCurse EmmonsCurse left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM~ Skip coverage check as it mainly relies on end-to-end tests.

@EmmonsCurse EmmonsCurse merged commit 6076add into PaddlePaddle:develop Jun 17, 2026
70 of 75 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants