diff --git a/colmena/models.py b/colmena/models.py
index c39bc1e..3cd55aa 100644
--- a/colmena/models.py
+++ b/colmena/models.py
@@ -6,6 +6,7 @@
 from dataclasses import dataclass
 from datetime import datetime
 from enum import Enum
+from functools import partial
 from io import StringIO
 from pathlib import Path
 from subprocess import run
@@ -68,6 +69,68 @@ def deserialize(method: 'SerializationMethod', message: str) -> Any:
             raise NotImplementedError(f'Method {method} not yet implemented')
 
 
+def _serialized_str_to_bytes_shim(
+    s: str,
+    method: Union[str, SerializationMethod],
+) -> bytes:
+    """Shim between Colmena serialized objects and bytes.
+
+    Colmena's serialization mechanisms produce strings but ProxyStore
+    serializes to bytes, so this shim takes an object serialized by Colmena
+    and converts it to bytes.
+
+    Args:
+        s: Serialized string object
+        method: Serialization method used to produce s
+
+    Returns:
+        bytes representation of s
+
+    Raises:
+        NotImplementedError: If method is neither "json" nor "pickle"
+    """
+    if method == "json":
+        return s.encode('utf-8')
+    elif method == "pickle":
+        # In this case the conversion goes from obj > bytes > str > bytes
+        # which results in an unnecessary conversion to a string but this is
+        # an unavoidable side effect of converting between the Colmena
+        # and ProxyStore serialization formats.
+        return bytes.fromhex(s)
+    else:
+        raise NotImplementedError(f'Method {method} not yet implemented')
+
+
+def _serialized_bytes_to_obj_wrapper(
+    b: bytes,
+    method: Union[str, SerializationMethod],
+) -> Any:
+    """Wrapper which converts bytes to strings before deserializing.
+
+    Args:
+        b: Byte string of serialized object
+        method: Serialization method used to produce b
+
+    Returns:
+        Deserialized object
+
+    Raises:
+        NotImplementedError: If method is neither "json" nor "pickle"
+    """
+    if method == "json":
+        s = b.decode('utf-8')
+    elif method == "pickle":
+        # In this case the conversion goes from bytes > str > bytes > obj
+        # which results in an unnecessary conversion to a string but this is
+        # an unavoidable side effect of converting between the Colmena
+        # and ProxyStore serialization formats.
+        s = b.hex()
+    else:
+        raise NotImplementedError(f'Method {method} not yet implemented')
+
+    return SerializationMethod.deserialize(method, s)
+
+
 class FailureInformation(BaseModel):
     """Stores information about a task failure"""
 
@@ -306,7 +369,25 @@ def _serialize_and_proxy(value, evict=False) -> Tuple[str, int]:
                 not isinstance(value, Proxy) and
                 value_size >= self.proxystore_threshold
             ):
-                value_proxy = store.proxy(value, evict=evict)
+                # Override ProxyStore's default serialization with these shims
+                # to Colmena's serialization mechanisms. This avoids value
+                # being serialized twice: once to get the size of the
+                # serialized object and once by proxy().
+                deserializer = partial(
+                    _serialized_bytes_to_obj_wrapper,
+                    method=self.serialization_method,
+                )
+                serializer = partial(
+                    _serialized_str_to_bytes_shim,
+                    method=self.serialization_method,
+                )
+
+                value_proxy = store.proxy(
+                    value_str,
+                    evict=evict,
+                    deserializer=deserializer,
+                    serializer=serializer,
+                )
                 logger.debug(f'Proxied object of type {type(value)} with id={id(value)}')
                 proxies.append(value_proxy)
 