Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement {to/of}_yojson functions for Metadata types. #25

Merged
merged 1 commit into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions lib/codecs/array_to_bytes.ml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type error =
[ `Bytes_encode_error of string
| `Bytes_decode_error of string
| `Sharding_shape_mismatch of int array * int array * string
| Extensions.error
| Array_to_array.error
| Bytes_to_bytes.error ]

Expand Down Expand Up @@ -259,21 +260,21 @@ end = struct
(string, [> error]) result
= fun x t ->
let open Util in
let open Extensions in
let open Util.Result_syntax in
let shard_shape = Ndarray.shape x in
let cps = Array.map2 (/) shard_shape t.chunk_shape in
let idx_shp = Array.append cps [|2|] in
let shard_idx =
Ndarray.create Bigarray.Int64 idx_shp Int64.max_int in
let sg =
Extensions.RegularGrid.create t.chunk_shape in
let shard_idx = Ndarray.create Bigarray.Int64 idx_shp Int64.max_int in
RegularGrid.create ~array_shape:shard_shape t.chunk_shape
>>= fun grid ->
let slice =
Array.make
(Ndarray.num_dims x) (Owl_types.R []) in
let coords = Indexing.coords_of_slice slice shard_shape in
let tbl = Arraytbl.create @@ Array.length coords in
Ndarray.iteri (fun i y ->
let k, c = Extensions.RegularGrid.index_coord_pair sg coords.(i) in
let k, c = RegularGrid.index_coord_pair grid coords.(i) in
Arraytbl.add tbl k (c, y)) x;
let fill_value =
Arraytbl.to_seq_values tbl
Expand Down Expand Up @@ -378,7 +379,8 @@ end = struct
if Ndarray.for_all (Int64.equal Int64.max_int) shard_idx then
Ok (Ndarray.create repr.kind repr.shape repr.fill_value)
else
let sg = RegularGrid.create t.chunk_shape in
RegularGrid.create ~array_shape:repr.shape t.chunk_shape
>>= fun sg ->
let slice =
Array.make
(Array.length repr.shape)
Expand Down Expand Up @@ -416,10 +418,10 @@ end = struct
inner.kind (Array.of_list res) repr.shape

let rec chain_to_yojson chain =
[%to_yojson: Yojson.Safe.t list] @@
List.map ArrayToArray.to_yojson chain.a2a @
(ArrayToBytes.to_yojson chain.a2b) ::
List.map BytesToBytes.to_yojson chain.b2b
`List
(List.map ArrayToArray.to_yojson chain.a2a @
(ArrayToBytes.to_yojson chain.a2b) ::
List.map BytesToBytes.to_yojson chain.b2b)

and to_yojson t =
let codecs = chain_to_yojson t.codecs
Expand Down
9 changes: 5 additions & 4 deletions lib/codecs/array_to_bytes.mli
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@ val pp_chain : Format.formatter -> chain -> unit
val show_chain : chain -> string

type error =
[ `Bytes_encode_error of string
| `Bytes_decode_error of string
| `Sharding_shape_mismatch of int array * int array * string
[ Extensions.error
| Array_to_array.error
| Bytes_to_bytes.error ]
| Bytes_to_bytes.error
| `Bytes_encode_error of string
| `Bytes_decode_error of string
| `Sharding_shape_mismatch of int array * int array * string ]

module ArrayToBytes : sig
val parse
Expand Down
10 changes: 5 additions & 5 deletions lib/codecs/codecs.ml
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,14 @@ module Chain = struct
(fun c acc -> acc >>= ArrayToArray.decode c)
t.a2a (ArrayToBytes.decode y repr' t.a2b)

let equal x y =
let ( = ) x y =
x.a2a = y.a2a && x.a2b = y.a2b && x.b2b = y.b2b

let to_yojson t =
[%to_yojson: Yojson.Safe.t list] @@
List.map ArrayToArray.to_yojson t.a2a @
(ArrayToBytes.to_yojson t.a2b) ::
List.map BytesToBytes.to_yojson t.b2b
`List
(List.map ArrayToArray.to_yojson t.a2a @
(ArrayToBytes.to_yojson t.a2b) ::
List.map BytesToBytes.to_yojson t.b2b)

let of_yojson x =
let filter_partition f encoded =
Expand Down
2 changes: 1 addition & 1 deletion lib/codecs/codecs.mli
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ module Chain : sig
string ->
(('a, 'b) Ndarray.t, [> error]) result

val equal : t -> t -> bool
val ( = ) : t -> t -> bool

val of_yojson : Yojson.Safe.t -> (t, string) result

Expand Down
5 changes: 1 addition & 4 deletions lib/dune
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
(public_name zarr)
(libraries
yojson
ppx_deriving_yojson.runtime
ezgzip
owl
stdint
Expand All @@ -12,9 +11,7 @@
(:standard -O3))
(preprocess
(pps
ppx_deriving.eq
ppx_deriving.show
ppx_deriving_yojson))
ppx_deriving.show))
(instrumentation
(backend bisect_ppx)))

Expand Down
26 changes: 21 additions & 5 deletions lib/extensions.ml
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
type grid_info =
{msg : string
;chunk_shape : int array
;array_shape : int array}

type error =
[ `Grid of grid_info ]

module RegularGrid = struct
type t = int array

let chunk_shape t = t

let create chunk_shape = chunk_shape
let create ~array_shape chunk_shape =
match chunk_shape, array_shape with
| c, a when Array.(length c <> length a) ->
let msg = "grid chunk and array shape must have the same the length." in
Result.error @@ `Grid {msg; array_shape; chunk_shape}
| c, a when Util.(max c > max a) ->
let msg = "grid chunk dimension size must not be larger than array's." in
Result.error @@ `Grid {msg; array_shape; chunk_shape}
| c, _ -> Ok c

let ceildiv x y =
Float.(to_int @@ ceil (of_int x /. of_int y))
Expand All @@ -27,7 +43,7 @@ module RegularGrid = struct
|> Util.Indexing.cartesian_prod
|> List.map Array.of_list

let equal x y = x = y
let ( = ) x y = x = y

let to_yojson t =
let chunk_shape =
Expand All @@ -54,7 +70,7 @@ module RegularGrid = struct
"Regular grid chunk_shape must only contain positive integers."
in
Error msg) xs (Ok [])
>>| fun l' -> Array.of_list l'
>>| Array.of_list
| _ -> Error "Invalid Chunk grid name or configuration."
end

Expand Down Expand Up @@ -83,7 +99,7 @@ module ChunkKeyEncoding = struct
String.concat sep @@
Array.fold_right f index []

let equal x y =
let ( = ) x y =
x.name = y.name && x.sep = y.sep

let to_yojson {name; sep} =
Expand Down Expand Up @@ -128,7 +144,7 @@ module Datatype = struct
| Int
| Nativeint

let equal : t -> t -> bool = fun x y -> x = y
let ( = ) : t -> t -> bool = fun x y -> x = y

let of_kind : type a b. (a, b) Bigarray.kind -> t = function
| Bigarray.Char -> Char
Expand Down
16 changes: 12 additions & 4 deletions lib/extensions.mli
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
type grid_info =
{msg : string
;chunk_shape : int array
;array_shape : int array}

type error =
[ `Grid of grid_info ]

module RegularGrid : sig
type t
val create : int array -> t
val create : array_shape:int array -> int array -> (t, [> error]) result
val chunk_shape : t -> int array
val grid_shape : t -> int array -> int array
val indices : t -> int array -> int array list
val index_coord_pair : t -> int array -> int array * int array
val equal : t -> t -> bool
val ( = ) : t -> t -> bool
val of_yojson : Yojson.Safe.t -> (t, string) result
val to_yojson : t -> Yojson.Safe.t
end
Expand All @@ -14,7 +22,7 @@ module ChunkKeyEncoding : sig
type t
val create : [< `Slash | `Dot > `Slash ] -> t
val encode : t -> int array -> string
val equal : t -> t -> bool
val ( = ) : t -> t -> bool
val of_yojson : Yojson.Safe.t -> (t, string) result
val to_yojson : t -> Yojson.Safe.t
end
Expand All @@ -38,7 +46,7 @@ module Datatype : sig
| Nativeint
(** A type for the supported data types of a Zarr array. *)

val equal : t -> t -> bool
val ( = ) : t -> t -> bool
val of_kind : ('a, 'b) Bigarray.kind -> t
val of_yojson : Yojson.Safe.t -> (t, string) result
val to_yojson : t -> Yojson.Safe.t
Expand Down
Loading
Loading