diff --git a/CodeEntropy/config/argparse.py b/CodeEntropy/config/argparse.py index 72670bc6..484dd6e8 100644 --- a/CodeEntropy/config/argparse.py +++ b/CodeEntropy/config/argparse.py @@ -151,6 +151,96 @@ class ArgSpec: help="Type of neighbor search to use." "Default RAD; grid search is also available", ), + "parallel_frames": ArgSpec( + type=bool, + default=False, + help="Execute frame-local covariance calculations in parallel using Dask.", + ), + "use_dask": ArgSpec( + type=bool, + default=False, + help="Enable local Dask frame parallelism.", + ), + "dask_workers": ArgSpec( + type=int, + default=None, + help="Number of local Dask worker processes.", + ), + "dask_threads_per_worker": ArgSpec( + type=int, + default=1, + help="Threads per local Dask worker. Use 1 for MDAnalysis trajectory safety.", + ), + "hpc": ArgSpec( + type=bool, + default=False, + help="Use a SLURM-backed Dask cluster for parallel frame execution.", + ), + "hpc_account": ArgSpec( + type=str, + default=None, + help="SLURM account/project code.", + ), + "hpc_qos": ArgSpec( + type=str, + default=None, + help="Optional SLURM QoS.", + ), + "hpc_constraint": ArgSpec( + type=str, + default=None, + help="Optional SLURM node constraint.", + ), + "hpc_queue": ArgSpec( + type=str, + default=None, + help="SLURM partition/queue.", + ), + "hpc_cores": ArgSpec( + type=int, + default=1, + help="Number of CPU cores per Dask worker job.", + ), + "hpc_processes": ArgSpec( + type=int, + default=1, + help="Number of Dask worker processes per SLURM job.", + ), + "hpc_memory": ArgSpec( + type=str, + default="4GB", + help="Memory requested per Dask worker job.", + ), + "hpc_walltime": ArgSpec( + type=str, + default="01:00:00", + help="Walltime for each Dask worker job, formatted as HH:MM:SS.", + ), + "hpc_nodes": ArgSpec( + type=int, + default=1, + help="Number of SLURM Dask worker jobs to launch.", + ), + "submit": ArgSpec( + type=bool, + default=False, + help="Submit a master SLURM job instead of running 
immediately.", + ), + "conda_path": ArgSpec( + type=str, + default="conda", + help="Path to conda executable used by SLURM worker prologue.", + ), + "conda_exec": ArgSpec( + type=str, + default="conda", + help="Conda-compatible executable to use, usually conda or mamba.", + ), + "conda_env": ArgSpec( + type=str, + default=None, + help="Conda environment name to activate on Dask workers.", + ), } @@ -385,6 +475,7 @@ def validate_inputs(self, u: Any, args: argparse.Namespace) -> None: self._check_input_bin_width(args) self._check_input_temperature(args) self._check_input_force_partitioning(args) + self._check_parallel_frame_options(args) @staticmethod def _check_input_start(u: Any, args: argparse.Namespace) -> None: @@ -443,3 +534,50 @@ def _check_input_force_partitioning(self, args: argparse.Namespace) -> None: args.force_partitioning, default_value, ) + + @staticmethod + def _check_parallel_frame_options(args: argparse.Namespace) -> None: + """Validate local Dask, HPC Dask, and submit-related options.""" + dask_workers = getattr(args, "dask_workers", None) + if dask_workers is not None and dask_workers < 1: + raise ValueError("'dask_workers' must be at least 1 if provided.") + + dask_threads = getattr(args, "dask_threads_per_worker", 1) + if dask_threads < 1: + raise ValueError("'dask_threads_per_worker' must be at least 1.") + + using_hpc = bool(getattr(args, "hpc", False)) + submitting = bool(getattr(args, "submit", False)) + + if submitting and not using_hpc: + raise ValueError("'submit' requires 'hpc' to be enabled.") + + if not using_hpc and not submitting: + return + + if not getattr(args, "hpc_queue", None): + raise ValueError("'hpc_queue' must be provided when using HPC Dask.") + + if getattr(args, "hpc_nodes", 1) < 1: + raise ValueError("'hpc_nodes' must be at least 1.") + + if getattr(args, "hpc_cores", 1) < 1: + raise ValueError("'hpc_cores' must be at least 1.") + + if getattr(args, "hpc_processes", 1) < 1: + raise ValueError("'hpc_processes' must be at 
least 1.") + + if not getattr(args, "hpc_memory", None): + raise ValueError("'hpc_memory' must be provided when using HPC Dask.") + + if not getattr(args, "hpc_walltime", None): + raise ValueError("'hpc_walltime' must be provided when using HPC Dask.") + + if not getattr(args, "conda_env", None): + raise ValueError("'conda_env' must be provided when using HPC Dask.") + + if not getattr(args, "conda_path", None): + raise ValueError("'conda_path' must be provided when using HPC Dask.") + + if not getattr(args, "conda_exec", None): + raise ValueError("'conda_exec' must be provided when using HPC Dask.") diff --git a/CodeEntropy/config/runtime.py b/CodeEntropy/config/runtime.py index 6a00d04a..1c76c857 100644 --- a/CodeEntropy/config/runtime.py +++ b/CodeEntropy/config/runtime.py @@ -33,6 +33,7 @@ from rich.text import Text from CodeEntropy.config.argparse import ConfigResolver +from CodeEntropy.core.dask_clusters import HPCDaskManager from CodeEntropy.core.logging import LoggingConfig from CodeEntropy.entropy.workflow import EntropyWorkflow from CodeEntropy.levels.dihedrals import ConformationStateBuilder @@ -223,8 +224,9 @@ def run_entropy_workflow(self) -> None: This method: - Sets up logging and prints the splash screen - - Loads YAML config from CWD and parses CLI args + - Loads YAML configuration from CWD and parses CLI args - Merges args with YAML per-run config + - Optionally submits a master SLURM job and exits - Builds the MDAnalysis Universe (with optional force merging) - Validates user parameters - Constructs dependencies and executes EntropyWorkflow @@ -266,6 +268,11 @@ def run_entropy_workflow(self) -> None: self._validate_required_args(args) + if getattr(args, "submit", False): + self._config_manager._check_parallel_frame_options(args) + HPCDaskManager(args).submit_master() + return + self.print_args_table(args) universe_operations = UniverseOperations() diff --git a/CodeEntropy/core/dask_clusters.py b/CodeEntropy/core/dask_clusters.py new file mode 
class HPCDaskManager:
    """Manage SLURM-backed Dask clusters and submission utilities for HPC environments.

    All settings are read from the parsed run arguments passed to the
    constructor: the ``hpc_*`` options describe the SLURM resources and the
    ``conda_*`` options describe the environment activated inside each job.
    An optional ``hpc_modules`` attribute lists environment modules to load.
    """

    # Interfaces considered safe to hand to Dask, in order of preference.
    _KNOWN_INTERFACES = ("bond0", "ib0", "hsn0", "eth0")
    _WORKER_SCRIPT_NAME = "dask-cluster-submit.sh"
    _MASTER_SCRIPT_NAME = "CodeEntropy-master-submit.sh"

    def __init__(self, args):
        """Initialise HPCDaskManager with runtime arguments.

        Args:
            args: Parsed CLI arguments containing HPC and conda configuration.
        """
        self.args = args

    def check_slurm_env(self) -> None:
        """Remove ``SLURM_CPU_BIND`` from the environment if present.

        Some HPC systems require this variable to be unset for correct CPU
        binding. Calling this when the variable is absent is a no-op.
        """
        os.environ.pop("SLURM_CPU_BIND", None)

    def system_network_interface(self) -> str:
        """Return the first known HPC network interface present on this host.

        This deliberately follows the WaterEntropy behaviour and only selects
        from known HPC-safe interfaces. It avoids selecting arbitrary
        interfaces such as eno1, which may exist on the master node but not
        on worker nodes.

        Returns:
            Name of the selected interface.

        Raises:
            RuntimeError: If none of the known interfaces is available.
        """
        interfaces = list(psutil.net_if_addrs().keys())

        for iface in self._KNOWN_INTERFACES:
            if iface in interfaces:
                return iface

        raise RuntimeError(
            "Could not find a known HPC network interface. "
            f"Available interfaces: {interfaces}. "
            "Expected one of: bond0, ib0, hsn0, eth0."
        )

    def slurm_directives(self) -> tuple[list[str], list[str]]:
        """Build the extra SLURM directives and the directives to skip.

        Returns:
            Tuple ``(extra, skip)``. ``extra`` holds ``--account``, ``--qos``
            and ``--constraint`` directives for the options that are set;
            ``skip`` always contains ``--mem``.
        """
        args = self.args

        extra: list[str] = []
        for flag, value in (
            ("account", args.hpc_account),
            ("qos", args.hpc_qos),
            ("constraint", args.hpc_constraint),
        ):
            if value:
                extra.append(f'--{flag}="{value}"')

        skip = ["--mem"]

        return extra, skip

    def _environment_commands(self) -> list[str]:
        """Return the shell commands that load modules and activate conda.

        Shared by the worker prologue and the master submission script so the
        two cannot drift apart. Values are interpolated verbatim rather than
        shell-quoted, so paths written with ``~`` or ``$HOME`` keep expanding.
        """
        args = self.args
        commands = [
            f"module load {module_name}"
            for module_name in getattr(args, "hpc_modules", None) or []
        ]

        commands.append(f'eval "$({args.conda_path} shell.bash hook)"')

        if args.conda_exec == "mamba":
            commands.append(f'eval "$({args.conda_exec} shell hook --shell bash)"')

        commands.append(f"{args.conda_exec} activate {args.conda_env}")
        return commands

    def slurm_prologues(self) -> list[str]:
        """Build the environment setup commands for the SLURM worker job script.

        Returns:
            List of shell commands executed before the Dask worker starts.
        """
        # NOTE(review): the CPU frequency is hard-coded and looks
        # site-specific; confirm it is appropriate for every target system.
        return [*self._environment_commands(), "export SLURM_CPU_FREQ_REQ=2250000"]

    def configure_cluster(self) -> "Client":
        """Configure a SLURM-backed Dask cluster.

        The worker job script is also written to ``dask-cluster-submit.sh`` in
        the current directory for inspection.

        Returns:
            Dask distributed client connected to the SLURMCluster.
        """
        args = self.args

        extra, skip = self.slurm_directives()
        prologue = self.slurm_prologues()
        iface = self.system_network_interface()

        self.check_slurm_env()

        cluster = SLURMCluster(
            cores=args.hpc_cores,
            processes=args.hpc_processes,
            memory=args.hpc_memory,
            queue=args.hpc_queue,
            job_directives_skip=skip,
            job_extra_directives=extra,
            python="srun python",
            walltime=args.hpc_walltime,
            shebang="#!/bin/bash --login",
            local_directory="$PWD",
            interface=iface,
            job_script_prologue=prologue,
        )

        try:
            cluster.scale(jobs=args.hpc_nodes)

            client = Client(cluster)

            with open(self._WORKER_SCRIPT_NAME, "w", encoding="utf-8") as f:
                f.write(cluster.job_script())
        except Exception:
            # Do not leave queued worker jobs behind when setup fails half-way.
            cluster.close()
            raise

        return client

    @staticmethod
    def _strip_submit_flag(argv: list[str]) -> list[str]:
        """Return ``argv`` without any ``--submit`` flag.

        Handles ``--submit``, ``--submit true|false`` and ``--submit=<value>``,
        including repeated occurrences. All other tokens are kept in order.
        """
        remaining: list[str] = []
        value_may_follow = False

        for token in argv:
            text = str(token)

            if value_may_follow:
                value_may_follow = False
                if text.lower() in {"true", "false"}:
                    continue

            if text == "--submit":
                value_may_follow = True
                continue

            if text.startswith("--submit="):
                continue

            remaining.append(token)

        return remaining

    def _master_script(self, cli: list[str]) -> str:
        """Return the text of the master SLURM submission script.

        Args:
            cli: Command-line arguments to pass to ``CodeEntropy`` in the job.
        """
        import shlex

        args = self.args

        lines = [
            "#!/bin/bash --login",
            "",
            "#SBATCH --job-name=codeentropy-master",
            "#SBATCH --nodes=1",
            "#SBATCH --ntasks=1",
            "#SBATCH --cpus-per-task=2",
            f"#SBATCH --time={args.hpc_walltime}",
        ]

        if args.hpc_account:
            lines.append(f"#SBATCH --account={args.hpc_account}")

        lines.append(f"#SBATCH --partition={args.hpc_queue}")

        if args.hpc_qos:
            lines.append(f"#SBATCH --qos={args.hpc_qos}")

        lines.append("")
        lines.extend(self._environment_commands())
        lines.append("")

        # The arguments were already split by the user's shell. Re-quote them
        # so values containing spaces (e.g. a selection string such as
        # "name CA") reach the job as a single argument.
        lines.append(shlex.join(["srun", "CodeEntropy", *(str(t) for t in cli)]))

        return "\n".join(lines) + "\n"

    def submit_master(self) -> None:
        """Submit a SLURM job that runs the master CodeEntropy process.

        Writes ``CodeEntropy-master-submit.sh`` to the current directory and
        submits it via ``sbatch``. The script re-runs ``CodeEntropy`` with the
        current command line minus any ``--submit`` flag.

        Submission is best-effort: the output of ``sbatch`` is printed and a
        failed submission is reported on stdout rather than raised.

        NOTE(review): only the command-line flag is stripped. If ``submit`` is
        enabled in ``config.yaml`` the submitted job presumably reads the same
        value again and would resubmit itself; confirm how the config resolver
        treats this case.
        """
        cli = self._strip_submit_flag(list(sys.argv[1:]))
        script_name = self._MASTER_SCRIPT_NAME

        with open(script_name, "w", encoding="utf-8") as f:
            f.write(self._master_script(cli))

        self.check_slurm_env()

        try:
            # Argument list instead of a shell string: nothing is re-parsed.
            result = subprocess.check_output(["sbatch", script_name])
        except subprocess.CalledProcessError as e:
            print((e.output or b"").decode("utf-8", errors="replace"))
        except OSError as e:
            # sbatch is missing or not executable on this host.
            print(f"Could not run sbatch: {e}")
        else:
            print(result.decode("utf-8", errors="replace"))
- with self._reporter.progress(transient=False) as p: - self._run_level_dag(shared_data, progress=p) - self._run_entropy_graph(shared_data, progress=p) + self._configure_parallel_frame_execution(shared_data) + + try: + with self._reporter.progress(transient=False) as p: + self._run_level_dag(shared_data, progress=p) + self._run_entropy_graph(shared_data, progress=p) + finally: + client = shared_data.get("dask_client") + if client is not None: + client.close() self._finalize_molecule_results() self._reporter.log_tables() + def _configure_parallel_frame_execution(self, shared_data: SharedData) -> None: + """Attach a Dask client to shared_data if parallel frames are requested. + + Supports: + - Local Dask via --parallel_frames true / --use_dask true + - SLURM-backed Dask via --hpc true + """ + use_parallel = bool( + getattr(self._args, "parallel_frames", False) + or getattr(self._args, "use_dask", False) + or getattr(self._args, "hpc", False) + ) + + if not use_parallel: + return + + if "dask_client" in shared_data: + shared_data["parallel_frames"] = True + return + + if getattr(self._args, "hpc", False): + client = HPCDaskManager(self._args).configure_cluster() + shared_data["dask_client"] = client + shared_data["parallel_frames"] = True + return + + try: + from dask.distributed import Client + except ImportError as exc: + raise RuntimeError( + "Parallel frame execution was requested, but dask.distributed " + "is not installed." + ) from exc + + shared_data["dask_client"] = Client( + processes=True, + n_workers=getattr(self._args, "dask_workers", None), + threads_per_worker=getattr(self._args, "dask_threads_per_worker", 1), + ) + shared_data["parallel_frames"] = True + def _build_frame_selection(self) -> FrameSelection: """Build the workflow frame selection. 
def _execute_frame_worker(
    shared_data: dict[str, Any],
    frame_index: int,
    universe_operations: Any | None = None,
) -> tuple[int, Any]:
    """Run the frame graph for a single frame and return its output.

    A fresh ``FrameGraph`` is built on every call, so nothing is carried over
    between frames handled by the same worker.

    Args:
        shared_data: Worker-local shared calculation inputs.
        frame_index: Frame index to process; coerced with ``int``.
        universe_operations: Optional universe operations adapter forwarded
            to ``FrameGraph``.

    Returns:
        Tuple of the integer frame index and whatever ``execute_frame``
        returned for that frame.
    """
    graph = FrameGraph(universe_operations=universe_operations).build()
    index = int(frame_index)
    frame_output = graph.execute_frame(shared_data, index)
    return index, frame_output
def _run_frame_stage_dask(
    self,
    shared_data: dict[str, Any],
    *,
    frame_indices: list[int],
    client: Any,
    progress: _RichProgressSink | None = None,
    task: TaskID | None = None,
) -> None:
    """Execute frame-local DAG tasks in parallel using Dask.

    Workers return frame-local covariance payloads. The parent process performs
    all reductions into the shared accumulators.

    Frames are computed in whatever order the workers finish, but they are
    reduced strictly in the order given by ``frame_indices``. Results that
    arrive early are buffered until every earlier frame has been reduced, so
    the accumulated result does not depend on worker scheduling and the
    reduction sequence matches the sequential path.

    Important:
        Do not scatter/broadcast worker_shared. It contains stateful objects
        such as frame_source / universe trajectory state. Broadcasting can reuse
        mutable state across tasks on the same worker and make frames interfere
        with one another.

    Args:
        shared_data: Shared data dictionary holding the reduction accumulators.
        frame_indices: Frame indices to process, in reduction order.
        client: Dask client used to submit and cancel the frame tasks.
        progress: Optional progress sink.
        task: Optional progress task updated once per reduced frame.

    Raises:
        RuntimeError: If ``distributed`` cannot be imported, or if fewer
            frames were reduced than were requested.
    """
    try:
        from distributed import as_completed
    except ImportError as exc:
        raise RuntimeError(
            "Parallel frame execution requires dask.distributed to be installed."
        ) from exc

    worker_shared = self._make_frame_worker_shared_data(shared_data)

    futures = [
        client.submit(
            _execute_frame_worker,
            worker_shared,
            frame_index,
            self._universe_operations,
            pure=False,
        )
        for frame_index in frame_indices
    ]

    # Submission position of every future; used to restore frame order.
    position_of = {future: position for position, future in enumerate(futures)}

    # Results that finished before an earlier frame, keyed by position.
    buffered: dict[int, tuple[int, Any]] = {}
    completed = 0

    try:
        for future in as_completed(futures):
            buffered[position_of[future]] = future.result()

            # Drain every result that is now next in line. ``completed`` is
            # also the position of the next frame due for reduction.
            while completed in buffered:
                frame_index, frame_out = buffered.pop(completed)
                completed += 1

                if progress is not None and task is not None:
                    progress.update(task, title=f"Frame {frame_index}")

                self._reduce_one_frame(shared_data, frame_out)

                if progress is not None and task is not None:
                    progress.advance(task)

        if completed != len(frame_indices):
            raise RuntimeError(
                f"Parallel frame execution completed {completed} frames, "
                f"but expected {len(frame_indices)}."
            )

    except BaseException:
        # BaseException so that Ctrl-C also cancels the outstanding tasks
        # instead of leaving them running on the workers.
        client.cancel(futures)
        raise
- ``molecules`` - ``str`` - * - ``--kcal_force_units`` - - Set input units as kcal/mol - - ``False`` - - ``bool`` * - ``--combined_forcetorque`` - - Use the combined force-torque covariance matrix for the highest level to match the 2019 paper + - Use the combined force-torque covariance matrix for the highest level to match the + 2019 paper. - ``True`` - ``bool`` * - ``--customised_axes`` - - Use custom bonded axes to get COM, MOI and PA that match the 2019 paper + - Use custom bonded axes to get COM, MOI and PA that match the 2019 paper. - ``True`` - ``bool`` * - ``--search_type`` - - Method for finding neighbouring molecules + - Method for finding neighbouring molecules. - ``RAD`` - ``str`` + * - ``--parallel_frames`` + - Execute frame-local covariance calculations in parallel. When enabled, frame-level + work is submitted to Dask and reduced in the parent process. + - ``False`` + - ``bool`` + * - ``--use_dask`` + - Enable local Dask frame parallelism. This is useful for running frame-level work + across local worker processes. + - ``False`` + - ``bool`` + * - ``--dask_workers`` + - Number of local Dask worker processes to use for parallel frame execution. If unset, + Dask chooses a default. + - ``None`` + - ``int`` + * - ``--dask_threads_per_worker`` + - Number of threads per local Dask worker. ``1`` is recommended for trajectory safety + with MDAnalysis. + - ``1`` + - ``int`` + * - ``--hpc`` + - Use a SLURM-backed Dask cluster for parallel frame execution. + - ``False`` + - ``bool`` + * - ``--submit`` + - Submit a master SLURM job and exit instead of running immediately in the current + process. This is intended for HPC batch submission. + - ``False`` + - ``bool`` + * - ``--hpc_queue`` + - SLURM partition or queue to use for Dask worker jobs. + - ``None`` + - ``str`` + * - ``--hpc_nodes`` + - Number of SLURM Dask worker jobs to launch. + - ``1`` + - ``int`` + * - ``--hpc_cores`` + - Number of CPU cores requested per Dask worker job. 
+ - ``1`` + - ``int`` + * - ``--hpc_processes`` + - Number of Dask worker processes per SLURM job. + - ``1`` + - ``int`` + * - ``--hpc_memory`` + - Memory requested per Dask worker job, for example ``4GB`` or ``16GB``. + - ``4GB`` + - ``str`` + * - ``--hpc_walltime`` + - Walltime requested for each Dask worker job, formatted as ``HH:MM:SS``. + - ``01:00:00`` + - ``str`` + * - ``--hpc_account`` + - Optional SLURM account or project code. + - ``None`` + - ``str`` + * - ``--hpc_qos`` + - Optional SLURM QoS value. + - ``None`` + - ``str`` + * - ``--hpc_constraint`` + - Optional SLURM node constraint. + - ``None`` + - ``str`` + * - ``--conda_path`` + - Path to the conda executable used in the SLURM worker prologue. + - ``conda`` + - ``str`` + * - ``--conda_exec`` + - Conda-compatible executable to use for environment activation, usually ``conda`` or + ``mamba``. + - ``conda`` + - ``str`` + * - ``--conda_env`` + - Conda environment name to activate on SLURM workers. + - ``None`` + - ``str`` + +Parallel Frame Execution +------------------------ + +CodeEntropy can optionally process trajectory frames in parallel using Dask. This is +most useful for larger trajectories where the frame-local covariance calculations are +one of the slowest parts of the workflow. + +The parallel implementation works as a map/reduce workflow: + +* each Dask worker processes one frame at a time; +* each worker returns a frame-local covariance result; +* the parent process reduces those frame-local results into the final running + covariance averages; +* the entropy graph runs after frame reduction has completed. + +This means workers do not directly modify the shared covariance accumulators. The +parent process remains responsible for reduction, which keeps the parallel execution +consistent with the sequential workflow. + +Local Dask Execution +^^^^^^^^^^^^^^^^^^^^ + +For local workstation or laptop use, enable ``parallel_frames`` and ``use_dask`` in +``config.yaml``: + +.. 
code-block:: yaml + + --- + + run1: + top_traj_file: ["md_A4_dna.tpr", "md_A4_dna_xf.trr"] + selection_string: "all" + start: 0 + end: 100 + step: 1 + + parallel_frames: true + use_dask: true + dask_workers: 4 + dask_threads_per_worker: 1 + +The recommended value for ``dask_threads_per_worker`` is ``1``. This keeps each worker +process independent and avoids thread-safety issues when reading trajectory data. + +The same run can also be started from the command line: + +.. code-block:: bash + + CodeEntropy \ + --parallel_frames true \ + --use_dask true \ + --dask_workers 4 \ + --dask_threads_per_worker 1 + +For very small systems or short trajectories, local Dask may not be faster than the +sequential path because there is overhead in starting workers and transferring frame +data. It is best suited to larger calculations with many frames. + +SLURM / HPC Dask Execution +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +On a SLURM-based HPC system, CodeEntropy can create a Dask cluster using SLURM worker +jobs. This is enabled with ``hpc: true``. + +Example ``config.yaml``: + +.. code-block:: yaml + + --- + + run1: + top_traj_file: ["1AKI_prod_new.tpr", "1AKI_prod_new.trr"] + selection_string: "all" + start: 0 + end: 500 + step: 1 + + parallel_frames: true + hpc: true + + hpc_queue: standard + hpc_nodes: 4 + hpc_cores: 8 + hpc_processes: 1 + hpc_memory: 16GB + hpc_walltime: "02:00:00" + + hpc_account: null + hpc_qos: null + hpc_constraint: null + + conda_path: conda + conda_exec: conda + conda_env: codeentropy + +The important HPC options are: + +* ``hpc_queue``: SLURM partition or queue. +* ``hpc_nodes``: number of Dask worker jobs to launch. +* ``hpc_cores``: number of CPU cores requested per Dask worker job. +* ``hpc_processes``: number of Dask worker processes per SLURM job. +* ``hpc_memory``: memory requested per Dask worker job. +* ``hpc_walltime``: walltime requested for each worker job. +* ``conda_env``: environment to activate on the worker jobs. 
+ +Submitting a Master SLURM Job +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If you want CodeEntropy to submit a master SLURM job and then exit, set +``submit: true`` as well as ``hpc: true``: + +.. code-block:: yaml + + --- + + run1: + top_traj_file: ["1AKI_prod.tpr", "1AKI_prod.trr"] + selection_string: "all" + start: 0 + end: 500 + step: 1 + + submit: true + parallel_frames: true + hpc: true + + hpc_queue: standard + hpc_nodes: 4 + hpc_cores: 8 + hpc_processes: 1 + hpc_memory: 16GB + hpc_walltime: "02:00:00" + + hpc_account: null + hpc_qos: null + hpc_constraint: null + + conda_path: conda + conda_exec: conda + conda_env: codeentropy + +Run CodeEntropy from the working directory containing ``config.yaml``: + +.. code-block:: bash + + CodeEntropy + +In submit mode, CodeEntropy writes and submits a master SLURM script, then exits from +the current process. The submitted master job starts CodeEntropy again on the cluster, +where the SLURM-backed Dask workers are then launched. + +Choosing a Parallel Mode +^^^^^^^^^^^^^^^^^^^^^^^^ + +Use sequential execution for small tests and debugging: + +.. code-block:: yaml + + parallel_frames: false + use_dask: false + hpc: false + submit: false + +Use local Dask when running on a workstation: + +.. code-block:: yaml + + parallel_frames: true + use_dask: true + dask_workers: 4 + dask_threads_per_worker: 1 + +Use HPC Dask when running inside an allocated HPC session or batch job: + +.. code-block:: yaml + + parallel_frames: true + hpc: true + +Use submit mode when you want CodeEntropy to create and submit the master SLURM job +for you: + +.. 
code-block:: yaml + + submit: true + parallel_frames: true + hpc: true Averaging --------- diff --git a/pyproject.toml b/pyproject.toml index 6515a790..2fcd96d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ dependencies = [ "requests>=2.32,<3.0", "rdkit>=2025.9.5", "numba>=0.65.0,<0.70", + "dask-jobqueue>=0.9.0, <0.10", ] [project.urls] diff --git a/tests/unit/CodeEntropy/config/argparse/conftest.py b/tests/unit/CodeEntropy/config/argparse/conftest.py index 7ad0a4bc..adcabf07 100644 --- a/tests/unit/CodeEntropy/config/argparse/conftest.py +++ b/tests/unit/CodeEntropy/config/argparse/conftest.py @@ -43,6 +43,32 @@ def _make(**overrides): return _make +@pytest.fixture() +def make_valid_hpc_args(make_args): + """Factory to build a valid HPC/Dask args object for validation tests.""" + + def _make(**overrides): + base = dict( + dask_workers=None, + dask_threads_per_worker=1, + hpc=True, + submit=False, + hpc_queue="standard", + hpc_nodes=1, + hpc_cores=1, + hpc_processes=1, + hpc_memory="4GB", + hpc_walltime="01:00:00", + conda_env="codeentropy", + conda_path="conda", + conda_exec="conda", + ) + base.update(overrides) + return make_args(**base) + + return _make + + @pytest.fixture() def empty_cli_args(resolver): """Argparse Namespace with all parser defaults.""" diff --git a/tests/unit/CodeEntropy/config/argparse/test_argparse_validate_inputs.py b/tests/unit/CodeEntropy/config/argparse/test_argparse_validate_inputs.py index 1769a9c3..73ca005d 100644 --- a/tests/unit/CodeEntropy/config/argparse/test_argparse_validate_inputs.py +++ b/tests/unit/CodeEntropy/config/argparse/test_argparse_validate_inputs.py @@ -5,6 +5,7 @@ def test_validate_inputs_valid_does_not_raise(resolver, dummy_universe, make_args): args = make_args() + resolver.validate_inputs(dummy_universe, args) @@ -57,3 +58,218 @@ def test_check_input_force_partitioning_non_default_logs_warning( resolver._check_input_force_partitioning(args) assert "differs from the default" in caplog.text + + 
+def test_check_parallel_frame_options_valid_local_dask_does_not_raise( + resolver, make_args +): + args = make_args( + dask_workers=2, + dask_threads_per_worker=1, + hpc=False, + submit=False, + ) + + resolver._check_parallel_frame_options(args) + + +def test_check_parallel_frame_options_allows_dask_workers_none(resolver, make_args): + args = make_args( + dask_workers=None, + dask_threads_per_worker=1, + hpc=False, + submit=False, + ) + + resolver._check_parallel_frame_options(args) + + +def test_check_parallel_frame_options_valid_hpc_does_not_raise( + resolver, make_valid_hpc_args +): + args = make_valid_hpc_args() + + resolver._check_parallel_frame_options(args) + + +def test_check_parallel_frame_options_valid_hpc_submit_does_not_raise( + resolver, make_valid_hpc_args +): + args = make_valid_hpc_args(submit=True) + + resolver._check_parallel_frame_options(args) + + +def test_check_parallel_frame_options_raises_when_dask_workers_less_than_one( + resolver, make_args +): + args = make_args( + dask_workers=0, + dask_threads_per_worker=1, + hpc=False, + submit=False, + ) + + with pytest.raises( + ValueError, + match="'dask_workers' must be at least 1 if provided.", + ): + resolver._check_parallel_frame_options(args) + + +def test_check_parallel_frame_options_raises_when_dask_threads_less_than_one( + resolver, make_args +): + args = make_args( + dask_workers=None, + dask_threads_per_worker=0, + hpc=False, + submit=False, + ) + + with pytest.raises( + ValueError, + match="'dask_threads_per_worker' must be at least 1.", + ): + resolver._check_parallel_frame_options(args) + + +def test_check_parallel_frame_options_raises_when_submit_without_hpc( + resolver, make_valid_hpc_args +): + args = make_valid_hpc_args( + hpc=False, + submit=True, + ) + + with pytest.raises( + ValueError, + match="'submit' requires 'hpc' to be enabled.", + ): + resolver._check_parallel_frame_options(args) + + +def test_check_parallel_frame_options_raises_when_hpc_queue_missing( + resolver, 
make_valid_hpc_args +): + args = make_valid_hpc_args( + hpc_queue=None, + ) + + with pytest.raises( + ValueError, + match="'hpc_queue' must be provided when using HPC Dask.", + ): + resolver._check_parallel_frame_options(args) + + +def test_check_parallel_frame_options_raises_when_hpc_nodes_less_than_one( + resolver, make_valid_hpc_args +): + args = make_valid_hpc_args( + hpc_nodes=0, + ) + + with pytest.raises( + ValueError, + match="'hpc_nodes' must be at least 1.", + ): + resolver._check_parallel_frame_options(args) + + +def test_check_parallel_frame_options_raises_when_hpc_cores_less_than_one( + resolver, make_valid_hpc_args +): + args = make_valid_hpc_args( + hpc_cores=0, + ) + + with pytest.raises( + ValueError, + match="'hpc_cores' must be at least 1.", + ): + resolver._check_parallel_frame_options(args) + + +def test_check_parallel_frame_options_raises_when_hpc_processes_less_than_one( + resolver, make_valid_hpc_args +): + args = make_valid_hpc_args( + hpc_processes=0, + ) + + with pytest.raises( + ValueError, + match="'hpc_processes' must be at least 1.", + ): + resolver._check_parallel_frame_options(args) + + +def test_check_parallel_frame_options_raises_when_hpc_memory_missing( + resolver, make_valid_hpc_args +): + args = make_valid_hpc_args( + hpc_memory=None, + ) + + with pytest.raises( + ValueError, + match="'hpc_memory' must be provided when using HPC Dask.", + ): + resolver._check_parallel_frame_options(args) + + +def test_check_parallel_frame_options_raises_when_hpc_walltime_missing( + resolver, make_valid_hpc_args +): + args = make_valid_hpc_args( + hpc_walltime=None, + ) + + with pytest.raises( + ValueError, + match="'hpc_walltime' must be provided when using HPC Dask.", + ): + resolver._check_parallel_frame_options(args) + + +def test_check_parallel_frame_options_raises_when_conda_env_missing( + resolver, make_valid_hpc_args +): + args = make_valid_hpc_args( + conda_env=None, + ) + + with pytest.raises( + ValueError, + match="'conda_env' must be 
provided when using HPC Dask.", + ): + resolver._check_parallel_frame_options(args) + + +def test_check_parallel_frame_options_raises_when_conda_path_missing( + resolver, make_valid_hpc_args +): + args = make_valid_hpc_args( + conda_path=None, + ) + + with pytest.raises( + ValueError, + match="'conda_path' must be provided when using HPC Dask.", + ): + resolver._check_parallel_frame_options(args) + + +def test_check_parallel_frame_options_raises_when_conda_exec_missing( + resolver, make_valid_hpc_args +): + args = make_valid_hpc_args( + conda_exec=None, + ) + + with pytest.raises( + ValueError, + match="'conda_exec' must be provided when using HPC Dask.", + ): + resolver._check_parallel_frame_options(args) diff --git a/tests/unit/CodeEntropy/config/runtime/test_run_entropy_workflow_branches.py b/tests/unit/CodeEntropy/config/runtime/test_run_entropy_workflow_branches.py index 9ef6f761..55454cd5 100644 --- a/tests/unit/CodeEntropy/config/runtime/test_run_entropy_workflow_branches.py +++ b/tests/unit/CodeEntropy/config/runtime/test_run_entropy_workflow_branches.py @@ -156,3 +156,45 @@ class BadArgs: "Run arguments at failure could not be serialized" in str(call.args[0]) for call in error_spy.call_args_list ) + + +def test_run_entropy_workflow_submit_calls_submit_master_and_returns(runner): + runner._logging_config = MagicMock() + runner._config_manager = MagicMock() + runner._reporter = MagicMock() + runner.show_splash = MagicMock() + runner.print_args_table = MagicMock() + runner._build_universe = MagicMock() + + run_logger = MagicMock() + runner._logging_config.configure.return_value = run_logger + + runner._config_manager.load_config.return_value = {"run1": {}} + + args = SimpleNamespace( + output_file="out.json", + verbose=False, + submit=True, + top_traj_file=["topology.tpr", "trajectory.trr"], + selection_string="all", + force_file=None, + file_format=None, + kcal_force_units=False, + ) + + parser = MagicMock() + parser.parse_known_args.return_value = (args, 
[]) + runner._config_manager.build_parser.return_value = parser + runner._config_manager.resolve.return_value = args + runner._config_manager.validate_inputs = MagicMock() + + with patch("CodeEntropy.config.runtime.HPCDaskManager") as HPCDaskManagerCls: + runner.run_entropy_workflow() + + HPCDaskManagerCls.assert_called_once_with(args) + HPCDaskManagerCls.return_value.submit_master.assert_called_once() + + runner.print_args_table.assert_not_called() + runner._build_universe.assert_not_called() + runner._config_manager.validate_inputs.assert_not_called() + runner._logging_config.export_console.assert_not_called() diff --git a/tests/unit/CodeEntropy/core/dask_clusters/test_dask_clusters.py b/tests/unit/CodeEntropy/core/dask_clusters/test_dask_clusters.py new file mode 100644 index 00000000..fe95684e --- /dev/null +++ b/tests/unit/CodeEntropy/core/dask_clusters/test_dask_clusters.py @@ -0,0 +1,507 @@ +"""Tests for CodeEntropy HPC/Dask SLURM cluster helpers.""" + +import argparse +import os +import subprocess +import sys +from unittest import mock + +import pytest + +from CodeEntropy.core.dask_clusters import HPCDaskManager + + +def args_helper(args_list): + """Build test args for HPCDaskManager.""" + parser = argparse.ArgumentParser() + + parser.add_argument("--hpc-account", type=str, default="") + parser.add_argument("--hpc-constraint", type=str, default="") + parser.add_argument("--hpc-qos", type=str, default="") + parser.add_argument("--hpc-queue", type=str, default="standard") + parser.add_argument("--hpc-cores", type=int, default=20) + parser.add_argument("--hpc-memory", type=str, default="16GB") + parser.add_argument("--hpc-nodes", type=int, default=4) + parser.add_argument("--hpc-processes", type=int, default=20) + parser.add_argument("--hpc-walltime", type=str, default="24:00:00") + parser.add_argument("--hpc-modules", nargs="+", default=None) + + parser.add_argument("--conda-env", type=str, default="codeentropy") + parser.add_argument("--conda-exec", 
type=str, default="conda") + parser.add_argument("--conda-path", type=str, default="/path/to/conda") + + return parser.parse_args(args_list) + + +def test_check_slurm_env_removes_cpu_bind(): + args = args_helper([]) + manager = HPCDaskManager(args) + + os.environ["SLURM_CPU_BIND"] = "1" + assert os.environ["SLURM_CPU_BIND"] == "1" + + manager.check_slurm_env() + + assert "SLURM_CPU_BIND" not in os.environ + + +def test_slurm_directives_account(): + args = args_helper(["--hpc-account", "c01"]) + manager = HPCDaskManager(args) + + extra, skip = manager.slurm_directives() + + assert extra == ['--account="c01"'] + assert skip == ["--mem"] + + +def test_slurm_directives_constraint(): + args = args_helper(["--hpc-constraint", "intel25"]) + manager = HPCDaskManager(args) + + extra, _skip = manager.slurm_directives() + + assert extra == ['--constraint="intel25"'] + + +def test_slurm_directives_qos(): + args = args_helper(["--hpc-qos", "standard"]) + manager = HPCDaskManager(args) + + extra, _skip = manager.slurm_directives() + + assert extra == ['--qos="standard"'] + + +def test_slurm_directives_all(): + args = args_helper( + [ + "--hpc-account", + "c01", + "--hpc-qos", + "standard", + "--hpc-constraint", + "intel25", + ] + ) + manager = HPCDaskManager(args) + + extra, skip = manager.slurm_directives() + + assert extra == [ + '--account="c01"', + '--qos="standard"', + '--constraint="intel25"', + ] + assert skip == ["--mem"] + + +def test_slurm_prologues_conda(): + args = args_helper( + [ + "--conda-env", + "codeentropy", + "--conda-exec", + "conda", + "--conda-path", + "/path/to/conda", + ] + ) + manager = HPCDaskManager(args) + + prologue = manager.slurm_prologues() + + assert prologue == [ + 'eval "$(/path/to/conda shell.bash hook)"', + "conda activate codeentropy", + "export SLURM_CPU_FREQ_REQ=2250000", + ] + + +def test_slurm_prologues_mamba(): + args = args_helper( + [ + "--conda-env", + "codeentropy", + "--conda-exec", + "mamba", + "--conda-path", + "/path/to/conda", 
+ ] + ) + manager = HPCDaskManager(args) + + prologue = manager.slurm_prologues() + + assert prologue == [ + 'eval "$(/path/to/conda shell.bash hook)"', + 'eval "$(mamba shell hook --shell bash)"', + "mamba activate codeentropy", + "export SLURM_CPU_FREQ_REQ=2250000", + ] + + +def test_slurm_prologues_includes_hpc_modules(): + args = args_helper( + [ + "--hpc-modules", + "apps/binapps/conda/miniforge3/25.9.1", + "gcc/12.2.0", + "--conda-env", + "codeentropy", + "--conda-exec", + "conda", + "--conda-path", + "/path/to/conda", + ] + ) + manager = HPCDaskManager(args) + + prologue = manager.slurm_prologues() + + assert prologue == [ + "module load apps/binapps/conda/miniforge3/25.9.1", + "module load gcc/12.2.0", + 'eval "$(/path/to/conda shell.bash hook)"', + "conda activate codeentropy", + "export SLURM_CPU_FREQ_REQ=2250000", + ] + + +@mock.patch("psutil.net_if_addrs") +def test_system_network_interface_prefers_bond0(net_if_addrs): + net_if_addrs.return_value = {"bond0": [], "ib0": [], "eth0": []} + + args = args_helper([]) + manager = HPCDaskManager(args) + + assert manager.system_network_interface() == "bond0" + + +@mock.patch("psutil.net_if_addrs") +def test_system_network_interface_prefers_ib0_when_bond0_missing(net_if_addrs): + net_if_addrs.return_value = {"ib0": [], "eth0": []} + + args = args_helper([]) + manager = HPCDaskManager(args) + + assert manager.system_network_interface() == "ib0" + + +@mock.patch("psutil.net_if_addrs") +def test_system_network_interface_prefers_hsn0_when_bond0_and_ib0_missing( + net_if_addrs, +): + net_if_addrs.return_value = {"hsn0": [], "eth0": []} + + args = args_helper([]) + manager = HPCDaskManager(args) + + assert manager.system_network_interface() == "hsn0" + + +@mock.patch("psutil.net_if_addrs") +def test_system_network_interface_prefers_eth0_when_only_eth0_known_interface( + net_if_addrs, +): + net_if_addrs.return_value = {"eth0": [], "eno1": []} + + args = args_helper([]) + manager = HPCDaskManager(args) + + assert 
manager.system_network_interface() == "eth0" + + +@mock.patch("psutil.net_if_addrs") +def test_system_network_interface_raises_without_known_hpc_interface(net_if_addrs): + net_if_addrs.return_value = {"lo": [], "docker0": [], "eno1": []} + + args = args_helper([]) + manager = HPCDaskManager(args) + + with pytest.raises(RuntimeError, match="Could not find a known HPC network"): + manager.system_network_interface() + + +@mock.patch("subprocess.check_output") +def test_submit_master_writes_expected_script_conda(check_output): + check_output.return_value = b"Submitted batch job 12345\n" + + args = args_helper( + [ + "--conda-env", + "codeentropy", + "--conda-exec", + "conda", + "--conda-path", + "/path/to/conda", + "--hpc-account", + "c01-bio", + "--hpc-qos", + "standard", + "--hpc-queue", + "standard", + "--hpc-walltime", + "24:00:00", + ] + ) + manager = HPCDaskManager(args) + + cli = [ + "CodeEntropy", + "--top_traj_file", + "topology.tpr", + "trajectory.trr", + "--start", + "0", + "--end", + "512", + "--step", + "1", + "--hpc", + "true", + "--hpc_nodes", + "4", + "--submit", + "true", + ] + + with mock.patch.object(sys, "argv", cli): + manager.submit_master() + + with open("CodeEntropy-master-submit.sh", encoding="utf-8") as file: + script = file.read() + + assert "#SBATCH --job-name=codeentropy-master" in script + assert "#SBATCH --nodes=1" in script + assert "#SBATCH --ntasks=1" in script + assert "#SBATCH --cpus-per-task=2" in script + assert "#SBATCH --time=24:00:00" in script + assert "#SBATCH --account=c01-bio" in script + assert "#SBATCH --partition=standard" in script + assert "#SBATCH --qos=standard" in script + assert 'eval "$(/path/to/conda shell.bash hook)"' in script + assert "conda activate codeentropy" in script + assert "srun CodeEntropy" in script + assert "--submit" not in script + assert " --submit " not in script + assert not script.rstrip().endswith(" true") + + os.remove("CodeEntropy-master-submit.sh") + + 
+@mock.patch("subprocess.check_output") +def test_submit_master_writes_expected_script_mamba(check_output): + check_output.return_value = b"Submitted batch job 12345\n" + + args = args_helper( + [ + "--conda-env", + "codeentropy", + "--conda-exec", + "mamba", + "--conda-path", + "/path/to/conda", + "--hpc-account", + "c01-bio", + "--hpc-qos", + "standard", + "--hpc-queue", + "standard", + ] + ) + manager = HPCDaskManager(args) + + cli = [ + "CodeEntropy", + "--top_traj_file", + "topology.tpr", + "trajectory.trr", + "--hpc", + "true", + "--submit", + "true", + ] + + with mock.patch.object(sys, "argv", cli): + manager.submit_master() + + with open("CodeEntropy-master-submit.sh", encoding="utf-8") as file: + script = file.read() + + assert 'eval "$(/path/to/conda shell.bash hook)"' in script + assert 'eval "$(mamba shell hook --shell bash)"' in script + assert "mamba activate codeentropy" in script + assert "srun CodeEntropy" in script + assert "--submit" not in script + + os.remove("CodeEntropy-master-submit.sh") + + +@mock.patch("subprocess.check_output") +def test_submit_master_writes_hpc_modules(check_output): + check_output.return_value = b"Submitted batch job 12345\n" + + args = args_helper( + [ + "--hpc-modules", + "apps/binapps/conda/miniforge3/25.9.1", + "--conda-env", + "codeentropy", + "--conda-exec", + "conda", + "--conda-path", + "/path/to/conda", + "--hpc-queue", + "standard", + ] + ) + manager = HPCDaskManager(args) + + cli = [ + "CodeEntropy", + "--top_traj_file", + "topology.tpr", + "trajectory.trr", + "--hpc", + "true", + "--submit", + "true", + ] + + try: + with mock.patch.object(sys, "argv", cli): + manager.submit_master() + + with open("CodeEntropy-master-submit.sh", encoding="utf-8") as file: + script = file.read() + + assert "module load apps/binapps/conda/miniforge3/25.9.1" in script + assert 'eval "$(/path/to/conda shell.bash hook)"' in script + assert "conda activate codeentropy" in script + + finally: + if 
os.path.exists("CodeEntropy-master-submit.sh"): + os.remove("CodeEntropy-master-submit.sh") + + +@mock.patch("CodeEntropy.core.dask_clusters.Client") +@mock.patch("CodeEntropy.core.dask_clusters.SLURMCluster") +@mock.patch.object(HPCDaskManager, "system_network_interface") +def test_configure_cluster_writes_job_script( + system_network_interface, + slurm_cluster, + client, +): + system_network_interface.return_value = "ib0" + + cluster_instance = mock.MagicMock() + cluster_instance.job_script.return_value = "#!/bin/bash\n# dask worker script\n" + slurm_cluster.return_value = cluster_instance + + client_instance = mock.MagicMock() + client.return_value = client_instance + + args = args_helper( + [ + "--conda-env", + "codeentropy", + "--conda-exec", + "conda", + "--conda-path", + "/path/to/conda", + "--hpc-account", + "c01-bio", + "--hpc-qos", + "standard", + "--hpc-queue", + "standard", + "--hpc-cores", + "8", + "--hpc-processes", + "1", + "--hpc-memory", + "16GB", + "--hpc-nodes", + "4", + "--hpc-walltime", + "02:00:00", + ] + ) + manager = HPCDaskManager(args) + + returned_client = manager.configure_cluster() + + assert returned_client is client_instance + + slurm_cluster.assert_called_once() + _, kwargs = slurm_cluster.call_args + assert kwargs["interface"] == "ib0" + assert "scheduler_options" not in kwargs + + cluster_instance.scale.assert_called_once_with(jobs=4) + client.assert_called_once_with(cluster_instance) + + with open("dask-cluster-submit.sh", encoding="utf-8") as file: + script = file.read() + + assert script == "#!/bin/bash\n# dask worker script\n" + + os.remove("dask-cluster-submit.sh") + + +@mock.patch("subprocess.check_output") +def test_submit_master_prints_called_process_error_output(check_output, capsys): + error_output = b"sbatch: error: invalid partition\n" + + check_output.side_effect = subprocess.CalledProcessError( + returncode=1, + cmd=["bash", "-c", "sbatch CodeEntropy-master-submit.sh"], + output=error_output, + ) + + args = 
args_helper( + [ + "--conda-env", + "codeentropy", + "--conda-exec", + "conda", + "--conda-path", + "/path/to/conda", + "--hpc-account", + "c01-bio", + "--hpc-qos", + "standard", + "--hpc-queue", + "standard", + "--hpc-walltime", + "24:00:00", + ] + ) + manager = HPCDaskManager(args) + + cli = [ + "CodeEntropy", + "--top_traj_file", + "topology.tpr", + "trajectory.trr", + "--hpc", + "true", + "--submit", + "true", + ] + + try: + with mock.patch.object(sys, "argv", cli): + manager.submit_master() + + captured = capsys.readouterr() + + assert "sbatch: error: invalid partition" in captured.out + check_output.assert_called_once_with( + ["bash", "-c", "sbatch CodeEntropy-master-submit.sh"] + ) + + finally: + if os.path.exists("CodeEntropy-master-submit.sh"): + os.remove("CodeEntropy-master-submit.sh") diff --git a/tests/unit/CodeEntropy/entropy/test_workflow.py b/tests/unit/CodeEntropy/entropy/test_workflow.py index 5f6ab900..04b9c900 100644 --- a/tests/unit/CodeEntropy/entropy/test_workflow.py +++ b/tests/unit/CodeEntropy/entropy/test_workflow.py @@ -1,4 +1,6 @@ import logging +import sys +import types from types import SimpleNamespace from unittest.mock import MagicMock, patch @@ -563,3 +565,306 @@ def test_finalize_molecule_results_skips_group_total_rows(): assert any( row[1] == "Group Total" and row[3] == 1.5 for row in wf._reporter.molecule_data ) + + +def test_configure_parallel_frame_execution_returns_when_disabled(): + args = SimpleNamespace( + parallel_frames=False, + use_dask=False, + output_file="out.json", + ) + wf = _make_wf(args) + shared_data = {} + + wf._configure_parallel_frame_execution(shared_data) + + assert shared_data == {} + + +def test_configure_parallel_frame_execution_reuses_existing_client(): + args = SimpleNamespace( + parallel_frames=True, + use_dask=False, + output_file="out.json", + ) + wf = _make_wf(args) + + client = MagicMock() + shared_data = {"dask_client": client} + + wf._configure_parallel_frame_execution(shared_data) + + assert 
shared_data["dask_client"] is client + assert shared_data["parallel_frames"] is True + + +def test_configure_parallel_frame_execution_creates_local_dask_client(): + args = SimpleNamespace( + parallel_frames=True, + use_dask=False, + hpc=False, + dask_workers=3, + dask_threads_per_worker=1, + output_file="out.json", + ) + wf = _make_wf(args) + + fake_client_instance = MagicMock() + fake_client_cls = MagicMock(return_value=fake_client_instance) + + fake_dask = types.ModuleType("dask") + fake_distributed = types.ModuleType("dask.distributed") + fake_distributed.Client = fake_client_cls + + shared_data = {} + + with patch.dict( + sys.modules, + { + "dask": fake_dask, + "dask.distributed": fake_distributed, + }, + ): + wf._configure_parallel_frame_execution(shared_data) + + fake_client_cls.assert_called_once_with( + processes=True, + n_workers=3, + threads_per_worker=1, + ) + assert shared_data["dask_client"] is fake_client_instance + assert shared_data["parallel_frames"] is True + + +def test_configure_parallel_frame_execution_raises_when_dask_missing(): + args = SimpleNamespace( + parallel_frames=True, + use_dask=False, + hpc=False, + dask_workers=2, + dask_threads_per_worker=1, + output_file="out.json", + ) + wf = _make_wf(args) + + real_import = __import__ + + def fake_import(name, globals=None, locals=None, fromlist=(), level=0): + if name == "dask.distributed": + raise ImportError("No module named dask.distributed") + return real_import(name, globals, locals, fromlist, level) + + with patch("builtins.__import__", side_effect=fake_import): + with pytest.raises( + RuntimeError, match="Parallel frame execution was requested" + ): + wf._configure_parallel_frame_execution({}) + + +def test_build_shared_data_contains_frame_source_and_frame_indices(): + args = SimpleNamespace( + selection_string="all", + water_entropy=False, + output_file="out.json", + ) + wf = _make_wf(args) + + reduced_universe = MagicMock() + levels = {0: ["united_atom"]} + groups = {2: [0, 1]} + 
frame_selection = FrameSelection.from_bounds(start=2, stop=8, step=2) + + with patch("CodeEntropy.entropy.workflow.FrameSource") as FrameSourceCls: + frame_source = MagicMock() + FrameSourceCls.return_value = frame_source + + shared_data = wf._build_shared_data( + reduced_universe=reduced_universe, + levels=levels, + groups=groups, + frame_selection=frame_selection, + ) + + FrameSourceCls.assert_called_once_with( + universe=reduced_universe, + selection=frame_selection, + ) + + assert shared_data["entropy_manager"] is wf + assert shared_data["run_manager"] is wf._run_manager + assert shared_data["reporter"] is wf._reporter + assert shared_data["args"] is args + assert shared_data["universe"] is wf._universe + assert shared_data["reduced_universe"] is reduced_universe + assert shared_data["levels"] is levels + assert shared_data["groups"] == groups + assert shared_data["start"] == frame_selection.source_start + assert shared_data["end"] == frame_selection.source_stop_exclusive + assert shared_data["step"] == frame_selection.infer_source_step() + assert shared_data["n_frames"] == frame_selection.n_frames + assert shared_data["frame_selection"] is frame_selection + assert shared_data["frame_source"] is frame_source + assert shared_data["frame_indices"] == [2, 4, 6] + assert shared_data["source_frame_indices"] == [2, 4, 6] + + +def test_run_level_dag_builds_and_executes_level_dag(): + args = SimpleNamespace(output_file="out.json") + wf = _make_wf(args) + shared_data = {"x": 1} + progress = MagicMock() + + with patch("CodeEntropy.entropy.workflow.LevelDAG") as LevelDAGCls: + level_dag = LevelDAGCls.return_value + built_dag = level_dag.build.return_value + + wf._run_level_dag(shared_data, progress=progress) + + LevelDAGCls.assert_called_once_with(wf._universe_operations) + level_dag.build.assert_called_once() + built_dag.execute.assert_called_once_with(shared_data, progress=progress) + + +def test_run_entropy_graph_executes_and_updates_shared_data(): + args = 
SimpleNamespace(output_file="out.json") + wf = _make_wf(args) + shared_data = {"existing": "value"} + progress = MagicMock() + + with patch("CodeEntropy.entropy.workflow.EntropyGraph") as GraphCls: + graph = GraphCls.return_value + built_graph = graph.build.return_value + built_graph.execute.return_value = {"entropy_results": {"ok": True}} + + wf._run_entropy_graph(shared_data, progress=progress) + + GraphCls.assert_called_once() + graph.build.assert_called_once() + built_graph.execute.assert_called_once_with(shared_data, progress=progress) + assert shared_data["existing"] == "value" + assert shared_data["entropy_results"] == {"ok": True} + + +def test_get_trajectory_bounds_none_values_use_defaults(): + args = SimpleNamespace( + start=None, + end=None, + step=None, + output_file="out.json", + ) + universe = SimpleNamespace(trajectory=list(range(7))) + + wf = EntropyWorkflow( + run_manager=MagicMock(), + args=args, + universe=universe, + reporter=MagicMock(), + group_molecules=MagicMock(), + dihedral_analysis=MagicMock(), + universe_operations=MagicMock(), + ) + + assert wf._get_trajectory_bounds() == (0, 7, 1) + + +def test_compute_water_entropy_appends_not_water_to_existing_selection(): + args = SimpleNamespace( + selection_string="protein", + water_entropy=True, + temperature=298.0, + output_file="out.json", + ) + wf = _make_wf(args) + + frame_selection = FrameSelection.from_bounds(start=0, stop=5, step=1) + + with patch("CodeEntropy.entropy.workflow.WaterEntropy") as WaterCls: + inst = WaterCls.return_value + inst.calculate_and_log = MagicMock() + + wf._compute_water_entropy(frame_selection, water_groups={9: [0]}) + + inst.calculate_and_log.assert_called_once_with( + universe=wf._universe, + start=0, + end=5, + step=1, + group_id=9, + ) + assert args.selection_string == "protein and not water" + + +def test_execute_closes_dask_client_in_finally(): + args = SimpleNamespace( + start=0, + end=-1, + step=1, + grouping="molecules", + water_entropy=False, + 
selection_string="all", + output_file="out.json", + ) + + universe = MagicMock() + universe.trajectory = list(range(5)) + + reporter = MagicMock() + reporter.molecule_data = [] + reporter.residue_data = [] + + progress_cm = MagicMock() + progress_cm.__enter__.return_value = MagicMock() + progress_cm.__exit__.return_value = False + reporter.progress.return_value = progress_cm + + wf = EntropyWorkflow( + run_manager=MagicMock(), + args=args, + universe=universe, + reporter=reporter, + group_molecules=MagicMock(), + dihedral_analysis=MagicMock(), + universe_operations=MagicMock(), + ) + + client = MagicMock() + + wf._build_reduced_universe = MagicMock(return_value=MagicMock()) + wf._detect_levels = MagicMock(return_value={0: ["united_atom"]}) + wf._split_water_groups = MagicMock(return_value=({0: [0]}, {})) + wf._build_shared_data = MagicMock(return_value={"dask_client": client}) + wf._configure_parallel_frame_execution = MagicMock() + wf._run_level_dag = MagicMock() + wf._run_entropy_graph = MagicMock() + wf._finalize_molecule_results = MagicMock() + wf._group_molecules.grouping_molecules.return_value = {0: [0]} + + wf.execute() + + client.close.assert_called_once() + + +def test_configure_parallel_frame_execution_uses_hpc_dask_manager(): + args = SimpleNamespace( + parallel_frames=False, + use_dask=False, + hpc=True, + dask_workers=None, + dask_threads_per_worker=1, + output_file="out.json", + ) + wf = _make_wf(args) + + shared_data = {} + client = MagicMock() + + with patch("CodeEntropy.entropy.workflow.HPCDaskManager") as HPCDaskManagerCls: + HPCDaskManagerCls.return_value.configure_cluster.return_value = client + + wf._configure_parallel_frame_execution(shared_data) + + HPCDaskManagerCls.assert_called_once_with(args) + HPCDaskManagerCls.return_value.configure_cluster.assert_called_once() + + assert shared_data["dask_client"] is client + assert shared_data["parallel_frames"] is True diff --git a/tests/unit/CodeEntropy/levels/test_level_dag.py 
b/tests/unit/CodeEntropy/levels/test_level_dag.py new file mode 100644 index 00000000..067f8dd3 --- /dev/null +++ b/tests/unit/CodeEntropy/levels/test_level_dag.py @@ -0,0 +1,769 @@ +"""Unit tests for LevelDAG orchestration, reduction, and parallel frame execution.""" + +from __future__ import annotations + +import sys +import types +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +from CodeEntropy.levels import level_dag as level_dag_module +from CodeEntropy.levels.level_dag import LevelDAG + + +def _empty_frame_out() -> dict: + """Return an empty frame-local covariance payload.""" + return { + "force": {"ua": {}, "res": {}, "poly": {}}, + "torque": {"ua": {}, "res": {}, "poly": {}}, + } + + +def _shared_force_torque() -> dict: + """Return minimal shared data for force/torque reduction tests.""" + return { + "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, + "torque_covariances": {"ua": {}, "res": [None], "poly": [None]}, + "frame_counts": { + "ua": {}, + "res": np.zeros(1, dtype=int), + "poly": np.zeros(1, dtype=int), + }, + "group_id_to_index": {7: 0, 9: 0}, + } + + +def _shared_forcetorque() -> dict: + """Return minimal shared data for combined force-torque reduction tests.""" + return { + "forcetorque_covariances": {"res": [None], "poly": [None]}, + "forcetorque_counts": { + "res": np.zeros(1, dtype=int), + "poly": np.zeros(1, dtype=int), + }, + "group_id_to_index": {7: 0, 9: 0}, + } + + +def test_incremental_mean_none_returns_copy_for_numpy(): + arr = np.array([1.0, 2.0]) + + out = LevelDAG._incremental_mean(None, arr, n=1) + + assert np.all(out == arr) + + arr[0] = 999.0 + assert out[0] != 999.0 + + +def test_incremental_mean_updates_mean_correctly(): + old = np.array([2.0, 2.0]) + new = np.array([4.0, 0.0]) + + out = LevelDAG._incremental_mean(old, new, n=2) + + np.testing.assert_allclose(out, np.array([3.0, 1.0])) + + +def test_incremental_mean_handles_non_copyable_values(): + out = 
LevelDAG._incremental_mean(old=None, new=3.0, n=1) + + assert out == 3.0 + + +def test_execute_sets_default_axes_manager_and_runs_stages(): + dag = LevelDAG() + + shared = { + "reduced_universe": MagicMock(), + "start": 0, + "end": 0, + "step": 1, + "n_frames": 1, + } + + dag._run_static_stage = MagicMock() + dag._run_frame_stage = MagicMock() + + out = dag.execute(shared) + + assert out is shared + assert "axes_manager" in shared + dag._run_static_stage.assert_called_once_with(shared, progress=None) + dag._run_frame_stage.assert_called_once_with(shared, progress=None) + + +def test_build_registers_static_nodes_and_builds_frame_dag(): + with ( + patch("CodeEntropy.levels.level_dag.DetectMoleculesNode"), + patch("CodeEntropy.levels.level_dag.DetectLevelsNode"), + patch("CodeEntropy.levels.level_dag.BuildBeadsNode"), + patch("CodeEntropy.levels.level_dag.InitCovarianceAccumulatorsNode"), + patch("CodeEntropy.levels.level_dag.ComputeConformationalStatesNode"), + patch("CodeEntropy.levels.level_dag.ComputeNeighborsNode"), + ): + dag = LevelDAG(universe_operations=MagicMock()) + dag._frame_dag.build = MagicMock() + + out = dag.build() + + assert out is dag + assert "detect_molecules" in dag._static_nodes + assert "detect_levels" in dag._static_nodes + assert "build_beads" in dag._static_nodes + assert "init_covariance_accumulators" in dag._static_nodes + assert "compute_conformational_states" in dag._static_nodes + assert "find_neighbors" in dag._static_nodes + dag._frame_dag.build.assert_called_once() + + +def test_add_static_adds_dependency_edges(): + dag = LevelDAG() + + dag._add_static("A", MagicMock()) + dag._add_static("B", MagicMock(), deps=["A"]) + + assert dag._static_nodes["A"] is not None + assert dag._static_nodes["B"] is not None + assert ("A", "B") in dag._static_graph.edges + + +def test_run_static_stage_calls_nodes_in_topological_sort_order(): + dag = LevelDAG() + dag._static_graph.add_node("a") + dag._static_graph.add_node("b") + + 
dag._static_nodes["a"] = MagicMock() + dag._static_nodes["b"] = MagicMock() + + with patch("networkx.topological_sort", return_value=["a", "b"]): + dag._run_static_stage({"X": 1}) + + dag._static_nodes["a"].run.assert_called_once() + dag._static_nodes["b"].run.assert_called_once() + + +def test_run_static_stage_forwards_progress_when_node_accepts_it(): + dag = LevelDAG() + dag._static_graph.add_node("a") + + node = MagicMock() + dag._static_nodes["a"] = node + + progress = MagicMock() + + with patch("networkx.topological_sort", return_value=["a"]): + dag._run_static_stage({"X": 1}, progress=progress) + + node.run.assert_called_once_with({"X": 1}, progress=progress) + + +def test_run_static_stage_falls_back_when_node_does_not_accept_progress(): + dag = LevelDAG() + dag._static_graph.add_node("a") + + node = MagicMock() + node.run.side_effect = [TypeError("no progress"), None] + dag._static_nodes["a"] = node + + progress = MagicMock() + + with patch("networkx.topological_sort", return_value=["a"]): + dag._run_static_stage({"X": 1}, progress=progress) + + assert node.run.call_count == 2 + node.run.assert_any_call({"X": 1}, progress=progress) + node.run.assert_any_call({"X": 1}) + + +def test_run_frame_stage_iterates_selected_frames_and_reduces_each(): + dag = LevelDAG() + + frame_source = MagicMock() + frame_source.iter_indices.return_value = [10, 11] + + shared = { + "frame_source": frame_source, + "n_frames": 2, + } + + frame_outputs = [_empty_frame_out(), _empty_frame_out()] + + dag._frame_dag = MagicMock() + dag._frame_dag.execute_frame.side_effect = frame_outputs + dag._reduce_one_frame = MagicMock() + + dag._run_frame_stage(shared) + + assert shared["n_frames"] == 2 + frame_source.iter_indices.assert_called_once() + + assert dag._frame_dag.execute_frame.call_count == 2 + dag._frame_dag.execute_frame.assert_any_call(shared, 10) + dag._frame_dag.execute_frame.assert_any_call(shared, 11) + + assert dag._reduce_one_frame.call_count == 2 + 
dag._reduce_one_frame.assert_any_call(shared, frame_outputs[0]) + dag._reduce_one_frame.assert_any_call(shared, frame_outputs[1]) + + +def test_run_frame_stage_progress_total_comes_from_frame_source_indices(): + dag = LevelDAG() + + frame_source = MagicMock() + frame_source.iter_indices.return_value = list(range(10)) + + shared = { + "frame_source": frame_source, + "n_frames": 0, + } + + frame_out = _empty_frame_out() + + dag._frame_dag = MagicMock() + dag._frame_dag.execute_frame.return_value = frame_out + dag._reduce_one_frame = MagicMock() + + progress = MagicMock() + progress.add_task.return_value = 123 + + dag._run_frame_stage(shared, progress=progress) + + progress.add_task.assert_called_once_with( + "[green]Frame processing", + total=10, + title="Initializing", + ) + + assert shared["n_frames"] == 10 + frame_source.iter_indices.assert_called_once() + assert dag._frame_dag.execute_frame.call_count == 10 + assert dag._reduce_one_frame.call_count == 10 + assert progress.advance.call_count == 10 + + +def test_run_frame_stage_with_progress_creates_task_and_updates_titles(): + dag = LevelDAG() + + frame_source = MagicMock() + frame_source.iter_indices.return_value = [10, 11] + + shared = { + "frame_source": frame_source, + "n_frames": 2, + } + + frame_out = _empty_frame_out() + + dag._frame_dag = MagicMock() + dag._frame_dag.execute_frame.return_value = frame_out + dag._reduce_one_frame = MagicMock() + + progress = MagicMock() + progress.add_task.return_value = 77 + + dag._run_frame_stage(shared, progress=progress) + + progress.add_task.assert_called_once_with( + "[green]Frame processing", + total=2, + title="Initializing", + ) + + assert progress.update.call_count == 2 + progress.update.assert_any_call(77, title="Frame 10") + progress.update.assert_any_call(77, title="Frame 11") + + assert progress.advance.call_count == 2 + progress.advance.assert_any_call(77) + + assert dag._frame_dag.execute_frame.call_count == 2 + 
dag._frame_dag.execute_frame.assert_any_call(shared, 10) + dag._frame_dag.execute_frame.assert_any_call(shared, 11) + + assert dag._reduce_one_frame.call_count == 2 + dag._reduce_one_frame.assert_any_call(shared, frame_out) + + +def test_run_frame_stage_falls_back_to_sequential_when_only_one_frame(): + dag = LevelDAG() + + frame_source = MagicMock() + frame_source.iter_indices.return_value = [0] + + client = MagicMock() + + shared_data = { + "frame_source": frame_source, + "dask_client": client, + "parallel_frames": True, + } + + frame_out = _empty_frame_out() + + dag._run_frame_stage_dask = MagicMock() + dag._frame_dag = MagicMock() + dag._frame_dag.execute_frame.return_value = frame_out + dag._reduce_one_frame = MagicMock() + + dag._run_frame_stage(shared_data) + + dag._run_frame_stage_dask.assert_not_called() + dag._frame_dag.execute_frame.assert_called_once_with(shared_data, 0) + dag._reduce_one_frame.assert_called_once_with(shared_data, frame_out) + assert shared_data["n_frames"] == 1 + + +def test_run_frame_stage_falls_back_to_sequential_without_client(): + dag = LevelDAG() + + frame_source = MagicMock() + frame_source.iter_indices.return_value = [0, 1] + + shared_data = { + "frame_source": frame_source, + "parallel_frames": True, + } + + frame_out = _empty_frame_out() + + dag._run_frame_stage_dask = MagicMock() + dag._frame_dag = MagicMock() + dag._frame_dag.execute_frame.return_value = frame_out + dag._reduce_one_frame = MagicMock() + + dag._run_frame_stage(shared_data) + + dag._run_frame_stage_dask.assert_not_called() + assert dag._frame_dag.execute_frame.call_count == 2 + assert dag._reduce_one_frame.call_count == 2 + assert shared_data["n_frames"] == 2 + + +def test_reduce_force_and_torque_handles_empty_frame_gracefully(): + dag = LevelDAG() + shared = _shared_force_torque() + + dag._reduce_force_and_torque(shared_data=shared, frame_out=_empty_frame_out()) + + assert shared["force_covariances"]["ua"] == {} + assert shared["torque_covariances"]["ua"] == 
{} + assert shared["frame_counts"]["res"][0] == 0 + assert shared["frame_counts"]["poly"][0] == 0 + + +def test_reduce_force_and_torque_updates_counts_and_means(): + dag = LevelDAG() + shared = _shared_force_torque() + + F1 = np.eye(3) + T1 = 2.0 * np.eye(3) + + frame_out = { + "force": {"ua": {(0, 0): F1}, "res": {9: F1}, "poly": {}}, + "torque": {"ua": {(0, 0): T1}, "res": {9: T1}, "poly": {}}, + } + + dag._reduce_force_and_torque(shared, frame_out) + + assert shared["frame_counts"]["ua"][(0, 0)] == 1 + np.testing.assert_allclose(shared["force_covariances"]["ua"][(0, 0)], F1) + np.testing.assert_allclose(shared["torque_covariances"]["ua"][(0, 0)], T1) + + assert shared["frame_counts"]["res"][0] == 1 + np.testing.assert_allclose(shared["force_covariances"]["res"][0], F1) + np.testing.assert_allclose(shared["torque_covariances"]["res"][0], T1) + + +def test_reduce_force_and_torque_exercises_count_branches(): + dag = LevelDAG() + shared = _shared_force_torque() + + frame_out = { + "force": { + "ua": {(9, 0): np.array([1.0])}, + "res": {7: np.array([2.0])}, + "poly": {7: np.array([3.0])}, + }, + "torque": { + "ua": {(9, 0): np.array([4.0])}, + "res": {7: np.array([5.0])}, + "poly": {7: np.array([6.0])}, + }, + } + + dag._reduce_force_and_torque(shared, frame_out) + + assert (9, 0) in shared["torque_covariances"]["ua"] + assert shared["frame_counts"]["res"][0] == 1 + assert shared["frame_counts"]["poly"][0] == 1 + np.testing.assert_allclose(shared["force_covariances"]["res"][0], np.array([2.0])) + np.testing.assert_allclose(shared["torque_covariances"]["res"][0], np.array([5.0])) + np.testing.assert_allclose(shared["force_covariances"]["poly"][0], np.array([3.0])) + np.testing.assert_allclose(shared["torque_covariances"]["poly"][0], np.array([6.0])) + + +def test_reduce_force_and_torque_res_torque_increments_when_res_count_is_zero(): + dag = LevelDAG() + shared = _shared_force_torque() + + frame_out = { + "force": {"ua": {}, "res": {}, "poly": {}}, + "torque": {"ua": 
{}, "res": {7: np.eye(3)}, "poly": {}}, + } + + dag._reduce_force_and_torque(shared, frame_out) + + assert shared["frame_counts"]["res"][0] == 1 + assert shared["torque_covariances"]["res"][0] is not None + + +def test_reduce_force_and_torque_poly_torque_increments_when_poly_count_is_zero(): + dag = LevelDAG() + shared = _shared_force_torque() + + frame_out = { + "force": {"ua": {}, "res": {}, "poly": {}}, + "torque": {"ua": {}, "res": {}, "poly": {7: np.eye(3)}}, + } + + dag._reduce_force_and_torque(shared, frame_out) + + assert shared["frame_counts"]["poly"][0] == 1 + assert shared["torque_covariances"]["poly"][0] is not None + + +def test_reduce_force_and_torque_increments_ua_frame_counts_for_force(): + dag = LevelDAG() + shared = _shared_force_torque() + + key = (9, 0) + F = np.eye(3) + + frame_out = { + "force": {"ua": {key: F}, "res": {}, "poly": {}}, + "torque": {"ua": {}, "res": {}, "poly": {}}, + } + + dag._reduce_force_and_torque(shared, frame_out) + + assert shared["frame_counts"]["ua"][key] == 1 + assert key in shared["force_covariances"]["ua"] + np.testing.assert_array_equal(shared["force_covariances"]["ua"][key], F) + + +def test_reduce_force_and_torque_ua_torque_increments_count_when_force_missing_key(): + dag = LevelDAG() + shared = _shared_force_torque() + + key = (9, 0) + T = np.eye(3) + + frame_out = { + "force": {"ua": {}, "res": {}, "poly": {}}, + "torque": {"ua": {key: T}, "res": {}, "poly": {}}, + } + + dag._reduce_force_and_torque(shared, frame_out) + + assert shared["frame_counts"]["ua"][key] == 1 + np.testing.assert_array_equal(shared["torque_covariances"]["ua"][key], T) + + +def test_reduce_one_frame_calls_force_torque_and_forcetorque_reducers(): + dag = LevelDAG() + shared = {} + frame_out = {} + + dag._reduce_force_and_torque = MagicMock() + dag._reduce_forcetorque = MagicMock() + + dag._reduce_one_frame(shared, frame_out) + + dag._reduce_force_and_torque.assert_called_once_with(shared, frame_out) + 
dag._reduce_forcetorque.assert_called_once_with(shared, frame_out) + + +def test_reduce_forcetorque_no_key_is_noop(): + dag = LevelDAG() + shared = _shared_forcetorque() + + dag._reduce_forcetorque(shared, frame_out={}) + + assert shared["forcetorque_counts"]["res"][0] == 0 + assert shared["forcetorque_covariances"]["res"][0] is None + + +def test_reduce_forcetorque_updates_res_and_poly(): + dag = LevelDAG() + shared = _shared_forcetorque() + + frame_out = { + "forcetorque": { + "res": {7: np.array([1.0, 1.0])}, + "poly": {7: np.array([2.0, 2.0])}, + } + } + + dag._reduce_forcetorque(shared, frame_out) + + assert shared["forcetorque_counts"]["res"][0] == 1 + assert shared["forcetorque_counts"]["poly"][0] == 1 + np.testing.assert_allclose( + shared["forcetorque_covariances"]["res"][0], + np.array([1.0, 1.0]), + ) + np.testing.assert_allclose( + shared["forcetorque_covariances"]["poly"][0], + np.array([2.0, 2.0]), + ) + + +def test_make_frame_worker_shared_data_excludes_parent_only_keys(): + shared_data = { + "force_covariances": "force accumulator", + "torque_covariances": "torque accumulator", + "forcetorque_covariances": "ft accumulator", + "frame_counts": "frame counts", + "forcetorque_counts": "ft counts", + "force_torque_stats": "legacy ft accumulator alias", + "force_torque_counts": "legacy ft counts alias", + "n_frames": 10, + "entropy_manager": "manager", + "run_manager": "run manager", + "reporter": "reporter", + "dask_client": "client", + "frame_source": "frame source", + "levels": "levels", + "groups": "groups", + "args": "args", + } + + worker_shared = LevelDAG._make_frame_worker_shared_data(shared_data) + + assert worker_shared == { + "frame_source": "frame source", + "levels": "levels", + "groups": "groups", + "args": "args", + } + + +def test_execute_frame_worker_builds_frame_graph_and_returns_frame_output(): + shared_data = {"x": 1} + universe_operations = MagicMock() + + with patch("CodeEntropy.levels.level_dag.FrameGraph") as FrameGraphCls: + graph 
= MagicMock() + graph.execute_frame.return_value = {"force": {}, "torque": {}} + FrameGraphCls.return_value.build.return_value = graph + + frame_index, frame_out = level_dag_module._execute_frame_worker( + shared_data, + frame_index="5", + universe_operations=universe_operations, + ) + + FrameGraphCls.assert_called_once_with(universe_operations=universe_operations) + FrameGraphCls.return_value.build.assert_called_once() + graph.execute_frame.assert_called_once_with(shared_data, 5) + + assert frame_index == 5 + assert frame_out == {"force": {}, "torque": {}} + + +def test_run_frame_stage_uses_dask_when_client_present(): + dag = LevelDAG() + + frame_source = MagicMock() + frame_source.iter_indices.return_value = [0, 1, 2] + + client = MagicMock() + + shared_data = { + "frame_source": frame_source, + "dask_client": client, + "parallel_frames": True, + } + + dag._run_frame_stage_dask = MagicMock() + dag._frame_dag = MagicMock() + dag._reduce_one_frame = MagicMock() + + dag._run_frame_stage(shared_data) + + dag._run_frame_stage_dask.assert_called_once_with( + shared_data, + frame_indices=[0, 1, 2], + client=client, + progress=None, + task=None, + ) + dag._frame_dag.execute_frame.assert_not_called() + dag._reduce_one_frame.assert_not_called() + assert shared_data["n_frames"] == 3 + + +def test_run_frame_stage_dask_submits_each_frame_and_reduces_completed_results(): + dag = LevelDAG() + + shared_data = { + "keep": "value", + "force_covariances": "exclude me", + "reporter": "exclude me too", + } + + client = MagicMock() + + frame_out0 = _empty_frame_out() + frame_out1 = _empty_frame_out() + + future0 = MagicMock() + future0.result.return_value = (0, frame_out0) + + future1 = MagicMock() + future1.result.return_value = (1, frame_out1) + + client.submit.side_effect = [future0, future1] + + fake_distributed = types.ModuleType("distributed") + fake_distributed.as_completed = MagicMock(return_value=[future0, future1]) + + dag._reduce_one_frame = MagicMock() + + with 
patch.dict(sys.modules, {"distributed": fake_distributed}): + dag._run_frame_stage_dask( + shared_data, + frame_indices=[0, 1], + client=client, + progress=None, + task=None, + ) + + assert client.submit.call_count == 2 + + for call in client.submit.call_args_list: + args, kwargs = call + assert args[0] is level_dag_module._execute_frame_worker + assert args[1] == {"keep": "value"} + assert kwargs == {"pure": False} + + assert dag._reduce_one_frame.call_count == 2 + dag._reduce_one_frame.assert_any_call(shared_data, frame_out0) + dag._reduce_one_frame.assert_any_call(shared_data, frame_out1) + client.cancel.assert_not_called() + + +def test_run_frame_stage_dask_updates_progress(): + dag = LevelDAG() + + shared_data = {"keep": "value"} + client = MagicMock() + + frame_out = _empty_frame_out() + future = MagicMock() + future.result.return_value = (7, frame_out) + client.submit.return_value = future + + fake_distributed = types.ModuleType("distributed") + fake_distributed.as_completed = MagicMock(return_value=[future]) + + progress = MagicMock() + dag._reduce_one_frame = MagicMock() + + with patch.dict(sys.modules, {"distributed": fake_distributed}): + dag._run_frame_stage_dask( + shared_data, + frame_indices=[7], + client=client, + progress=progress, + task="task-id", + ) + + progress.update.assert_called_once_with("task-id", title="Frame 7") + progress.advance.assert_called_once_with("task-id") + dag._reduce_one_frame.assert_called_once_with(shared_data, frame_out) + + +def test_run_frame_stage_dask_cancels_futures_and_reraises_on_result_error(): + dag = LevelDAG() + + shared_data = {"keep": "value"} + client = MagicMock() + + future = MagicMock() + future.result.side_effect = RuntimeError("worker failed") + client.submit.return_value = future + + fake_distributed = types.ModuleType("distributed") + fake_distributed.as_completed = MagicMock(return_value=[future]) + + with patch.dict(sys.modules, {"distributed": fake_distributed}): + with pytest.raises(RuntimeError, 
match="worker failed"): + dag._run_frame_stage_dask( + shared_data, + frame_indices=[0], + client=client, + progress=None, + task=None, + ) + + client.cancel.assert_called_once_with([future]) + + +def test_run_frame_stage_dask_raises_when_distributed_missing(): + dag = LevelDAG() + client = MagicMock() + + real_import = __import__ + + def fake_import(name, globals=None, locals=None, fromlist=(), level=0): + if name == "distributed": + raise ImportError("No module named distributed") + return real_import(name, globals, locals, fromlist, level) + + with patch("builtins.__import__", side_effect=fake_import): + with pytest.raises(RuntimeError, match="requires dask.distributed"): + dag._run_frame_stage_dask( + {"keep": "value"}, + frame_indices=[0], + client=client, + progress=None, + task=None, + ) + + +def test_run_frame_stage_dask_raises_if_completed_count_mismatch(): + dag = LevelDAG() + + shared_data = {"keep": "value"} + client = MagicMock() + + future0 = MagicMock() + future0.result.return_value = (0, _empty_frame_out()) + + future1 = MagicMock() + + client.submit.side_effect = [future0, future1] + + fake_distributed = types.ModuleType("distributed") + fake_distributed.as_completed = MagicMock(return_value=[future0]) + + dag._reduce_one_frame = MagicMock() + + with patch.dict(sys.modules, {"distributed": fake_distributed}): + with pytest.raises( + RuntimeError, + match="Parallel frame execution completed 1 frames, but expected 2", + ): + dag._run_frame_stage_dask( + shared_data, + frame_indices=[0, 1], + client=client, + progress=None, + task=None, + ) + + client.cancel.assert_called_once() diff --git a/tests/unit/CodeEntropy/levels/test_level_dag_orchestration.py b/tests/unit/CodeEntropy/levels/test_level_dag_orchestration.py deleted file mode 100644 index 1238a9eb..00000000 --- a/tests/unit/CodeEntropy/levels/test_level_dag_orchestration.py +++ /dev/null @@ -1,386 +0,0 @@ -from unittest.mock import MagicMock, patch - -import numpy as np - -from 
CodeEntropy.levels.level_dag import LevelDAG - - -def _shared(): - return { - "levels": [["united_atom"]], - "frame_counts": {}, - "force_covariances": {}, - "torque_covariances": {}, - "force_counts": {}, - "torque_counts": {}, - "reduced_force_covariances": {}, - "reduced_torque_covariances": {}, - "reduced_force_counts": {}, - "reduced_torque_counts": {}, - "group_id_to_index": {0: 0}, - } - - -def test_execute_sets_default_axes_manager_once(): - dag = LevelDAG() - - shared = { - "reduced_universe": MagicMock(), - "start": 0, - "end": 0, - "step": 1, - "n_frames": 1, - } - - dag._run_static_stage = MagicMock() - dag._run_frame_stage = MagicMock() - - dag.execute(shared) - - assert "axes_manager" in shared - dag._run_static_stage.assert_called_once() - dag._run_frame_stage.assert_called_once() - - -def test_run_static_stage_calls_nodes_in_topological_sort_order(): - dag = LevelDAG() - dag._static_graph.add_node("a") - dag._static_graph.add_node("b") - - dag._static_nodes["a"] = MagicMock() - dag._static_nodes["b"] = MagicMock() - - with patch("networkx.topological_sort", return_value=["a", "b"]): - dag._run_static_stage({"X": 1}) - - dag._static_nodes["a"].run.assert_called_once() - dag._static_nodes["b"].run.assert_called_once() - - -def test_run_frame_stage_iterates_selected_frames_and_reduces_each(): - dag = LevelDAG() - - frame_source = MagicMock() - frame_source.iter_indices.return_value = [10, 11] - - shared = { - "frame_source": frame_source, - "n_frames": 2, - } - - frame_outputs = [ - { - "force": {"ua": {}, "res": {}, "poly": {}}, - "torque": {"ua": {}, "res": {}, "poly": {}}, - }, - { - "force": {"ua": {}, "res": {}, "poly": {}}, - "torque": {"ua": {}, "res": {}, "poly": {}}, - }, - ] - - dag._frame_dag = MagicMock() - dag._frame_dag.execute_frame.side_effect = frame_outputs - dag._reduce_one_frame = MagicMock() - - dag._run_frame_stage(shared) - - assert shared["n_frames"] == 2 - frame_source.iter_indices.assert_called_once() - - assert 
dag._frame_dag.execute_frame.call_count == 2 - dag._frame_dag.execute_frame.assert_any_call(shared, 10) - dag._frame_dag.execute_frame.assert_any_call(shared, 11) - - assert dag._reduce_one_frame.call_count == 2 - dag._reduce_one_frame.assert_any_call(shared, frame_outputs[0]) - dag._reduce_one_frame.assert_any_call(shared, frame_outputs[1]) - - -def test_incremental_mean_handles_non_copyable_values(): - out = LevelDAG._incremental_mean(old=None, new=3.0, n=1) - assert out == 3.0 - - -def test_reduce_forcetorque_no_key_is_noop(): - dag = LevelDAG() - shared = { - "forcetorque_covariances": {"res": [None], "poly": [None]}, - "forcetorque_counts": { - "res": np.zeros(1, dtype=int), - "poly": np.zeros(1, dtype=int), - }, - "group_id_to_index": {9: 0}, - } - dag._reduce_forcetorque(shared, frame_out={}) - assert shared["forcetorque_counts"]["res"][0] == 0 - assert shared["forcetorque_covariances"]["res"][0] is None - - -def test_build_registers_static_nodes_and_builds_frame_dag(): - with ( - patch("CodeEntropy.levels.level_dag.DetectMoleculesNode") as _, - patch("CodeEntropy.levels.level_dag.DetectLevelsNode") as _, - patch("CodeEntropy.levels.level_dag.BuildBeadsNode") as _, - patch("CodeEntropy.levels.level_dag.InitCovarianceAccumulatorsNode") as _, - patch("CodeEntropy.levels.level_dag.ComputeConformationalStatesNode") as _, - ): - dag = LevelDAG(universe_operations=MagicMock()) - dag._frame_dag.build = MagicMock() - - dag.build() - - assert "detect_molecules" in dag._static_nodes - assert "detect_levels" in dag._static_nodes - assert "build_beads" in dag._static_nodes - assert "init_covariance_accumulators" in dag._static_nodes - assert "compute_conformational_states" in dag._static_nodes - dag._frame_dag.build.assert_called_once() - - -def test_add_static_adds_dependency_edges(): - dag = LevelDAG() - dag._add_static("A", MagicMock()) - dag._add_static("B", MagicMock(), deps=["A"]) - - assert ("A", "B") in dag._static_graph.edges - - -def 
test_reduce_force_and_torque_hits_zero_count_branches(): - dag = LevelDAG() - - shared = { - "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, - "torque_covariances": {"ua": {}, "res": [None], "poly": [None]}, - "frame_counts": {"ua": {}, "res": [0], "poly": [0]}, - "group_id_to_index": {7: 0}, - } - - frame_out = { - "force": { - "ua": {(7, 0): np.eye(1)}, - "res": {7: np.eye(2)}, - "poly": {7: np.eye(3)}, - }, - "torque": { - "ua": {(7, 0): np.eye(1)}, - "res": {7: np.eye(2)}, - "poly": {7: np.eye(3)}, - }, - } - - dag._reduce_force_and_torque(shared, frame_out) - - assert shared["frame_counts"]["ua"][(7, 0)] == 1 - assert (7, 0) in shared["force_covariances"]["ua"] - assert (7, 0) in shared["torque_covariances"]["ua"] - - assert shared["frame_counts"]["res"][0] == 1 - assert shared["frame_counts"]["poly"][0] == 1 - - -def test_reduce_force_and_torque_handles_empty_frame_gracefully(): - dag = LevelDAG() - - shared = { - "group_id_to_index": {0: 0}, - "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, - "torque_covariances": {"ua": {}, "res": [None], "poly": [None]}, - "frame_counts": {"ua": {}, "res": [0], "poly": [0]}, - } - - frame_out = { - "force": {"ua": {}, "res": {}, "poly": {}}, - "torque": {"ua": {}, "res": {}, "poly": {}}, - } - - dag._reduce_force_and_torque(shared_data=shared, frame_out=frame_out) - - assert shared["force_covariances"]["ua"] == {} - assert shared["torque_covariances"]["ua"] == {} - assert shared["frame_counts"]["res"][0] == 0 - assert shared["frame_counts"]["poly"][0] == 0 - - -def test_reduce_force_and_torque_increments_res_and_poly_counts_from_zero(): - dag = LevelDAG() - - shared = { - "group_id_to_index": {7: 0}, - "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, - "torque_covariances": {"ua": {}, "res": [None], "poly": [None]}, - "frame_counts": {"ua": {}, "res": [0], "poly": [0]}, - } - - F = np.eye(3) - T = np.eye(3) * 2 - - frame_out = { - "force": {"ua": {}, "res": {7: F}, "poly": {7: 
F}}, - "torque": {"ua": {}, "res": {7: T}, "poly": {7: T}}, - } - - dag._reduce_force_and_torque(shared_data=shared, frame_out=frame_out) - - assert shared["frame_counts"]["res"][0] == 1 - assert shared["frame_counts"]["poly"][0] == 1 - assert np.allclose(shared["torque_covariances"]["res"][0], T) - assert np.allclose(shared["torque_covariances"]["poly"][0], T) - - -def test_reduce_one_frame_skips_missing_force_and_torque_keys(): - dag = LevelDAG() - shared = _shared() - - bead_key = (0, "united_atom", 0) - frame_out = { - "beads": {bead_key: [1, 2, 3]}, - "counts": {bead_key: 1}, - "force": {"ua": {}, "res": {}, "poly": {}}, - "torque": {"ua": {}, "res": {}, "poly": {}}, - } - - dag._reduce_one_frame(shared_data=shared, frame_out=frame_out) - - assert shared["force_covariances"] == {} - assert shared["torque_covariances"] == {} - - -def test_reduce_force_and_torque_skips_when_counts_are_zero(): - dag = LevelDAG() - shared = _shared() - - k = (0, "united_atom", 0) - shared["force_covariances"][k] = np.eye(3) - shared["torque_covariances"][k] = np.eye(3) - shared["force_counts"][k] = 0 - shared["torque_counts"][k] = 0 - shared["frame_counts"][k] = 0 - - frame_out = { - "force": {"ua": {}, "res": {}, "poly": {}}, - "torque": {"ua": {}, "res": {}, "poly": {}}, - "beads": {}, - } - - dag._reduce_force_and_torque(shared_data=shared, frame_out=frame_out) - - assert shared["reduced_force_covariances"] == {} - assert shared["reduced_torque_covariances"] == {} - assert shared["reduced_force_counts"] == {} - assert shared["reduced_torque_counts"] == {} - - -def test_run_static_stage_forwards_progress_when_node_accepts_it(): - dag = LevelDAG() - dag._static_graph.add_node("a") - - node = MagicMock() - dag._static_nodes["a"] = node - - progress = MagicMock() - - with patch("networkx.topological_sort", return_value=["a"]): - dag._run_static_stage({"X": 1}, progress=progress) - - node.run.assert_called_once_with({"X": 1}, progress=progress) - - -def 
test_run_static_stage_falls_back_when_node_does_not_accept_progress(): - dag = LevelDAG() - dag._static_graph.add_node("a") - - class NoProgressNode: - def run(self, shared_data): - return None - - dag._static_nodes["a"] = NoProgressNode() - progress = MagicMock() - - with patch("networkx.topological_sort", return_value=["a"]): - dag._run_static_stage({"X": 1}, progress=progress) # should not raise - - -def test_run_frame_stage_with_progress_creates_task_and_updates_titles(): - dag = LevelDAG() - - frame_source = MagicMock() - frame_source.iter_indices.return_value = [10, 11] - - shared = { - "frame_source": frame_source, - "n_frames": 2, - } - - frame_out = { - "force": {"ua": {}, "res": {}, "poly": {}}, - "torque": {"ua": {}, "res": {}, "poly": {}}, - } - - dag._frame_dag = MagicMock() - dag._frame_dag.execute_frame.return_value = frame_out - dag._reduce_one_frame = MagicMock() - - progress = MagicMock() - progress.add_task.return_value = 77 - - dag._run_frame_stage(shared, progress=progress) - - progress.add_task.assert_called_once_with( - "[green]Frame processing", - total=2, - title="Initializing", - ) - - assert progress.update.call_count == 2 - progress.update.assert_any_call(77, title="Frame 10") - progress.update.assert_any_call(77, title="Frame 11") - - assert progress.advance.call_count == 2 - progress.advance.assert_any_call(77) - - assert dag._frame_dag.execute_frame.call_count == 2 - dag._frame_dag.execute_frame.assert_any_call(shared, 10) - dag._frame_dag.execute_frame.assert_any_call(shared, 11) - - assert dag._reduce_one_frame.call_count == 2 - dag._reduce_one_frame.assert_any_call(shared, frame_out) - - -def test_run_frame_stage_progress_total_comes_from_frame_source_indices(): - dag = LevelDAG() - - frame_source = MagicMock() - frame_source.iter_indices.return_value = list(range(10)) - - shared = { - "frame_source": frame_source, - "n_frames": 0, - } - - frame_out = { - "force": {"ua": {}, "res": {}, "poly": {}}, - "torque": {"ua": {}, "res": {}, 
"poly": {}}, - } - - dag._frame_dag = MagicMock() - dag._frame_dag.execute_frame.return_value = frame_out - dag._reduce_one_frame = MagicMock() - - progress = MagicMock() - progress.add_task.return_value = 123 - - dag._run_frame_stage(shared, progress=progress) - - progress.add_task.assert_called_once_with( - "[green]Frame processing", - total=10, - title="Initializing", - ) - - assert shared["n_frames"] == 10 - frame_source.iter_indices.assert_called_once() - - assert dag._frame_dag.execute_frame.call_count == 10 - assert dag._reduce_one_frame.call_count == 10 - assert progress.advance.call_count == 10 diff --git a/tests/unit/CodeEntropy/levels/test_level_dag_reduce.py b/tests/unit/CodeEntropy/levels/test_level_dag_reduce.py deleted file mode 100644 index cb14bc1a..00000000 --- a/tests/unit/CodeEntropy/levels/test_level_dag_reduce.py +++ /dev/null @@ -1,208 +0,0 @@ -import numpy as np - -from CodeEntropy.levels.level_dag import LevelDAG - - -def test_incremental_mean_first_sample_copies(): - x = np.array([1.0, 2.0]) - out = LevelDAG._incremental_mean(None, x, n=1) - assert np.allclose(out, x) - x[0] = 999.0 - assert out[0] != 999.0 - - -def test_reduce_force_and_torque_exercises_count_branches(): - dag = LevelDAG() - - shared = { - "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, - "torque_covariances": {"ua": {}, "res": [None], "poly": [None]}, - "frame_counts": {"ua": {}, "res": [0], "poly": [0]}, - "group_id_to_index": {7: 0}, - } - - frame_out = { - "force": { - "ua": {(9, 0): np.array([1.0])}, - "res": {7: np.array([2.0])}, - "poly": {7: np.array([3.0])}, - }, - "torque": { - "ua": {(9, 0): np.array([4.0])}, - "res": {7: np.array([5.0])}, - "poly": {7: np.array([6.0])}, - }, - } - - dag._reduce_force_and_torque(shared, frame_out) - - assert (9, 0) in shared["torque_covariances"]["ua"] - assert shared["frame_counts"]["res"][0] == 1 - assert shared["frame_counts"]["poly"][0] == 1 - - -def test_reduce_forcetorque_returns_when_missing_key(): - dag 
= LevelDAG() - shared = { - "forcetorque_covariances": {"res": [None], "poly": [None]}, - "forcetorque_counts": {"res": [0], "poly": [0]}, - "group_id_to_index": {7: 0}, - } - dag._reduce_forcetorque(shared, frame_out={}) - assert shared["forcetorque_counts"]["res"][0] == 0 - - -def test_reduce_forcetorque_updates_res_and_poly(): - dag = LevelDAG() - - shared = { - "forcetorque_covariances": {"res": [None], "poly": [None]}, - "forcetorque_counts": {"res": [0], "poly": [0]}, - "group_id_to_index": {7: 0}, - } - - frame_out = { - "forcetorque": { - "res": {7: np.array([1.0, 1.0])}, - "poly": {7: np.array([2.0, 2.0])}, - } - } - - dag._reduce_forcetorque(shared, frame_out) - - assert shared["forcetorque_counts"]["res"][0] == 1 - assert shared["forcetorque_counts"]["poly"][0] == 1 - assert shared["forcetorque_covariances"]["res"][0] is not None - assert shared["forcetorque_covariances"]["poly"][0] is not None - - -def test_reduce_force_and_torque_res_torque_increments_when_res_count_is_zero(): - dag = LevelDAG() - shared = { - "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, - "torque_covariances": {"ua": {}, "res": [None], "poly": [None]}, - "frame_counts": {"ua": {}, "res": [0], "poly": [0]}, - "group_id_to_index": {7: 0}, - } - - frame_out = { - "force": {"ua": {}, "res": {}, "poly": {}}, - "torque": {"ua": {}, "res": {7: np.eye(3)}, "poly": {}}, - } - - dag._reduce_force_and_torque(shared, frame_out) - - assert shared["frame_counts"]["res"][0] == 1 - assert shared["torque_covariances"]["res"][0] is not None - - -def test_reduce_force_and_torque_poly_torque_increments_when_poly_count_is_zero(): - dag = LevelDAG() - shared = { - "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, - "torque_covariances": {"ua": {}, "res": [None], "poly": [None]}, - "frame_counts": {"ua": {}, "res": [0], "poly": [0]}, - "group_id_to_index": {7: 0}, - } - - frame_out = { - "force": {"ua": {}, "res": {}, "poly": {}}, - "torque": {"ua": {}, "res": {}, "poly": 
{7: np.eye(3)}}, - } - - dag._reduce_force_and_torque(shared, frame_out) - - assert shared["frame_counts"]["poly"][0] == 1 - assert shared["torque_covariances"]["poly"][0] is not None - - -def test_reduce_force_and_torque_increments_ua_frame_counts_for_force(): - dag = LevelDAG() - - shared = { - "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, - "torque_covariances": {"ua": {}, "res": [None], "poly": [None]}, - "frame_counts": {"ua": {}, "res": [0], "poly": [0]}, - "group_id_to_index": {7: 0}, - } - - k = (9, 0) - frame_out = { - "force": {"ua": {k: np.eye(3)}, "res": {}, "poly": {}}, - "torque": {"ua": {}, "res": {}, "poly": {}}, - } - - dag._reduce_force_and_torque(shared, frame_out) - - assert shared["frame_counts"]["ua"][k] == 1 - assert k in shared["force_covariances"]["ua"] - - -def test_reduce_force_and_torque_increments_ua_counts_from_zero(): - dag = LevelDAG() - - key = (9, 0) - F = np.eye(3) - - shared = { - "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, - "torque_covariances": {"ua": {}, "res": [None], "poly": [None]}, - "frame_counts": {"ua": {}, "res": [0], "poly": [0]}, - "group_id_to_index": {7: 0}, - } - - frame_out = { - "force": {"ua": {key: F}, "res": {}, "poly": {}}, - "torque": {"ua": {}, "res": {}, "poly": {}}, - } - - dag._reduce_force_and_torque(shared, frame_out) - - assert shared["frame_counts"]["ua"][key] == 1 - - np.testing.assert_array_equal(shared["force_covariances"]["ua"][key], F) - - -def test_reduce_force_and_torque_hits_ua_force_count_increment_line(): - dag = LevelDAG() - key = (9, 0) - - shared = { - "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, - "torque_covariances": {"ua": {}, "res": [None], "poly": [None]}, - "frame_counts": {"ua": {}, "res": [0], "poly": [0]}, - "group_id_to_index": {7: 0}, - } - - frame_out = { - "force": {"ua": {key: np.eye(3)}, "res": {}, "poly": {}}, - "torque": {"ua": {}, "res": {}, "poly": {}}, - } - - dag._reduce_force_and_torque(shared, frame_out) - 
- assert shared["frame_counts"]["ua"][key] == 1 - - -def test_reduce_force_and_torque_ua_torque_increments_count_when_force_missing_key(): - dag = LevelDAG() - - key = (9, 0) - T = np.eye(3) - - shared = { - "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, - "torque_covariances": {"ua": {}, "res": [None], "poly": [None]}, - "frame_counts": {"ua": {}, "res": [0], "poly": [0]}, - "group_id_to_index": {7: 0}, - } - - frame_out = { - "force": {"ua": {}, "res": {}, "poly": {}}, - "torque": {"ua": {key: T}, "res": {}, "poly": {}}, - } - - dag._reduce_force_and_torque(shared, frame_out) - - assert shared["frame_counts"]["ua"][key] == 1 - np.testing.assert_array_equal(shared["torque_covariances"]["ua"][key], T) diff --git a/tests/unit/CodeEntropy/levels/test_level_dag_reduction.py b/tests/unit/CodeEntropy/levels/test_level_dag_reduction.py deleted file mode 100644 index 68e981e9..00000000 --- a/tests/unit/CodeEntropy/levels/test_level_dag_reduction.py +++ /dev/null @@ -1,106 +0,0 @@ -from unittest.mock import MagicMock - -import numpy as np - -from CodeEntropy.levels.level_dag import LevelDAG - - -def test_incremental_mean_none_returns_copy_for_numpy(): - arr = np.array([1.0, 2.0]) - out = LevelDAG._incremental_mean(None, arr, n=1) - assert np.all(out == arr) - arr[0] = 999.0 - assert out[0] != 999.0 - - -def test_incremental_mean_updates_mean_correctly(): - old = np.array([2.0, 2.0]) - new = np.array([4.0, 0.0]) - out = LevelDAG._incremental_mean(old, new, n=2) - np.testing.assert_allclose(out, np.array([3.0, 1.0])) - - -def test_reduce_force_and_torque_updates_counts_and_means(): - dag = LevelDAG() - - shared = { - "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, - "torque_covariances": {"ua": {}, "res": [None], "poly": [None]}, - "frame_counts": { - "ua": {}, - "res": np.zeros(1, dtype=int), - "poly": np.zeros(1, dtype=int), - }, - "group_id_to_index": {9: 0}, - } - - F1 = np.eye(3) - T1 = 2.0 * np.eye(3) - - frame_out = { - "force": {"ua": 
{(0, 0): F1}, "res": {9: F1}, "poly": {}}, - "torque": {"ua": {(0, 0): T1}, "res": {9: T1}, "poly": {}}, - } - - dag._reduce_force_and_torque(shared, frame_out) - - assert shared["frame_counts"]["ua"][(0, 0)] == 1 - np.testing.assert_allclose(shared["force_covariances"]["ua"][(0, 0)], F1) - np.testing.assert_allclose(shared["torque_covariances"]["ua"][(0, 0)], T1) - - assert shared["frame_counts"]["res"][0] == 1 - np.testing.assert_allclose(shared["force_covariances"]["res"][0], F1) - np.testing.assert_allclose(shared["torque_covariances"]["res"][0], T1) - - -def test_reduce_forcetorque_no_key_is_noop(): - dag = LevelDAG() - shared = { - "forcetorque_covariances": {"res": [None], "poly": [None]}, - "forcetorque_counts": { - "res": np.zeros(1, dtype=int), - "poly": np.zeros(1, dtype=int), - }, - "group_id_to_index": {9: 0}, - } - - dag._reduce_forcetorque(shared, frame_out={}) - assert shared["forcetorque_counts"]["res"][0] == 0 - assert shared["forcetorque_covariances"]["res"][0] is None - - -def test_run_frame_stage_calls_execute_frame_for_each_frame_index(): - dag = LevelDAG() - - frame_source = MagicMock() - frame_source.iter_indices.return_value = list(range(10)) - - shared = { - "frame_source": frame_source, - "n_frames": 10, - } - - frame_out = { - "force": {"ua": {}, "res": {}, "poly": {}}, - "torque": {"ua": {}, "res": {}, "poly": {}}, - } - - dag._frame_dag = MagicMock() - dag._frame_dag.execute_frame.side_effect = lambda shared_data, frame_index: ( - frame_out - ) - - dag._reduce_one_frame = MagicMock() - - dag._run_frame_stage(shared) - - frame_source.iter_indices.assert_called_once() - - assert dag._frame_dag.execute_frame.call_count == 10 - assert dag._reduce_one_frame.call_count == 10 - - for frame_index in range(10): - dag._frame_dag.execute_frame.assert_any_call(shared, frame_index) - - for call in dag._reduce_one_frame.call_args_list: - assert call.args == (shared, frame_out)