diff --git a/benchmarks/benchmark_model.py b/benchmarks/benchmark_model.py index 16e003d6693..07fa4e2c158 100644 --- a/benchmarks/benchmark_model.py +++ b/benchmarks/benchmark_model.py @@ -157,8 +157,12 @@ def prepare_for_experiment(self, dynamo_compilation_opts): self.model_iter_fn = torch.compile(self.model_iter_fn, **compilation_opts) if keep_model_data_on_cuda: - assert self.example_inputs[0].device.type.lower( - ) == 'cuda', 'When keep_model_data_on_cuda is set, the input data should remain on the CUDA device.' + + def assert_func(t): + assert t.device.type.lower( + ) == 'cuda', 'When keep_model_data_on_cuda is set, the input data should remain on the CUDA device.' + + pytree.tree_map_only(torch.Tensor, assert_func, self.example_inputs) def pick_grad(self): if self.benchmark_experiment.test == "eval":