Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/aca_model/agent/labor_market.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import jax.numpy as jnp
import numpy as np
from lcm import categorical
from lcm.typing import (
ContinuousState,
Expand Down Expand Up @@ -39,12 +40,16 @@ class SpousalIncome:
married_has_inc: ScalarInt


HOURS_VALUES = jnp.array([0.0, 1000.0, 1500.0, 2000.0, 2500.0])
# Host array, not a module-level JAX array: a device array here would
# reserve the GPU memory pool at import time in every process that imports
# the model. It is converted to a device array at each indexing site, where
# the value folds into the surrounding compiled function.
HOURS_VALUES = np.array([0.0, 1000.0, 1500.0, 2000.0, 2500.0])


def working_hours_value(labor_supply: DiscreteAction) -> FloatND:
"""Map labor supply choice to annual hours worked."""
return HOURS_VALUES[labor_supply]
return jnp.asarray(HOURS_VALUES)[labor_supply]


def wage(
Expand Down Expand Up @@ -74,7 +79,7 @@ def income(

income = wage * hours^(1 + exp) * int^(-exp)
"""
hours = HOURS_VALUES[labor_supply]
hours = jnp.asarray(HOURS_VALUES)[labor_supply]
return jnp.where(
hours > 0.0,
wage
Expand Down
24 changes: 24 additions & 0 deletions tests/test_labor_market.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,35 @@

import jax.numpy as jnp
import numpy as np
import pytest

from aca_model.agent import labor_market
from aca_model.agent.labor_market import LaborSupply


def test_hours_values_is_host_array_so_import_allocates_no_device_memory() -> None:
"""`HOURS_VALUES` is a host (NumPy) array, not a device-pinned JAX array.

A module-level JAX array materializes on the default device the moment the
module is imported, reserving the GPU memory pool in every process that
imports the model — including the estimation orchestrator, which only
launches GPU worker ranks and must leave the devices free for them.
"""
assert isinstance(labor_market.HOURS_VALUES, np.ndarray)


@pytest.mark.parametrize(
("choice", "expected_hours"),
[(0, 0.0), (1, 1000.0), (2, 1500.0), (3, 2000.0), (4, 2500.0)],
)
def test_working_hours_value_maps_choice_to_annual_hours(
choice: int, expected_hours: float
) -> None:
"""Each labor-supply choice maps to its annual hours worked."""
result = labor_market.working_hours_value(jnp.asarray(choice, dtype=jnp.int32))
np.testing.assert_allclose(float(result), expected_hours)


def test_wage_combines_age_health_profile_with_residual() -> None:
"""`wage = exp(log_ft_wage_mean[period, good_health] + log_ft_wage_std * res)`."""
log_ft_wage_mean = jnp.array([[1.0, 2.0], [3.0, 4.0]]) # [period, good_health]
Expand Down
Loading