Address retry bug for HighThroughputExecutor. #142

Merged 6 commits on Sep 3, 2024
1 change: 0 additions & 1 deletion colmena/proxy.py
@@ -1,7 +1,6 @@
"""Utilities for interacting with ProxyStore"""
import logging
import warnings
from typing import Dict

import proxystore
from proxystore.proxy import extract
20 changes: 11 additions & 9 deletions colmena/task_server/base.py
@@ -124,8 +124,18 @@ def perform_callback(self, future: Future, result: Result, topic: str):
topic: Topic used to send back to the user
"""

# Get any exception thrown by the workflow engine
task_exc = future.exception()

# If the engine raised an exception, send back a modified copy of the input structure
if task_exc is not None:
# Mark it as unsuccessful and capture the exception information
result.success = False
result.failure_info = FailureInformation.from_exception(task_exc)
else:
# If not, the result object is the one we need
result = future.result()

# The task could have failed at the workflow engine level (task_exc)
# or application level (result.failure_info)
task_failed = (task_exc is not None) or (result.failure_info is not None)
@@ -145,15 +155,7 @@ def perform_callback(self, future: Future, result: Result, topic: str):
# Do not send the result back to the user
return

# If it was, send back a modified copy of the input structure
if task_exc is not None:
# Mark it as unsuccessful and capture the exception information
result.success = False
result.failure_info = FailureInformation.from_exception(task_exc)
else:
# If not, the result object is the one we need
result = future.result()

# Mark the task received timestamp
result.mark_task_received()

# Put them back in the pipe with the proper topic
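To summarize the change in this file: the exception raised by the workflow engine is now folded into the result before the retry decision is made, so the retry check sees the failure from the current attempt instead of stale state. Below is a minimal sketch of that ordering; it is an illustration rather than the PR's exact code, and it assumes Result and FailureInformation are importable from colmena.models as they are in this module.

from concurrent.futures import Future
from typing import Tuple

from colmena.models import FailureInformation, Result  # assumed import path

def resolve_result(future: Future, result: Result) -> Tuple[Result, bool]:
    """Fold an engine-level exception into the result, then report whether the task failed."""
    task_exc = future.exception()
    if task_exc is not None:
        # Engine-level failure: keep the modified copy of the input structure
        result.success = False
        result.failure_info = FailureInformation.from_exception(task_exc)
    else:
        # The engine returned normally: the future holds the result object we need
        result = future.result()

    # Failure may come from the engine (task_exc) or the application (failure_info);
    # only after resolving the result above is this check reliable
    task_failed = (task_exc is not None) or (result.failure_info is not None)
    return result, task_failed

The retry logic can then call a helper like this first and resubmit only when task_failed is true and retries remain.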
31 changes: 23 additions & 8 deletions colmena/task_server/tests/test_retry.py
@@ -1,5 +1,6 @@
from typing import Tuple, Generator
from parsl import ThreadPoolExecutor
from parsl.executors import HighThroughputExecutor
from parsl.providers import LocalProvider
from parsl.config import Config
from colmena.queue.base import ColmenaQueues
from colmena.queue.python import PipeQueues
@@ -10,6 +11,7 @@
# Make global state for the retry task
RETRY_COUNT = 0


def retry_task(success_idx: int) -> bool:
"""Task that will succeed (return True) every `success_idx` times."""
global RETRY_COUNT
@@ -18,30 +20,44 @@ def retry_task(success_idx: int) -> bool:
if RETRY_COUNT < success_idx:
RETRY_COUNT += 1
raise ValueError('Retry')

# Reset the retry count
RETRY_COUNT = 0
return True


@fixture
def reset_retry_count():
"""Reset the retry count before each test."""
global RETRY_COUNT
RETRY_COUNT = 0


@fixture()
def config(tmpdir):
"""Make the Parsl configuration."""
# We must use a HighThroughputExecutor because the ThreadPoolExecutor
# does not serialize result objects within the task server.
# Serialization is needed to test the retry policy: otherwise the
# failure information stored in the result (needed to check for retries)
# lives in shared memory, and the perform_callback method's input result
# would already hold the previous failure information.
return Config(
executors=[
ThreadPoolExecutor(max_threads=2)
HighThroughputExecutor(
address='localhost',
label='htex',
max_workers_per_node=1,
provider=LocalProvider(init_blocks=1, max_blocks=1),
),
],
strategy=None,
run_dir=str(tmpdir / 'run'),
)


@fixture
def server_and_queue(config) -> Generator[Tuple[ParslTaskServer, ColmenaQueues], None, None]:
queues = PipeQueues()
server = ParslTaskServer([retry_task], queues, config)
yield server, queues
@@ -62,7 +78,6 @@ def test_retry_policy_max_retries_zero(server_and_queue, reset_retry_count):
success_idx = 1

for i in range(4):
# The task will fail every other time (setting success_idx=1)
queue.send_inputs(success_idx, method='retry_task', max_retries=0)
result = queue.get_result()
assert result.success == (i % 2 == 1)
@@ -73,11 +88,12 @@ def test_retry_policy_max_retries_zero(server_and_queue, reset_retry_count):
assert not result.success
assert 'Retry' in str(result.failure_info.exception)


@mark.timeout(10)
@mark.parametrize(('success_idx', 'max_retries'), [(0, 0), (1, 1), (4, 10)])
def test_retry_policy_max_retries(server_and_queue, reset_retry_count, success_idx: int, max_retries: int):
"""Test the retry policy.

This test checks the following cases:
- A task that always succeeds (success_idx=0, max_retries=0)
- A task that succeeds after one retry (success_idx=1, max_retries=1)
@@ -88,8 +104,7 @@ def test_retry_policy_max_retries(server_and_queue, reset_retry_count, success_idx: int, max_retries: int):
server, queue = server_and_queue
server.start()

# The task will fail every other time (setting success_idx=1)
# However, we set max_retries=1, so it should succeed after the first try
# Send the task to the queue
queue.send_inputs(success_idx, method='retry_task', max_retries=max_retries)
result = queue.get_result()
assert result is not None
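The switch from ThreadPoolExecutor to HighThroughputExecutor in the config fixture matters because the retry check relies on the result crossing a serialization boundary. The sketch below is a rough, hypothetical illustration of that distinction, not Colmena or Parsl code; FakeResult and the two delivery helpers exist only for this example.

import pickle

class FakeResult:
    """Stand-in for a Colmena Result, holding only the field the retry check reads."""
    def __init__(self):
        self.failure_info = None

def deliver_in_process(result: FakeResult) -> FakeResult:
    # Thread-based executor: server and worker share the same object, so
    # failure information set on a previous attempt is still visible here
    return result

def deliver_across_processes(result: FakeResult) -> FakeResult:
    # Process-based executor such as HighThroughputExecutor: the result is
    # pickled to the worker and back, so each attempt arrives as a fresh copy
    return pickle.loads(pickle.dumps(result))

This mirrors the reasoning in the fixture's comment: without serialization, the result passed to perform_callback already carries failure information from the previous attempt in shared memory, so the retry behaviour under test would not be exercised.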