diff --git a/parsl/executors/high_throughput/executor.py b/parsl/executors/high_throughput/executor.py index 2e20f41795..92a1965bb1 100644 --- a/parsl/executors/high_throughput/executor.py +++ b/parsl/executors/high_throughput/executor.py @@ -1,13 +1,13 @@ import logging import math import pickle +import subprocess import threading import typing import warnings from collections import defaultdict from concurrent.futures import Future from dataclasses import dataclass -from multiprocessing import Process from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union import typeguard @@ -18,7 +18,7 @@ from parsl.app.errors import RemoteExceptionWrapper from parsl.data_provider.staging import Staging from parsl.executors.errors import BadMessage, ScalingFailed -from parsl.executors.high_throughput import interchange, zmq_pipes +from parsl.executors.high_throughput import zmq_pipes from parsl.executors.high_throughput.errors import CommandClientTimeoutError from parsl.executors.high_throughput.mpi_prefix_composer import ( VALID_LAUNCHERS, @@ -26,7 +26,6 @@ ) from parsl.executors.status_handling import BlockProviderExecutor from parsl.jobs.states import TERMINAL_STATES, JobState, JobStatus -from parsl.multiprocessing import ForkProcess from parsl.process_loggers import wrap_with_logs from parsl.providers import LocalProvider from parsl.providers.base import ExecutionProvider @@ -305,7 +304,7 @@ def __init__(self, self._task_counter = 0 self.worker_ports = worker_ports self.worker_port_range = worker_port_range - self.interchange_proc: Optional[Process] = None + self.interchange_proc: Optional[subprocess.Popen] = None self.interchange_port_range = interchange_port_range self.heartbeat_threshold = heartbeat_threshold self.heartbeat_period = heartbeat_period @@ -520,38 +519,45 @@ def _queue_management_worker(self): logger.info("Queue management worker finished") - def _start_local_interchange_process(self): + def _start_local_interchange_process(self) -> None: """ Starts the interchange process locally - Starts the interchange process locally and uses an internal command queue to + Starts the interchange process locally and uses the command queue to get the worker task and result ports that the interchange has bound to. """ - self.interchange_proc = ForkProcess(target=interchange.starter, - kwargs={"client_address": "127.0.0.1", - "client_ports": (self.outgoing_q.port, - self.incoming_q.port, - self.command_client.port), - "interchange_address": self.address, - "worker_ports": self.worker_ports, - "worker_port_range": self.worker_port_range, - "hub_address": self.hub_address, - "hub_zmq_port": self.hub_zmq_port, - "logdir": self.logdir, - "heartbeat_threshold": self.heartbeat_threshold, - "poll_period": self.poll_period, - "logging_level": logging.DEBUG if self.worker_debug else logging.INFO, - "cert_dir": self.cert_dir, - }, - daemon=True, - name="HTEX-Interchange" - ) - self.interchange_proc.start() + interchange_config = {"client_address": "127.0.0.1", + "client_ports": (self.outgoing_q.port, + self.incoming_q.port, + self.command_client.port), + "interchange_address": self.address, + "worker_ports": self.worker_ports, + "worker_port_range": self.worker_port_range, + "hub_address": self.hub_address, + "hub_zmq_port": self.hub_zmq_port, + "logdir": self.logdir, + "heartbeat_threshold": self.heartbeat_threshold, + "poll_period": self.poll_period, + "logging_level": logging.DEBUG if self.worker_debug else logging.INFO, + "cert_dir": self.cert_dir, + } + + config_pickle = pickle.dumps(interchange_config) + + self.interchange_proc = subprocess.Popen(b"interchange.py", stdin=subprocess.PIPE) + stdin = self.interchange_proc.stdin + assert stdin is not None, "Popen should have created an IO object (vs default None) because of PIPE mode" + + logger.debug("Popened interchange process. Writing config object") + stdin.write(config_pickle) + stdin.flush() + logger.debug("Sent config object. Requesting worker ports") try: (self.worker_task_port, self.worker_result_port) = self.command_client.run("WORKER_PORTS", timeout_s=120) except CommandClientTimeoutError: - logger.error("Interchange has not completed initialization in 120s. Aborting") + logger.error("Interchange has not completed initialization. Aborting") raise Exception("Interchange failed to start") + logger.debug("Got worker ports") def _start_queue_management_thread(self): """Method to start the management thread as a daemon. @@ -810,13 +816,12 @@ def shutdown(self, timeout: float = 10.0): logger.info("Attempting HighThroughputExecutor shutdown") self.interchange_proc.terminate() - self.interchange_proc.join(timeout=timeout) - if self.interchange_proc.is_alive(): + try: + self.interchange_proc.wait(timeout=timeout) + except subprocess.TimeoutExpired: logger.info("Unable to terminate Interchange process; sending SIGKILL") self.interchange_proc.kill() - self.interchange_proc.close() - logger.info("Finished HighThroughputExecutor shutdown attempt") def get_usage_information(self): diff --git a/parsl/executors/high_throughput/interchange.py b/parsl/executors/high_throughput/interchange.py index 764c9805a0..9fe94dbabd 100644 --- a/parsl/executors/high_throughput/interchange.py +++ b/parsl/executors/high_throughput/interchange.py @@ -672,13 +672,10 @@ def start_file_logger(filename: str, level: int = logging.DEBUG, format_string: logger.addHandler(handler) -@wrap_with_logs(target="interchange") -def starter(*args: Any, **kwargs: Any) -> None: - """Start the interchange process - - The executor is expected to call this function. The args, kwargs match that of the Interchange.__init__ - """ +if __name__ == "__main__": setproctitle("parsl: HTEX interchange") - # logger = multiprocessing.get_logger() - ic = Interchange(*args, **kwargs) + + config = pickle.load(sys.stdin.buffer) + + ic = Interchange(**config) ic.start() diff --git a/parsl/tests/test_htex/test_htex.py b/parsl/tests/test_htex/test_htex.py index ca95773e1b..2227529f82 100644 --- a/parsl/tests/test_htex/test_htex.py +++ b/parsl/tests/test_htex/test_htex.py @@ -1,11 +1,11 @@ import pathlib import warnings +from subprocess import Popen, TimeoutExpired from unittest import mock import pytest from parsl import HighThroughputExecutor, curvezmq -from parsl.multiprocessing import ForkProcess _MOCK_BASE = "parsl.executors.high_throughput.executor" @@ -78,16 +78,33 @@ def test_htex_shutdown( timeout_expires: bool, htex: HighThroughputExecutor, ): - mock_ix_proc = mock.Mock(spec=ForkProcess) + mock_ix_proc = mock.Mock(spec=Popen) if started: htex.interchange_proc = mock_ix_proc - mock_ix_proc.is_alive.return_value = True + + # This will, in the absence of any exit trigger, block forever if + # no timeout is given and if the interchange does not terminate. + # Raise an exception to report that, rather than actually block, + # and hope that nothing is catching that exception. + + # this function implements the behaviour if the interchange has + # not received a termination call + def proc_wait_alive(timeout): + if timeout: + raise TimeoutExpired(cmd="mock-interchange", timeout=timeout) + else: + raise RuntimeError("This wait call would hang forever") + + def proc_wait_terminated(timeout): + return 0 + + mock_ix_proc.wait.side_effect = proc_wait_alive if not timeout_expires: # Simulate termination of the Interchange process def kill_interchange(*args, **kwargs): - mock_ix_proc.is_alive.return_value = False + mock_ix_proc.wait.side_effect = proc_wait_terminated mock_ix_proc.terminate.side_effect = kill_interchange @@ -96,8 +113,8 @@ def kill_interchange(*args, **kwargs): mock_logs = mock_logger.info.call_args_list if started: assert mock_ix_proc.terminate.called - assert mock_ix_proc.join.called - assert {"timeout": 10} == mock_ix_proc.join.call_args[1] + assert mock_ix_proc.wait.called + assert {"timeout": 10} == mock_ix_proc.wait.call_args[1] if timeout_expires: assert "Unable to terminate Interchange" in mock_logs[1][0][0] assert mock_ix_proc.kill.called @@ -105,7 +122,7 @@ def kill_interchange(*args, **kwargs): assert "Finished" in mock_logs[-1][0][0] else: assert not mock_ix_proc.terminate.called - assert not mock_ix_proc.join.called + assert not mock_ix_proc.wait.called assert "has not started" in mock_logs[0][0][0] diff --git a/setup.py b/setup.py index dae3e64ca4..85e014dc18 100755 --- a/setup.py +++ b/setup.py @@ -56,6 +56,7 @@ python_requires=">=3.8.0", install_requires=install_requires, scripts = ['parsl/executors/high_throughput/process_worker_pool.py', + 'parsl/executors/high_throughput/interchange.py', 'parsl/executors/workqueue/exec_parsl_function.py', 'parsl/executors/workqueue/parsl_coprocess.py', ],