Skip to content

Commit

Permalink
Split array->bytes codecs into fixed and variable.
Browse files Browse the repository at this point in the history
This is to help the type checker rule out the ShardingIndexed codec
when pattern matching on the array->bytes codecs used for the index_codecs
chain.
  • Loading branch information
zoj613 committed Jul 25, 2024
1 parent 3a3b548 commit 9ec68ba
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 94 deletions.
97 changes: 43 additions & 54 deletions lib/codecs/array_to_bytes.ml
Original file line number Diff line number Diff line change
Expand Up @@ -84,47 +84,43 @@ module Any = Owl.Dense.Ndarray.Any

module rec ArrayToBytes : sig
val parse :
array_tobytes ->
arraytobytes ->
('a, 'b) array_repr ->
(unit, [> error]) result
val compute_encoded_size : int -> array_tobytes -> int
val default : array_tobytes
val compute_encoded_size : int -> fixed_arraytobytes -> int
val encode :
array_tobytes ->
arraytobytes ->
('a, 'b) Ndarray.t ->
(string, [> error]) result
val decode :
array_tobytes ->
arraytobytes ->
('a, 'b) array_repr ->
string ->
(('a, 'b) Ndarray.t, [> `Store_read of string | error]) result
val of_yojson : Yojson.Safe.t -> (array_tobytes, string) result
val to_yojson : array_tobytes -> Yojson.Safe.t
val of_yojson : Yojson.Safe.t -> (arraytobytes, string) result
val to_yojson : arraytobytes -> Yojson.Safe.t
end = struct

let default = `Bytes LE

let parse t decoded_repr =
match t with
| `Bytes _ -> Ok ()
| `ShardingIndexed c -> ShardingIndexedCodec.parse c decoded_repr

let compute_encoded_size input_size = function
| `Bytes _ ->
BytesCodec.compute_encoded_size input_size
| `ShardingIndexed s ->
ShardingIndexedCodec.compute_encoded_size input_size s.index_codecs
let compute_encoded_size :
int -> fixed_arraytobytes -> int
= fun input_size -> function
| `Bytes _ -> BytesCodec.compute_encoded_size input_size

let encode :
array_tobytes ->
arraytobytes ->
('a, 'b) Ndarray.t ->
(string, [> error]) result
= fun t x -> match t with
| `Bytes endian -> BytesCodec.encode x endian
| `ShardingIndexed c -> ShardingIndexedCodec.encode c x

let decode :
array_tobytes ->
arraytobytes ->
('a, 'b) array_repr ->
string ->
(('a, 'b) Ndarray.t, [> error]) result
Expand All @@ -148,7 +144,6 @@ end
and ShardingIndexedCodec : sig
type t = internal_shard_config
val parse : t -> ('a, 'b) array_repr -> (unit, [> error]) result
val compute_encoded_size : int -> fixed_bytestobytes internal_chain -> int
val encode : t -> ('a, 'b) Ndarray.t -> (string, [> error]) result
val partial_encode :
t ->
Expand Down Expand Up @@ -177,7 +172,11 @@ and ShardingIndexedCodec : sig
end = struct
type t = internal_shard_config

let parse_chain repr chain =
let parse_chain :
('a, 'b) array_repr ->
(arraytobytes, bytestobytes) internal_chain ->
(unit, [> error ]) result
= fun repr chain ->
List.fold_left
(fun acc c ->
acc >>= fun r ->
Expand All @@ -194,23 +193,21 @@ end = struct
(match Array.(length r.shape = length t.chunk_shape) with
| true -> Ok ()
| false ->
let msg =
"sharding chunk_shape length must equal the dimensionality of
the decoded representaton of a shard." in
let msg = "chunk_shape size must equal the dimensionality of its shard." in
Result.error @@ `Sharding (t.chunk_shape, r.shape, msg))
>>= fun () ->
(match
Array.for_all2 (fun x y -> (x mod y) = 0) r.shape t.chunk_shape
with
| true -> Ok ()
| false ->
let msg =
"sharding chunk_shape must evenly divide the size of the shard shape."
in
let msg = "chunk_shape must evenly divide the size of the shard shape." in
Result.error @@ `Sharding (t.chunk_shape, r.shape, msg))
>>= fun () ->
parse_chain r t.codecs >>= fun () ->
parse_chain {r with shape = Array.append r.shape [|2|]} t.index_codecs
parse_chain
{r with shape = Array.append r.shape [|2|]}
(t.index_codecs :> (arraytobytes, bytestobytes) internal_chain)

let compute_encoded_size input_size chain =
List.fold_left BytesToBytes.compute_encoded_size
Expand All @@ -232,17 +229,14 @@ end = struct
(ArrayToBytes.encode t.a2b y) t.b2b

let encode_index_chain :
fixed_bytestobytes internal_chain ->
(fixed_arraytobytes, fixed_bytestobytes) internal_chain ->
Stdint.uint64 Any.arr ->
(string, [> error]) result
= fun t x ->
let open Stdint in
(match t.a2a with
| [] -> Ok x
| `Transpose o :: _ ->
try Result.ok @@ Any.transpose ~axis:o x with
| Assert_failure _ ->
Error (`Transpose_order (o, "Invalid transpose order.")))
| `Transpose o :: _ -> Result.ok @@ Any.transpose ~axis:o x)
>>= fun y ->
let buf = Bytes.create @@ 8 * Any.numel y in
let z =
Expand All @@ -253,9 +247,6 @@ end = struct
| `Bytes BE ->
Any.iteri (fun i v -> Uint64.to_bytes_big_endian v buf (i*8)) y;
Ok (Bytes.to_string buf)
| `ShardingIndexed _ ->
Result.error @@
`CodecChain "Sharding codec is not allowed for shard index."
in
List.fold_left
(fun acc c -> acc >>= BytesToBytes.encode c)
Expand Down Expand Up @@ -320,7 +311,7 @@ end = struct
t.a2a (ArrayToBytes.decode t.a2b repr' y)

let decode_index_chain :
fixed_bytestobytes internal_chain ->
(fixed_arraytobytes, fixed_bytestobytes) internal_chain ->
(int64, Bigarray.int64_elt) array_repr ->
string ->
(Stdint.uint64 Any.arr, [> error]) result
Expand All @@ -344,22 +335,18 @@ end = struct
| `Bytes BE ->
Any.init repr'.shape @@ fun i ->
Uint64.of_bytes_big_endian buf (i*8)
| `ShardingIndexed _ ->
failwith "Sharding codec is not allowed for shard index."
in
match t.a2a with
| [] -> Ok arr
| `Transpose o :: _ ->
try
let inv_order = Array.(make (length o) 0) in
Array.iteri (fun i x -> inv_order.(x) <- i) o;
Result.ok @@ Any.transpose ~axis:inv_order arr
with
| Assert_failure _ ->
Error (`Transpose_order (o, "Invalid transpose order."))
let inv_order = Array.(make (length o) 0) in
Array.iteri (fun i x -> inv_order.(x) <- i) o;
Result.ok @@ Any.transpose ~axis:inv_order arr

let index_size :
fixed_bytestobytes internal_chain -> int array -> int
(fixed_arraytobytes, fixed_bytestobytes) internal_chain ->
int array ->
int
= fun index_chain cps ->
compute_encoded_size (16 * Util.prod cps) index_chain

Expand Down Expand Up @@ -527,8 +514,8 @@ end = struct
let to_yojson t =
let codecs = chain_to_yojson t.codecs in
let index_codecs =
chain_to_yojson (t.index_codecs :> bytestobytes internal_chain)
in
chain_to_yojson
(t.index_codecs :> (arraytobytes, bytestobytes) internal_chain) in
let index_location =
match t.index_location with
| End -> `String "end"
Expand All @@ -549,7 +536,8 @@ end = struct
("codecs", codecs)])]

let chain_of_yojson :
Yojson.Safe.t list -> (bytestobytes internal_chain, string) result
Yojson.Safe.t list ->
((arraytobytes, bytestobytes) internal_chain, string) result
= fun codecs ->
let filter_partition f encoded =
List.fold_right (fun c (l, r) ->
Expand Down Expand Up @@ -624,11 +612,12 @@ end = struct
(fun c acc ->
acc >>= fun l ->
match c with
| `Crc32c ->
Ok (`Crc32c :: l)
| `Gzip _ ->
Error "index_codecs must not contain variable-sized codecs.")
ic.b2b (Ok [])
>>| fun b2b ->
{index_codecs = {ic with b2b}; index_location; codecs; chunk_shape}
| `Crc32c -> Ok (`Crc32c :: l)
| `Gzip _ -> Error msg) ic.b2b (Ok [])
>>= fun b2b ->
(match ic.a2b with
| `Bytes e -> Ok (`Bytes e)
| `ShardingIndexed _ -> Error msg)
>>| fun a2b ->
{index_codecs = {ic with a2b; b2b}; index_location; codecs; chunk_shape}
end
13 changes: 6 additions & 7 deletions lib/codecs/array_to_bytes.mli
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,22 @@ open Codecs_intf

module ArrayToBytes : sig
val parse :
array_tobytes ->
arraytobytes ->
('a, 'b) array_repr ->
(unit, [> error]) result
val compute_encoded_size : int -> array_tobytes -> int
val default : array_tobytes
val compute_encoded_size : int -> fixed_arraytobytes -> int
val encode :
array_tobytes ->
arraytobytes ->
('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t ->
(string, [> error]) result
val decode :
array_tobytes ->
arraytobytes ->
('a, 'b) array_repr ->
string ->
(('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t
,[> `Store_read of string | error]) result
val of_yojson : Yojson.Safe.t -> (array_tobytes, string) result
val to_yojson : array_tobytes -> Yojson.Safe.t
val of_yojson : Yojson.Safe.t -> (arraytobytes, string) result
val to_yojson : arraytobytes -> Yojson.Safe.t
end

module ShardingIndexedCodec : sig
Expand Down
38 changes: 24 additions & 14 deletions lib/codecs/codecs.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,31 @@ open Util.Result_syntax

include Codecs_intf

type arraytobytes =
[ `Bytes of endianness
| `ShardingIndexed of shard_config ]
type fixed_arraytobytes =
[ `Bytes of endianness ]

type variable_array_tobytes =
[ `ShardingIndexed of shard_config ]

and shard_config =
{chunk_shape : int array
;codecs :
[ arraytoarray
| `Bytes of endianness
| fixed_arraytobytes
| `ShardingIndexed of shard_config
| bytestobytes ] list
;index_codecs :
[ arraytoarray
| `Bytes of endianness
| `ShardingIndexed of shard_config
| fixed_bytestobytes ] list
[ arraytoarray | fixed_arraytobytes | fixed_bytestobytes ] list
;index_location : loc}

type array_tobytes =
[ fixed_arraytobytes | variable_array_tobytes ]

type codec_chain =
[ arraytoarray | arraytobytes | bytestobytes ] list
[ arraytoarray | array_tobytes | bytestobytes ] list

module Chain = struct
type t = bytestobytes internal_chain
type t = (arraytobytes, bytestobytes) internal_chain

let rec create :
type a b. (a, b) array_repr -> codec_chain -> (t, [> error ]) result
Expand All @@ -36,33 +38,41 @@ module Chain = struct
List.partition_map
(function
| #arraytoarray as c -> Either.left c
| #arraytobytes as c -> Either.right c
| #array_tobytes as c -> Either.right c
| #bytestobytes as c -> Either.right c) cc
in
List.fold_right
(fun c acc ->
acc >>= fun (l, r) ->
match c with
| #bytestobytes as c -> Ok (l, c :: r)
| `Bytes e -> Ok (`Bytes e :: l, r)
| #fixed_arraytobytes as c -> Ok (c :: l, r)
| `ShardingIndexed cfg ->
create repr cfg.codecs >>= fun codecs ->
create
{repr with shape = Array.append repr.shape [|2|]}
(cfg.index_codecs :> codec_chain) >>= fun index_codecs ->
(* coerce to a fixed_bytestobytes internal_chain list type *)
(* coerce to a fixed codec internal_chain list type *)
let b2b =
fst @@
List.partition_map
(function
| #fixed_bytestobytes as c -> Either.left c
| c -> Either.right c) index_codecs.b2b
in
let a2b =
List.hd @@
fst @@
List.partition_map
(function
| #fixed_arraytobytes as c -> Either.left c
| c -> Either.right c) [index_codecs.a2b]

Check warning on line 69 in lib/codecs/codecs.ml

View check run for this annotation

Codecov / codecov/patch

lib/codecs/codecs.ml#L69

Added line #L69 was not covered by tests
in
let cfg' : internal_shard_config =
{codecs
;chunk_shape = cfg.chunk_shape
;index_location = cfg.index_location
;index_codecs = {index_codecs with b2b}}
;index_codecs = {index_codecs with a2b; b2b}}
in Ok (`ShardingIndexed cfg' :: l, r)) rest (Ok ([], []))
>>= fun result ->
(match result with
Expand Down
Loading

0 comments on commit 9ec68ba

Please sign in to comment.