Training example for llama3; misc changes to make the training work. (#…
qihqi committed Jun 7, 2024
1 parent 56ddd5d commit f4a612c
Showing 14 changed files with 1,748 additions and 9 deletions.
194 changes: 194 additions & 0 deletions experimental/torch_xla2/docs/torch_xla2_dynamo.md
@@ -0,0 +1,194 @@
# Dynamo backend for torchxla2

## Goal

Have a dynamo backend backed by torch_xla2.

The users should be able to do the following:

```python
m = ...  # some torch.nn.Module
m_compiled = torch.compile(m, backend='torch_xla2_compile')  # backend name TBD
result = m_compiled(*inputs)
```

The above should run on TPU with low overhead.

## Challenge

Usually the challenge of a dynamo backend is the compiler that
transforms an fx graph with torch (or ATen) ops into a compiled executable.
In our case, however, that piece is already solved.

For every `call_function` node, we look up the Jax implementation of the
corresponding ATen op in a dictionary, and we just call it.

This is illustrated here: https://github.com/pytorch/xla/blob/master/experimental/torch_xla2/torch_xla2/export.py#L23
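
For illustration, a minimal sketch of that lookup-and-call loop over an fx graph. The `ATEN_TO_JAX` dictionary and the `run_fx_with_jax` helper are hypothetical names for this sketch, not the real torch_xla2 code:

```python
import torch
from torch import fx
import jax.numpy as jnp

# Hypothetical mapping from ATen ops to their jax implementations.
ATEN_TO_JAX = {
    torch.ops.aten.add.Tensor: jnp.add,
    torch.ops.aten.mul.Tensor: jnp.multiply,
}

def run_fx_with_jax(gm: fx.GraphModule, *jax_args):
    values = {}
    args_iter = iter(jax_args)
    for node in gm.graph.nodes:
        if node.op == 'placeholder':
            values[node] = next(args_iter)
        elif node.op == 'call_function':
            # Look up the jax implementation of this ATen op and call it.
            args = fx.node.map_arg(node.args, lambda n: values[n])
            kwargs = fx.node.map_arg(node.kwargs, lambda n: values[n])
            values[node] = ATEN_TO_JAX[node.target](*args, **kwargs)
        elif node.op == 'output':
            return fx.node.map_arg(node.args[0], lambda n: values[n])
```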

Now, the challenge is for dynamo to 1. produce the graph; and 2. not incur
any data copies in the process.


Consider the following pseudocode:

```python
class XLATensor2(torch.Tensor):
    _data: jax.Array

    def __torch_dispatch__(...):
        # do stuff with self._data, get new_data
        return XLATensor2(new_data)


def dynamo_backend(fx, sample):
    compiled = ...  # compile fx into a callable that manipulates jax.Array

    def returned_callable(inputs):
        datas = [i._data for i in inputs]
        res = compiled(*datas)
        return XLATensor2(res)

    return returned_callable


model = torch.compile(model, backend=dynamo_backend)
inputs = ...  # a list of XLATensor2, or a list of torch.Tensor?
model(*inputs)
```

What would be the type of `inputs`?
If the inputs are of type `XLATensor2`, then dynamo
will attempt to trace through the `__torch_dispatch__` method
and throw an error, because it doesn't know what `_data` is or
understand the operations on it.

If the `inputs` are of type `torch.Tensor`, then it works: dynamo
calls the backend, and the backend can produce the correct result.
But the `inputs` need to be converted to `XLATensor2` first inside of
the backend, which usually means a data copy. This happens every time
the compiled callable is executed, so it is not desirable.
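
To make the cost concrete, one straightforward conversion path (the `to_jax` helper name is hypothetical) round-trips through host memory on every call:

```python
import jax.numpy as jnp
import torch

def to_jax(t: torch.Tensor) -> jnp.ndarray:
    # torch.Tensor -> numpy -> jax.Array: a host copy (plus a device transfer
    # on TPU) that would happen on every invocation of the compiled callable.
    return jnp.asarray(t.detach().cpu().numpy())

x = torch.randn(1024, 1024)
x_jax = to_jax(x)  # repeated per call if conversion lives inside the backend
```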

## The Desired behavior

When *tracing*, dynamo treats `XLATensor2` as if it were a regular tensor
without a dispatch override; and when executing the compiled callable,
`XLATensor2` is passed in as-is. We know that dynamo can do this for
some tensor subclasses, namely `FakeTensor`.


Let's list out the possible ways we could accomplish this behavior.


# Option 1. Have the jax.Array object held in C++

Roughly, we would have a `Tensor` subclass in C++, very
similar to the `LazyTensor` subclass that backs the current `XLATensor`.
This tensor can hold its own state in C++. In our case, that would
be a `PyObject*` that happens to point to either a `jnp.ndarray` or
jax's `Traced<ShapedArray>` during `jax.jit`. We might further reuse the
`XLA` dispatch key to route the operators to the jax implementation,
emulating what `__torch_dispatch__` does.

This way, eager mode will continue to work, and dynamo would work
because the Python class is still `torch.Tensor` (not a subclass), and
there is no Python logic in dispatching, so there is nothing for dynamo to trace through.

## Pros:
* Very clear that this will work.
* Recommended by ezyang

## Cons:
We now need to deal with C++ builds. In particular, `torch` becomes a source
dependency instead of a pip dependency; meaning, again, we need to
build torch first, then build torch_xla2. This might be mitigated if
that subclass can be upstreamed.


# Option 2. Modify dynamo to do the desired behavior

We have one instance where a `torch.Tensor` dispatch subclass
just works with dynamo, without dynamo making a fuss when it traces
`__torch_dispatch__`. That is `FakeTensor` (https://github.com/pytorch/pytorch/pull/100017/files).
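
For reference, a tiny demonstration of what `FakeTensor` does (this uses `FakeTensorMode`; the import path is PyTorch-internal and may change):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# FakeTensors carry no data but still propagate shapes and dtypes through ops,
# which is why dynamo can trace them without running any real computation.
with FakeTensorMode():
    x = torch.empty(4, 8)
    y = torch.empty(8, 16)
    z = x @ y
    print(z.shape, z.dtype)  # torch.Size([4, 16]) torch.float32
```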

The idea is to make dynamo trace as if the inputs were `FakeTensor` and
not `XLATensor`, and only after the fx graph and the backend are created does
dynamo call the compiled callable with `XLATensor`.

## Pros:
* Likely pure Python changes.

## Cons:
* We also need to design a mechanism to distinguish tensor subclasses that
  dynamo should trace through from those it should not.
* Likely a significant amount of work.


# Option 3. Register all the ops as custom ops

Currently dynamo traces `__torch_dispatch__`, and we don't like that
because it will find the operations on Jax arrays and doesn't understand them.

What if we make dynamo **able** to understand what is inside?
The [Black box python functions](https://docs.google.com/document/d/1ZuCVyMfibExwvtzhd9cfMWk5zXT3Dhy1b3kuvAIkBoU/edit#heading=h.56tggsazyrkh) doc
points out the possibility of registering things that we don't want dynamo
to go into as custom ops. So we could, theoretically, do the following:

1. Register the jax impl of an ATen op as a custom op,
   i.e. register `jaten.add` for `aten.add`.
2. For meta kernels, just call the meta kernel of `aten.add`.
3. In `__torch_dispatch__`, we forward the call from `aten.add` to `jaten.add`.

When dynamo attempts to go inside `__torch_dispatch__`, it will find
`jaten.add`. Then it will record that in the `fx.Graph`.

Our backend will see the same ops but in a different namespace (`jaten`).
That is fine, as long as we know how to look up their implementations.
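
A rough sketch of steps 1 and 2 using `torch.library` (the `jaten` namespace comes from the list above; the CPU implementation here is a stand-in so the sketch is self-contained, whereas the real one would dispatch to jax):

```python
import torch
from torch.library import Library

# Define a "jaten" namespace mirroring aten.
jaten_lib = Library("jaten", "DEF")
jaten_lib.define("add(Tensor a, Tensor b) -> Tensor")

def jaten_add(a, b):
    # The real implementation would unwrap to jax.Array and call jnp.add;
    # torch.add stands in here.
    return torch.add(a, b)

# Step 1: register the implementation for a dispatch key.
jaten_lib.impl("add", jaten_add, "CPU")
# Step 2: for the meta kernel, just defer to aten.add's own meta behavior.
jaten_lib.impl("add", lambda a, b: torch.ops.aten.add.Tensor(a, b), "Meta")

x, y = torch.ones(3), torch.ones(3)
print(torch.ops.jaten.add(x, y))  # the op dynamo would record in the fx.Graph
```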

Note: we probably also need to hook up gradients of custom ops via `torch.autograd.Function`.
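
A sketch of what that hookup could look like for `jaten.add`, assuming the registration sketched above (for an elementwise add, the gradient is just the incoming gradient):

```python
import torch

class JatenAdd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        return torch.ops.jaten.add(a, b)

    @staticmethod
    def backward(ctx, grad_out):
        # d(a + b)/da = d(a + b)/db = identity
        return grad_out, grad_out

a = torch.ones(3, requires_grad=True)
b = torch.ones(3, requires_grad=True)
JatenAdd.apply(a, b).sum().backward()
print(a.grad, b.grad)
```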


## Pros / Cons:
We haven't tried this, so we don't know whether it will work.






# Appendix: Failed attempts

## Attempt 1: move dispatch to a mode (i.e. the subclass has no dispatch override)

```python
import jax
import torch

# `tensor` is torch_xla2's tensor module; j2t_dtype maps a jax dtype to the
# corresponding torch dtype (import path assumed).
from torch_xla2 import tensor


class Subclass(torch.Tensor):

    @staticmethod
    def __new__(cls, elem):
        dtype = tensor.j2t_dtype(elem.dtype)
        shape = list(elem.shape)
        for i, s in enumerate(shape):
            if not isinstance(s, int):
                shape[i] = 1
        if dtype is None:
            dtype = torch.float32

        self = torch.Tensor._make_wrapper_subclass(
            cls,
            shape,
            dtype=dtype,
            device='meta',
            requires_grad=False,
        )
        self._meta = torch.empty(
            shape, dtype=dtype, device='meta', requires_grad=False
        )
        self._elem = elem
        return self

    def __init__(self, elem: jax.Array):
        super().__init__()
        self._elem = elem

    def __str__(self):
        return "Subclass({} {})".format(str(type(self._elem)), str(self._elem))

```

This fails with an error saying that it exhausted all subclasses and all of the `__torch_dispatch__` implementations returned `NotImplemented`.
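
For context, the "mode" half of this attempt would look roughly like the following (a sketch only; the real mode would unwrap `Subclass._elem`, call the jax implementation of `func`, and rewrap the result):

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class JaxDispatchMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Real version: unwrap Subclass._elem (jax.Array) from args, call the
        # jax implementation of `func`, and wrap the result back in Subclass.
        return func(*args, **kwargs)
```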

73 changes: 73 additions & 0 deletions experimental/torch_xla2/examples/_grad_of_attention.py
@@ -0,0 +1,73 @@
import jax.numpy as jnp
import jax
from jax.experimental.pallas.ops.tpu import flash_attention

import torch_xla2
from jax.experimental import mesh_utils
from torch_xla2.ops.jtorch import _tpu_flash_attention

# Configure the torch_xla2 environment: a 1-D "fsdp" mesh over 4 devices,
# with flash attention enabled.
env = torch_xla2.default_env()
jax.config.update('jax_enable_x64', False)
env._mesh = jax.sharding.Mesh(
    mesh_utils.create_device_mesh((4, )),
    axis_names=("fsdp", ),
)
env.use_flash_attention = True


from torch.nn import functional as F


def attn(q, k, v):
  # Wrap the jax.Arrays as torch_xla2 tensors (no copy), run the torch op
  # inside the env, then unwrap the result back to a jax.Array.
  q, k, v = env.j2t_iso((q, k, v))
  with env:
    x = F.scaled_dot_product_attention(q, k, v, is_causal=True)
  x = env.t2j_iso(x)
  return jnp.sum(x)


import torch

class M(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.a = torch.nn.Linear(10, 10)

  def forward(self, x):
    return self.a(x)

m = M()
from torch_xla2.interop import JittableModule

mjit = JittableModule(m)

from torch.nn.utils import stateless

def f(weights, x):
  res = mjit.functional_call('forward', weights, {}, (x, ))
  return torch.sum(res)


def crossent(x, y):
  x, y = env.j2t_iso((x, y))
  res = torch.func.functional_call(m, x, (y, ))
  return env.t2j_iso(res)

# Differentiate the attention function with respect to its jax.Array inputs.
graded = jax.value_and_grad(attn)

shape = (4, 32, 128, 32)
q = jnp.ones(shape, dtype='bfloat16')
v = jnp.ones(shape, dtype='bfloat16')
k = jnp.ones(shape, dtype='bfloat16')


env = torch_xla2.default_env()
weights = env.t2j_iso(env.to_xla(mjit.params))

from torch_xla2.interop import jax_view

#print(jax.jit(graded).lower(q, v, k).as_text())
# Lower the jitted gradient of the functional-call wrapper and print the lowered IR text.
print(jax.jit(jax.grad(jax_view(f))).lower(
    weights, jax.ShapeDtypeStruct((10, ), 'float32')
).as_text())