diff --git a/lib/codecs/array_to_bytes.ml b/lib/codecs/array_to_bytes.ml index 72c9502..b320b16 100644 --- a/lib/codecs/array_to_bytes.ml +++ b/lib/codecs/array_to_bytes.ml @@ -84,39 +84,35 @@ module Any = Owl.Dense.Ndarray.Any module rec ArrayToBytes : sig val parse : - array_tobytes -> + arraytobytes -> ('a, 'b) array_repr -> (unit, [> error]) result - val compute_encoded_size : int -> array_tobytes -> int - val default : array_tobytes + val compute_encoded_size : int -> fixed_arraytobytes -> int val encode : - array_tobytes -> + arraytobytes -> ('a, 'b) Ndarray.t -> (string, [> error]) result val decode : - array_tobytes -> + arraytobytes -> ('a, 'b) array_repr -> string -> (('a, 'b) Ndarray.t, [> `Store_read of string | error]) result - val of_yojson : Yojson.Safe.t -> (array_tobytes, string) result - val to_yojson : array_tobytes -> Yojson.Safe.t + val of_yojson : Yojson.Safe.t -> (arraytobytes, string) result + val to_yojson : arraytobytes -> Yojson.Safe.t end = struct - let default = `Bytes LE - let parse t decoded_repr = match t with | `Bytes _ -> Ok () | `ShardingIndexed c -> ShardingIndexedCodec.parse c decoded_repr - let compute_encoded_size input_size = function - | `Bytes _ -> - BytesCodec.compute_encoded_size input_size - | `ShardingIndexed s -> - ShardingIndexedCodec.compute_encoded_size input_size s.index_codecs + let compute_encoded_size : + int -> fixed_arraytobytes -> int + = fun input_size -> function + | `Bytes _ -> BytesCodec.compute_encoded_size input_size let encode : - array_tobytes -> + arraytobytes -> ('a, 'b) Ndarray.t -> (string, [> error]) result = fun t x -> match t with @@ -124,7 +120,7 @@ end = struct | `ShardingIndexed c -> ShardingIndexedCodec.encode c x let decode : - array_tobytes -> + arraytobytes -> ('a, 'b) array_repr -> string -> (('a, 'b) Ndarray.t, [> error]) result @@ -148,7 +144,6 @@ end and ShardingIndexedCodec : sig type t = internal_shard_config val parse : t -> ('a, 'b) array_repr -> (unit, [> error]) result - val 
compute_encoded_size : int -> fixed_bytestobytes internal_chain -> int val encode : t -> ('a, 'b) Ndarray.t -> (string, [> error]) result val partial_encode : t -> @@ -177,7 +172,11 @@ and ShardingIndexedCodec : sig end = struct type t = internal_shard_config - let parse_chain repr chain = + let parse_chain : + ('a, 'b) array_repr -> + (arraytobytes, bytestobytes) internal_chain -> + (unit, [> error ]) result + = fun repr chain -> List.fold_left (fun acc c -> acc >>= fun r -> @@ -194,9 +193,7 @@ end = struct (match Array.(length r.shape = length t.chunk_shape) with | true -> Ok () | false -> - let msg = - "sharding chunk_shape length must equal the dimensionality of - the decoded representaton of a shard." in + let msg = "chunk_shape size must equal the dimensionality of its shard." in Result.error @@ `Sharding (t.chunk_shape, r.shape, msg)) >>= fun () -> (match @@ -204,13 +201,13 @@ end = struct with | true -> Ok () | false -> - let msg = - "sharding chunk_shape must evenly divide the size of the shard shape." - in + let msg = "chunk_shape must evenly divide the size of the shard shape." 
in Result.error @@ `Sharding (t.chunk_shape, r.shape, msg)) >>= fun () -> parse_chain r t.codecs >>= fun () -> - parse_chain {r with shape = Array.append r.shape [|2|]} t.index_codecs + parse_chain + {r with shape = Array.append r.shape [|2|]} + (t.index_codecs :> (arraytobytes, bytestobytes) internal_chain) let compute_encoded_size input_size chain = List.fold_left BytesToBytes.compute_encoded_size @@ -232,17 +229,14 @@ end = struct (ArrayToBytes.encode t.a2b y) t.b2b let encode_index_chain : - fixed_bytestobytes internal_chain -> + (fixed_arraytobytes, fixed_bytestobytes) internal_chain -> Stdint.uint64 Any.arr -> (string, [> error]) result = fun t x -> let open Stdint in (match t.a2a with | [] -> Ok x - | `Transpose o :: _ -> - try Result.ok @@ Any.transpose ~axis:o x with - | Assert_failure _ -> - Error (`Transpose_order (o, "Invalid transpose order."))) + | `Transpose o :: _ -> Result.ok @@ Any.transpose ~axis:o x) >>= fun y -> let buf = Bytes.create @@ 8 * Any.numel y in let z = @@ -253,9 +247,6 @@ end = struct | `Bytes BE -> Any.iteri (fun i v -> Uint64.to_bytes_big_endian v buf (i*8)) y; Ok (Bytes.to_string buf) - | `ShardingIndexed _ -> - Result.error @@ - `CodecChain "Sharding codec is not allowed for shard index." in List.fold_left (fun acc c -> acc >>= BytesToBytes.encode c) @@ -320,7 +311,7 @@ end = struct t.a2a (ArrayToBytes.decode t.a2b repr' y) let decode_index_chain : - fixed_bytestobytes internal_chain -> + (fixed_arraytobytes, fixed_bytestobytes) internal_chain -> (int64, Bigarray.int64_elt) array_repr -> string -> (Stdint.uint64 Any.arr, [> error]) result @@ -344,22 +335,18 @@ end = struct | `Bytes BE -> Any.init repr'.shape @@ fun i -> Uint64.of_bytes_big_endian buf (i*8) - | `ShardingIndexed _ -> - failwith "Sharding codec is not allowed for shard index." 
in match t.a2a with | [] -> Ok arr | `Transpose o :: _ -> - try - let inv_order = Array.(make (length o) 0) in - Array.iteri (fun i x -> inv_order.(x) <- i) o; - Result.ok @@ Any.transpose ~axis:inv_order arr - with - | Assert_failure _ -> - Error (`Transpose_order (o, "Invalid transpose order.")) + let inv_order = Array.(make (length o) 0) in + Array.iteri (fun i x -> inv_order.(x) <- i) o; + Result.ok @@ Any.transpose ~axis:inv_order arr let index_size : - fixed_bytestobytes internal_chain -> int array -> int + (fixed_arraytobytes, fixed_bytestobytes) internal_chain -> + int array -> + int = fun index_chain cps -> compute_encoded_size (16 * Util.prod cps) index_chain @@ -527,8 +514,8 @@ end = struct let to_yojson t = let codecs = chain_to_yojson t.codecs in let index_codecs = - chain_to_yojson (t.index_codecs :> bytestobytes internal_chain) - in + chain_to_yojson + (t.index_codecs :> (arraytobytes, bytestobytes) internal_chain) in let index_location = match t.index_location with | End -> `String "end" @@ -549,7 +536,8 @@ end = struct ("codecs", codecs)])] let chain_of_yojson : - Yojson.Safe.t list -> (bytestobytes internal_chain, string) result + Yojson.Safe.t list -> + ((arraytobytes, bytestobytes) internal_chain, string) result = fun codecs -> let filter_partition f encoded = List.fold_right (fun c (l, r) -> @@ -624,11 +612,12 @@ end = struct (fun c acc -> acc >>= fun l -> match c with - | `Crc32c -> - Ok (`Crc32c :: l) - | `Gzip _ -> - Error "index_codecs must not contain variable-sized codecs.") - ic.b2b (Ok []) - >>| fun b2b -> - {index_codecs = {ic with b2b}; index_location; codecs; chunk_shape} + | `Crc32c -> Ok (`Crc32c :: l) + | `Gzip _ -> Error msg) ic.b2b (Ok []) + >>= fun b2b -> + (match ic.a2b with + | `Bytes e -> Ok (`Bytes e) + | `ShardingIndexed _ -> Error msg) + >>| fun a2b -> + {index_codecs = {ic with a2b; b2b}; index_location; codecs; chunk_shape} end diff --git a/lib/codecs/array_to_bytes.mli b/lib/codecs/array_to_bytes.mli index 
8bfaa3d..60ac5bd 100644 --- a/lib/codecs/array_to_bytes.mli +++ b/lib/codecs/array_to_bytes.mli @@ -2,23 +2,22 @@ open Codecs_intf module ArrayToBytes : sig val parse : - array_tobytes -> + arraytobytes -> ('a, 'b) array_repr -> (unit, [> error]) result - val compute_encoded_size : int -> array_tobytes -> int - val default : array_tobytes + val compute_encoded_size : int -> fixed_arraytobytes -> int val encode : - array_tobytes -> + arraytobytes -> ('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t -> (string, [> error]) result val decode : - array_tobytes -> + arraytobytes -> ('a, 'b) array_repr -> string -> (('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t ,[> `Store_read of string | error]) result - val of_yojson : Yojson.Safe.t -> (array_tobytes, string) result - val to_yojson : array_tobytes -> Yojson.Safe.t + val of_yojson : Yojson.Safe.t -> (arraytobytes, string) result + val to_yojson : arraytobytes -> Yojson.Safe.t end module ShardingIndexedCodec : sig diff --git a/lib/codecs/codecs.ml b/lib/codecs/codecs.ml index 5ab8ae9..82f76fd 100644 --- a/lib/codecs/codecs.ml +++ b/lib/codecs/codecs.ml @@ -5,29 +5,31 @@ open Util.Result_syntax include Codecs_intf -type arraytobytes = - [ `Bytes of endianness - | `ShardingIndexed of shard_config ] +type fixed_arraytobytes = + [ `Bytes of endianness ] + +type variable_array_tobytes = + [ `ShardingIndexed of shard_config ] and shard_config = {chunk_shape : int array ;codecs : [ arraytoarray - | `Bytes of endianness + | fixed_arraytobytes | `ShardingIndexed of shard_config | bytestobytes ] list ;index_codecs : - [ arraytoarray - | `Bytes of endianness - | `ShardingIndexed of shard_config - | fixed_bytestobytes ] list + [ arraytoarray | fixed_arraytobytes | fixed_bytestobytes ] list ;index_location : loc} +type array_tobytes = + [ fixed_arraytobytes | variable_array_tobytes ] + type codec_chain = - [ arraytoarray | arraytobytes | bytestobytes ] list + [ arraytoarray | array_tobytes | bytestobytes ] list module Chain = struct - 
type t = bytestobytes internal_chain + type t = (arraytobytes, bytestobytes) internal_chain let rec create : type a b. (a, b) array_repr -> codec_chain -> (t, [> error ]) result @@ -36,7 +38,7 @@ module Chain = struct List.partition_map (function | #arraytoarray as c -> Either.left c - | #arraytobytes as c -> Either.right c + | #array_tobytes as c -> Either.right c | #bytestobytes as c -> Either.right c) cc in List.fold_right @@ -44,13 +46,13 @@ module Chain = struct acc >>= fun (l, r) -> match c with | #bytestobytes as c -> Ok (l, c :: r) - | `Bytes e -> Ok (`Bytes e :: l, r) + | #fixed_arraytobytes as c -> Ok (c :: l, r) | `ShardingIndexed cfg -> create repr cfg.codecs >>= fun codecs -> create {repr with shape = Array.append repr.shape [|2|]} (cfg.index_codecs :> codec_chain) >>= fun index_codecs -> - (* coerse to a fixed_bytestobytes internal_chain list type *) + (* coerce to a fixed codec internal_chain list type *) let b2b = fst @@ List.partition_map @@ -58,11 +60,19 @@ module Chain = struct | #fixed_bytestobytes as c -> Either.left c | c -> Either.right c) index_codecs.b2b in + let a2b = + List.hd @@ + fst @@ + List.partition_map + (function + | #fixed_arraytobytes as c -> Either.left c + | c -> Either.right c) [index_codecs.a2b] + in let cfg' : internal_shard_config = {codecs ;chunk_shape = cfg.chunk_shape ;index_location = cfg.index_location - ;index_codecs = {index_codecs with b2b}} + ;index_codecs = {index_codecs with a2b; b2b}} in Ok (`ShardingIndexed cfg' :: l, r)) rest (Ok ([], [])) >>= fun result -> (match result with diff --git a/lib/codecs/codecs_intf.ml b/lib/codecs/codecs_intf.ml index b00d7d2..3aea986 100644 --- a/lib/codecs/codecs_intf.ml +++ b/lib/codecs/codecs_intf.ml @@ -17,20 +17,27 @@ type endianness = LE | BE type loc = Start | End -type array_tobytes = - [ `Bytes of endianness - | `ShardingIndexed of internal_shard_config ] +type fixed_arraytobytes = + [ `Bytes of endianness ] + +type variable_arraytobytes = + [ `ShardingIndexed of 
internal_shard_config ] and internal_shard_config = {chunk_shape : int array - ;codecs : bytestobytes internal_chain - ;index_codecs : fixed_bytestobytes internal_chain + ;codecs : + ([fixed_arraytobytes | `ShardingIndexed of internal_shard_config ] + ,bytestobytes) internal_chain + ;index_codecs : (fixed_arraytobytes, fixed_bytestobytes) internal_chain + ;index_location : loc} -and 'a internal_chain = +and ('a, 'b) internal_chain = {a2a : arraytoarray list - ;a2b : array_tobytes - ;b2b : 'a list} + ;a2b : 'a + ;b2b : 'b list} + +type arraytobytes = [ fixed_arraytobytes | variable_arraytobytes ] type error = [ `Extension of string @@ -77,28 +84,31 @@ module type Interface = sig type loc = Start | End (** The type of [array -> bytes] codecs. *) - type arraytobytes = - [ `Bytes of endianness - | `ShardingIndexed of shard_config ] + type fixed_arraytobytes = + [ `Bytes of endianness ] + + type variable_array_tobytes = + [ `ShardingIndexed of shard_config ] (** A type representing the Sharding indexed codec's configuration parameters. *) and shard_config = {chunk_shape : int array ;codecs : [ arraytoarray - | `Bytes of endianness + | fixed_arraytobytes | `ShardingIndexed of shard_config | bytestobytes ] list ;index_codecs : - [ arraytoarray - | `Bytes of endianness - | `ShardingIndexed of shard_config - | fixed_bytestobytes ] list + [ arraytoarray | fixed_arraytobytes | fixed_bytestobytes ] list ;index_location : loc} + (** The type of [array -> bytes] codecs. *) + type array_tobytes = + [ fixed_arraytobytes | variable_array_tobytes ] + (** A type used to build a user-defined chain of codecs when creating a Zarr array. *) type codec_chain = - [ arraytoarray | arraytobytes | bytestobytes ] list + [ arraytoarray | array_tobytes | bytestobytes ] list (** The type of errors returned upon failure when an calling a function on a {!Chain} type. 
*) diff --git a/test/test_codecs.ml b/test/test_codecs.ml index ad164ac..e565e9a 100644 --- a/test/test_codecs.ml +++ b/test/test_codecs.ml @@ -382,8 +382,30 @@ let tests = [ in let r = Chain.of_yojson @@ Yojson.Safe.from_string str in assert_bool - "Encoding this nested sharding chain should not fail" @@ - Result.is_ok r) + "Encoding this nested sharding chain should not fail" @@ Result.is_ok r; + (* test if decoding of indexed_codec with sharding for array->bytes fails.*) + let str = + {|[ + {"name": "sharding_indexed", + "configuration": + {"index_location": "start", + "chunk_shape": [5, 5, 5], + "codecs": + [{"name": "bytes", "configuration": {"endian": "big"}}], + "index_codecs": + [{"name": "sharding_indexed", + "configuration": + {"index_location": "end", + "chunk_shape": [5, 5, 5, 1], + "index_codecs": + [{"name": "bytes", "configuration": {"endian": "big"}}], + "codecs": + [{"name": "bytes", "configuration": {"endian": "big"}}]}}]}}]|} + in + let r = Chain.of_yojson @@ Yojson.Safe.from_string str in + assert_bool + "Decoding of index_codec chain with sharding should fail" @@ + Result.is_error r) ;