PyTorch triple backward #200

Draft
asglover wants to merge 4 commits into main from pytorch-triple-backward

Collaborator

asglover commented Jun 2, 2026

Adds triple-backward support for higher-order training in PyTorch.
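For readers unfamiliar with the term, here is a minimal sketch of what a triple backward pass looks like, using plain PyTorch autograd. The function `higher_order_loss` and the toy energy `w * x**3` are hypothetical stand-ins, not this project's custom ops: the loss depends on a second derivative with respect to the inputs, so backpropagating it to the weights requires differentiating through two earlier backward passes.

```python
import torch


def higher_order_loss(w, x):
    """Toy loss built from a second derivative w.r.t. the inputs (illustrative only)."""
    # Hypothetical scalar "energy"; a real model would call its own ops here.
    energy = (w * x**3).sum()
    # First backward: dE/dx = 3 * w * x**2. Keep the graph so it stays differentiable.
    (d1,) = torch.autograd.grad(energy, x, create_graph=True)
    # Second backward: d/dx of sum(dE/dx) = 6 * w * x. Keep the graph again.
    (d2,) = torch.autograd.grad(d1.sum(), x, create_graph=True)
    # The loss depends on the second derivative.
    return (d2**2).sum()


w = torch.tensor([0.5, -1.0], dtype=torch.float64, requires_grad=True)
x = torch.tensor([1.0, 2.0], dtype=torch.float64, requires_grad=True)

loss = higher_order_loss(w, x)
# Third backward: differentiates through both grad calls above.
# Analytically, loss = sum(36 * w**2 * x**2), so d(loss)/dw = 72 * w * x**2.
loss.backward()
```

With built-in ops this works automatically; a custom op only supports it if each of its backward functions is itself differentiable, which is presumably the gap this PR closes.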

asglover added the ci-ready label (triggers CI checks for a pull request) Jun 2, 2026
Collaborator Author

asglover commented Jun 2, 2026

I'm going to run the full test suite tomorrow.
The changes are mostly about triple backward, although my model was convinced that there was an error in the stream testing, and on closer inspection it looks like there was.
I'll promote this from Draft to a regular PR when it's ready for review.

Member

vbharadwaj-bk left a comment

95% looks good. I think the stream_test.py modifications are redundant, since each of the custom ops has already been tested. Otherwise good.

):
    assert self.torch_op

    in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True)
Member

TODO-someday: I wonder if we can combine all of these derivative functions into one to compact this file.

Comment thread tests/batch_test.py
@pytest.fixture(scope="class")
def problem(self, dtype, with_jax):
    if with_jax:
        pytest.skip("N/A for JAX")
Member

TODO-someday: we could expand this test to cover JAX, but not in this commit.

Comment thread tests/stream_test.py
    return (X, Y, W, edge_index[0], edge_index[1])


@pytest.fixture
Member

Hmm. I don't think we should have any modifications to stream_test.py. Because triple_backward is a composition of existing ops that all work fine with streams, I see no reason why their composition shouldn't pass stream tests. We need to test anything that's implemented as a custom op to make sure that the stream information is lowered correctly onto the kernel, but then any composition of those operators should be ok. Let's shrink the diff here.

Comment thread tests/conv_test.py
        self.check_result(result, fieldname)


class TestTripleBackwardConvDirectOps:
Member

Remind me, what is the purpose of these DirectOps tests?

Labels: ci-ready (triggers CI checks for a pull request)

2 participants