diff --git a/examples-proposed/020-simple-ensemble/driver.py b/examples-proposed/020-simple-ensemble/driver.py index 0f1f5a8a..36028dc6 100644 --- a/examples-proposed/020-simple-ensemble/driver.py +++ b/examples-proposed/020-simple-ensemble/driver.py @@ -55,11 +55,16 @@ def step(self, timestamp=0.0): # given set of variables. E.g., the instance corresponding to # {'A' : 2, 'B' : 5.82, 'C' : 'baz'} is probably found in the # `INSTANCE_1` subdirectory. + + # We also demonstrate that stdout and stderr output per instance can + # be captured in files by specifying logfile and errfile, respectively. mapping = self.services.run_ensemble(template, variables, run_dir=Path('.').absolute(), name='INSTANCE_', num_nodes=1, - cores_per_instance=1) + cores_per_instance=1, + logfile='stdout.txt', + errfile='stderr.txt') # Print each mapping of instance name to what variable values were used. for instance in mapping: - self.services.info(f'{instance!s}') \ No newline at end of file + self.services.info(f'{instance!s}') diff --git a/examples-proposed/020-simple-ensemble/ensemble.conf b/examples-proposed/020-simple-ensemble/ensemble.conf index 19ca57b5..706af55f 100644 --- a/examples-proposed/020-simple-ensemble/ensemble.conf +++ b/examples-proposed/020-simple-ensemble/ensemble.conf @@ -1,7 +1,7 @@ SIM_NAME = simpleensemble SIM_ROOT = $PWD/ENSEMBLES LOG_FILE = log -LOG_LEVEL = INFO +LOG_LEVEL = DEBUG SIMULATION_MODE = NORMAL [PORTS] diff --git a/examples-proposed/020-simple-ensemble/instance_component.py b/examples-proposed/020-simple-ensemble/instance_component.py index 3effca67..c69feca9 100644 --- a/examples-proposed/020-simple-ensemble/instance_component.py +++ b/examples-proposed/020-simple-ensemble/instance_component.py @@ -16,6 +16,8 @@ class InstanceComponent(Component): def step(self, timestamp: float = 0.0, **keywords): start = time() + + # ENSEMBLE_INSTANCE is a special IPS variable that contains the # string uniquely identifying this instance. Each instance will have # the `run_ensemble()` `name` argument prepended to a unique number @@ -24,6 +26,8 @@ def step(self, timestamp: float = 0.0, **keywords): self.services.info(f'{instance_id}: Start of step of instance ' f'component.') + print(f'start of instance component for {instance_id}') + # Echo the parameters we're expecting, A, B, and C self.services.info(f'{instance_id}: instance component parameters: ' f'A={self.A}, B={self.B}, C={self.C}') @@ -38,6 +42,7 @@ def step(self, timestamp: float = 0.0, **keywords): writer.writerow([instance_id, sys.argv[0], run_env['hostname'], run_env['pid'], run_env['core_id'], start, time()]) + print(f'Wrote stats.csv for {instance_id}') self.services.info(f'{instance_id}: End of step of instance ' f'component.') diff --git a/examples-proposed/020-simple-ensemble/template.conf b/examples-proposed/020-simple-ensemble/template.conf index 71a4b3d8..9e64e7d7 100644 --- a/examples-proposed/020-simple-ensemble/template.conf +++ b/examples-proposed/020-simple-ensemble/template.conf @@ -1,7 +1,7 @@ SIM_NAME = simpleensembleinstance SIM_ROOT = $PWD LOG_FILE = log -LOG_LEVEL = INFO +LOG_LEVEL = DEBUG SIMULATION_MODE = NORMAL [PORTS] diff --git a/ipsframework/services.py b/ipsframework/services.py index d9e33573..d7ac7f9c 100644 --- a/ipsframework/services.py +++ b/ipsframework/services.py @@ -11,6 +11,7 @@ import logging.handlers from datetime import datetime import os +import platform import queue import shutil import signal @@ -74,7 +75,7 @@ def launch(executable: Any, * `logfile` - where the task output is written; if not specified, STDOUT used * `errfile` - where the task error output is written; if not specified, - STDOUT used + STDERR used * `task_env` - A dictionary of environment variables to set * `timeout` - The timeout in seconds for the task to complete. * `cpus_per_proc` - The number of cpus per process to use for the task. @@ -94,14 +95,17 @@ def launch(executable: Any, import logging from dask.distributed import get_worker # pylint: disable=import-outside-toplevel - # Later, we use Client.forward_logging() to handle these log messages. - log = logging.getLogger('launch') - worker = get_worker() task_key = worker.get_current_task() + # Later, we use Client.forward_logging() to handle these log messages. We + # access the root logger for forward_logging() to work. + log = logging.getLogger() + log.info(f'Launching task {task_name} with id {task_key!s} and ' f'worker {worker.name!s} in {working_dir}') + print(f'Launching task {task_name} with id {task_key!s} and ' + f'worker {worker.name!s} in {working_dir}') start_time = time.time() working_dir_path = Path(working_dir) @@ -112,12 +116,14 @@ def launch(executable: Any, # via a subprocess.Popen() # Do we write the Popen stdout to sys.stdout or to a file? - subprocess_stdout = sys.stdout + subprocess_stdout = subprocess.PIPE close_stdout = False # is true if we need to later close the file + log_path = None try: log_filename = kwargs['logfile'] except KeyError: log.info('No logfile specified, using stdout for task output') + print('No logfile specified, using stdout for task output') else: log_path = Path(log_filename) if not log_path.is_absolute(): @@ -125,27 +131,36 @@ def launch(executable: Any, subprocess_stdout = open(log_path, 'w') close_stdout = True # Welp, gotta close it now log.info(f'Task output log file: {log_path}') + print(f'Task output log file: {log_path}') # Repeat the same for stderr - subprocess_errfile = subprocess.STDOUT + subprocess_stderr = subprocess.STDOUT close_stderr = False try: - subprocess_errfile = kwargs['errfile'] + err_filename = kwargs['errfile'] except KeyError: log.info('No errfile specified, using STDOUT for task errors') + print('No errfile specified, using STDOUT for task errors') else: - err_path = Path(subprocess_errfile) + err_path = Path(err_filename) if not err_path.is_absolute(): err_path = working_dir_path / err_path - try: - subprocess_errfile = open(err_path, 'w') - except OSError: - log.info(f'Could not open errfile {err_path}, ' - f'using STDOUT for task errors') - subprocess_errfile = subprocess.STDOUT + if log_path is not None and err_path.resolve(strict=False) == log_path.resolve(strict=False): + log.info(f'Task error log file matches output log file: {log_path}') + print(f'Task error log file matches output log file: {log_path}') else: - close_stderr = True - log.info(f'Task error log file: {err_path}') + try: + subprocess_stderr = open(err_path, 'w') + except OSError: + log.info(f'Could not open errfile {err_path}, ' + f'using STDOUT for task errors') + print(f'Could not open errfile {err_path}, ' + f'using STDOUT for task errors') + subprocess_stderr = subprocess.STDOUT + else: + close_stderr = True + log.info(f'Task error log file: {err_path}') + print(f'Task error log file: {err_path}') task_env = kwargs.get('task_env', {}) new_env = os.environ.copy() @@ -160,9 +175,11 @@ def launch(executable: Any, dvm_uri_file = Path(worker.dvm_uri_file) if not dvm_uri_file.exists(): log.error(f'DVM URI file {dvm_uri_file} does not exist') + print(f'DVM URI file {dvm_uri_file} does not exist') # print(f'DVM URI file {dvm_uri_file} does not exist', flush=True) else: log.debug(f'Using DVM URI file: {dvm_uri_file}') + print(f'Using DVM URI file: {dvm_uri_file}') # print(f'Using DVM URI file: {dvm_uri_file}', flush=True) # PMIX_SERVER_URI41 is used by prun to figure out how to talk to the DVM @@ -174,12 +191,18 @@ def launch(executable: Any, log.debug(f"DVM environment variable PMIX_SERVER_URI41 " f"set in task_env to " f"{task_env['PMIX_SERVER_URI41']}") + print(f"DVM environment variable PMIX_SERVER_URI41 " + f"set in task_env to " + f"{task_env['PMIX_SERVER_URI41']}") # print(f'DVM environment variable PMIX_SERVER_URI41 set in task_' # f'env to {task_env["PMIX_SERVER_URI41"]}', flush=True) if 'PMIX_SERVER_URI41' in os.environ: log.debug(f"DVM environment variable PMIX_SERVER_URI41 set " f"in os.environ to " f"{os.environ['PMIX_SERVER_URI41']}") + print(f"DVM environment variable PMIX_SERVER_URI41 set " + f"in os.environ to " + f"{os.environ['PMIX_SERVER_URI41']}") # print(f'DVM environment variable PMIX_SERVER_URI41 set in os.environ ' # f'to {os.environ["PMIX_SERVER_URI41"]}', flush=True) @@ -188,6 +211,7 @@ def launch(executable: Any, cmd = f'{executable} {" ".join(map(str, args))}' log.debug(f'Launching task {task_name} with command: {cmd}') + print(f'Launching task {task_name} with command: {cmd}') worker.log_event('ips', { @@ -205,8 +229,9 @@ def launch(executable: Any, try: process = subprocess.Popen(cmd_lst, stdout=subprocess_stdout, - stderr=subprocess_errfile, + stderr=subprocess_stderr, cwd=working_dir_path, + text=True, preexec_fn=os.setsid, env=new_env) # noqa: PLW1509 (TODO: look into this to potentially avoid deadlocks) except Exception as e: worker.log_event('ips', @@ -221,6 +246,8 @@ def launch(executable: Any, }) log.error(f'Failed to launch task {task_name} with ' f'command {cmd}: {e}') + print(f'Failed to launch task {task_name} with ' + f'command {cmd}: {e}') raise try: @@ -253,6 +280,8 @@ def launch(executable: Any, process.kill() log.error(f'Task {task_name} with command {cmd} timed out ' f'after {timeout}s') + print(f'Task {task_name} with command {cmd} timed out ' + f'after {timeout}s') ret_val = -1 except Exception as e: worker.log_event('ips', @@ -264,12 +293,19 @@ def launch(executable: Any, f'Exception when calling ' f'{executable!s}: {e!s}'}) log.error(f'Task {task_name} with command {cmd} failed with {e!s}') + print(f'Task {task_name} with command {cmd} failed with {e!s}') finally: + if 'logfile' not in kwargs: + print(process.stdout.read() if process and process.stdout else '') + if 'errfile' not in kwargs: + print(process.stderr.read() if process and process.stderr else '') + if close_stdout: subprocess_stdout.close() if close_stderr: - subprocess_errfile.close() + subprocess_stderr.close() + elif isinstance(executable, Callable): # binary not a string, but is a python callable, so we call it directly # with the given *args @@ -316,6 +352,8 @@ def launch(executable: Any, f'{executable!s}: {e!s}'}) log.error(f'Task {task_name} with callable {executable!s} failed ' f'with {e!s}') + print(f'Task {task_name} with callable {executable!s} failed ' + f'with {e!s}') finally: os.chdir(str(original_dir)) else: @@ -323,10 +361,44 @@ def launch(executable: Any, f'callable, cannot launch task {task_name}') log.info(f'Task {task_name} finished with return value: {ret_val}') + print(f'Task {task_name} finished with return value: {ret_val}') return task_name, ret_val +def launch_mapped_task( + executable: Any, + task_name: str, + working_dir: Union[str, os.PathLike], + task_args: Iterable[Any], + task_keywords: dict[str, Any], + cpus_per_proc: int, + worker_event_logfile: Optional[str]): + """ Adapt task-specific launch arguments for :meth:`Client.map`. + + This is a wrapper for `launch()` because we need to ensure `cpus_per_proc` + and `worker_event_logfile` get stuffed into the expected `task_args` and + `task_keywords` that `launch()` expects. + + This is invoked in :meth:`TaskPool.submit_dask_tasks()`. + + TODO remove `worker_event_logfile` since that's no longer needed. + + :param executable: to be invoked + :param task_name: name of the task + :param working_dir: working directory + :param task_args: list of arguments + :param task_keywords: keyword arguments + :param cpus_per_proc: number of cpus + :param worker_event_logfile: event logfile + """ + task_keywords = dict(task_keywords) + task_keywords['cpus_per_proc'] = cpus_per_proc + task_keywords['worker_event_logfile'] = worker_event_logfile + + return launch(executable, task_name, working_dir, *task_args, **task_keywords) + + class ServicesProxy: """The *ServicesProxy* object is responsible for marshalling invocations of framework services to the framework process using a @@ -2290,6 +2362,8 @@ def submit_tasks( dask_worker_per_gpu=False, oversubscribe=False, hwthreads=False, + logfile=None, + errfile=None, ): """ Launch all unfinished tasks in task pool *task_pool_name*. If *block* is ``True``, @@ -2316,6 +2390,8 @@ def submit_tasks( :param hwthreads: if True, use hardware threads as the basis for resource allocation; if False, use physical cores as the basis for resource allocation + :param logfile: optional default file name for redirected task stdout + :param errfile: optional default file name for redirected task stderr :returns: task return value """ start_time = time.time() @@ -2325,7 +2401,8 @@ def submit_tasks( retval = task_pool.submit_tasks( block, use_dask, dask_nodes, dask_ppw, launch_interval, use_shifter, shifter_args, dask_worker_plugin, - dask_worker_per_gpu, oversubscribe, hwthreads + dask_worker_per_gpu, oversubscribe, hwthreads, + logfile, errfile ) elapsed_time = time.time() - start_time self._send_monitor_event('IPS_TASK_POOL_END', @@ -2735,7 +2812,7 @@ def send_ensemble_instance_to_portal(ensemble_name: str, data_path: Path) -> Non num_nodes = int(os.environ['SLURM_JOB_NUM_NODES']) elif num_nodes is None: num_nodes = 1 - self.debug(f'run_ensemble() num_nodes = {num_nodes}') + self.info(f'run_ensemble() num_nodes = {num_nodes}') # Ensure that we create a unique task pool name for this using the @@ -2835,6 +2912,8 @@ def send_ensemble_instance_to_portal(ensemble_name: str, data_path: Path) -> Non dask_ppw=cores_per_instance, oversubscribe=oversubscribe, hwthreads=hwthreads, + logfile=logfile, + errfile=errfile, # launch_interval=0.0, # use_shifter=False, # shifter_args=None, @@ -2846,9 +2925,9 @@ def send_ensemble_instance_to_portal(ensemble_name: str, data_path: Path) -> Non self.critical(f'Got an exception running ensemble: {e!s}') traceback.print_exc() finally: - self.debug('Getting finished tasks') - exit_status = self.get_finished_tasks(task_pool_name) - self.info(f'Finished tasks: {exit_status!s}') + # self.debug('Getting finished tasks') + # exit_status = self.get_finished_tasks(task_pool_name) + # self.info(f'Finished tasks: {exit_status!s}') self.remove_task_pool(task_pool_name) @@ -2896,8 +2975,8 @@ def setup(self, worker: Worker): self.logger.info('Launching DVM') self.worker.dvm_uri_file = f'/tmp/dvm.uri.{os.getpid()}' command = [#'srun', '--mpi=pmix_v4', '-N', os.environ['SLURM_NNODES'], '--ntasks-per-node=1', - 'prte', #'--no-daemonize', - '--report-uri', self.worker.dvm_uri_file] + 'prte', #'--no-daemonize', + '--report-uri', self.worker.dvm_uri_file] mapping_policy = 'core' # by default bind to cores if self.hwthreads: @@ -2942,11 +3021,28 @@ def setup(self, worker: Worker): def teardown(self, worker: Worker): self.logger.info(f'Shutting down DVM at {self.worker.dvm_uri}') - command = ['pterm', '--dvm-uri', self.worker.dvm_uri] - subprocess.call(command) + + # On some systems we use `pterm` to shut down the DVM, and on others we + # use `prte-term`, so check for both. + pterm_cmd = shutil.which('pterm') + if pterm_cmd is None: + # On MacOS homebrew, pterm -> prte-term + pterm_cmd = shutil.which('prte-term') + if pterm_cmd is None: + # if it's *still* none, then there is a serious + # configuration problem. + self.logger.critical('Neither pterm nor prte-term command found') + else: + self.logger.debug(f'DVMPluggin.teardown(), pterm: {pterm_cmd!s}') + command = [pterm_cmd, '--dvm-uri', self.worker.dvm_uri] + subprocess.call(command) + # Regardless if we have `pterm` or `prte-term`, we can still just + # kill the process directly. self.worker.dvm_proc.terminate() self.worker.dvm_proc.kill() + self.logger.info('DVM shutdown') + class TaskPool: """ @@ -3059,6 +3155,16 @@ def add_task(self, task_name: str, nproc: int, working_dir: str, binary: str, *a self.serial_pool = self.serial_pool and (nproc == 1) self.queued_tasks[task_name] = Task(task_name, nproc, working_dir, binary_fullpath, *args, **keywords['keywords']) + @staticmethod + def _launch_keywords_with_defaults(task_keywords, logfile=None, errfile=None): + """Return launch keywords with submission-level log defaults applied.""" + keywords = dict(task_keywords) + if logfile: + keywords.setdefault('logfile', logfile) + if errfile: + keywords.setdefault('errfile', errfile) + return keywords + def _process_dask_event(self, event): """ This will create an IPS monitor event from a Dask event @@ -3095,7 +3201,9 @@ def submit_dask_tasks( dask_worker_plugin=None, dask_worker_per_gpu=False, oversubscribe=False, - hwthreads=False + hwthreads=False, + logfile=None, + errfile=None, ): """Launch tasks in *queued_tasks* using dask. @@ -3127,6 +3235,10 @@ def submit_dask_tasks( :type oversubscribe: bool :param hwthreads: Whether to use hardware threads :type hwthreads: bool + :param logfile: Optional default file name for redirected task stdout + :type logfile: str + :param errfile: Optional default file name for redirected task stderr + :type errfile: str FIXME consider having n processes instead of n threads given that we're likely running in a HPC context. @@ -3333,6 +3445,20 @@ def _make_worker_args(num_workers: int, num_threads: int, use_shifter: bool, shi hwthreads=hwthreads)) self.services.debug('Registered DVMPlugin') + # Wait for so many workers to be online before proceeding to + # more evenly spread the load instead of biasing the tasks by the + # first set of workers to spin up. Note that we don't wait for + # 100% of the workers since it's possible that a few will have problems + # (e.g., due to node failures). + if dask_nodes > 1: + num_to_wait_for = max(1, int(dask_nodes * 0.8)) + self.services.debug(f'Waiting for {num_to_wait_for} Dask workers') + self.dask_client.wait_for_workers(num_to_wait_for) + self.services.debug(f'Have {num_to_wait_for} Dask workers available ... proceeding') + else: + self.services.info('Only a single Dask worker needed, proceeding') + + try: # FIXME this is deprecated, but be mindful of blithely deleting file_id = str(self.services._portal_runid) if self.services._portal_runid > 0 else self.services._fallback_portal_runid @@ -3342,26 +3468,44 @@ def _make_worker_args(num_workers: int, num_threads: int, use_shifter: bool, shi # USE_PORTAL == False self.worker_event_logfile = None + # accumulate arguments for different tasks in lists suitable for + # invoking map(). launch.__module__ = '__main__' - self.futures = [] + launch_mapped_task.__module__ = '__main__' + task_names = [] + binaries = [] + working_dirs = [] + task_args = [] + task_keywords = [] + cpus_per_procs = [] + worker_event_logfiles = [] for task_name, task in self.queued_tasks.items(): self.services.debug(f'Submitting task {task_name} to dask client with {dask_ppw} cores per worker') self.services.debug(f'Task {task_name} working dir: {task.working_dir}') self.services.debug(f'Task args: {task.args} keywords: {task.keywords}') - self.futures.append( - self.dask_client.submit( - launch, - task.binary, - task_name, - task.working_dir, - *task.args, - **task.keywords, - pure=False, - key=task_name, - cpus_per_proc=dask_ppw, - worker_event_logfile=self.worker_event_logfile, - ) - ) + keywords = self._launch_keywords_with_defaults(task.keywords, + logfile, + errfile) + task_names.append(task_name) + binaries.append(task.binary) + working_dirs.append(task.working_dir) + task_args.append(task.args) + task_keywords.append(keywords) + cpus_per_procs.append(dask_ppw) + worker_event_logfiles.append(self.worker_event_logfile) + + self.futures = self.dask_client.map( + launch_mapped_task, + binaries, + task_names, + working_dirs, + task_args, + task_keywords, + cpus_per_procs, + worker_event_logfiles, + pure=False, + key=task_names, + ) self.active_tasks = self.queued_tasks self.queued_tasks = {} @@ -3378,6 +3522,9 @@ def _make_worker_args(num_workers: int, num_threads: int, use_shifter: bool, shi # Set this to empty list so that get_dask_finished_tasks_status # doesn't try to gather() needlessly again. self.futures = [] + + # Since we're done with Dask, let's shut it down + self._shutdown_dask() else: self.services.debug(f'submit_dask_tasks: not blocking tasks') @@ -3395,7 +3542,9 @@ def submit_tasks( dask_worker_plugin=None, dask_worker_per_gpu=False, oversubscribe=False, - hwthreads=False + hwthreads=False, + logfile=None, + errfile=None, ): """Launch tasks in *queued_tasks*. Finished tasks are handled before launching new ones. If *block* is ``True``, the number of @@ -3434,6 +3583,10 @@ def submit_tasks( :param hwthreads: If True then use hardware threads when launching tasks. Default is False. :type hwthreads: bool + :param logfile: Optional default file name for redirected task stdout + :type logfile: str + :param errfile: Optional default file name for redirected task stderr + :type errfile: str :returns: """ if use_dask: @@ -3447,7 +3600,8 @@ def submit_tasks( return self.submit_dask_tasks( block, dask_nodes, dask_ppw, use_shifter, shifter_args, dask_worker_plugin, - dask_worker_per_gpu, oversubscribe, hwthreads + dask_worker_per_gpu, oversubscribe, hwthreads, + logfile, errfile ) elif not TaskPool.dask or not TaskPool.distributed: raise RuntimeError( @@ -3484,18 +3638,23 @@ def _shutdown_dask(self): Shut down the dask client, scheduler, and workers. Side effect is setting self.dask_sched_pid and self.dask_client - to None. + to None as well as other internal state. :returns: None """ + self.services.debug('Shutting down Dask client, scheduler, and workers') + # Gently release any pending futures for f in self.futures: f.release() + self.services.debug('Released pending futures') if self.dask_client is not None: # Shutdown handles ending client, scheduler, and workers self.dask_client.unsubscribe_topic('ips') # unregister handler + self.services.debug('Unsubscribed from Dask client ips topic') self.dask_client.shutdown() + self.services.debug('Shutdown Dask client') # TODO a more gentle way to shutdown: # 1. self.dask_client.close() @@ -3520,6 +3679,22 @@ def _shutdown_dask(self): # # time.sleep(1) # Give time for the scheduler to shut down + self.finished_tasks = {} + self.active_tasks = {} + self.services.wait_task(self.dask_workers_tid) + self.dask_scheduler_file = None + self.dask_workers_tid = None + self.dask_sched_pid: Optional[int] = None + self.dask_sched_popen = None + self.dask_pool = False + + # Presumably the default state for TaskPool is serial task execution, so + # we revert to that after the Dask system is shutdown. + self.serial_pool = True + + self.services.debug('Shutdown Dask system') + + def get_dask_finished_tasks_status(self): """Return a dictionary of exit status values for all dask tasks that have finished since the last time finished tasks were polled. @@ -3597,17 +3772,7 @@ def get_dask_finished_tasks_status(self): self._shutdown_dask() self.services.debug(f'get_dask_finished_tasks_status: after _shutdown_dask()') - # TODO These probably should be migrated to _shutdown_dask() since - # these are part of that housekeeping. - self.finished_tasks = {} - self.active_tasks = {} - self.services.wait_task(self.dask_workers_tid) - self.dask_scheduler_file = None - self.dask_workers_tid = None - self.dask_sched_pid: Optional[int] = None - self.dask_sched_popen = None - self.dask_pool = False - self.serial_pool = True + if result is not None: self.services.debug('get_dask_finished_tasks_status: have result') diff --git a/ipsframework/taskManager.py b/ipsframework/taskManager.py index 8311cd75..92c55aef 100644 --- a/ipsframework/taskManager.py +++ b/ipsframework/taskManager.py @@ -379,7 +379,12 @@ def build_launch_cmd( ppn_flag = '-npernode' host_select = '-H' if smp_node or mpi_binary == 'prun': - cmd = ' '.join([mpicmd, nproc_flag, str(nproc)]) + # --display MAP-DEVEL is added to show the DVM state when + # invoking this prun. We do this so that we can verify the + # resources managed by DVM for this task as displayed in + # detailed messages sent to stdout prior to running the + # desired IPS task. + cmd = ' '.join([mpicmd, '--display', 'ALLOCATION,MAP-DEVEL,BINDINGS', nproc_flag, str(nproc)]) else: cmd = ' '.join([mpicmd, nproc_flag, str(nproc), ppn_flag, str(ppn)]) cmd = f'{cmd} -x PYTHONPATH' # Propagate PYTHONPATH to compute nodes diff --git a/tests/new/test_run_ensemble.py b/tests/new/test_run_ensemble.py new file mode 100644 index 00000000..1197f51e --- /dev/null +++ b/tests/new/test_run_ensemble.py @@ -0,0 +1,221 @@ +import logging +import os + +from ipsframework import ServicesProxy, TaskPool +from ipsframework import services as services_module + + +class DummyFramework: + logger = logging.getLogger(__name__) + + +class DummyDaskWorker: + name = 'worker_0' + + def get_current_task(self): + return 'task-key' + + def log_event(self, topic, event): + pass + + +def write_stdout_stderr_script(tmpdir): + script = tmpdir.join('write_stdout_stderr.sh') + script.write('#!/bin/sh\necho stdout-line\necho stderr-line >&2\n') + script.chmod(448) # 700 + return script + + +def test_run_ensemble_passes_logfile_and_errfile_to_add_task(tmpdir, monkeypatch): + template = tmpdir.join('template.config') + template.write('[comp]\nA = ?\n') + run_dir = tmpdir.mkdir('runs') + + services = ServicesProxy(None, None, None, {'USE_PORTAL': 'False'}, None) + services.fwk = DummyFramework() + services.logger = logging.getLogger(__name__) + + submitted_kwargs = [] + submit_tasks_kwargs = [] + + def record_add_task(task_pool_name, task_name, nproc, working_dir, binary, *args, **kwargs): + submitted_kwargs.append(kwargs) + + def record_submit_tasks(*args, **kwargs): + submit_tasks_kwargs.append(kwargs) + return 1 + + monkeypatch.setattr(services, 'create_task_pool', lambda name: None) + monkeypatch.setattr(services, 'add_task', record_add_task) + monkeypatch.setattr(services, 'submit_tasks', record_submit_tasks) + monkeypatch.setattr(services, 'get_finished_tasks', lambda task_pool_name: {}) + monkeypatch.setattr(services, 'remove_task_pool', lambda task_pool_name: None) + + services.run_ensemble( + template, + {'comp': {'A': ['1', '2']}}, + run_dir, + 'ensemble', + num_nodes=1, + logfile='instance.out', + errfile='instance.err', + ) + + assert submitted_kwargs == [ + {'logfile': 'instance.out', 'errfile': 'instance.err'}, + {'logfile': 'instance.out', 'errfile': 'instance.err'}, + ] + assert submit_tasks_kwargs == [ + { + 'block': True, + 'use_dask': True, + 'dask_nodes': 1, + 'dask_ppw': None, + 'oversubscribe': False, + 'hwthreads': False, + 'logfile': 'instance.out', + 'errfile': 'instance.err', + } + ] + + +def test_services_submit_tasks_passes_logfile_and_errfile_to_task_pool(): + services = ServicesProxy(None, None, None, {'USE_PORTAL': 'False'}, None) + services.logger = logging.getLogger(__name__) + + class DummyTaskPool: + def __init__(self): + self.submit_args = None + + def submit_tasks(self, *args): + self.submit_args = args + return 1 + + task_pool = DummyTaskPool() + services.task_pools['pool'] = task_pool + services._send_monitor_event = lambda *args, **kwargs: None + + assert services.submit_tasks('pool', logfile='instance.out', errfile='instance.err') == 1 + assert task_pool.submit_args[-2:] == ('instance.out', 'instance.err') + + +def test_task_pool_submit_tasks_passes_logfile_and_errfile_to_dask(monkeypatch): + services = ServicesProxy(None, None, None, {'USE_PORTAL': 'False'}, None) + task_pool = TaskPool('pool', services) + task_pool.serial_pool = True + submitted_args = [] + + def record_submit_dask_tasks(*args): + submitted_args.append(args) + return 1 + + monkeypatch.setattr(TaskPool, 'dask', object()) + monkeypatch.setattr(TaskPool, 'distributed', object()) + monkeypatch.setattr(task_pool, 'submit_dask_tasks', record_submit_dask_tasks) + + assert task_pool.submit_tasks(use_dask=True, logfile='instance.out', errfile='instance.err') == 1 + assert submitted_args[0][-2:] == ('instance.out', 'instance.err') + + +def test_task_pool_launch_keywords_use_logfile_and_errfile_defaults(): + keywords = TaskPool._launch_keywords_with_defaults( + {'block': False}, + logfile='instance.out', + errfile='instance.err', + ) + + assert keywords == { + 'block': False, + 'logfile': 'instance.out', + 'errfile': 'instance.err', + } + + +def test_task_pool_launch_keywords_preserve_task_logfile_and_errfile(): + keywords = TaskPool._launch_keywords_with_defaults( + { + 'block': False, + 'logfile': 'task.out', + 'errfile': 'task.err', + }, + logfile='instance.out', + errfile='instance.err', + ) + + assert keywords == { + 'block': False, + 'logfile': 'task.out', + 'errfile': 'task.err', + } + + +def test_launch_mapped_task_passes_logfile_and_errfile_to_launch(monkeypatch): + launch_calls = [] + + def record_launch(executable, task_name, working_dir, *args, **kwargs): + launch_calls.append((args, kwargs)) + + monkeypatch.setattr(services_module, 'launch', record_launch) + + services_module.launch_mapped_task( + '/bin/echo', + 'task_0', + '/tmp', + ['hello'], + {'logfile': 'instance.out', 'errfile': 'instance.err'}, + 1, + None, + ) + + assert launch_calls == [ + ( + ('hello',), + { + 'logfile': 'instance.out', + 'errfile': 'instance.err', + 'cpus_per_proc': 1, + 'worker_event_logfile': None, + }, + ) + ] + + +def test_launch_writes_stderr_to_logfile_when_errfile_is_omitted(tmpdir, monkeypatch): + script = write_stdout_stderr_script(tmpdir) + + import dask.distributed + + monkeypatch.setattr(dask.distributed, 'get_worker', lambda: DummyDaskWorker()) + + assert services_module.launch( + str(script), + 'task_0', + str(tmpdir), + logfile='task.log', + ) == ('task_0', 0) + + assert tmpdir.join('task.log').readlines() == [ + 'stdout-line\n', + 'stderr-line\n', + ] + + +def test_launch_writes_stderr_to_logfile_when_errfile_matches_logfile(tmpdir, monkeypatch): + script = write_stdout_stderr_script(tmpdir) + + import dask.distributed + + monkeypatch.setattr(dask.distributed, 'get_worker', lambda: DummyDaskWorker()) + + assert services_module.launch( + str(script), + 'task_0', + str(tmpdir), + logfile='task.log', + errfile=os.path.join(str(tmpdir), 'task.log'), + ) == ('task_0', 0) + + assert tmpdir.join('task.log').readlines() == [ + 'stdout-line\n', + 'stderr-line\n', + ]