[JAX] score_mod attention to support "enforce_precompiled" flag in cuDNN graph deserialization #3046

@jberchtold-nvidia

Description

Once NVIDIA/cudnn-frontend#254 is completed, we should support the "enforce_precompiled" flag in TE's cuDNN usage with score_mod. This would give clearer error messages, instead of potentially incorrect values or less clear errors, in cases where the cuDNN Python FE and the cuDNN C++ FE require the C++ FE to recompile the kernels.
