Fix to_numpy() crash on bfloat16 tensors by upcasting to float32 #1346

Merged
jlarson4 merged 1 commit into TransformerLensOrg:dev from robbiebusinessacc:contrib/to-numpy-bfloat16
May 29, 2026

@robbiebusinessacc

transformer_lens.utilities.tensors.to_numpy (re-exported as
transformer_lens.utils.to_numpy) raises TypeError: Got unsupported ScalarType BFloat16 when passed a bfloat16 tensor, because NumPy has no
bfloat16 dtype and torch.Tensor.numpy() cannot convert it directly.

bfloat16 is common in TransformerLens since many pretrained models load in
reduced precision, and to_numpy is used in several utility paths (e.g.
utilities/slice.py), so this is easy to hit.

Fix: in the torch.Tensor / nn.Parameter branch, detach + move to CPU,
then upcast to float32 when the dtype is bfloat16 before calling
.numpy(). All other dtypes (float32, float16, int, etc.) are unchanged.
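
For readers who want to see the shape of the change, here is a minimal sketch of the conversion logic described above. It is not the merged diff: the function name to_numpy_sketch is hypothetical, and the non-tensor branches (array, list/tuple, scalar, invalid type) are assumptions inferred from the test list in this description.

```python
import numpy as np
import torch
from torch import nn


def to_numpy_sketch(tensor):
    """Hypothetical stand-in for to_numpy; illustrates the bfloat16 upcast only."""
    if isinstance(tensor, np.ndarray):
        # NumPy arrays are returned unchanged (same object).
        return tensor
    if isinstance(tensor, (list, tuple)):
        return np.array(tensor)
    if isinstance(tensor, (torch.Tensor, nn.Parameter)):
        # Detach from the autograd graph and move to CPU first.
        cpu_tensor = tensor.detach().cpu()
        # NumPy has no bfloat16 dtype, so upcast before converting.
        if cpu_tensor.dtype == torch.bfloat16:
            cpu_tensor = cpu_tensor.to(torch.float32)
        return cpu_tensor.numpy()
    if isinstance(tensor, (int, float, bool, str)):
        return np.array(tensor)
    # Assumed error type, taken from the "invalid-type ValueError" test above.
    raise ValueError(f"Input to to_numpy has invalid type: {type(tensor)}")
```

The upcast is lossless, since every bfloat16 value is exactly representable in float32. The returned array has dtype float32 rather than a 16-bit type, so callers that care about memory should be aware of the size change.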

Tests: added a TestToNumpy class covering bfloat16 tensors and
nn.Parameter, float32/float16/int passthrough, numpy-array identity,
list/tuple, scalars, and the invalid-type ValueError. Reverting the fix
makes the two bfloat16 tests fail with the original TypeError, confirming
they're meaningful.
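
A rough illustration of what the two bfloat16 checks guard against, written as plain test functions rather than the actual TestToNumpy class (the function names here are hypothetical). The first reproduces the pre-fix failure by calling .numpy() directly; the second checks the upcast path.

```python
import numpy as np
import torch


def check_direct_numpy_call_fails():
    # Without the upcast, PyTorch cannot hand a bfloat16 buffer to NumPy.
    x = torch.ones(3, dtype=torch.bfloat16)
    try:
        x.numpy()
    except TypeError as err:
        assert "BFloat16" in str(err)
    else:
        raise AssertionError("expected TypeError for bfloat16 .numpy()")


def check_upcast_preserves_values():
    # These values are exactly representable in bfloat16, so the
    # float32 result should match them exactly.
    x = torch.tensor([0.5, -2.0, 3.0], dtype=torch.bfloat16)
    arr = x.detach().cpu().to(torch.float32).numpy()
    assert arr.dtype == np.float32
    np.testing.assert_array_equal(
        arr, np.array([0.5, -2.0, 3.0], dtype=np.float32)
    )


check_direct_numpy_call_fails()
check_upcast_preserves_values()
```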

No existing tracking issue.

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality
    to not work as expected)
  • This change requires a documentation update

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature
    works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect
    backward compatibility

NumPy has no bfloat16 dtype, so calling .numpy() on a bfloat16 tensor
raises TypeError. Detach/move to CPU, upcast bfloat16 to float32, then
convert. bfloat16 is common since many pretrained models load in reduced
precision. Adds a TestToNumpy class covering bfloat16, float32/float16/int
passthrough, numpy/list/tuple/scalar inputs, and the invalid-type error.
@jlarson4 merged commit 08c9ef9 into TransformerLensOrg:dev on May 29, 2026
24 checks passed

2 participants