diff --git a/src/aca_model/agent/labor_market.py b/src/aca_model/agent/labor_market.py index 1421e9e..5665d92 100644 --- a/src/aca_model/agent/labor_market.py +++ b/src/aca_model/agent/labor_market.py @@ -4,6 +4,7 @@ """ import jax.numpy as jnp +import numpy as np from lcm import categorical from lcm.typing import ( ContinuousState, @@ -39,12 +40,16 @@ class SpousalIncome: married_has_inc: ScalarInt -HOURS_VALUES = jnp.array([0.0, 1000.0, 1500.0, 2000.0, 2500.0]) +# Host array, not a module-level JAX array: a device array here would +# reserve the GPU memory pool at import time in every process that imports +# the model. It is converted to a device array at each indexing site, where +# the value folds into the surrounding compiled function. +HOURS_VALUES = np.array([0.0, 1000.0, 1500.0, 2000.0, 2500.0]) def working_hours_value(labor_supply: DiscreteAction) -> FloatND: """Map labor supply choice to annual hours worked.""" - return HOURS_VALUES[labor_supply] + return jnp.asarray(HOURS_VALUES)[labor_supply] def wage( @@ -74,7 +79,7 @@ def income( income = wage * hours^(1 + exp) * int^(-exp) """ - hours = HOURS_VALUES[labor_supply] + hours = jnp.asarray(HOURS_VALUES)[labor_supply] return jnp.where( hours > 0.0, wage diff --git a/tests/test_labor_market.py b/tests/test_labor_market.py index 18dcaa2..f26992a 100644 --- a/tests/test_labor_market.py +++ b/tests/test_labor_market.py @@ -2,11 +2,35 @@ import jax.numpy as jnp import numpy as np +import pytest from aca_model.agent import labor_market from aca_model.agent.labor_market import LaborSupply +def test_hours_values_is_host_array_so_import_allocates_no_device_memory() -> None: + """`HOURS_VALUES` is a host (NumPy) array, not a device-pinned JAX array. + + A module-level JAX array materializes on the default device the moment the + module is imported, reserving the GPU memory pool in every process that + imports the model — including the estimation orchestrator, which only + launches GPU worker ranks and must leave the devices free for them. + """ + assert isinstance(labor_market.HOURS_VALUES, np.ndarray) + + +@pytest.mark.parametrize( + ("choice", "expected_hours"), + [(0, 0.0), (1, 1000.0), (2, 1500.0), (3, 2000.0), (4, 2500.0)], +) +def test_working_hours_value_maps_choice_to_annual_hours( + choice: int, expected_hours: float +) -> None: + """Each labor-supply choice maps to its annual hours worked.""" + result = labor_market.working_hours_value(jnp.asarray(choice, dtype=jnp.int32)) + np.testing.assert_allclose(float(result), expected_hours) + + def test_wage_combines_age_health_profile_with_residual() -> None: """`wage = exp(log_ft_wage_mean[period, good_health] + log_ft_wage_std * res)`.""" log_ft_wage_mean = jnp.array([[1.0, 2.0], [3.0, 4.0]]) # [period, good_health]