From 049eac083eec2f03bc10febe6c14ef82a964dbf4 Mon Sep 17 00:00:00 2001 From: lward Date: Thu, 7 Mar 2024 15:11:02 -0500 Subject: [PATCH] Move run test from base to task model The wrapper has moved --- colmena/task_server/tests/test_base.py | 57 -------------------------- colmena/tests/test_task_model.py | 47 ++++++++++++++++++++- 2 files changed, 46 insertions(+), 58 deletions(-) diff --git a/colmena/task_server/tests/test_base.py b/colmena/task_server/tests/test_base.py index 16bd900..64e2d0b 100644 --- a/colmena/task_server/tests/test_base.py +++ b/colmena/task_server/tests/test_base.py @@ -1,18 +1,9 @@ from typing import Any, Dict, Tuple, List, Optional from pathlib import Path -from proxystore.connectors.file import FileConnector -from proxystore.store import Store -from proxystore.store import register_store -from proxystore.store import unregister_store -from pytest import fixture - -from colmena.models import Result, SerializationMethod from colmena.models.tasks import ExecutableTask -from colmena.task_server.base import run_and_record_timing -# TODO (wardlt): Figure how to import this from test_models class EchoTask(ExecutableTask): def __init__(self): super().__init__(executable=['echo']) @@ -35,51 +26,3 @@ def preprocess(self, run_dir: Path, args: Tuple[Any], kwargs: Dict[str, Any]) -> def postprocess(self, run_dir: Path) -> Any: return (run_dir / 'colmena.stdout').read_text() - - -def test_run_with_executable(): - result = Result(inputs=((1,), {})) - func = EchoTask() - run_and_record_timing(func, result) - result.deserialize() - assert result.value == '1\n' - - -@fixture -def store(tmpdir): - with Store('store', FileConnector(tmpdir), metrics=True) as store: - register_store(store) - yield store - unregister_store(store) - - -def test_run_function(store): - """Make sure the run function behaves as expected: - - - Records runtimes - - Tracks proxy statistics - """ - - # Make the result and configure it to use the store - result = Result(inputs=(('a' * 1024,), {})) - result.proxystore_name = store.name - result.proxystore_threshold = 128 - result.proxystore_config = store.config() - - # Serialize it - result.serialization_method = SerializationMethod.PICKLE - result.serialize() - - # Run the function - run_and_record_timing(lambda x: x.upper(), result) - - # Make sure the timings are all set - assert result.time.running > 0 - assert result.time.async_resolve_proxies > 0 - assert result.time.deserialize_inputs > 0 - assert result.time.serialize_results > 0 - assert result.timestamp.compute_ended > result.timestamp.compute_started - - # Make sure we have stats for both proxies - assert len(result.time.proxy) == 2 - assert all('store.proxy' in v['times'] for v in result.time.proxy.values()) diff --git a/colmena/tests/test_task_model.py b/colmena/tests/test_task_model.py index 8017ed5..d1292f1 100644 --- a/colmena/tests/test_task_model.py +++ b/colmena/tests/test_task_model.py @@ -4,9 +4,13 @@ from math import isnan from pytest import fixture +from proxystore.connectors.file import FileConnector +from proxystore.store import Store +from proxystore.store import register_store +from proxystore.store import unregister_store from colmena.models.tasks import ExecutableTask, PythonTask, PythonGeneratorTask -from colmena.models import ResourceRequirements, Result +from colmena.models import ResourceRequirements, Result, SerializationMethod class EchoTask(ExecutableTask): @@ -28,6 +32,14 @@ def result() -> Result: return result +@fixture +def store(tmpdir): + with Store('store', FileConnector(tmpdir), metrics=True) as store: + register_store(store) + yield store + unregister_store(store) + + def echo(x: Any) -> Any: return x @@ -100,3 +112,36 @@ def test_executable_task(result): result = task(result) result.deserialize() assert result.value == '-N 6 -n 3 --cc depth echo 1\n' + + +def test_run_function(store): + """Make sure the run function behaves as expected: + + - Records runtimes + - Tracks proxy statistics + """ + + # Make the result and configure it to use the store + result = Result(inputs=(('a' * 1024,), {})) + result.proxystore_name = store.name + result.proxystore_threshold = 128 + result.proxystore_config = store.config() + + # Serialize it + result.serialization_method = SerializationMethod.PICKLE + result.serialize() + + # Run the function + task = PythonTask(lambda x: x.upper(), name='upper') + result = task(result) + + # Make sure the timings are all set + assert result.time.running > 0 + assert result.time.async_resolve_proxies > 0 + assert result.time.deserialize_inputs > 0 + assert result.time.serialize_results > 0 + assert result.timestamp.compute_ended > result.timestamp.compute_started + + # Make sure we have stats for both proxies + assert len(result.time.proxy) == 2 + assert all('store.proxy' in v['times'] for v in result.time.proxy.values())