From 979df58ec81e554cff01b33b8b7b2a13c3c545ad Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sun, 21 Jun 2026 15:41:36 +0200 Subject: [PATCH] Make HOURS_VALUES a host array to avoid import-time GPU preallocation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The module-level `HOURS_VALUES = jnp.array(...)` materialized on the default device at import. With XLA_PYTHON_CLIENT_PREALLOCATE=true that first array op reserves 95% of device 0 in every process that imports the model — including the MSM estimation's pytask orchestrator, which only `srun`s GPU ranks and must leave the devices free for them. The orchestrator thus starved the rank's pool reservation, surfacing as the device-0 OOM. Make HOURS_VALUES a host (NumPy) array, converted to JAX at the indexing sites where the value folds into the surrounding compiled function. No numerical change. Co-Authored-By: Claude Opus 4.8 --- src/aca_model/agent/labor_market.py | 11 ++++++++--- tests/test_labor_market.py | 24 ++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/src/aca_model/agent/labor_market.py b/src/aca_model/agent/labor_market.py index 1421e9e..5665d92 100644 --- a/src/aca_model/agent/labor_market.py +++ b/src/aca_model/agent/labor_market.py @@ -4,6 +4,7 @@ """ import jax.numpy as jnp +import numpy as np from lcm import categorical from lcm.typing import ( ContinuousState, @@ -39,12 +40,16 @@ class SpousalIncome: married_has_inc: ScalarInt -HOURS_VALUES = jnp.array([0.0, 1000.0, 1500.0, 2000.0, 2500.0]) +# Host array, not a module-level JAX array: a device array here would +# reserve the GPU memory pool at import time in every process that imports +# the model. It is converted to a device array at each indexing site, where +# the value folds into the surrounding compiled function. +HOURS_VALUES = np.array([0.0, 1000.0, 1500.0, 2000.0, 2500.0]) def working_hours_value(labor_supply: DiscreteAction) -> FloatND: """Map labor supply choice to annual hours worked.""" - return HOURS_VALUES[labor_supply] + return jnp.asarray(HOURS_VALUES)[labor_supply] def wage( @@ -74,7 +79,7 @@ def income( income = wage * hours^(1 + exp) * int^(-exp) """ - hours = HOURS_VALUES[labor_supply] + hours = jnp.asarray(HOURS_VALUES)[labor_supply] return jnp.where( hours > 0.0, wage diff --git a/tests/test_labor_market.py b/tests/test_labor_market.py index 18dcaa2..f26992a 100644 --- a/tests/test_labor_market.py +++ b/tests/test_labor_market.py @@ -2,11 +2,35 @@ import jax.numpy as jnp import numpy as np +import pytest from aca_model.agent import labor_market from aca_model.agent.labor_market import LaborSupply +def test_hours_values_is_host_array_so_import_allocates_no_device_memory() -> None: + """`HOURS_VALUES` is a host (NumPy) array, not a device-pinned JAX array. + + A module-level JAX array materializes on the default device the moment the + module is imported, reserving the GPU memory pool in every process that + imports the model — including the estimation orchestrator, which only + launches GPU worker ranks and must leave the devices free for them. + """ + assert isinstance(labor_market.HOURS_VALUES, np.ndarray) + + +@pytest.mark.parametrize( + ("choice", "expected_hours"), + [(0, 0.0), (1, 1000.0), (2, 1500.0), (3, 2000.0), (4, 2500.0)], +) +def test_working_hours_value_maps_choice_to_annual_hours( + choice: int, expected_hours: float +) -> None: + """Each labor-supply choice maps to its annual hours worked.""" + result = labor_market.working_hours_value(jnp.asarray(choice, dtype=jnp.int32)) + np.testing.assert_allclose(float(result), expected_hours) + + def test_wage_combines_age_health_profile_with_residual() -> None: """`wage = exp(log_ft_wage_mean[period, good_health] + log_ft_wage_std * res)`.""" log_ft_wage_mean = jnp.array([[1.0, 2.0], [3.0, 4.0]]) # [period, good_health]