[Bugfix] Fix qwen3-tts create_causal_mask kwarg for transformers >=5.9.0 #3786
Yadan-Wei wants to merge 5 commits into
linyueqian left a comment
Verified on an H20 against the real Qwen3-TTS-12Hz-0.6B-Base speech tokenizer, transformers 4.57.6 / 5.8.1 / 5.9.0.
The bug is real. transformers 5.9.0 (current PyPI latest) removed the deprecated input_embeds alias from create_causal_mask, and vllm 0.21.0 pins transformers!=5.0.*,...,!=5.5.0,>=4.56.0 with no upper bound, so a fresh install resolves to 5.9.0 and every qwen3-tts decode fails as reported.
But this patch does not fix it, and it regresses transformers 4.x. Two things changed in 5.9.0, not one: the input_embeds to inputs_embeds rename, and removal of cache_position from both create_causal_mask and create_sliding_window_causal_mask. mask_kwargs still passes cache_position, so after this rename the decode still raises TypeError: create_causal_mask() got an unexpected keyword argument 'cache_position' on 5.9.0. And transformers 4.56 to 4.57.x, also allowed by vllm 0.21.0, only accept input_embeds (singular), so renaming to inputs_embeds breaks 4.x.
| mask_kwargs | tfm 4.57.6 | tfm 5.8.1 | tfm 5.9.0 |
|---|---|---|---|
| current main | works | works | fails (input_embeds) |
| this PR | fails (inputs_embeds) | works | fails (cache_position) |
| signature-filtered | works | works | works |
The matrix above is create_causal_mask binding plus end-to-end decode() on real Qwen3-TTS-12Hz-0.6B-Base weights. This PR trades a 5.9.0-only break for a 4.x break that still fails on 5.9.0. Requesting changes; the fix needs to be version-aware. See the inline comment.
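The binding half of the matrix can be reproduced without installing several transformers versions. What follows is a minimal sketch, not the real library: `fake_mask_4x` and `fake_mask_59` are hypothetical stand-ins whose parameter lists are assumed only from the behaviour described in this review (4.x takes `input_embeds` and `cache_position`; 5.9.0 takes `inputs_embeds` and no `cache_position`), and `binds` is an illustrative helper. The 5.8.1 column is left out because it depends on the deprecated-alias decorator.

```python
import inspect

# Hypothetical stand-ins: parameter lists are assumptions taken from this
# review, not copied from transformers.
def fake_mask_4x(config, input_embeds, attention_mask, cache_position,
                 past_key_values, position_ids=None):
    return "mask"

def fake_mask_59(config, inputs_embeds, attention_mask,
                 past_key_values, position_ids=None):
    return "mask"

def binds(fn, kwargs):
    """Return True if fn(**kwargs) would pass argument binding."""
    try:
        inspect.signature(fn).bind(**kwargs)
        return True
    except TypeError:
        return False

common = {"config": None, "attention_mask": None, "cache_position": None,
          "past_key_values": None, "position_ids": None}
main_kwargs = {**common, "input_embeds": None}   # current main: singular key
pr_kwargs = {**common, "inputs_embeds": None}    # this PR: plural key

for name, kwargs in [("current main", main_kwargs), ("this PR", pr_kwargs)]:
    print(name,
          "| 4.x binds:", binds(fake_mask_4x, kwargs),
          "| 5.9.0 binds:", binds(fake_mask_59, kwargs))
```

Under these assumed signatures, neither static dict binds against both stand-ins, which is the point of the matrix.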
  mask_kwargs = {
      "config": self.config,
-     "input_embeds": inputs_embeds,
+     "inputs_embeds": inputs_embeds,
[blocking] One rename is not enough, and this also breaks transformers 4.x. transformers 5.9.0 removed two kwargs from create_causal_mask and create_sliding_window_causal_mask: the input_embeds alias and cache_position. mask_kwargs still feeds cache_position into both helpers (lines 576 and 580), so on 5.9.0 the decode still raises TypeError: ... unexpected keyword argument 'cache_position'. And transformers 4.56 to 4.57.x only accept input_embeds (singular), so the plural name fails there. No static dict works on both 4.x and 5.9.0; filter by the live signature instead:
# add `import inspect` at module top
mask_kwargs = {
    "config": self.config,
    "inputs_embeds": inputs_embeds,
    "attention_mask": attention_mask,
    "cache_position": cache_position,
    "past_key_values": past_key_values,
    "position_ids": position_ids,
}

def _mask_args(fn):
    params = inspect.signature(fn).parameters
    args = {k: v for k, v in mask_kwargs.items() if k in params}
    if "inputs_embeds" not in params and "input_embeds" in params:
        args["input_embeds"] = args.pop("inputs_embeds")
    return args

causal_mask_mapping = {"full_attention": create_causal_mask(**_mask_args(create_causal_mask))}
if self.has_sliding_layers:
    causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(
        **_mask_args(create_sliding_window_causal_mask)
    )

Verified end-to-end on real Qwen3-TTS weights: this passes on both transformers 4.57.6 and 5.9.0.
Thanks for the careful matrix — you're right that the rename alone wasn't enough and that 4.x needs the singular form. Adopted the signature-filtered helper exactly as suggested in 77ee11b: added import inspect, kept the full mask_kwargs, and routed both create_causal_mask and create_sliding_window_causal_mask through _mask_args(fn) so each call only forwards kwargs the installed transformers version accepts (drops cache_position on 5.9.0, renames to input_embeds on 4.x). PTAL when you have a moment.
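Pushed a one-line fixup as 5f3fdb7. The _mask_args helper popped inputs_embeds from the already-filtered args dict, which raises KeyError: 'inputs_embeds' on transformers 4.x; it now reads the value from mask_kwargs instead. Verified end to end on Qwen3-TTS-12Hz-0.6B-Base: decode() passes on transformers 4.57.6 and 5.9.0. @Gaohan123 could you take another look, since you own this tokenizer model? The change is small, but it sits on the codec decode path that runs for every request.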
fix dco please |
The Qwen3TTSTokenizer V2 forward path passes mask_kwargs to transformers.masking_utils.create_causal_mask with a key named "input_embeds" (singular). transformers renamed this kwarg to "inputs_embeds" in 5.5.1, kept "input_embeds" as a deprecated alias via @deprecate_kwarg, and removed the alias in 5.9.0 (released 2026-05-20, https://github.com/huggingface/transformers/releases/tag/v5.9.0).

After the alias removal, qwen3-tts inference fails on first request:

    File ".../qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py", line 576
      causal_mask_mapping = { "full_attention": create_causal_mask(**mask_kwargs), ...}
    TypeError: create_causal_mask() got an unexpected keyword argument 'input_embeds'

Rename the dict key to "inputs_embeds" so the unpacked kwargs match the current upstream signature. Every other reference to inputs_embeds in this file (including the function signature on line 532) already uses the plural form; line 568 was a stray typo.

This restores compatibility with transformers >=5.9.0 while remaining compatible with 5.5.1..5.8.x (the deprecation alias path).

Signed-off-by: Yadan Wei <yadanwei@amazon.com>
Reviewer pointed out two problems with the prior single-rename fix:

- transformers 5.9.0 also dropped `cache_position` from create_causal_mask and create_sliding_window_causal_mask, so renaming alone still fails on 5.9.0 with `unexpected keyword argument 'cache_position'`.
- transformers 4.56-4.57.x (also allowed by vllm 0.21.0) only accept the singular `input_embeds`, so the static rename regresses 4.x.

Inspect each helper's live signature and forward only the kwargs it accepts; rename `inputs_embeds` -> `input_embeds` when only the singular form exists.

Verified by reviewer end-to-end on Qwen3-TTS-12Hz-0.6B-Base weights against transformers 4.57.6, 5.8.1, and 5.9.0.

Signed-off-by: Yadan Wei <yadanwei@amazon.com>
…mers 4.x

The _mask_args helper popped "inputs_embeds" from the already-filtered args dict, but on transformers 4.x the parameter is the singular "input_embeds", so that key was never added to args, making the remap raise KeyError: 'inputs_embeds'. Read the value from mask_kwargs, which always holds it, instead.

Verified end to end on Qwen3-TTS-12Hz-0.6B-Base: decode() passes on transformers 4.57.6 and 5.9.0.

Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
Signed-off-by: Yadan Wei <yadanwei@amazon.com>
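The fixup described in this commit message can be sketched as follows. This is an illustration under stated assumptions, not the merged code: `fake_mask_4x` and `fake_mask_59` are hypothetical stand-ins for the two signature shapes discussed in this PR, and `mask_args` takes `mask_kwargs` as an explicit argument here, whereas the helper in the PR is a closure named `_mask_args(fn)`.

```python
import inspect

# Hypothetical stand-ins: parameter lists are assumptions based on this PR's
# discussion, not copied from transformers.
def fake_mask_4x(config, input_embeds, attention_mask, cache_position,
                 past_key_values, position_ids=None):
    return input_embeds

def fake_mask_59(config, inputs_embeds, attention_mask,
                 past_key_values, position_ids=None):
    return inputs_embeds

def mask_args(fn, mask_kwargs):
    """Keep only the kwargs fn accepts, remapping the embeds key if needed."""
    params = inspect.signature(fn).parameters
    args = {k: v for k, v in mask_kwargs.items() if k in params}
    if "inputs_embeds" not in params and "input_embeds" in params:
        # The filtered dict never held the plural key in this branch, so
        # read the value from the unfiltered mask_kwargs instead of popping.
        args["input_embeds"] = mask_kwargs["inputs_embeds"]
    return args

mask_kwargs = {
    "config": None,
    "inputs_embeds": "embeds",
    "attention_mask": None,
    "cache_position": None,
    "past_key_values": None,
    "position_ids": None,
}

args_4x = mask_args(fake_mask_4x, mask_kwargs)
args_59 = mask_args(fake_mask_59, mask_kwargs)
print(sorted(args_4x))
print(sorted(args_59))
```

With the earlier `args.pop("inputs_embeds")` form, the call against the singular-only stand-in would raise the KeyError that this commit addresses.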
a767a11 to
44a0788
Summary
Qwen3TTSTokenizerV2DecoderTransformerModel.forward constructs a mask_kwargs dict with key "input_embeds" (singular) and unpacks it into transformers.masking_utils.create_causal_mask. transformers renamed this kwarg to inputs_embeds (plural) in 5.5.1, kept input_embeds as a deprecated alias via @deprecate_kwarg, and removed the alias in 5.9.0 (released 2026-05-20).

Result: every qwen3-tts request fails on the first forward with:

    TypeError: create_causal_mask() got an unexpected keyword argument 'input_embeds'

This is a single-character typo: the rest of the file (function signature on line 532, every other usage in this file) already uses the plural inputs_embeds. Line 568 is a stray.

Change

  mask_kwargs = {
      "config": self.config,
-     "input_embeds": inputs_embeds,
+     "inputs_embeds": inputs_embeds,
      "attention_mask": attention_mask,
      "cache_position": cache_position,
      "past_key_values": past_key_values,
      "position_ids": position_ids,
  }

Compatibility
Test plan
References
- @deprecate_kwarg introduction: 5.5.1 in src/transformers/masking_utils.py
- src/transformers/masking_utils.py (@deprecate_kwarg decorator removed)