diff --git a/colmena/task_server/base.py b/colmena/task_server/base.py
index 4d42505..94507b3 100644
--- a/colmena/task_server/base.py
+++ b/colmena/task_server/base.py
@@ -124,14 +124,20 @@ def perform_callback(self, future: Future, result: Result, topic: str):
             topic: Topic used to send back to the user
         """
 
+        # Get any exception thrown by the workflow engine
         task_exc = future.exception()
 
-        # The task could have failed at the workflow engine level (task_exc)
-        # or application level (result.failure_info)
-        task_failed = (task_exc is not None) or (result.failure_info is not None)
+        # If an exception was raised, send back a modified copy of the input structure
+        if task_exc is not None:
+            # Mark it as unsuccessful and capture the exception information
+            result.success = False
+            result.failure_info = FailureInformation.from_exception(task_exc)
+        else:
+            # If not, the result object is the one we need
+            result = future.result()
 
         # If the task failed and we have retries left, try again
-        if task_failed and result.retries < result.max_retries:
+        if not result.success and result.retries < result.max_retries:
            # Increment the retry count and clear the failure information
             result.retries += 1
             result.failure_info, result.success = None, None
@@ -145,15 +151,7 @@ def perform_callback(self, future: Future, result: Result, topic: str):
             # Do not send the result back to the user
             return
 
-        # If it was, send back a modified copy of the input structure
-        if task_exc is not None:
-            # Mark it as unsuccessful and capture the exception information
-            result.success = False
-            result.failure_info = FailureInformation.from_exception(task_exc)
-        else:
-            # If not, the result object is the one we need
-            result = future.result()
-
+        # Mark the task received timestamp
         result.mark_task_received()
 
         # Put them back in the pipe with the proper topic
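For readers skimming the hunks above: the net effect is that engine-level failures are folded into the Result before the retry check runs, so the retry decision always reads up-to-date failure state rather than the stale input object. The sketch below restates that flow with hypothetical helper names (resolve_result, should_retry) and an assumed import path; it is an illustration of the logic, not the code in base.py.

# Sketch only: helper names are hypothetical, import path is assumed
from concurrent.futures import Future

from colmena.models import FailureInformation, Result  # assumed import path


def resolve_result(future: Future, result: Result) -> Result:
    """Pick the Result object the retry check should inspect."""
    task_exc = future.exception()
    if task_exc is not None:
        # Workflow-engine failure: annotate the Result we already hold
        result.success = False
        result.failure_info = FailureInformation.from_exception(task_exc)
        return result
    # Otherwise the engine shipped back the (possibly failed) Result itself
    return future.result()


def should_retry(result: Result) -> bool:
    """Mirror of the reordered retry condition in the callback."""
    return (not result.success) and result.retries < result.max_retries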
diff --git a/colmena/task_server/tests/test_retry.py b/colmena/task_server/tests/test_retry.py
index 8e49504..b1eed3a 100644
--- a/colmena/task_server/tests/test_retry.py
+++ b/colmena/task_server/tests/test_retry.py
@@ -1,5 +1,6 @@
 from typing import Tuple, Generator
-from parsl import ThreadPoolExecutor
+from parsl.executors import HighThroughputExecutor
+from parsl.providers import LocalProvider
 from parsl.config import Config
 from colmena.queue.base import ColmenaQueues
 from colmena.queue.python import PipeQueues
@@ -35,9 +36,20 @@ def reset_retry_count():
 @fixture()
 def config(tmpdir):
     """Make the Parsl configuration."""
+    # We must use a HighThroughputExecutor because the ThreadPoolExecutor
+    # does not serialize the result objects within the task server.
+    # Serialization is needed to test the retry policy: without it, the
+    # failure information stored in the result (used to check for retries)
+    # lives in shared memory, so the result passed to perform_callback
+    # would still hold the previous failure information.
     return Config(
         executors=[
-            ThreadPoolExecutor(max_threads=2)
+            HighThroughputExecutor(
+                address='localhost',
+                label='htex',
+                max_workers_per_node=1,
+                provider=LocalProvider(init_blocks=1, max_blocks=1),
+            ),
         ],
         strategy=None,
         run_dir=str(tmpdir / 'run'),
@@ -66,7 +78,6 @@ def test_retry_policy_max_retries_zero(server_and_queue, reset_retry_count):
     success_idx = 1
 
     for i in range(4):
-        # The task will fail every other time (setting success_idx=1)
         queue.send_inputs(success_idx, method='retry_task', max_retries=0)
         result = queue.get_result()
         assert result.success == (i % 2 == 1)
@@ -93,8 +104,7 @@ def test_retry_policy_max_retries(server_and_queue, reset_retry_count, success_i
     server, queue = server_and_queue
     server.start()
 
-    # The task will fail every other time (setting success_idx=1)
-    # However, we set max_retries=1, so it should succeed after the first try
+    # Send the task to the queue
     queue.send_inputs(success_idx, method='retry_task', max_retries=max_retries)
     result = queue.get_result()
     assert result is not None
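The fixture comment above argues that a thread-based executor shares one Result object between the task and perform_callback, so failure state written by a previous attempt leaks into the retry check. The snippet below is a minimal, self-contained illustration of that pitfall using a made-up FakeResult stand-in; it is not colmena code, and the function names are invented for the example.

# Minimal illustration (hypothetical, not colmena code) of the shared-memory pitfall
from copy import deepcopy
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeResult:
    """Stand-in for colmena's Result with just the fields the retry check reads."""
    success: Optional[bool] = None
    failure_info: Optional[str] = None


def fail_in_shared_memory(result: FakeResult) -> FakeResult:
    """Thread-style execution: the task mutates the caller's own object."""
    result.success = False
    result.failure_info = 'boom'
    return result


def fail_across_processes(result: FakeResult) -> FakeResult:
    """Process-style execution: pickling hands the task an independent copy."""
    local = deepcopy(result)  # crude stand-in for serialize/deserialize
    local.success = False
    local.failure_info = 'boom'
    return local


shared = FakeResult()
fail_in_shared_memory(shared)
assert shared.failure_info == 'boom'   # failure state lingers on the input object

clean = FakeResult()
returned = fail_across_processes(clean)
assert clean.failure_info is None      # the callback's input result stays untouched
assert returned.failure_info == 'boom'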