Skip to content

Commit

Permalink
Implementations for u8->u64 and i8->i64
Browse files Browse the repository at this point in the history
  • Loading branch information
timwedde committed Dec 4, 2023
1 parent beee7a1 commit 64891fa
Showing 1 changed file with 251 additions and 0 deletions.
251 changes: 251 additions & 0 deletions dfdx-core/src/tensor/numpy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,166 @@ impl NumpyDtype for f64 {
}
}

impl NumpyDtype for u8 {
const NUMPY_DTYPE_STR: &'static str = "u1";
fn read_endian<R: Read>(r: &mut R, endian: Endian) -> io::Result<Self> {
let mut bytes = [0; 1];
r.read_exact(&mut bytes)?;
Ok(match endian {
Endian::Big => Self::from_be_bytes(bytes),
Endian::Little => Self::from_le_bytes(bytes),
Endian::Native => Self::from_ne_bytes(bytes),
})
}
fn write_endian<W: Write>(&self, w: &mut W, endian: Endian) -> io::Result<()> {
match endian {
Endian::Big => w.write_all(&self.to_be_bytes()),
Endian::Little => w.write_all(&self.to_le_bytes()),
Endian::Native => w.write_all(&self.to_ne_bytes()),
}
}
}

impl NumpyDtype for u16 {
const NUMPY_DTYPE_STR: &'static str = "u2";
fn read_endian<R: Read>(r: &mut R, endian: Endian) -> io::Result<Self> {
let mut bytes = [0; 2];
r.read_exact(&mut bytes)?;
Ok(match endian {
Endian::Big => Self::from_be_bytes(bytes),
Endian::Little => Self::from_le_bytes(bytes),
Endian::Native => Self::from_ne_bytes(bytes),
})
}
fn write_endian<W: Write>(&self, w: &mut W, endian: Endian) -> io::Result<()> {
match endian {
Endian::Big => w.write_all(&self.to_be_bytes()),
Endian::Little => w.write_all(&self.to_le_bytes()),
Endian::Native => w.write_all(&self.to_ne_bytes()),
}
}
}

impl NumpyDtype for u32 {
const NUMPY_DTYPE_STR: &'static str = "u4";
fn read_endian<R: Read>(r: &mut R, endian: Endian) -> io::Result<Self> {
let mut bytes = [0; 4];
r.read_exact(&mut bytes)?;
Ok(match endian {
Endian::Big => Self::from_be_bytes(bytes),
Endian::Little => Self::from_le_bytes(bytes),
Endian::Native => Self::from_ne_bytes(bytes),
})
}
fn write_endian<W: Write>(&self, w: &mut W, endian: Endian) -> io::Result<()> {
match endian {
Endian::Big => w.write_all(&self.to_be_bytes()),
Endian::Little => w.write_all(&self.to_le_bytes()),
Endian::Native => w.write_all(&self.to_ne_bytes()),
}
}
}

impl NumpyDtype for u64 {
const NUMPY_DTYPE_STR: &'static str = "u8";
fn read_endian<R: Read>(r: &mut R, endian: Endian) -> io::Result<Self> {
let mut bytes = [0; 8];
r.read_exact(&mut bytes)?;
Ok(match endian {
Endian::Big => Self::from_be_bytes(bytes),
Endian::Little => Self::from_le_bytes(bytes),
Endian::Native => Self::from_ne_bytes(bytes),
})
}
fn write_endian<W: Write>(&self, w: &mut W, endian: Endian) -> io::Result<()> {
match endian {
Endian::Big => w.write_all(&self.to_be_bytes()),
Endian::Little => w.write_all(&self.to_le_bytes()),
Endian::Native => w.write_all(&self.to_ne_bytes()),
}
}
}

impl NumpyDtype for i8 {
const NUMPY_DTYPE_STR: &'static str = "i1";
fn read_endian<R: Read>(r: &mut R, endian: Endian) -> io::Result<Self> {
let mut bytes = [0; 1];
r.read_exact(&mut bytes)?;
Ok(match endian {
Endian::Big => Self::from_be_bytes(bytes),
Endian::Little => Self::from_le_bytes(bytes),
Endian::Native => Self::from_ne_bytes(bytes),
})
}
fn write_endian<W: Write>(&self, w: &mut W, endian: Endian) -> io::Result<()> {
match endian {
Endian::Big => w.write_all(&self.to_be_bytes()),
Endian::Little => w.write_all(&self.to_le_bytes()),
Endian::Native => w.write_all(&self.to_ne_bytes()),
}
}
}

impl NumpyDtype for i16 {
const NUMPY_DTYPE_STR: &'static str = "i2";
fn read_endian<R: Read>(r: &mut R, endian: Endian) -> io::Result<Self> {
let mut bytes = [0; 2];
r.read_exact(&mut bytes)?;
Ok(match endian {
Endian::Big => Self::from_be_bytes(bytes),
Endian::Little => Self::from_le_bytes(bytes),
Endian::Native => Self::from_ne_bytes(bytes),
})
}
fn write_endian<W: Write>(&self, w: &mut W, endian: Endian) -> io::Result<()> {
match endian {
Endian::Big => w.write_all(&self.to_be_bytes()),
Endian::Little => w.write_all(&self.to_le_bytes()),
Endian::Native => w.write_all(&self.to_ne_bytes()),
}
}
}

impl NumpyDtype for i32 {
const NUMPY_DTYPE_STR: &'static str = "i4";
fn read_endian<R: Read>(r: &mut R, endian: Endian) -> io::Result<Self> {
let mut bytes = [0; 4];
r.read_exact(&mut bytes)?;
Ok(match endian {
Endian::Big => Self::from_be_bytes(bytes),
Endian::Little => Self::from_le_bytes(bytes),
Endian::Native => Self::from_ne_bytes(bytes),
})
}
fn write_endian<W: Write>(&self, w: &mut W, endian: Endian) -> io::Result<()> {
match endian {
Endian::Big => w.write_all(&self.to_be_bytes()),
Endian::Little => w.write_all(&self.to_le_bytes()),
Endian::Native => w.write_all(&self.to_ne_bytes()),
}
}
}

impl NumpyDtype for i64 {
const NUMPY_DTYPE_STR: &'static str = "i8";
fn read_endian<R: Read>(r: &mut R, endian: Endian) -> io::Result<Self> {
let mut bytes = [0; 8];
r.read_exact(&mut bytes)?;
Ok(match endian {
Endian::Big => Self::from_be_bytes(bytes),
Endian::Little => Self::from_le_bytes(bytes),
Endian::Native => Self::from_ne_bytes(bytes),
})
}
fn write_endian<W: Write>(&self, w: &mut W, endian: Endian) -> io::Result<()> {
match endian {
Endian::Big => w.write_all(&self.to_be_bytes()),
Endian::Little => w.write_all(&self.to_le_bytes()),
Endian::Native => w.write_all(&self.to_ne_bytes()),
}
}
}

#[derive(Debug)]
pub enum NpyError {
/// Magic number did not match the expected value.
Expand Down Expand Up @@ -560,4 +720,95 @@ mod tests {
.load_from_npy(file.path())
.expect_err("");
}

#[test]
fn test_0d_u8_save() {
let dev: TestDevice = Default::default();

let x = dev.tensor(0u8);

let file = NamedTempFile::new().expect("failed to create tempfile");

x.save_to_npy(file.path()).expect("Saving failed");

let mut f = File::open(file.path()).expect("No file found");

let mut found = Vec::new();
f.read_to_end(&mut found).expect("Reading failed");

assert_eq!(
&found,
&[
147, 78, 85, 77, 80, 89, 1, 0, 64, 0, 123, 39, 100, 101, 115, 99, 114, 39, 58, 32,
39, 60, 117, 49, 39, 44, 32, 39, 102, 111, 114, 116, 114, 97, 110, 95, 111, 114,
100, 101, 114, 39, 58, 32, 70, 97, 108, 115, 101, 44, 32, 39, 115, 104, 97, 112,
101, 39, 58, 32, 40, 41, 44, 32, 125, 32, 32, 32, 32, 32, 32, 32, 32, 10, 0
]
);
}

#[test]
fn test_0d_u8_load() {
let dev: TestDevice = Default::default();
let x = dev.tensor(2u8);

let file = NamedTempFile::new().expect("failed to create tempfile");

x.save_to_npy(file.path()).expect("Saving failed");

let mut v = dev.tensor(0u8);
v.load_from_npy(file.path()).expect("Loading failed");
assert_eq!(v.array(), x.array());

dev.tensor(0u16).load_from_npy(file.path()).expect_err("");
dev.tensor([0u8; 1])
.load_from_npy(file.path())
.expect_err("");
}

#[test]
fn test_0d_i8_save() {
let dev: TestDevice = Default::default();

let x = dev.tensor(0i8);

let file = NamedTempFile::new().expect("failed to create tempfile");

x.save_to_npy(file.path()).expect("Saving failed");
x.save_to_npy("out.npy").expect("Saving failed");

let mut f = File::open(file.path()).expect("No file found");

let mut found = Vec::new();
f.read_to_end(&mut found).expect("Reading failed");

assert_eq!(
&found,
&[
147, 78, 85, 77, 80, 89, 1, 0, 64, 0, 123, 39, 100, 101, 115, 99, 114, 39, 58, 32,
39, 60, 105, 49, 39, 44, 32, 39, 102, 111, 114, 116, 114, 97, 110, 95, 111, 114,
100, 101, 114, 39, 58, 32, 70, 97, 108, 115, 101, 44, 32, 39, 115, 104, 97, 112,
101, 39, 58, 32, 40, 41, 44, 32, 125, 32, 32, 32, 32, 32, 32, 32, 32, 10, 0
]
);
}

#[test]
fn test_0d_i8_load() {
let dev: TestDevice = Default::default();
let x = dev.tensor(2i8);

let file = NamedTempFile::new().expect("failed to create tempfile");

x.save_to_npy(file.path()).expect("Saving failed");

let mut v = dev.tensor(0i8);
v.load_from_npy(file.path()).expect("Loading failed");
assert_eq!(v.array(), x.array());

dev.tensor(0i16).load_from_npy(file.path()).expect_err("");
dev.tensor([0i8; 1])
.load_from_npy(file.path())
.expect_err("");
}
}

0 comments on commit 64891fa

Please sign in to comment.