Skip to content

Commit

Permalink
Fix benchmark not to assume the input is a list of tensor. (#7280)
Browse files Browse the repository at this point in the history
  • Loading branch information
vanbasten23 committed Jun 15, 2024
1 parent 28f9887 commit afd31c3
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions benchmarks/benchmark_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,12 @@ def prepare_for_experiment(self, dynamo_compilation_opts):
self.model_iter_fn = torch.compile(self.model_iter_fn, **compilation_opts)

if keep_model_data_on_cuda:
assert self.example_inputs[0].device.type.lower(
) == 'cuda', 'When keep_model_data_on_cuda is set, the input data should remain on the CUDA device.'

def assert_func(t):
assert t.device.type.lower(
) == 'cuda', 'When keep_model_data_on_cuda is set, the input data should remain on the CUDA device.'

pytree.tree_map_only(torch.Tensor, assert_func, self.example_inputs)

def pick_grad(self):
if self.benchmark_experiment.test == "eval":
Expand Down

0 comments on commit afd31c3

Please sign in to comment.