Skip to content

Commit

Permalink
Forcing empty list/empty tuple behavior.
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil committed Aug 5, 2024
1 parent 77dfa5c commit 2db5741
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 6 deletions.
28 changes: 22 additions & 6 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ use pyo3::exceptions::{PyException, PyFileNotFoundError};
use pyo3::prelude::*;
use pyo3::sync::GILOnceCell;
use pyo3::types::IntoPyDict;
use pyo3::types::PySlice;
use pyo3::types::{PyByteArray, PyBytes, PyDict, PyList};
use pyo3::types::{PyByteArray, PyBytes, PyDict, PyList, PySlice};
use pyo3::Bound as PyBound;
use pyo3::{intern, PyErr};
use safetensors::slice::TensorIndexer;
Expand Down Expand Up @@ -768,13 +767,13 @@ struct PySafeSlice {
storage: Arc<Storage>,
}

#[derive(FromPyObject)]
#[derive(Debug, FromPyObject)]
enum SliceIndex<'a> {
Slice(PyBound<'a, PySlice>),
Index(i32),
}

#[derive(FromPyObject)]
#[derive(Debug, FromPyObject)]
enum Slice<'a> {
Slice(SliceIndex<'a>),
Slices(Vec<SliceIndex<'a>>),
Expand Down Expand Up @@ -842,10 +841,27 @@ impl PySafeSlice {
pub fn __getitem__(&self, slices: &PyBound<'_, PyAny>) -> PyResult<PyObject> {
match &self.storage.as_ref() {
Storage::Mmap(mmap) => {
let slices: Slice = slices.extract()?;
let pyslices = slices;
let slices: Slice = pyslices.extract()?;
let is_list = pyslices.is_instance_of::<PyList>();
let slices: Vec<SliceIndex> = match slices {
Slice::Slice(slice) => vec![slice],
Slice::Slices(slices) => slices,
Slice::Slices(slices) => {
if slices.is_empty() && is_list {
vec![SliceIndex::Slice(PySlice::new_bound(
pyslices.py(),
0,
0,
0,
))]
} else if is_list {
return Err(SafetensorError::new_err(
"Non empty lists are not implemented",
));
} else {
slices
}
}
};
let data = &mmap[self.info.data_offsets.0 + self.offset
..self.info.data_offsets.1 + self.offset];
Expand Down
20 changes: 20 additions & 0 deletions bindings/python/tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,10 @@ def test_torch_slice(self):
self.assertEqual(list(tensor.shape), [10, 5])
torch.testing.assert_close(tensor, A)

tensor = slice_[tuple()]
self.assertEqual(list(tensor.shape), [10, 5])
torch.testing.assert_close(tensor, A)

tensor = slice_[:2]
self.assertEqual(list(tensor.shape), [2, 5])
torch.testing.assert_close(tensor, A[:2])
Expand All @@ -270,6 +274,10 @@ def test_torch_slice(self):
self.assertEqual(list(tensor.shape), [8])
torch.testing.assert_close(tensor, A[2:, -1])

tensor = slice_[list()]
self.assertEqual(list(tensor.shape), [0, 5])
torch.testing.assert_close(tensor, A[list()])

def test_numpy_slice(self):
A = np.random.rand(10, 5)
tensors = {
Expand All @@ -284,6 +292,10 @@ def test_numpy_slice(self):
self.assertEqual(list(tensor.shape), [10, 5])
self.assertTrue(np.allclose(tensor, A))

tensor = slice_[tuple()]
self.assertEqual(list(tensor.shape), [10, 5])
self.assertTrue(np.allclose(tensor, A))

tensor = slice_[:2]
self.assertEqual(list(tensor.shape), [2, 5])
self.assertTrue(np.allclose(tensor, A[:2]))
Expand Down Expand Up @@ -312,10 +324,18 @@ def test_numpy_slice(self):
self.assertEqual(list(tensor.shape), [8])
self.assertTrue(np.allclose(tensor, A[2:, -5]))

tensor = slice_[list()]
self.assertEqual(list(tensor.shape), [0, 5])
self.assertTrue(np.allclose(tensor, A[list()]))

with self.assertRaises(SafetensorError) as cm:
tensor = slice_[2:, -6]
self.assertEqual(str(cm.exception), "Invalid index -6 for dimension 1 of size 5")

with self.assertRaises(SafetensorError) as cm:
tensor = slice_[[0, 1]]
self.assertEqual(str(cm.exception), "Non empty lists are not implemented")

with self.assertRaises(SafetensorError) as cm:
tensor = slice_[2:, 20]
self.assertEqual(
Expand Down

0 comments on commit 2db5741

Please sign in to comment.