Skip to content

Add RationalQuadraticSpline#71

Open
gvcallen wants to merge 1 commit into
lockwo:mainfrom
gvcallen:rqs
Open

Add RationalQuadraticSpline#71
gvcallen wants to merge 1 commit into
lockwo:mainfrom
gvcallen:rqs

Conversation

@gvcallen

@gvcallen gvcallen commented Apr 4, 2026

Copy link
Copy Markdown
Contributor

Adds the rational quadratic spline bijector, very useful for complex normalizing flows. Reference distrax code is here

@gvcallen gvcallen changed the title Add RationalQuadraticSpline bijector Add RationalQuadraticSpline Apr 4, 2026
bin_slope = bin_height / bin_width

z = (x - x_pos_bin[0]) / bin_width
z = jnp.clip(z, 0.0, 1.0)

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this function has some good comments in the distrax codebase, we should copy them over

"""A rational-quadratic spline bijector.

Implements the spline bijector introduced by:
> Durkan et al., Neural Spline Flows, https://arxiv.org/abs/1906.04032, 2019.

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can do the special bib notation here for mkdocs (e.g. ??? cite "References")

):
"""Initializes a RationalQuadraticSpline bijector.

Args:

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should match docstring style (e.g. **Arguments**)

unnormalized_bin_heights = params[
..., self.num_bins : 2 * self.num_bins # noqa: E203
]
unnormalized_knot_slopes = params[..., 2 * self.num_bins :] # noqa: E203

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought I fixed the black/flake interaction for these, does it error if you remove the # noqa: E203 ?

from distreqx.bijectors import RationalQuadraticSpline


class RationalQuadraticSplineTest(TestCase):

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a few more low hanging fruit from the distrax tests, e.g. https://github.com/google-deepmind/distrax/blob/main/distrax/_src/bijectors/rational_quadratic_spline_test.py#L153 that are worth porting over

gvcallen added a commit to gvcallen/distreqx that referenced this pull request Jul 1, 2026
- Port distrax's reference comments explaining the spline algorithm and
  the numerically-stable quadratic root solver.
- Add a References section using mkdocs-material's ??? cite block, and
  full **Arguments:** docs for __init__ matching the codebase's docstring
  style.
- Remove the one # noqa: E203 that isn't actually needed (verified with
  flake8 directly); keep the two that are.
- Port a few more tests from distrax: constructor validation, all
  boundary_slopes options, and monotonicity.
Implements the rational-quadratic spline bijector introduced by Durkan
et al. (Neural Spline Flows, 2019), ported from distrax's reference
implementation. Very useful for building complex normalizing flows.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants