Skip to content

Commit

Permalink
Add torch_xla2 export_program_to_stablehlo API with unbounded dynam…
Browse files Browse the repository at this point in the history
…ism support (#7093)
  • Loading branch information
GleasonK authored May 22, 2024
1 parent 6023855 commit baf08ae
Show file tree
Hide file tree
Showing 7 changed files with 1,007 additions and 31 deletions.
8 changes: 8 additions & 0 deletions experimental/torch_xla2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ the instructions below from scratch (fresh venv / conda environment.)

### 1. Installing `torch_xla2`

The following instructions assume you are in the `torch_xla2` directory:

```
$ git clone https://github.com/pytorch/xla.git
$ cd xla/experimental/torch_xla2
```


#### 1.0 (recommended) Make a virtualenv / conda env

If you are using VSCode, then [you can create a new environment from
Expand Down
2 changes: 1 addition & 1 deletion experimental/torch_xla2/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ name = "torch_xla2"
dependencies = [
"absl-py",
"immutabledict",
"jax>=0.4.24",
"jax[cpu]>=0.4.24",
"pytest",
"tensorflow-cpu",
# Developers should install `dev-requirements.txt` first
Expand Down
100 changes: 88 additions & 12 deletions experimental/torch_xla2/test/test_exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,33 +34,109 @@ def setUp(self):

def test_interpolate(self):

# Check Accuracy
arg = (torch.randn(3, 3, 200, 200),)
model = Interpolate()

ans = model(*arg)

with torch.no_grad():
exported = torch.export.export(model, arg)
weights, func = torch_xla2.export.exported_program_to_jax(exported)
argj = tensor.t2j(arg[0])
ans2 = jax.jit(func)(weights, (argj,))[0]
ans2 = tensor.j2t(ans2)
self.assertTrue(torch.allclose(ans, ans2, atol=1e-3))
weights, func = torch_xla2.export.exported_program_to_jax(exported)
argj = tensor.t2j(arg[0])
ans2 = jax.jit(func)(weights, (argj,))[0]
ans2 = tensor.j2t(ans2)
self.assertTrue(torch.allclose(ans, ans2, atol=1e-3))

# Convert to StableHLO
stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
module_str = str(stablehlo.mlir_module())
self.assertIn("func.func public @main", module_str)
self.assertIn("func.func private @clip(%arg0: tensor<500xf32>", module_str)
self.assertIn("stablehlo.minimum", module_str)

def test_constant(self):

# Check Accuracy
arg = (torch.randn(10, 10),)
model = TensorConstant()

ans = model(*arg)

with torch.no_grad():
exported = torch.export.export(model, arg)
weights, func = torch_xla2.export.exported_program_to_jax(exported)
argj = tensor.t2j(arg[0])
ans2 = jax.jit(func)(weights, (argj,))[0]
ans2 = tensor.j2t(ans2)
self.assertTrue(torch.allclose(ans, ans2, atol=1e-5))

weights, func = torch_xla2.export.exported_program_to_jax(exported)
argj = tensor.t2j(arg[0])
ans2 = jax.jit(func)(weights, (argj,))[0]
ans2 = tensor.j2t(ans2)
self.assertTrue(torch.allclose(ans, ans2, atol=1e-5))

# Convert to StableHLO
stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
module_str = str(stablehlo.mlir_module())
self.assertIn("func.func public @main", module_str)
self.assertIn("stablehlo.divide", module_str)

def test_interpolate_dynamic(self):
# Export with dynamic dimension constraints on both min and max
arg = (torch.randn(3, 3, 200, 200),)
model = Interpolate()
ans = model(*arg)
dynamic_shapes = ({0: torch.export.Dim("b", min=3, max=10)},)

with torch.no_grad():
exported = torch.export.export(model, arg, dynamic_shapes=dynamic_shapes)
stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
module_str = str(stablehlo.mlir_module())

# Look for dynamic shape artifacts
self.assertIn("func.func public @main(%arg0: tensor<?x3x200x200xf32>", module_str)
self.assertIn("stablehlo.dynamic_broadcast_in_dim", module_str)
self.assertIn("stablehlo.dynamic_gather", module_str)

def test_export_dtypes(self):
DTYPE_TO_MLIR_STR = {
# NO_MAPPING : jnp.float0 (signless scalar int)
torch.bool : "i1",
# NO_MAPPING : "i4"
torch.int8 : "i8",
torch.int16 : "i16",
torch.int32 : "i32",
torch.int64 : "i64",
torch.long : "i64",
# NO_MAPPING : "ui4"
torch.uint8 : "ui8",
torch.uint16 : "ui16",
torch.uint32 : "ui32",
torch.uint64 : "ui64",
# NO_MAPPING : "f8E4M3B11FNUZ"
torch.float8_e4m3fn : "f8E4M3FN",
# NO_MAPPING : f8E4M3FNUZ
torch.float8_e5m2 : "f8E5M2",
# NO_MAPPING : f8E5M2FNUZ
torch.bfloat16 : "bf16",
torch.half : "f16",
torch.float16 : "f16",
torch.float32 : "f32",
torch.float64 : "f64",
torch.double : "f64",
torch.complex64 : "complex<f32>",
torch.complex128 : "complex<f64>",
None : None,
}

model = TensorConstant()
for torch_dtype in torch_xla2.tensor.TORCH_DTYPE_TO_JAX.keys():
if torch_dtype == None:
## TODO: Figure out what the None mapping should be, seems like:
## torch.tensor(dtype=None) maps to f32
## jnp.tensor(dtype=None) maps to f64
continue
arg = (torch.randn(10).to(torch_dtype),)
with torch.no_grad():
exported = torch.export.export(model, arg)
stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
module_str = str(stablehlo.mlir_module())
self.assertIn(DTYPE_TO_MLIR_STR[torch_dtype], module_str)


if __name__ == '__main__':
Expand Down
92 changes: 92 additions & 0 deletions experimental/torch_xla2/test/test_symbolic_shapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import unittest
import torch
import jax
import torch_xla2

class AddOne(torch.nn.Module):

def __init__(self):
super().__init__()

def forward(self, a):
return a + 1

class ConcatAddModel(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, a, b):
a = torch.concat([a, a], dim=0)
return a + b

class SymbolicShapeTest(unittest.TestCase):
"""Test possible symbolic shape computations that upstream torch export can
emit. Seems to be currently limited to a few binary math operations where one
operand is a symbolic variable/expr and the other is a constant integer.
"""

def setUp(self):
torch.manual_seed(0)

def test_constraints_min_max(self):
"""Test a model with basic min/max dimension restrictions
"""

# Arg shapes are a=s0{<=10}, b=s0*2
model = AddOne()
args = (torch.rand(5),)
sym_a = torch.export.Dim("a", min=3, max=10)
dynamic_shapes = ({0: sym_a},)

with torch.no_grad():
exported = torch.export.export(model, args=args, dynamic_shapes=dynamic_shapes)
stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
module_str = str(stablehlo.mlir_module())

self.assertRegex(module_str, r"stablehlo.constant.*3")
self.assertRegex(module_str, r"shape_assertion.*s[0-9]+ >= 3")
self.assertRegex(module_str, r"stablehlo.constant.*10")
self.assertRegex(module_str, r"shape_assertion.*s[0-9]+ <= 10")

def test_constraints_multiply(self):
"""Test a model with a slightly more complex constraint, where the input
shapes are determined by an equation of the other, in this case input shapes
are s0{<=10} and s0*2.
"""
# Arg shapes are a=s0{<=10}, b=s0*2
model = ConcatAddModel()
args = (torch.rand(2),torch.rand(4))
sym_a = torch.export.Dim("a", max=10)
sym_b = sym_a*2
dynamic_shapes = ({0: sym_a}, {0: sym_b})

with torch.no_grad():
exported = torch.export.export(model, args=args, dynamic_shapes=dynamic_shapes)
stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
module_str = str(stablehlo.mlir_module())

self.assertRegex(module_str, r"stablehlo.constant.*10")
self.assertRegex(module_str, r"shape_assertion.*s[0-9]+ <= 10")
self.assertRegex(module_str, r"stablehlo.constant.*2")
self.assertRegex(module_str, r"shape_assertion.*2\*s[0-9]+")

def test_constraint_indirection(self):
"""Test a model where none of the shapes are directly symbolic variables
but all are expressions of symints that don't appear directly in the model.
"""

# Arg shapes are b=s0{<=10}*2
args = (torch.randn(10, 10),)
model = AddOne()
sym_a = torch.export.Dim("a", max=10)
sym_b = sym_a*2
dynamic_shapes = ({0: sym_b},)

with torch.no_grad():
exported = torch.export.export(model, args=args, dynamic_shapes=dynamic_shapes)
stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
module_str = str(stablehlo.mlir_module())

self.assertRegex(module_str, r"shape_assertion.*s[0-9]+ <= 10")
self.assertRegex(module_str, r"shape_assertion.*2\*s[0-9]+")

Loading

0 comments on commit baf08ae

Please sign in to comment.