Use pytorch2 optimized native attention #39

Open
attesaarela wants to merge 3 commits into Liuhong99:main from attesaarela:use-pytorch2-optimized-native-attention

Conversation

@attesaarela
Contributor

Hi, here is a pull request for a small speedup: attention is computed with the PyTorch 2 function "torch.nn.functional.scaled_dot_product_attention" when it is available.

It makes the optimizer run about 10% faster, according to a bit of testing I did.

This optimization was essentially copied from a recent version of nanoGPT.
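
For readers who have not seen the pattern, here is a minimal sketch of the kind of fallback described above. It is not the diff from this pull request: the names `HAS_SDPA`, `causal_attention`, and `manual_causal_attention` are hypothetical, dropout is left out, and the tensor layout (batch, heads, seq_len, head_dim) is an assumption.

```python
import math

import torch
import torch.nn.functional as F

# Hypothetical flag: True on PyTorch >= 2.0, where the fused function exists.
HAS_SDPA = hasattr(F, "scaled_dot_product_attention")


def manual_causal_attention(q, k, v):
    """Fallback path: explicit softmax(Q K^T / sqrt(d)) V with a causal mask."""
    seq_len = q.size(-2)
    att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
    # Lower-triangular mask so each position only attends to itself and the past.
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device))
    att = att.masked_fill(~mask, float("-inf"))
    att = torch.softmax(att, dim=-1)
    return att @ v


def causal_attention(q, k, v):
    """Use the PyTorch 2 function if available, else the manual path.

    q, k, v are assumed to have shape (batch, heads, seq_len, head_dim).
    """
    if HAS_SDPA:
        return F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True
        )
    return manual_causal_attention(q, k, v)
```

Both paths should agree up to floating-point tolerance, so the check on `HAS_SDPA` only changes speed, not results. Older PyTorch versions simply take the manual branch.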
