Skip to content

Commit

Permalink
Cleanup codecs interface.
Browse files Browse the repository at this point in the history
This attempts to make it more ergonomic to specify a codec
chain while also retaining the ability to preserve codec
chain invariants using the type system.
  • Loading branch information
zoj613 committed Jul 13, 2024
1 parent 8f21a53 commit ba58aa9
Show file tree
Hide file tree
Showing 11 changed files with 192 additions and 252 deletions.
11 changes: 2 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,8 @@ FilesystemStore.create_group store group_node;;
let array_node =
Result.get_ok @@ ArrayNode.(group_node / "name");;
let codec_chain =
{a2a = [`Transpose [|2; 0; 1|]]
;a2b = `Bytes Big
;b2b = [`Gzip L2]};;
FilesystemStore.create_array
~codecs:codec_chain
~codecs:[`Transpose [|2; 0; 1|]; `Bytes Big; `Gzip L2]
~shape:[|100; 100; 50|]
~chunks:[|10; 15; 20|]
Bigarray.Float32
Expand Down Expand Up @@ -89,13 +84,11 @@ let config =
;a2b = `Bytes Big
;b2b = [`Crc32c]}
;index_location = Start};;
let codec_chain =
{a2a = []; a2b = `ShardingIndexed config; b2b = []};;
let shard_node = Result.get_ok @@ ArrayNode.(group_node / "another");;
FilesystemStore.create_array
~codecs:codec_chain
~codecs:[`ShardingIndexed config]
~shape:[|100; 100; 50|]
~chunks:[|10; 15; 20|]
Bigarray.Complex32
Expand Down
28 changes: 13 additions & 15 deletions lib/codecs/array_to_array.ml
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
module Ndarray = Owl.Dense.Ndarray.Generic

type dimension_order = int array

type array_to_array =
| Transpose of dimension_order
type arraytoarray =
[ `Transpose of int array ]

type error =
[ `Transpose_order of dimension_order * string ]
[ `Transpose_order of int array * string ]

(* https://zarr-specs.readthedocs.io/en/latest/v3/codecs/transpose/v1.0.html *)
module TransposeCodec = struct
Expand Down Expand Up @@ -49,12 +47,12 @@ module TransposeCodec = struct
or negative values." in
Result.error @@ `Transpose_order (o, msg)
else
Result.ok @@ Transpose o
Result.ok @@ `Transpose o

let parse
: type a b.
(a, b) Util.array_repr ->
dimension_order ->
int array ->
(unit, [> error]) result
= fun repr o ->
ignore @@ parse_order o;
Expand Down Expand Up @@ -108,37 +106,37 @@ module TransposeCodec = struct
"transpose order must only
contain positive integers and unique values."
in Error msg) o (Ok [])
>>| fun o' -> Transpose (Array.of_list o')
>>| fun o' -> `Transpose (Array.of_list o')
| _ -> Error "Invalid transpose configuration."
end

module ArrayToArray = struct
let parse decoded_repr = function
| Transpose o -> TransposeCodec.parse decoded_repr o
| `Transpose o -> TransposeCodec.parse decoded_repr o

let compute_encoded_size input_size = function
| Transpose _ -> TransposeCodec.compute_encoded_size input_size
| `Transpose _ -> TransposeCodec.compute_encoded_size input_size

let compute_encoded_representation
: type a b.
array_to_array ->
arraytoarray ->
(a, b) Util.array_repr ->
((a, b) Util.array_repr, [> error]) result
= fun t repr ->
match t with
| Transpose o ->
| `Transpose o ->
TransposeCodec.compute_encoded_representation o repr

let encode t x =
match t with
| Transpose order -> TransposeCodec.encode order x
| `Transpose order -> TransposeCodec.encode order x

let decode t x =
match t with
| Transpose order -> TransposeCodec.decode order x
| `Transpose order -> TransposeCodec.decode order x

let to_yojson = function
| Transpose order -> TransposeCodec.to_yojson order
| `Transpose order -> TransposeCodec.to_yojson order

let of_yojson x =
match Util.get_name x with
Expand Down
22 changes: 10 additions & 12 deletions lib/codecs/array_to_array.mli
Original file line number Diff line number Diff line change
@@ -1,31 +1,29 @@
module Ndarray = Owl.Dense.Ndarray.Generic

type dimension_order = int array

type array_to_array =
| Transpose of dimension_order
type arraytoarray =
[ `Transpose of int array ]

type error =
[ `Transpose_order of dimension_order * string ]
[ `Transpose_order of int array * string ]

module ArrayToArray : sig
val parse
: ('a, 'b) Util.array_repr ->
array_to_array ->
arraytoarray ->
(unit, [> error]) result
val compute_encoded_size : int -> array_to_array -> int
val compute_encoded_size : int -> arraytoarray -> int
val compute_encoded_representation
: array_to_array ->
: arraytoarray ->
('a, 'b) Util.array_repr ->
(('a, 'b) Util.array_repr, [> error]) result
val encode
: array_to_array ->
: arraytoarray ->
('a, 'b) Ndarray.t ->
(('a, 'b) Ndarray.t, [> error]) result
val decode
: array_to_array ->
: arraytoarray ->
('a, 'b) Ndarray.t ->
(('a, 'b) Ndarray.t, [> error]) result
val of_yojson : Yojson.Safe.t -> (array_to_array, string) result
val to_yojson : array_to_array -> Yojson.Safe.t
val of_yojson : Yojson.Safe.t -> (arraytoarray, string) result
val to_yojson : arraytoarray -> Yojson.Safe.t
end
98 changes: 45 additions & 53 deletions lib/codecs/array_to_bytes.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,24 @@ open Util.Result_syntax

module Ndarray = Owl.Dense.Ndarray.Generic

type endianness =
| Little
| Big
type endianness = Little | Big

type loc =
| Start
| End
type loc = Start | End

type array_to_bytes =
| Bytes of endianness
| ShardingIndexed of shard_config
type arraytobytes =
[ `Bytes of endianness
| `ShardingIndexed of shard_config ]

and shard_config =
{chunk_shape : int array
;codecs : any_bytes_to_bytes sharding_chain
;index_codecs : fixed bytes_to_bytes sharding_chain
;codecs : bytestobytes shard_chain
;index_codecs : fixed_bytestobytes shard_chain
;index_location : loc}

and 'a sharding_chain = {
a2a: array_to_array list;
a2b: array_to_bytes;
b2b: 'a list}
and 'a shard_chain =
{a2a: arraytoarray list
;a2b: arraytobytes
;b2b: 'a list}

type error =
[ Extensions.error
Expand Down Expand Up @@ -115,65 +111,65 @@ end
module rec ArrayToBytes : sig
val parse
: ('a, 'b) Util.array_repr ->
array_to_bytes ->
arraytobytes ->
(unit, [> error]) result
val compute_encoded_size : int -> array_to_bytes -> int
val default : array_to_bytes
val compute_encoded_size : int -> arraytobytes -> int
val default : arraytobytes
val encode
: ('a, 'b) Ndarray.t ->
array_to_bytes ->
arraytobytes ->
(string, [> error]) result
val decode
: string ->
('a, 'b) Util.array_repr ->
array_to_bytes ->
arraytobytes ->
(('a, 'b) Ndarray.t, [> error]) result
val of_yojson : Yojson.Safe.t -> (array_to_bytes, string) result
val to_yojson : array_to_bytes -> Yojson.Safe.t
val of_yojson : Yojson.Safe.t -> (arraytobytes, string) result
val to_yojson : arraytobytes -> Yojson.Safe.t
end = struct

let default = Bytes Little
let default = `Bytes Little

let parse decoded_repr = function
| Bytes _ -> Ok ()
| ShardingIndexed c ->
| `Bytes _ -> Ok ()
| `ShardingIndexed c ->
ShardingIndexedCodec.parse decoded_repr c

let compute_encoded_size input_size = function
| Bytes _ ->
| `Bytes _ ->
BytesCodec.compute_encoded_size input_size
| ShardingIndexed s ->
| `ShardingIndexed s ->
ShardingIndexedCodec.compute_encoded_size input_size s

let encode
: type a b.
(a, b) Ndarray.t ->
array_to_bytes ->
arraytobytes ->
(string, [> error]) result
= fun x -> function
| Bytes endian -> BytesCodec.encode x endian
| ShardingIndexed c -> ShardingIndexedCodec.encode x c
| `Bytes endian -> BytesCodec.encode x endian
| `ShardingIndexed c -> ShardingIndexedCodec.encode x c

let decode
: type a b.
string ->
(a, b) Util.array_repr ->
array_to_bytes ->
arraytobytes ->
((a, b) Ndarray.t, [> error]) result
= fun b repr -> function
| Bytes endian -> BytesCodec.decode b repr endian
| ShardingIndexed c -> ShardingIndexedCodec.decode b repr c
| `Bytes endian -> BytesCodec.decode b repr endian
| `ShardingIndexed c -> ShardingIndexedCodec.decode b repr c

let to_yojson = function
| Bytes endian -> BytesCodec.to_yojson endian
| ShardingIndexed c -> ShardingIndexedCodec.to_yojson c
| `Bytes endian -> BytesCodec.to_yojson endian
| `ShardingIndexed c -> ShardingIndexedCodec.to_yojson c

let of_yojson x =
match Util.get_name x with
| "bytes" ->
BytesCodec.of_yojson x >>| fun e -> Bytes e
BytesCodec.of_yojson x >>| fun e -> `Bytes e
| "sharding_indexed" ->
ShardingIndexedCodec.of_yojson x >>| fun c -> ShardingIndexed c
ShardingIndexedCodec.of_yojson x >>| fun c -> `ShardingIndexed c
| _ -> Error ("array->bytes codec not supported: ")
end

Expand Down Expand Up @@ -248,7 +244,7 @@ end = struct

let rec encode_chain
: type a b.
any_bytes_to_bytes sharding_chain ->
bytestobytes shard_chain ->
(a, b) Ndarray.t ->
(string, [> error]) result
= fun t x ->
Expand Down Expand Up @@ -321,9 +317,7 @@ end = struct
offset := Int64.add !offset nbytes) cindices (Ok ())
>>= fun () ->
(* convert t.index_codecs to a generic bytes-to-bytes chain. *)
encode_chain
{t.index_codecs with b2b =
List.map (fun v -> Any v) t.index_codecs.b2b} shard_idx
encode_chain (t.index_codecs :> bytestobytes shard_chain) shard_idx
>>| fun b' ->
match t.index_location with
| Start ->
Expand All @@ -337,7 +331,7 @@ end = struct

let rec decode_chain
: type a b.
any_bytes_to_bytes sharding_chain ->
bytestobytes shard_chain ->
string ->
(a, b) Util.array_repr ->
((a, b) Ndarray.t, [> error]) result
Expand Down Expand Up @@ -378,9 +372,8 @@ end = struct
;shape = Array.append cps [|2|]}
in
decode_chain
{t.index_codecs with b2b =
List.map (fun v -> Any v) t.index_codecs.b2b} b' repr
>>| fun decoded ->
(t.index_codecs : fixed_bytestobytes shard_chain :> bytestobytes shard_chain)
b' repr >>| fun decoded ->
(decoded, rest)

and index_size t cps =
Expand Down Expand Up @@ -444,9 +437,7 @@ end = struct
and to_yojson t =
let codecs = chain_to_yojson t.codecs in
let index_codecs =
chain_to_yojson
{t.index_codecs with b2b =
List.map (fun v -> Any v) t.index_codecs.b2b}
chain_to_yojson (t.index_codecs :> bytestobytes shard_chain)
in
let index_location =
match t.index_location with
Expand All @@ -468,8 +459,7 @@ end = struct
("codecs", codecs)])]

let rec chain_of_yojson :
Yojson.Safe.t list ->
(any_bytes_to_bytes sharding_chain, string) result
Yojson.Safe.t list -> (bytestobytes shard_chain, string) result
= fun codecs ->
let filter_partition f encoded =
List.fold_right (fun c (l, r) ->
Expand Down Expand Up @@ -542,9 +532,11 @@ end = struct
(fun c acc ->
acc >>= fun l ->
match c with
| Any Crc32c -> Ok (Crc32c :: l)
| Any (Gzip _) ->
| `Crc32c ->
Ok (`Crc32c :: l)
| `Gzip _ ->
Error "index_codecs must not contain variable-sized codecs.")
ic.b2b (Ok []) >>| fun b2b ->
ic.b2b (Ok [])
>>| fun b2b ->
{index_codecs = {ic with b2b}; index_location; codecs; chunk_shape}
end
Loading

0 comments on commit ba58aa9

Please sign in to comment.