
Add reinterpreted_batch_ndims to Independent #75

Open
gvcallen wants to merge 1 commit into lockwo:main from gvcallen:multivariate_independent

@gvcallen gvcallen commented Apr 6, 2026

Contributor

Adds reinterpreted_batch_ndims to Independent, to support distributions that have been batched via eqx.filter_vmap (e.g., MultivariateNormalTri) and that don't handle batched arrays natively.

Although distreqx has expressly dropped the concept of batch dimensions throughout the codebase (which I strongly agree was the right choice), I do believe Independent still requires a reinterpreted_batch_ndims option; the example below gives the motivation.

Without this PR, passing a vmapped distribution that doesn't natively support broadcasting to Independent crashes on methods like log_prob, because the inner computations aren't mapped over the batch axes. This change lets users explicitly set the number of mapped axes, so Independent can unroll the vmap layers and reduce the results to a single scalar. The previous behaviour is unaffected with the default of reinterpreted_batch_ndims = 0.

import equinox as eqx
import jax.numpy as jnp
from distreqx import distributions as dist

locs = jnp.zeros((20, 10, 3))
scales_tri = jnp.stack([jnp.tri(3)]*20*10, axis=0).reshape(20, 10, 3, 3)
xs = jnp.ones((20, 10, 3))

# Create a batch of mvn's
mvns = eqx.filter_vmap(eqx.filter_vmap(dist.MultivariateNormalTri))(locs, scales_tri)

# Calling log_prob (expectedly) does not work because MultivariateNormalTri assumes k x k
# log_prob = mvns.log_prob(locs) # error

# We can manually vmap the computation, but then we get a vector of log-probs.
# We could add a sum, but it is better to encapsulate this in a single distribution
def simple_log_prob(d, x): return d.log_prob(x)
log_probs = eqx.filter_vmap(eqx.filter_vmap(simple_log_prob))(mvns, xs)
assert log_probs.shape == (20, 10,)

# If we attempt to use "Independent" to reinterpret the batch dims, we get a crash,
# because the computation dimensions are not being reinterpreted
reinterp_mvn = dist.Independent(mvns)
# reinterp_mvn.log_prob(xs) # error

# With the proposed changes, we can successfully encapsulate the independent MVNs.
# We could also pass e.g. 1 here and manually vmap the last dim if desired.
reinterp_mvn = dist.Independent(mvns, reinterpreted_batch_ndims=2)
log_prob = reinterp_mvn.log_prob(xs)
assert jnp.isscalar(log_prob)

Comment thread distreqx/distributions/_independent.py Outdated
- `reinterpreted_batch_ndims`: Number of batch dimensions to reinterpret
as event dimensions. Defaults to 0, which preserves standard broadcasting
behavior for natively batched distributions (e.g., `Normal`).
**Note:** If you are passing a distribution that does not natively broadcast

Owner

Does this note render correctly? I think we can make a nice note interface. Here is an AI summary of the interfaces:

Solid block (!!!), always visible:
  !!! note

      Body must be indented 4 spaces and separated from the
      `!!!` line by a blank line.

  Collapsible block (???), collapsed by default; user clicks to open:
  ??? note

      Same indentation rules.

  Open-by-default collapsible (???+):
  ???+ warning

      Starts expanded but is still collapsible.

  Custom title, a quoted string after the type:
  !!! warning "Bijectors are applied in reverse order"

      Given a sequence `[f, g]`, the `Chain` bijector computes `f(g(x))`...
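As a rough illustration of the solid-block form inside a Python docstring, here is a minimal sketch. `IndependentSketch` is a hypothetical stand-in class, not the real `Independent`, and the note wording is only an assumption based on the PR description; whether the block renders as an admonition depends on the docs toolchain being configured for it.

```python
class IndependentSketch:
    """Hypothetical stand-in used only to show the docstring layout."""

    def __init__(self, distribution, reinterpreted_batch_ndims: int = 0):
        """Initializes the wrapper.

        **Arguments:**

        - `distribution`: The wrapped distribution.
        - `reinterpreted_batch_ndims`: Number of batch dimensions to
            reinterpret as event dimensions. Defaults to 0.

        !!! note

            If `distribution` was batched with `eqx.filter_vmap`, set
            `reinterpreted_batch_ndims` to the number of mapped axes.
        """
        self.distribution = distribution
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
```

The `!!! note` marker sits at the docstring's base indentation, followed by a blank line and a body indented four further spaces.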

samples, log_prob = self.distribution.sample_and_log_prob(key)
log_prob = _reduce_helper(log_prob)
return samples, log_prob
if self.reinterpreted_batch_ndims == 0:

Owner

I'm not sure we should have a specific flag that goes to the reduce helper. This means it should basically reduce over all dimensions; should it instead be a required argument (which then becomes all axes rather than just 0)? Just to simplify the code?

Comment thread distreqx/distributions/_independent.py Outdated
total_batches = math.prod(bshape)
keys = jax.random.split(key, total_batches).reshape(*bshape)

def _single_sample_and_log_prob(d, k):

Owner

This pattern comes up several times. I think if we make a vmap helper it could reduce LoC? e.g.

  def _vmap_method(self, fn):
      for _ in range(self.reinterpreted_batch_ndims):
          fn = eqx.filter_vmap(fn)
      return fn(self.distribution)


  def _vmap_and_sum(self, fn):
      out = self._vmap_method(fn)
      return jnp.sum(out, axis=tuple(range(self.reinterpreted_batch_ndims)))

then each of these can just call vmap and sum
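To make the vmap-then-sum idea concrete outside the class, here is a self-contained sketch. It is not the PR's implementation: it assumes plain `jax.vmap` in place of `eqx.filter_vmap`, and `ToyMVNTri` is a hypothetical NamedTuple stand-in for an unbatched `MultivariateNormalTri` (NamedTuples are JAX pytrees, so they can be mapped over directly). The free function `vmap_and_sum` is likewise an illustrative name.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


class ToyMVNTri(NamedTuple):
    """Hypothetical unbatched MVN parameterised by a lower-triangular scale."""

    loc: jax.Array        # shape (k,)
    scale_tri: jax.Array  # shape (k, k), lower triangular

    def log_prob(self, x):
        # Written for a single (k,) event; it is not batch-aware on its own.
        k = self.loc.shape[-1]
        z = solve_triangular(self.scale_tri, x - self.loc, lower=True)
        log_det = jnp.sum(jnp.log(jnp.diag(self.scale_tri)))
        return -0.5 * jnp.dot(z, z) - log_det - 0.5 * k * jnp.log(2.0 * jnp.pi)


def vmap_and_sum(fn, ndims, *args):
    """Vmap `fn` over `ndims` leading axes of every argument, then sum them out."""
    for _ in range(ndims):
        fn = jax.vmap(fn)
    return jnp.sum(fn(*args), axis=tuple(range(ndims)))


locs = jnp.zeros((20, 10, 3))
scales_tri = jnp.broadcast_to(jnp.tri(3), (20, 10, 3, 3))
xs = jnp.ones((20, 10, 3))

batched = ToyMVNTri(locs, scales_tri)
# Two mapped axes are unrolled and reduced, leaving a single scalar log-prob.
total = vmap_and_sum(lambda d, x: d.log_prob(x), 2, batched, xs)
```

Each per-method body then only has to supply the unbatched function; the loop and the reduction live in one place.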

Comment thread distreqx/distributions/_independent.py Outdated
# on the raw, un-wrapped vmapped inner distributions.
d1_rndims = dist1.reinterpreted_batch_ndims
d2_rndims = dist2.reinterpreted_batch_ndims
p_base_shape = dist1.event_shape[d1_rndims:] # fmt: skip

Owner

What is `# fmt: skip` for here?

Comment thread distreqx/distributions/_independent.py Outdated
# Safely extract the base event shapes without triggering unsafe traces
# on the raw, un-wrapped vmapped inner distributions.
d1_rndims = dist1.reinterpreted_batch_ndims
d2_rndims = dist2.reinterpreted_batch_ndims

Owner

If these don't match, this has an asymmetry issue? Should we enforce these to be the same?
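One way such a check could look is sketched below. The function name `check_matching_rndims` and the error message are hypothetical, not taken from the PR; the sketch only shows failing early instead of summing over mismatched axes.

```python
def check_matching_rndims(p_rndims: int, q_rndims: int) -> int:
    """Return the shared value, or raise if the two sides disagree."""
    if p_rndims != q_rndims:
        raise ValueError(
            "kl_divergence requires both distributions to use the same "
            f"reinterpreted_batch_ndims, got {p_rndims} and {q_rndims}."
        )
    return p_rndims


shared = check_matching_rndims(2, 2)

try:
    check_matching_rndims(2, 1)
    mismatch_raised = False
except ValueError:
    mismatch_raised = True
```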

Comment thread tests/independent_test.py Outdated
)
return batch_mvn, locs, scales_tri

# --- 1. Basic & Legacy Tests ---

Owner

un-needed comments

Comment thread tests/independent_test.py Outdated
self.assertIsInstance(model, Independent)

def test_legacy_broadcasting_behavior(self):
"""Tests the reinterpreted_batch_ndims=0 fallback for standard distributions."""

Owner

Personally, I don't think we need much backwards compatibility; we can break things if the new API is better.

Comment thread tests/independent_test.py
def assertion_fn(self, rtol=1e-5):
return lambda x, y: np.testing.assert_allclose(x, y, rtol=rtol)

def _create_vmapped_mvn(self, M=20, N=10, D=3):

Owner

we should check across multiple distributions, e.g. bernoulli/normal/etc to make sure nothing unexpected happens
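A sketch of the kind of cross-distribution check this could mean: for log-densities that already broadcast, vmapping over the leading axes and summing should agree with broadcasting and summing. This uses `jax.scipy.stats` functions and `jax.vmap` as stand-ins for the distreqx `Normal`/`Bernoulli` classes and `eqx.filter_vmap`, so it illustrates the property rather than the actual test file; `vmap_and_sum` is an illustrative helper name.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import bernoulli, norm


def vmap_and_sum(fn, ndims, *args):
    """Vmap `fn` over `ndims` leading axes of every argument, then sum them out."""
    for _ in range(ndims):
        fn = jax.vmap(fn)
    return jnp.sum(fn(*args), axis=tuple(range(ndims)))


key_loc, key_x = jax.random.split(jax.random.PRNGKey(0))
loc = jax.random.normal(key_loc, (4, 3))
x = jax.random.normal(key_x, (4, 3))
probs = jnp.full((4, 3), 0.3)
outcomes = (jnp.arange(12) % 2).reshape(4, 3).astype(jnp.float32)

# Normal: elementwise vmapped evaluation vs. native broadcasting.
normal_vmapped = vmap_and_sum(lambda xi, li: norm.logpdf(xi, li, 1.0), 2, x, loc)
normal_broadcast = jnp.sum(norm.logpdf(x, loc, 1.0))

# Bernoulli: same comparison for a discrete distribution.
bern_vmapped = vmap_and_sum(lambda ki, pi: bernoulli.logpmf(ki, pi), 2, outcomes, probs)
bern_broadcast = jnp.sum(bernoulli.logpmf(outcomes, probs))
```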

gvcallen added a commit to gvcallen/distreqx that referenced this pull request Jul 1, 2026
- Convert the plain-text note in __init__'s docstring to a proper
  !!! note admonition block.
- Add _vmap_method/_vmap_and_sum helpers to de-duplicate the
  vmap-N-times-or-call-directly pattern repeated across sample,
  sample_and_log_prob, log_prob, entropy, log_cdf, mean, median,
  variance, stddev, mode, and kl_divergence.
- Remove the two # fmt: skip comments in kl_divergence (verified with
  black that neither line is actually reformatted without them).
- Enforce that both sides of kl_divergence share the same
  reinterpreted_batch_ndims, since a mismatch would silently sum over
  the wrong axes on one side.
- Drop unnecessary section-header comments in the test file and the
  legacy-broadcasting-behavior test (lockwo doesn't prioritize backwards
  compatibility here).
- Add coverage for natively-broadcasting distributions other than Normal
  (Bernoulli) under reinterpretation, and for the new
  reinterpreted_batch_ndims mismatch validation.
Lets Independent reinterpret one or more leading batch dimensions as
event dimensions, needed for distributions batched via eqx.filter_vmap
that don't natively broadcast (e.g. MultivariateNormalTri). Every
method is vmapped the requested number of times via shared
_vmap_method/_vmap_and_sum helpers rather than duplicating the
vmap-loop boilerplate per method.
@gvcallen gvcallen force-pushed the multivariate_independent branch from 3320770 to 8d5ed7b on July 1, 2026 21:02

2 participants