Skip to content

Commit

Permalink
[BugFix] Fix parsing integer batch size in AOT
Browse files Browse the repository at this point in the history
ghstack-source-id: 73e7dd429770e1c383b3b2a1c28dbbf661d65f07
Pull Request resolved: #1004
  • Loading branch information
vmoens committed Sep 20, 2024
1 parent 85b6b81 commit 9f6b899
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 9 deletions.
10 changes: 5 additions & 5 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -2061,7 +2061,7 @@ def _parse_batch_size(
source: T | dict | None,
batch_size: Sequence[int] | torch.Size | int | None = None,
) -> torch.Size:
ERR = "batch size was not specified when creating the TensorDict instance and it could not be retrieved from source."
ERR = "batch size {} was not specified when creating the TensorDict instance and it could not be retrieved from source."

if is_dynamo_compiling():
if isinstance(batch_size, torch.Size):
Expand All @@ -2072,22 +2072,22 @@ def _parse_batch_size(
return torch.Size(tuple(batch_size))
if batch_size is None:
return torch.Size([])
elif isinstance(batch_size, Number):
elif isinstance(batch_size, (Number, torch.SymInt)):
return torch.Size([batch_size])
elif isinstance(source, TensorDictBase):
return source.batch_size
raise ValueError()
raise ValueError(ERR.format(batch_size))

try:
return torch.Size(batch_size)
except Exception:
if batch_size is None:
return torch.Size([])
elif isinstance(batch_size, Number):
elif isinstance(batch_size, (Number, torch.SymInt)):
return torch.Size([batch_size])
elif isinstance(source, TensorDictBase):
return source.batch_size
raise ValueError(ERR)
raise ValueError(ERR.format(batch_size))

@property
def batch_dims(self) -> int:
Expand Down
37 changes: 33 additions & 4 deletions test/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,26 +774,55 @@ def call(x, td):


@pytest.mark.skipif(not _v2_5, reason="Requires PT>=2.5")
@pytest.mark.parametrize("strict", [False, True])
class TestExport:
def test_export_module(self):
def test_export_module(self, strict):
torch._dynamo.reset_code_caches()
tdm = Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"])
x = torch.randn(3)
y = torch.randn(3)
out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}, strict=strict)
assert (out.module()(x=x, y=y) == tdm(x=x, y=y)).all()

def test_export_seq(self):
def test_export_seq(self, strict):
torch._dynamo.reset_code_caches()
tdm = Seq(
Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]),
Mod(lambda z, x: z + x, in_keys=["z", "x"], out_keys=["out"]),
)
x = torch.randn(3)
y = torch.randn(3)
out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}, strict=strict)
torch.testing.assert_close(out.module()(x=x, y=y), tdm(x=x, y=y))

def test_td_output(self, strict):
class Test(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor):
return TensorDict(
{
"x": x,
"y": y,
},
batch_size=x.shape[0],
)

test = Test()
x, y = torch.zeros(2, 100), torch.zeros(2, 100)
result = torch.export.export(
test,
args=(x, y),
strict=False,
dynamic_shapes={
"x": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
"y": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
},
)
export_mod = result.module()
x_new, y_new = torch.zeros(5, 100), torch.zeros(5, 100)
export_test = export_mod(x_new, y_new)
eager_test = test(x_new, y_new)
assert (export_test == eager_test).all()


@pytest.mark.skipif(not _has_onnx, reason="ONNX is not available")
class TestONNXExport:
Expand Down

0 comments on commit 9f6b899

Please sign in to comment.