Skip to content

Commit

Permalink
Add support for the bool data type.
Browse files Browse the repository at this point in the history
  • Loading branch information
zoj613 committed Sep 17, 2024
1 parent a016aa9 commit d3cb229
Show file tree
Hide file tree
Showing 11 changed files with 40 additions and 5 deletions.
12 changes: 7 additions & 5 deletions zarr/src/codecs/array_to_bytes.ml
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ module BytesCodec = struct
let open (val endian_module t) in
let buf = Bytes.create @@ Ndarray.byte_size x in
match Ndarray.data_type x with
| Char-> Ndarray.iteri (set_char buf) x; Bytes.unsafe_to_string buf
| Uint8-> Ndarray.iteri (set_uint8 buf) x; Bytes.unsafe_to_string buf
| Int8-> Ndarray.iteri (set_int8 buf) x; Bytes.unsafe_to_string buf
| Int16-> Ndarray.iteri (set_int16 buf) x; Bytes.unsafe_to_string buf
| Uint16-> Ndarray.iteri (set_uint16 buf) x; Bytes.unsafe_to_string buf
| Char -> Ndarray.iteri (set_char buf) x; Bytes.unsafe_to_string buf
| Bool -> Ndarray.iteri (set_bool buf) x; Bytes.unsafe_to_string buf
| Uint8 -> Ndarray.iteri (set_uint8 buf) x; Bytes.unsafe_to_string buf
| Int8 -> Ndarray.iteri (set_int8 buf) x; Bytes.unsafe_to_string buf
| Int16 -> Ndarray.iteri (set_int16 buf) x; Bytes.unsafe_to_string buf
| Uint16 -> Ndarray.iteri (set_uint16 buf) x; Bytes.unsafe_to_string buf
| Int32 -> Ndarray.iteri (set_int32 buf) x; Bytes.unsafe_to_string buf
| Int64 -> Ndarray.iteri (set_int64 buf) x; Bytes.unsafe_to_string buf
| Uint64 -> Ndarray.iteri (set_uint64 buf) x; Bytes.unsafe_to_string buf
Expand All @@ -41,6 +42,7 @@ module BytesCodec = struct
let buf = Bytes.unsafe_of_string str in
match k, Ndarray.dtype_size k with
| Char, _ -> Ndarray.init k shp @@ get_char buf
| Bool, _ -> Ndarray.init k shp @@ get_bool buf
| Uint8, _ -> Ndarray.init k shp @@ get_int8 buf
| Int8, _ -> Ndarray.init k shp @@ get_uint8 buf
| Int16, s -> Ndarray.init k shp @@ fun i -> get_int16 buf (i*s)
Expand Down
6 changes: 6 additions & 0 deletions zarr/src/codecs/ebuffer.ml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module type S = sig
val set_char : bytes -> int -> char -> unit
val set_bool : bytes -> int -> bool -> unit
val set_int8 : bytes -> int -> int -> unit
val set_uint8 : bytes -> int -> int -> unit
val set_int16 : bytes -> int -> int -> unit
Expand All @@ -15,6 +16,7 @@ module type S = sig
val set_nativeint : bytes -> int -> nativeint -> unit

val get_char : bytes -> int -> char
val get_bool : bytes -> int -> bool
val get_int8 : bytes -> int -> int
val get_uint8 : bytes -> int -> int
val get_int16 : bytes -> int -> int
Expand All @@ -34,6 +36,7 @@ module Little = struct
let set_int8 = Bytes.set_int8
let set_uint8 = Bytes.set_uint8
let set_char buf i v = Char.code v |> set_uint8 buf i
let set_bool buf i v = Bool.to_int v |> set_uint8 buf i
let set_int16 buf i v = Bytes.set_int16_le buf (2*i) v
let set_uint16 buf i v = Bytes.set_uint16_le buf (2*i) v
let set_int32 buf i v = Bytes.set_int32_le buf (4*i) v
Expand All @@ -53,6 +56,7 @@ module Little = struct
let get_int8 = Bytes.get_int8
let get_uint8 = Bytes.get_uint8
let get_char buf i = get_uint8 buf i |> Char.chr
let get_bool buf i = match get_uint8 buf i with | 0 -> false | _ -> true
let get_int16 = Bytes.get_int16_le
let get_uint16 = Bytes.get_uint16_le
let get_int32 = Bytes.get_int32_le
Expand All @@ -74,6 +78,7 @@ module Big = struct
let set_int8 = Bytes.set_int8
let set_uint8 = Bytes.set_uint8
let set_char buf i v = Char.code v |> set_uint8 buf i
let set_bool buf i v = Bool.to_int v |> set_uint8 buf i
let set_int16 buf i v = Bytes.set_int16_be buf (i * 2) v
let set_uint16 buf i v = Bytes.set_uint16_be buf (i * 2) v
let set_int32 buf i v = Bytes.set_int32_be buf (i * 4) v
Expand All @@ -93,6 +98,7 @@ module Big = struct
let get_int8 = Bytes.get_int8
let get_uint8 = Bytes.get_uint8
let get_char buf i = get_uint8 buf i |> Char.chr
let get_bool buf i = match get_uint8 buf i with | 0 -> false | _ -> true
let get_int16 = Bytes.get_int16_be
let get_uint16 = Bytes.get_uint16_be
let get_int32 = Bytes.get_int32_be
Expand Down
2 changes: 2 additions & 0 deletions zarr/src/codecs/ebuffer.mli
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module type S = sig
val set_char : bytes -> int -> char -> unit
val set_bool : bytes -> int -> bool -> unit
val set_int8 : bytes -> int -> int -> unit
val set_uint8 : bytes -> int -> int -> unit
val set_int16 : bytes -> int -> int -> unit
Expand All @@ -15,6 +16,7 @@ module type S = sig
val set_nativeint : bytes -> int -> nativeint -> unit

val get_char : bytes -> int -> char
val get_bool : bytes -> int -> bool
val get_int8 : bytes -> int -> int
val get_uint8 : bytes -> int -> int
val get_int16 : bytes -> int -> int
Expand Down
4 changes: 4 additions & 0 deletions zarr/src/extensions.ml
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ end
module Datatype = struct
type t =
| Char
| Bool
| Int8
| Uint8
| Int16
Expand All @@ -123,6 +124,7 @@ module Datatype = struct

let of_kind : type a. a Ndarray.dtype -> t = function
| Ndarray.Char -> Char
| Ndarray.Bool -> Bool
| Ndarray.Int8 -> Int8
| Ndarray.Uint8 -> Uint8
| Ndarray.Int16 -> Int16
Expand All @@ -139,6 +141,7 @@ module Datatype = struct

let to_yojson = function
| Char -> `String "char"
| Bool -> `String "bool"
| Int8 -> `String "int8"
| Uint8 -> `String "uint8"
| Int16 -> `String "int16"
Expand All @@ -155,6 +158,7 @@ module Datatype = struct

let of_yojson = function
| `String "char" -> Ok Char
| `String "bool" -> Ok Bool
| `String "int8" -> Ok Int8
| `String "uint8" -> Ok Uint8
| `String "int16" -> Ok Int16
Expand Down
1 change: 1 addition & 0 deletions zarr/src/extensions.mli
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ module Datatype : sig

type t =
| Char
| Bool
| Int8
| Uint8
| Int16
Expand Down
3 changes: 3 additions & 0 deletions zarr/src/metadata.ml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ module FillValue = struct
= fun kind a ->
match kind with
| Ndarray.Char -> Char a
| Ndarray.Bool -> Bool a
| Ndarray.Int8 -> Int (Stdint.Uint64.of_int a)
| Ndarray.Uint8 -> Int (Stdint.Uint64.of_int a)
| Ndarray.Int16 -> Int (Stdint.Uint64.of_int a)
Expand Down Expand Up @@ -299,6 +300,7 @@ module Array = struct
= fun t kind ->
match kind, t.data_type with
| Ndarray.Char, Datatype.Char
| Ndarray.Bool, Datatype.Bool
| Ndarray.Int8, Datatype.Int8
| Ndarray.Uint8, Datatype.Uint8
| Ndarray.Int16, Datatype.Int16
Expand All @@ -319,6 +321,7 @@ module Array = struct
= fun t kind ->
match kind, t.fill_value with
| Ndarray.Char, FillValue.Char c -> c
| Ndarray.Bool, FillValue.Bool b -> b
| Ndarray.Int8, FillValue.Int i -> Stdint.Uint64.to_int i
| Ndarray.Uint8, FillValue.Int i -> Stdint.Uint64.to_int i
| Ndarray.Int16, FillValue.Int i -> Stdint.Uint64.to_int i
Expand Down
2 changes: 2 additions & 0 deletions zarr/src/ndarray.ml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
type _ dtype =
| Char : char dtype
| Bool : bool dtype
| Int8 : int dtype
| Uint8 : int dtype
| Int16 : int dtype
Expand All @@ -22,6 +23,7 @@ type 'a t =

let dtype_size : type a. a dtype -> int = function
| Char -> 1
| Bool -> 1
| Int8 -> 1
| Uint8 -> 1
| Int16 -> 2
Expand Down
1 change: 1 addition & 0 deletions zarr/src/ndarray.mli
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
(** Supported data types for a Zarr array. *)
type _ dtype =
| Char : char dtype
| Bool : bool dtype
| Int8 : int dtype
| Uint8 : int dtype
| Int16 : int dtype
Expand Down
4 changes: 4 additions & 0 deletions zarr/test/test_codecs.ml
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,10 @@ let tests = [
(* test encoding/decoding of Char *)
bytes_encode_decode {shape; kind = Ndarray.Char} '?';

(* test encoding/decoding of Bool *)
bytes_encode_decode {shape; kind = Ndarray.Bool} false;
bytes_encode_decode {shape; kind = Ndarray.Bool} true;

(* test encoding/decoding of int8 *)
bytes_encode_decode {shape; kind = Ndarray.Int8} 0;

Expand Down
8 changes: 8 additions & 0 deletions zarr/test/test_metadata.ml
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,14 @@ let array = [
let chunks = [|5; 2; 6|] in
let dimension_names = [Some "x"; None; Some "z"] in

(* tests using bool data type. *)
test_array_metadata
~shape
~chunks
Ndarray.Bool
Ndarray.Float32
false;

(* tests using char data type. *)
test_array_metadata
~shape
Expand Down
2 changes: 2 additions & 0 deletions zarr/test/test_ndarray.ml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ let tests = [
let shape = [|2; 5; 3|] in
run_test {shape; kind = M.Char} '?' 1;

run_test {shape; kind = M.Bool} false 1;

run_test {shape; kind = M.Int8} 0 1;

run_test {shape; kind = M.Uint8} 0 1;
Expand Down

0 comments on commit d3cb229

Please sign in to comment.