Refactor messages to separate "time" and "timestamp" better #121

Merged
merged 7 commits on Jan 5, 2024
108 changes: 61 additions & 47 deletions colmena/models.py
@@ -1,8 +1,9 @@
import json
import logging
import pickle as pkl
import shlex
import sys
from math import nan
import pickle as pkl
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
@@ -55,7 +56,7 @@ def deserialize(method: 'SerializationMethod', message: str) -> Any:
"""Deserialize an object

Args:
method: Method used to serialize the message
method: Method used to serialize
message: Message to deserialize
Returns:
Result object
@@ -70,13 +71,13 @@ def deserialize(method: 'SerializationMethod', message: str) -> Any:


def _serialized_str_to_bytes_shim(
s: str,
method: Union[str, SerializationMethod],
s: str,
method: Union[str, SerializationMethod],
) -> bytes:
"""Shim between Colmena serialized objects and bytes.

Colmena's serialization mechanisms produce strings but ProxyStore
serializes to bytes, so this shim takes a an object serialized by Colmena
serializes to bytes, so this shim takes an object serialized by Colmena
and converts it to bytes.

Args:
@@ -90,7 +91,7 @@ def _serialized_str_to_bytes_shim(
return s.encode('utf-8')
elif method == "pickle":
# In this case the conversion goes from obj > bytes > str > bytes
# which results in an unecessary conversion to a string but this is
# which results in an unnecessary conversion to a string but this is
# an unavoidable side effect of converting between the Colmena
# and ProxyStore serialization formats.
return bytes.fromhex(s)
@@ -99,8 +100,8 @@


def _serialized_bytes_to_obj_wrapper(
b: str,
method: Union[str, SerializationMethod],
b: str,
method: Union[str, SerializationMethod],
) -> Any:
"""Wrapper which converts bytes to strings before deserializing.

@@ -161,6 +162,42 @@ def total_ranks(self) -> int:
return self.node_count * self.cpu_processes


class Timestamps(BaseModel):
    """A class which records the system times at which key events in a task occurred

    All should be in UTC.
    """

    created: float = Field(description="Time this value object was created",
                           default_factory=lambda: datetime.now().timestamp())
    input_received: float = Field(nan, description="Time the inputs were received by the task server")
    compute_started: float = Field(nan, description="Time the workflow process began executing a task")
    compute_ended: float = Field(nan, description="Time the workflow process finished executing a task")
    result_sent: float = Field(nan, description="Time the message was sent from the server")
    result_received: float = Field(nan, description="Time the value was received by the client")
    start_task_submission: float = Field(nan, description="Time marking the start of the task submission to the workflow engine")
    task_received: float = Field(nan, description="Time the task result was received from the workflow engine")


class TimeSpans(BaseModel):
    """Amount of time elapsed between major events

    All are recorded in seconds
    """

    running: float = Field(nan, description="Runtime of the method, if available")
    serialize_inputs: float = Field(nan, description="Time required to serialize inputs on client")
    deserialize_inputs: float = Field(nan, description="Time required to deserialize inputs on worker")
    serialize_results: float = Field(nan, description="Time required to serialize results on worker")
    deserialize_results: float = Field(nan, description="Time required to deserialize results on client")
    async_resolve_proxies: float = Field(nan, description="Time required to start async resolves of proxies")
    proxy: Dict[str, Dict[str, dict]] = Field(default_factory=dict,
                                              description='Timings related to resolving ProxyStore proxies on the compute worker')

    additional: Dict[str, float] = Field(default_factory=dict,
                                         description="Additional timings reported by a task server")


class Result(BaseModel):
"""A class which describes the inputs and results of the calculations evaluated by the MethodServer

@@ -187,31 +224,12 @@ class Result(BaseModel):
resources: ResourceRequirements = Field(default_factory=ResourceRequirements, help='List of the resources required for a task, if desired')
failure_info: Optional[FailureInformation] = Field(None, description="Messages about task failure. Provided by Task Server")
worker_info: Optional[WorkerInformation] = Field(None, description="Information about the worker which executed a task. Provided by Task Server")

# Performance tracking
time_created: float = Field(None, description="Time this value object was created")
time_input_received: float = Field(None, description="Time the inputs was received by the task server")
time_compute_started: float = Field(None, description="Time workflow process began executing a task")
time_compute_ended: float = Field(None, description="Time workflow process finished executing a task")
time_result_sent: float = Field(None, description="Time message was sent from the server")
time_result_received: float = Field(None, description="Time value was received by client")
time_start_task_submission: float = Field(None, description="Time marking the start of the task submission to workflow engine")
time_task_received: float = Field(None, description="Time task result received from workflow engine")

time_running: float = Field(None, description="Runtime of the method, if available")
time_serialize_inputs: float = Field(None, description="Time required to serialize inputs on client")
time_deserialize_inputs: float = Field(None, description="Time required to deserialize inputs on worker")
time_serialize_results: float = Field(None, description="Time required to serialize results on worker")
time_deserialize_results: float = Field(None, description="Time required to deserialize results on client")
time_async_resolve_proxies: float = Field(None,
description="Time required to scan function inputs and start async resolves of proxies")

additional_timing: dict = Field(default_factory=dict,
description="Timings recorded by a TaskServer that are not defined by above")
proxy_timing: Dict[str, Dict[str, dict]] = Field(default_factory=dict,
description='Timings related to resolving ProxyStore proxies on the compute worker')
message_sizes: Dict[str, int] = Field(default_factory=dict, description='Sizes of the inputs and results in bytes')

# Timings
timestamp: Timestamps = Field(default_factory=Timestamps, description='Times at which major events occurred')
time: TimeSpans = Field(default_factory=TimeSpans, description='Elapsed time between major events')

# Serialization options
serialization_method: SerializationMethod = Field(SerializationMethod.JSON,
description="Method used to serialize input data")
@@ -229,10 +247,6 @@ def __init__(self, inputs: Tuple[Tuple[Any], Dict[str, Any]], **kwargs):
"""
super().__init__(inputs=inputs, **kwargs)

# Mark "created" only if the value is not already set
if 'time_created' not in kwargs:
self.time_created = datetime.now().timestamp()

@property
def args(self) -> Tuple[Any]:
return tuple(self.inputs[0])
@@ -277,33 +291,33 @@ def json(self, **kwargs: Dict[str, Any]) -> str:

def mark_result_received(self):
"""Mark that a completed computation was received by a client"""
self.time_result_received = datetime.now().timestamp()
self.timestamp.result_received = datetime.now().timestamp()

def mark_input_received(self):
"""Mark that a task server has received a value"""
self.time_input_received = datetime.now().timestamp()
self.timestamp.input_received = datetime.now().timestamp()

def mark_compute_started(self):
"""Mark that the compute for a method has started"""
self.time_compute_started = datetime.now().timestamp()
self.timestamp.compute_started = datetime.now().timestamp()

def mark_result_sent(self):
"""Mark when a result is sent from the task server"""
self.time_result_sent = datetime.now().timestamp()
self.timestamp.result_sent = datetime.now().timestamp()

def mark_start_task_submission(self):
"""Mark when the Task Server submits a task to the engine"""
self.time_start_task_submission = datetime.now().timestamp()
self.timestamp.start_task_submission = datetime.now().timestamp()

def mark_task_received(self):
"""Mark when the Task Server receives the task from the engine"""
self.time_task_received = datetime.now().timestamp()
self.timestamp.task_received = datetime.now().timestamp()

def mark_compute_ended(self):
"""Mark when the task finished executing"""
self.time_compute_ended = datetime.now().timestamp()
self.timestamp.compute_ended = datetime.now().timestamp()

def set_result(self, result: Any, runtime: float = None):
def set_result(self, result: Any, runtime: float = nan):
"""Set the value of this computation

Automatically sets the "time_result_completed" field and, if known, defines the runtime.
@@ -318,7 +332,7 @@ def set_result(self, result: Any, runtime: float = None):
self.value = result
if not self.keep_inputs:
self.inputs = ((), {})
self.time_running = runtime
self.time.running = runtime
self.success = True

def serialize(self) -> Tuple[float, List[Proxy]]:
@@ -386,7 +400,7 @@ def _serialize_and_proxy(value, evict=False) -> Tuple[str, int]:
proxies.append(value_proxy)

# Update the statistics
store_proxy_stats(value_proxy, self.proxy_timing)
store_proxy_stats(value_proxy, self.time.proxy)

# Serialize the proxy with Colmena's utilities. This is
# efficient since the proxy is just a reference and metadata
@@ -418,9 +432,9 @@ def _serialize_and_proxy(value, evict=False) -> Tuple[str, int]:
# so the value is evicted from the value server once it is resolved
# by the thinker.
if _value is not None:
self.value, value_size = _serialize_and_proxy(_value, evict=True)
self.value, size = _serialize_and_proxy(_value, evict=True)
if 'value' not in self.message_sizes:
self.message_sizes['value'] = value_size
self.message_sizes['value'] = size

return perf_counter() - start_time, proxies
except Exception as e:
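
Taken together, the models.py changes move every timing value off the flat time_* attributes and onto two nested models. A minimal sketch of reading them after this change (the method name 'simulate' and its inputs are hypothetical):

from colmena.models import Result

# Create a result as a client would; `created` is stamped by the default_factory
result = Result(inputs=((1, 2), {'option': True}), method='simulate')

# Wall-clock events now live on the Timestamps sub-model...
print(result.timestamp.created)          # epoch seconds, set at construction
print(result.timestamp.compute_started)  # nan until mark_compute_started()

# ...while elapsed durations live on the TimeSpans sub-model
print(result.time.running)               # nan until set_result() records a runtime
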
3 changes: 1 addition & 2 deletions colmena/proxy.py
@@ -133,8 +133,7 @@ def store_proxy_stats(proxy: Proxy, proxy_timing: dict):
# Get the key associated with this proxy
key = get_key(proxy)

# ProxyStore keys are NamedTuples so we cast to a string
# so we can use the key as a JSON key.
# ProxyStore keys are NamedTuples, so we cast to a string to use as a JSON key.
key = str(key)

# Get the store associated with this proxy
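
The comment change above is cosmetic, but the constraint behind it is worth a quick illustration: JSON object keys must be strings, so the NamedTuple keys ProxyStore uses are cast before they index the timing dictionary. A toy sketch with a hypothetical key type:

import json
from typing import NamedTuple

class StoreKey(NamedTuple):  # stand-in for a real ProxyStore key class
    object_id: str

proxy_timing: dict = {}
key = StoreKey(object_id='abc123')
proxy_timing[str(key)] = {'times': {'store.proxy': 0.01}}
print(json.dumps(proxy_timing))  # works because the key is now a plain string
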
4 changes: 2 additions & 2 deletions colmena/queue/base.py
@@ -163,7 +163,7 @@ def get_result(self, topic: str = 'default', timeout: Optional[float] = None) ->

# Parse the value and mark it as complete
result_obj = Result.parse_raw(message)
result_obj.time_deserialize_results = result_obj.deserialize()
result_obj.time.deserialize_results = result_obj.deserialize()
result_obj.mark_result_received()

# Some logging
@@ -238,7 +238,7 @@ def send_inputs(self,
)

# Push the serialized value to the task server
result.time_serialize_inputs, proxies = result.serialize()
result.time.serialize_inputs, proxies = result.serialize()
self._send_request(result.json(exclude_none=True), topic)
logger.info(f'Client sent a {method} task with topic {topic}. Created {len(proxies)} proxies for input values')

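
On the client side, the queue methods now write into the nested models rather than flat attributes. A sketch of the round trip, assuming the PipeQueues implementation from colmena.queue.python and a task server running elsewhere to process the request ('simulate' is a hypothetical registered method):

from colmena.queue.python import PipeQueues

queues = PipeQueues()
queues.send_inputs(1.0, method='simulate', topic='default')  # records time.serialize_inputs

# ... a task server picks up the request and sends back a result ...

result = queues.get_result(topic='default')  # blocks until a result arrives
print(result.time.serialize_inputs)      # set by send_inputs()
print(result.time.deserialize_results)   # set by get_result()
print(result.timestamp.result_received)  # set by mark_result_received()
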
8 changes: 4 additions & 4 deletions colmena/task_server/base.py
@@ -177,7 +177,7 @@ def run_and_record_timing(func: Callable, result: Result) -> Result:
result.mark_compute_started()

# Unpack the inputs
result.time_deserialize_inputs = result.deserialize()
result.time.deserialize_inputs = result.deserialize()

# Start resolving any proxies in the input asynchronously
start_time = perf_counter()
@@ -186,7 +186,7 @@ def run_and_record_timing(func: Callable, result: Result) -> Result:
input_proxies.extend(resolve_proxies_async(arg))
for value in result.kwargs.values():
input_proxies.extend(resolve_proxies_async(value))
result.time_async_resolve_proxies = perf_counter() - start_time
result.time.async_resolve_proxies = perf_counter() - start_time

# Execute the function
start_time = perf_counter()
@@ -222,10 +222,10 @@ def run_and_record_timing(func: Callable, result: Result) -> Result:
result.mark_compute_ended()

# Re-pack the results. Will store the proxy statistics
result.time_serialize_results, _ = result.serialize()
result.time.serialize_results, _ = result.serialize()

# Get the statistics for the proxy resolution
for proxy in input_proxies:
store_proxy_stats(proxy, result.proxy_timing)
store_proxy_stats(proxy, result.time.proxy)

return result
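
The worker-side flow mirrors the test at the bottom of this diff: run_and_record_timing deserializes the inputs, starts resolving any proxies, runs the function, and fills the nested timing models in place. A self-contained sketch (no ProxyStore store configured, so time.proxy stays empty):

from colmena.models import Result
from colmena.task_server.base import run_and_record_timing

result = Result(inputs=(('hello',), {}))
result.serialize()  # emulate the client packing the inputs

run_and_record_timing(lambda x: x.upper(), result)

assert result.time.deserialize_inputs > 0
assert result.time.serialize_results > 0
assert result.timestamp.compute_ended >= result.timestamp.compute_started
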
2 changes: 1 addition & 1 deletion colmena/task_server/globus.py
@@ -29,7 +29,7 @@ class GlobusComputeTaskServer(FutureBasedTaskServer):
`registers <https://funcx.readthedocs.io/en/latest/sdk.html#registering-functions>`_
the wrapped function with Globus Compute.
You must also provide a Globus Compute :class:`~globus_compute_sdk.client.Client`
that the task server will use to authenticate with the web service.
that the task server will use to authenticate with the web service.

The task server works using Globus Compute's :class:`~globus_compute_sdk.executor.Executor`
to communicate to the web service over a web socket.
30 changes: 11 additions & 19 deletions colmena/task_server/parsl.py
@@ -52,13 +52,13 @@ def _execute_preprocess(task: ExecutableTask, result: Result) -> Tuple[Result, P
result.mark_compute_started()

# Unpack the inputs
result.time_deserialize_inputs = result.deserialize()
result.time.deserialize_inputs = result.deserialize()

# Start resolving any proxies in the input asynchronously
start_time = perf_counter()
resolve_proxies_async(result.args)
resolve_proxies_async(result.kwargs)
result.time_async_resolve_proxies = perf_counter() - start_time
result.time.async_resolve_proxies = perf_counter() - start_time

# Create a temporary directory
# TODO (wardlt): Figure out how to allow users to define a path for temporary directories
@@ -76,7 +76,7 @@ def _execute_preprocess(task: ExecutableTask, result: Result) -> Tuple[Result, P
end_time = perf_counter()

# Record the time required to perform the pre-processing
result.additional_timing['exec_preprocess'] = end_time - start_time
result.time.additional['exec_preprocess'] = end_time - start_time

# Remove the inputs. We don't need to send them back to the manager (the manager already knows what it sent out)
result.inputs = ((), {})
@@ -108,23 +108,23 @@ def _execute_postprocess(task: ExecutableTask, exit_code: int, result: Result, t
result.failure_info = FailureInformation.from_exception(e)
finally:
end_time = perf_counter()
result.additional_timing['exec_postprocess'] = end_time - start_time
result.time.additional['exec_postprocess'] = end_time - start_time

# Store the results
if result.success:
result.set_result(output, datetime.now().timestamp() - result.time_compute_started)
result.set_result(output, datetime.now().timestamp() - result.timestamp.compute_started)

# Store the run time in the result object
result.additional_timing['exec_execution'] = (result.time_running -
result.additional_timing['exec_postprocess'] -
result.additional_timing['exec_preprocess'])
result.time.additional['exec_execution'] = (result.time.running -
result.time.additional['exec_postprocess'] -
result.time.additional['exec_preprocess'])

# Add the worker information into the tasks, if available
worker_info = {'hostname': platform.node()}
result.worker_info = worker_info

# Re-pack the results (will use proxystore, if able)
result.time_serialize_results, _ = result.serialize()
result.time.serialize_results, _ = result.serialize()

# Put the serialized inputs back, if desired
if result.keep_inputs:
@@ -222,7 +222,7 @@ def _preprocess_callback(
result.inputs = serialized_inputs

# Store the time it took to run the preprocessing
result.time_running = result.additional_timing.get('exec_preprocess', 0)
result.time.running = result.time.additional.get('exec_preprocess', 0)
return task_server.queues.send_result(result, topic)

# If successful, submit the execute step and pass its result to Parsl
@@ -383,20 +383,12 @@ def __init__(self, methods: List[Union[Callable, Tuple[Callable, Dict]]],

logger.info(f'Defined {len(self.methods_)} methods: {", ".join(self.methods_.keys())}')

# If only one method, store a default method
self.default_method_ = list(self.methods_.keys())[0] if len(self.methods_) == 1 else None
if self.default_method_ is not None:
logger.info(f'There is only one method, so we are using {self.default_method_} as a default')

# Initialize the base class
super().__init__(queues, self.methods_.keys(), timeout)

def _submit(self, task: Result, topic: str) -> Optional[Future]:
# Determine which method to run
if self.default_method_ and task.method is None:
method = self.default_method_
else:
method = task.method
method = task.method

# Submit the application
task.mark_start_task_submission()
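
For executable tasks, the bookkeeping above implies that the total runtime recorded by set_result decomposes into the three exec_* entries stored under time.additional. A toy check of that arithmetic:

time_running = 2.0  # total runtime recorded by set_result(), in seconds
additional = {'exec_preprocess': 0.4, 'exec_postprocess': 0.1}

# Same arithmetic as _execute_postprocess above
additional['exec_execution'] = (time_running
                                - additional['exec_postprocess']
                                - additional['exec_preprocess'])

assert abs(sum(additional.values()) - time_running) < 1e-9
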
14 changes: 7 additions & 7 deletions colmena/task_server/tests/test_base.py
@@ -73,12 +73,12 @@ def test_run_function(store):
run_and_record_timing(lambda x: x.upper(), result)

# Make sure the timings are all set
assert result.time_running > 0
assert result.time_async_resolve_proxies > 0
assert result.time_deserialize_inputs > 0
assert result.time_serialize_results > 0
assert result.time_compute_ended > result.time_compute_started
assert result.time.running > 0
assert result.time.async_resolve_proxies > 0
assert result.time.deserialize_inputs > 0
assert result.time.serialize_results > 0
assert result.timestamp.compute_ended > result.timestamp.compute_started

# Make sure we have stats for both proxies
assert len(result.proxy_timing) == 2
assert all('store.proxy' in v['times'] for v in result.proxy_timing.values())
assert len(result.time.proxy) == 2
assert all('store.proxy' in v['times'] for v in result.time.proxy.values())