Skip to content

Add Empirical#63

Open
gvcallen wants to merge 1 commit into
lockwo:mainfrom
gvcallen:empirical
Open

Add Empirical#63
gvcallen wants to merge 1 commit into
lockwo:mainfrom
gvcallen:empirical

Conversation

@gvcallen

@gvcallen gvcallen commented Apr 4, 2026

Copy link
Copy Markdown
Contributor

This PR adds AbstractEmpirical, Empirical and WeightedEmpirical distributions.

These distributions encapsulate a set of observed samples of a variable. The math was mirrored off tfp.distributions.Empirical with additional support for weighted samples.

@gvcallen gvcallen changed the title Add Empirical distributions Add Empirical Apr 4, 2026
Comment thread distreqx/distributions/_empirical.py Outdated
self.rtol == 0, self.atol, self.atol + self.rtol * jnp.abs(value)
)

def sample_and_log_prob(self, key: Key[Array, ""]) -> tuple[Array, Array]:

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for these functions (that do the same as the default approach), we can use the mixin approach to inheritance (e.g. inherit from AbstractSampleLogProbDistribution as well, this can be done for multiple classes)

Comment thread distreqx/distributions/_empirical.py
Comment thread tests/empirical_test.py Outdated
import jax

# Must be set before any JAX arrays are initialized
jax.config.update("jax_enable_x64", True)

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if memory serves, this can be spotty when enabled in pytest files, if you need it, we can put in a conftest.py

Comment thread tests/empirical_test.py

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think making sure the numerics check out for non 0 values of empirical rtol/atol should have additional tests

gvcallen added a commit to gvcallen/distreqx that referenced this pull request Jul 1, 2026
- AbstractEmpirical now inherits AbstractSampleLogProbDistribution's
  default sample_and_log_prob instead of duplicating it.
- Move jax_enable_x64 config into tests/conftest.py (more reliable than
  per-file config, per lockwo's suggestion) and drop the now-redundant
  per-file config lines fork-wide.
- Add rtol/atol tolerance tests for Empirical.
Encapsulate a set of observed samples as a distribution, with support
for weighted samples. Referenced against TensorFlow Probability's
Empirical. AbstractEmpirical inherits AbstractSampleLogProbDistribution
for its default sample_and_log_prob implementation.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants