From f64633e59c9f060a160657a74e21c0229fd1fa9b Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Thu, 18 Jul 2024 10:59:39 -0700 Subject: [PATCH] Update jaten.py --- experimental/torch_xla2/torch_xla2/ops/jaten.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 87e39e41d4f..7d40eb48ee7 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -284,6 +284,8 @@ def _aten_div(x, y, rounding_mode=""): if rounding_mode == "floor": res = jnp.floor_divide(x, y) + if _is_int(x) and _is_int(y): + res_dtype = jnp.dtype('int64') else: res = x / y if rounding_mode == "trunc":