Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support streaming the results of generators #127

Merged
merged 8 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ Building one using Parsl requires only that your computations are expressed as P

```python
from parsl.configs.htex_local import config # Configuration to run locally
from colmena.task_server import ParslTaskServer
from colmena.task_server.parsl import ParslTaskServer

# Define your function
def simulate(x: float) -> float:
Expand Down
79 changes: 56 additions & 23 deletions colmena/models/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,11 @@ def function(self, *args, **kwargs) -> Any:
"""Function provided by the Colmena user"""
raise NotImplementedError()

def __call__(self, result: Result, queues: Optional[ColmenaQueues] = None) -> Result:
def __call__(self, result: Result) -> Result:
"""Invoke a Colmena task request

Args:
result: Request, which includes the arguments and will hold the result
queues: Queues used to send intermediate results back [Not Yet Used]
Returns:
The input result object, populated with the results
"""
Expand All @@ -61,16 +60,28 @@ def __call__(self, result: Result, queues: Optional[ColmenaQueues] = None) -> Re
input_proxies.extend(resolve_proxies_async(value))
result.time.async_resolve_proxies = perf_counter() - start_time

# Add the worker information into the tasks, if available
worker_info = {}
# TODO (wardlt): Move this information into a separate, parsl-specific wrapper
for tag in ['PARSL_WORKER_RANK', 'PARSL_WORKER_POOL_ID']:
if tag in os.environ:
worker_info[tag] = os.environ[tag]
worker_info['hostname'] = platform.node()
result.worker_info = worker_info

# Determine additional kwargs to provide to the function
additional_kwargs = {}
for k, v in [('_resources', result.resources), ('_result', result)]:
if k in result.kwargs:
logger.warning(f'`{k}` provided as a kwargs. Unexpected things are about to happen')
if k in signature(self.function).parameters:
additional_kwargs[k] = v

# Execute the function
start_time = perf_counter()
success = True
try:
if '_resources' in result.kwargs:
logger.warning('`_resources` provided as a kwargs. Unexpected things are about to happen')
if '_resources' in signature(self.function).parameters:
output = self.function(*result.args, **result.kwargs, _resources=result.resources)
else:
output = self.function(*result.args, **result.kwargs)
output = self.function(*result.args, **result.kwargs, **additional_kwargs)
except BaseException as e:
output = None
success = False
Expand All @@ -82,16 +93,6 @@ def __call__(self, result: Result, queues: Optional[ColmenaQueues] = None) -> Re
result.set_result(output, end_time - start_time)
if not success:
result.success = False

# Add the worker information into the tasks, if available
worker_info = {}
# TODO (wardlt): Move this information into a separate, parsl-specific wrapper
for tag in ['PARSL_WORKER_RANK', 'PARSL_WORKER_POOL_ID']:
if tag in os.environ:
worker_info[tag] = os.environ[tag]
worker_info['hostname'] = platform.node()
result.worker_info = worker_info

result.mark_compute_ended()

# Re-pack the results. Will store the proxy statistics
Expand Down Expand Up @@ -124,6 +125,9 @@ def __init__(self, function: Callable, name: Optional[str] = None) -> None:
class PythonGeneratorMethod(ColmenaMethod):
"""Python function which runs on a single worker and generates results iteratively

Generator functions support streaming each iteration of the generator
to the Thinker when a `streaming_queue` is provided.

Args:
function: Generator function to be executed
name: Name of the function. Defaults to `function.__name__`
Expand All @@ -134,24 +138,53 @@ class PythonGeneratorMethod(ColmenaMethod):
def __init__(self,
function: Callable[..., Generator],
name: Optional[str] = None,
store_return_value: bool = False) -> None:
store_return_value: bool = False,
streaming_queue: Optional[ColmenaQueues] = None) -> None:
if not isgeneratorfunction(function):
raise ValueError('Function is not a generator function. Use `PythonTask` instead.')
self._function = function
self.name = name or function.__name__
self.store_return_value = store_return_value
self.streaming_queue = streaming_queue

def function(self, *args, **kwargs) -> Any:
def stream_result(self, y: Any, result: Result, start_time: float):
    """Push an intermediate result back to the Thinker over the streaming queue

    Args:
        y: Value yielded by the generator
        result: Result object carrying the task metadata
        start_time: When the generator started, used to report the runtime so far
    """

    # Work on a deep copy so the caller's Result object is left untouched
    intermediate = result.copy(deep=True)
    elapsed = perf_counter() - start_time
    intermediate.set_result(
        y, elapsed, intermediate=True,
    )

    # Serialize before sending; record how long serialization took
    intermediate.time.serialize_results, _ = intermediate.serialize()

    # Ship the intermediate value back to the Thinker
    self.streaming_queue.send_result(intermediate)

def function(self, *args, _result: Result, **kwargs) -> Any:
    """Run the Colmena task, streaming intermediate results if a queue is configured

    Each value yielded by the generator is either collected into a list
    (no streaming queue) or sent immediately back to the Thinker via
    ``stream_result`` (streaming queue provided).

    Args:
        _result: Result object for this task, used as the metadata template
            for streamed intermediate results

    Returns:
        - Streaming and ``store_return_value``: the generator's return value
        - Streaming without ``store_return_value``: ``None``
        - Not streaming and ``store_return_value``: tuple of (yielded values, return value)
        - Not streaming otherwise: list of yielded values
    """

    # TODO (wardlt): Make push to task queue asynchronous
    gen = self._function(*args, **kwargs)
    iter_results = []
    start_time = perf_counter()
    while True:
        try:
            y = next(gen)
            if self.streaming_queue is None:
                iter_results.append(y)
            else:
                self.stream_result(y, _result, start_time)
        except StopIteration as e:
            if self.streaming_queue is not None:
                # Intermediate values were already streamed; only the generator's
                # return value (if requested) is passed back. The explicit return
                # here avoids re-entering the loop and calling next() on an
                # exhausted generator, which would loop forever.
                return e.value if self.store_return_value else None
            elif self.store_return_value:
                return iter_results, e.value
            else:
                return iter_results
Expand Down
10 changes: 8 additions & 2 deletions colmena/models/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ class Result(BaseModel):
value: Any = Field(None, description="Output of a function")
method: Optional[str] = Field(None, description="Name of the method to run.")
success: Optional[bool] = Field(None, description="Whether the task completed successfully")
complete: Optional[bool] = Field(None, description="Whether this result is the last for a task instead of an intermediate result")

# Store task information
task_info: Optional[Dict[str, Any]] = Field(default_factory=dict,
Expand All @@ -234,6 +235,9 @@ class Result(BaseModel):
proxystore_threshold: Optional[int] = Field(None,
description="Proxy all input/output objects larger than this threshold in bytes")

# Task routing information
topic: Optional[str] = Field(None, description='Label used to group results in queue between Thinker and Task Server')

def __init__(self, inputs: Tuple[Tuple[Any], Dict[str, Any]], **kwargs):
"""
Args:
Expand Down Expand Up @@ -325,7 +329,7 @@ def mark_compute_ended(self):
"""Mark when the task finished executing"""
self.timestamp.compute_ended = datetime.now().timestamp()

def set_result(self, result: Any, runtime: float = nan):
def set_result(self, result: Any, runtime: float = nan, intermediate: bool = False):
"""Set the value of this computation

Automatically sets the "time_result_completed" field and, if known, defines the runtime.
Expand All @@ -335,13 +339,15 @@ def set_result(self, result: Any, runtime: float = nan):

Args:
result: Result to be stored
runtime (float): Runtime for the function
runtime: Runtime for the function
intermediate: If this result is not the final one in a workflow
"""
self.value = result
if not self.keep_inputs:
self.inputs = ((), {})
self.time.running = runtime
self.success = True
self.complete = not intermediate

def serialize(self) -> Tuple[float, List[Proxy]]:
"""Stores the input and value fields as a pickled objects
Expand Down
10 changes: 6 additions & 4 deletions colmena/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ def send_inputs(self,
self._check_role(QueueRole.CLIENT, 'send_inputs')

# Make sure the queue topic exists
assert topic in self.topics, f'Unknown topic: {topic}. Known are: {", ".join(self.topics)}'
if topic not in self.topics:
raise ValueError(f'Unknown topic: {topic}. Known are: {", ".join(self.topics)}')

# Make fake kwargs, if needed
if input_kwargs is None:
Expand Down Expand Up @@ -233,7 +234,8 @@ def send_inputs(self,
keep_inputs=_keep_inputs,
serialization_method=self.serialization_method,
task_info=task_info,
resources=resources or ResourceRequirements(), # Takes either the user specified or a default
resources=resources or ResourceRequirements(), # Takes either the user specified or a default,
topic=topic,
**ps_kwargs
)

Expand Down Expand Up @@ -285,7 +287,7 @@ def send_kill_signal(self):
self._check_role(QueueRole.CLIENT, 'send_kill_signal')
self._send_request("null", topic='default')

def send_result(self, result: Result, topic: str):
def send_result(self, result: Result):
"""Send a value to a client

Args:
Expand All @@ -294,7 +296,7 @@ def send_result(self, result: Result, topic: str):
"""
self._check_role(QueueRole.SERVER, 'send_result')
result.mark_result_sent()
self._send_result(result.json(), topic=topic)
self._send_result(result.json(), topic=result.topic)

@abstractmethod
def _get_request(self, timeout: int = None) -> Tuple[str, str]:
Expand Down
8 changes: 4 additions & 4 deletions colmena/queue/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_results(queue):

# Set a result value and send it back
request.serialize()
queue.send_result(request, topic)
queue.send_result(request)
result = queue.get_result(topic=topic)
assert result.value == 1

Expand All @@ -91,7 +91,7 @@ def test_serialization(queue):
x.x = 1
task.set_result(x)
task.serialize()
queue.send_result(task, topic)
queue.send_result(task)
result = queue.get_result(topic=topic)
assert result.args[0].x is None
assert result.value.x == 1
Expand Down Expand Up @@ -138,7 +138,7 @@ def test_task_info(queue):

# Send it back
result.serialize()
queue.send_result(result, topic)
queue.send_result(result)
result = queue.get_result()
assert result.task_info == {'id': 'test'}

Expand Down Expand Up @@ -169,7 +169,7 @@ def test_event_count(queue):
task.set_result(1)
print(queue._active_tasks)
print(task)
queue.send_result(task, topic)
queue.send_result(task)
queue.get_result()
assert queue.active_count == 0
assert queue.wait_until_done(timeout=1)
Expand Down
4 changes: 2 additions & 2 deletions colmena/queue/tests/test_queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_flush(queue):
# Test that it will flush a result
queue.send_inputs(1, method='method')
topic, result = queue.get_task()
queue.send_result(result, topic)
queue.send_result(result)
queue.flush()

with raises(TimeoutException):
Expand All @@ -66,7 +66,7 @@ def test_basic(queue, topic):
result.deserialize()
result.set_result(1, 1)
result.serialize()
pool.apply(queue.send_result, (result, topic))
pool.apply(queue.send_result, (result,))

# Make sure it does not appear in b
with raises(TimeoutException):
Expand Down
4 changes: 2 additions & 2 deletions colmena/task_server/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def listen_and_launch(self):
task.failure_info = FailureInformation.from_exception(
ValueError(f'Method name "{task.method}" not recognized. Options: {", ".join(self.method_names)}')
)
self.queues.send_result(task, topic)
self.queues.send_result(task)

except KillSignalException:
logger.info('Kill signal received')
Expand Down Expand Up @@ -138,7 +138,7 @@ def perform_callback(self, future: Future, result: Result, topic: str):
result.mark_task_received()

# Put them back in the pipe with the proper topic
self.queues.send_result(result, topic)
self.queues.send_result(result)

@abstractmethod
def _submit(self, task: Result, topic: str) -> Optional[Future]:
Expand Down
2 changes: 1 addition & 1 deletion colmena/task_server/parsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def _preprocess_callback(

# Store the time it took to run the preprocessing
result.time.running = result.time.additional.get('exec_preprocess', 0)
return task_server.queues.send_result(result, topic)
return task_server.queues.send_result(result)

# If successful, submit the execute step and pass its result to Parsl
logger.info(f'Preprocessing was successful for {result.method} task. Submitting to execute')
Expand Down
25 changes: 24 additions & 1 deletion colmena/tests/test_task_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from colmena.models.methods import ExecutableMethod, PythonMethod, PythonGeneratorMethod
from colmena.models import ResourceRequirements, Result, SerializationMethod
from colmena.queue import PipeQueues


class EchoTask(ExecutableMethod):
Expand Down Expand Up @@ -89,6 +90,28 @@ def test_generator_with_return(result):
assert result.value == [[0,], 'done']


def test_generator_streaming(result):
    """Trigger streaming by adding a queue to the task definition"""

    streaming_queue = PipeQueues()
    method = PythonGeneratorMethod(
        function=generator,
        name='stream',
        store_return_value=True,
        streaming_queue=streaming_queue,
    )

    # Run the task: only the generator's return value comes back directly
    result.topic = 'default'
    result = method(result)
    assert result.success, result.failure_info.traceback
    result.deserialize()
    assert result.value == 'done'
    assert result.complete

    # The yielded value should have been streamed through the queue
    streamed = streaming_queue.get_result(timeout=1)
    assert streamed.success
    streamed.deserialize()
    assert not streamed.complete
    assert streamed.value == 0

    # The final result reports at least as much runtime as the intermediate one
    assert result.time.running >= streamed.time.running


def test_executable_task(result):
# Run a basic task
task = EchoTask()
Expand Down Expand Up @@ -139,7 +162,7 @@ def test_run_function(store):
assert result.time.async_resolve_proxies > 0
assert result.time.deserialize_inputs > 0
assert result.time.serialize_results > 0
assert result.timestamp.compute_ended > result.timestamp.compute_started
assert result.timestamp.compute_ended >= result.timestamp.compute_started

# Make sure we have stats for both proxies
assert len(result.time.proxy) == 2
Expand Down
7 changes: 3 additions & 4 deletions colmena/thinker/tests/test_thinker.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,11 @@ def process(self, _):
thinker.start()

# Spoof a result completing

queues.send_inputs(1, method='test')
topic, result = queues.get_task()
result.set_result(1, 1)
with caplog.at_level(logging.INFO):
queues.send_result(result, topic)
queues.send_result(result)

# Wait then check the logs
sleep(0.5)
Expand Down Expand Up @@ -194,7 +193,7 @@ def test_run(queues):
queues.send_inputs(1)
topic, task = queues.get_task()
task.set_result(4)
queues.send_result(task, topic=topic)
queues.send_result(task)
sleep(0.1)
assert th.last_value == 4

Expand Down Expand Up @@ -230,7 +229,7 @@ def test_run(queues):
# The system should not exit until all results are back
topic, task = queues.get_task()
task.set_result(4)
queues.send_result(task, topic)
queues.send_result(task)
assert th.queues.wait_until_done(timeout=2)
sleep(0.1)
assert not th.is_alive()
Expand Down
Loading
Loading