Skip to content

Commit

Permalink
Move run test from base to task model
Browse files Browse the repository at this point in the history
The wrapper has moved
  • Loading branch information
WardLT committed Mar 7, 2024
1 parent 3902ba2 commit 049eac0
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 58 deletions.
57 changes: 0 additions & 57 deletions colmena/task_server/tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,9 @@
from typing import Any, Dict, Tuple, List, Optional
from pathlib import Path

from proxystore.connectors.file import FileConnector
from proxystore.store import Store
from proxystore.store import register_store
from proxystore.store import unregister_store
from pytest import fixture

from colmena.models import Result, SerializationMethod
from colmena.models.tasks import ExecutableTask
from colmena.task_server.base import run_and_record_timing


# TODO (wardlt): Figure how to import this from test_models
class EchoTask(ExecutableTask):
def __init__(self):
super().__init__(executable=['echo'])
Expand All @@ -35,51 +26,3 @@ def preprocess(self, run_dir: Path, args: Tuple[Any], kwargs: Dict[str, Any]) ->

def postprocess(self, run_dir: Path) -> Any:
return (run_dir / 'colmena.stdout').read_text()


def test_run_with_executable():
result = Result(inputs=((1,), {}))
func = EchoTask()
run_and_record_timing(func, result)
result.deserialize()
assert result.value == '1\n'


@fixture
def store(tmpdir):
with Store('store', FileConnector(tmpdir), metrics=True) as store:
register_store(store)
yield store
unregister_store(store)


def test_run_function(store):
"""Make sure the run function behaves as expected:
- Records runtimes
- Tracks proxy statistics
"""

# Make the result and configure it to use the store
result = Result(inputs=(('a' * 1024,), {}))
result.proxystore_name = store.name
result.proxystore_threshold = 128
result.proxystore_config = store.config()

# Serialize it
result.serialization_method = SerializationMethod.PICKLE
result.serialize()

# Run the function
run_and_record_timing(lambda x: x.upper(), result)

# Make sure the timings are all set
assert result.time.running > 0
assert result.time.async_resolve_proxies > 0
assert result.time.deserialize_inputs > 0
assert result.time.serialize_results > 0
assert result.timestamp.compute_ended > result.timestamp.compute_started

# Make sure we have stats for both proxies
assert len(result.time.proxy) == 2
assert all('store.proxy' in v['times'] for v in result.time.proxy.values())
47 changes: 46 additions & 1 deletion colmena/tests/test_task_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@
from math import isnan

from pytest import fixture
from proxystore.connectors.file import FileConnector
from proxystore.store import Store
from proxystore.store import register_store
from proxystore.store import unregister_store

from colmena.models.tasks import ExecutableTask, PythonTask, PythonGeneratorTask
from colmena.models import ResourceRequirements, Result
from colmena.models import ResourceRequirements, Result, SerializationMethod


class EchoTask(ExecutableTask):
Expand All @@ -28,6 +32,14 @@ def result() -> Result:
return result


@fixture
def store(tmpdir):
with Store('store', FileConnector(tmpdir), metrics=True) as store:
register_store(store)
yield store
unregister_store(store)


def echo(x: Any) -> Any:
return x

Expand Down Expand Up @@ -100,3 +112,36 @@ def test_executable_task(result):
result = task(result)
result.deserialize()
assert result.value == '-N 6 -n 3 --cc depth echo 1\n'


def test_run_function(store):
"""Make sure the run function behaves as expected:
- Records runtimes
- Tracks proxy statistics
"""

# Make the result and configure it to use the store
result = Result(inputs=(('a' * 1024,), {}))
result.proxystore_name = store.name
result.proxystore_threshold = 128
result.proxystore_config = store.config()

# Serialize it
result.serialization_method = SerializationMethod.PICKLE
result.serialize()

# Run the function
task = PythonTask(lambda x: x.upper(), name='upper')
result = task(result)

# Make sure the timings are all set
assert result.time.running > 0
assert result.time.async_resolve_proxies > 0
assert result.time.deserialize_inputs > 0
assert result.time.serialize_results > 0
assert result.timestamp.compute_ended > result.timestamp.compute_started

# Make sure we have stats for both proxies
assert len(result.time.proxy) == 2
assert all('store.proxy' in v['times'] for v in result.time.proxy.values())

0 comments on commit 049eac0

Please sign in to comment.