Skip to content

Commit

Permalink
Fixing zero-width tensor for in memory loading. (#510)
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil authored Jul 31, 2024
1 parent c00471e commit 8d21261
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
7 changes: 6 additions & 1 deletion bindings/python/py_src/safetensors/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,12 @@ def _view2torch(safeview) -> Dict[str, torch.Tensor]:
result = {}
for k, v in safeview:
dtype = _getdtype(v["dtype"])
arr = torch.frombuffer(v["data"], dtype=dtype).reshape(v["shape"])
if len(v["data"]) == 0:
# Workaround because frombuffer doesn't accept zero-size tensors
assert any(x == 0 for x in v["shape"])
arr = torch.empty(v["shape"], dtype=dtype)
else:
arr = torch.frombuffer(v["data"], dtype=dtype).reshape(v["shape"])
if sys.byteorder == "big":
arr = torch.from_numpy(arr.numpy().byteswap(inplace=False))
result[k] = arr
Expand Down
2 changes: 2 additions & 0 deletions bindings/python/tests/test_pt_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def test_zero_sized(self):
save_file(data, local)
reloaded = load_file(local)
self.assertTrue(torch.equal(data["test"], reloaded["test"]))
reloaded = load(open(local, "rb").read())
self.assertTrue(torch.equal(data["test"], reloaded["test"]))

def test_multiple_zero_sized(self):
data = {
Expand Down

0 comments on commit 8d21261

Please sign in to comment.