import threading


def _read_pipe(pipe, size, output):
    """Read from a pipe and store the result in a list (thread target used by read_pipe).

    :param pipe: pipe (file-like object) to read from
    :param size: number of bytes to read, passed as-is to pipe.read()
    :param output: list to which the data that was read is appended
    """
    output.append(pipe.read(size))


def read_pipe(pipe, size, timeout=None):
    """Read from a pipe using a separate thread, so the read can be abandoned after a timeout.

    When the timeout expires, the reader thread is left behind, still blocked in pipe.read();
    whatever it reads afterwards is discarded. The thread is a daemon thread so that such
    a leftover reader cannot prevent the Python interpreter from exiting.

    :param pipe: pipe (file-like object) to read from
    :param size: number of bytes to read
    :param timeout: timeout in seconds (default: None = no timeout);
                    negative values are treated as 0

    :return: data read from pipe

    :raises TimeoutError: when reading from pipe takes longer than specified timeout
    :raises Exception: any exception raised by pipe.read() is re-raised in the calling thread
    """
    output = []
    errors = []

    def _reader():
        # capture failures so they surface in the caller rather than as an IndexError on output[0]
        try:
            _read_pipe(pipe, size, output)
        except Exception as err:
            errors.append(err)

    if timeout is not None:
        # callers pass 'remaining time', which may have dropped below zero
        timeout = max(timeout, 0)

    reader = threading.Thread(target=_reader, daemon=True)
    reader.start()
    reader.join(timeout)

    if reader.is_alive():
        raise TimeoutError(f"Reading from pipe did not complete within {timeout} seconds")
    if errors:
        raise errors[0]
    return output[0]


def terminate_process(proc, timeout=20):
    """
    Terminate specified process (subprocess.Popen instance).
    Attempt to terminate the process using proc.terminate(), and if that fails, use proc.kill().

    :param proc: process to terminate
    :param timeout: timeout in seconds to wait for process to terminate
                    (applied separately after proc.terminate() and after proc.kill())

    :raises EasyBuildError: when the process is still running after proc.kill() + timeout
    """
    proc.terminate()
    try:
        proc.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        _log.warning(f"Process did not terminate after {timeout} seconds, sending SIGKILL")
    else:
        # process exited after proc.terminate(), no need to escalate
        return

    proc.kill()
    try:
        proc.wait(timeout=timeout)
    except subprocess.TimeoutExpired as err:
        error_msg = f"Process `{proc.args}` did not terminate after {timeout} seconds, giving up"
        raise EasyBuildError(error_msg) from err
def test_run_shell_cmd_timeout(self):
    """Check that run_shell_cmd enforces its timeout when output is not streamed."""
    slow_cmd = 'sleep 1; echo hello'
    timeout_regex = "Timeout during `.*` after .* seconds"

    # the command needs about 1 second, so a 0.5 second timeout must trigger an error
    with self.mocked_stdout_stderr():
        self.assertErrorRegex(EasyBuildError, timeout_regex, run_shell_cmd, slow_cmd, timeout=.5)

    # with a 3 second timeout the same command is allowed to finish
    with self.mocked_stdout_stderr():
        res = run_shell_cmd(slow_cmd, timeout=3)
    self.assertEqual(res.exit_code, 0)
    self.assertEqual(res.output, "hello\n")


def test_run_shell_cmd_timeout_stream(self):
    """Check that run_shell_cmd enforces its timeout when output is being streamed."""
    data = '0'*128
    timeout_regex = "Timeout during `.*` after .* seconds"

    failing_cmds = (
        # keeps producing output, but takes longer than the timeout overall
        f'for i in {{1..20}}; do echo {data} && sleep 0.1; done',
        # produces no output at all, so the read from stdout has to time out
        'timeout 1 cat -',
    )
    for failing_cmd in failing_cmds:
        with self.mocked_stdout_stderr():
            self.assertErrorRegex(
                EasyBuildError, timeout_regex,
                run_shell_cmd, failing_cmd, timeout=.5, stream_output=True
            )

    # command completes well within the timeout
    quick_cmd = 'sleep .5 && echo hello'
    with self.mocked_stdout_stderr():
        res = run_shell_cmd(quick_cmd, timeout=1.5, stream_output=True)

    self.assertEqual(res.exit_code, 0)
    self.assertEqual(res.output, "hello\n")