Skip to content

Commit

Permalink
Add basic sdxl inference example.
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi committed Sep 25, 2024
1 parent bdbfaa0 commit 7dd0663
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 55 deletions.
12 changes: 10 additions & 2 deletions experimental/reference_models/README.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
This directory will contain a list of reference models that
we have optimized and runs well with torch_xla or torch_xla2.
we have optimized and runs well on TPU.

Contents of this directory is organized in the following way:

* Every subdirectory is a self-contained model, as a seperate pip package.

* Each subdirectory must has a README indicating:
** is this training or inference
** on what devices it has been tested / developed
** instructions on running.

* Every subdirectory contains it's own set of shell scripts do with all the flags
set for the best performance that we turned, be it training or inference.

* Each subdirectory can specify their own dependencies, and can depend on models / layers
defined in well-known OSS libraries, such as HuggingFace transformers.
defined in well-known OSS libraries, such as HuggingFace transformers. But should ideally not depend on each other.

* (Optional) Each model can also have a GPU "original" version that illustrates and attributes where this model code came from, if any. This also helps to show case what changes we have done to make it performant on TPU.

48 changes: 0 additions & 48 deletions experimental/reference_models/tpu_sdxl/sdxl.py

This file was deleted.

3 changes: 2 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,8 @@ def setUpClass(cls):
print('op_db size: ', len(op_db), 'testing: ', len(ops_to_test))

def setUp(self):
self.env = tensor.Environment()
self.env = torch_xla2.default_env()
torch_xla2.enable_accuracy_mode()
#self.env.config.debug_accuracy_for_each_op = True
torch.manual_seed(0)

Expand Down
44 changes: 42 additions & 2 deletions experimental/torch_xla2/torch_xla2/interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,12 @@ def set_one(module, prefix):
set_one(m, '')


class JittableModule:
class JittableModule(torch.nn.Module):

# TODO: add statedict loading hook

def __init__(self, m: torch.nn.Module, extra_jit_args={}):
super().__init__()
self.params, self.buffers = extract_all_buffers(m)
self._model = m
self._jitted = {}
Expand Down Expand Up @@ -83,14 +86,51 @@ def jitted_forward(*args, **kwargs):
return self._jitted['forward'](*args, **kwargs)

def __getattr__(self, key):
if key == '_model':
return super().__getattr__(key)
if key in self._jitted:
return self._jitted[key]
return getattr(self._model, key)

def make_jitted(self, key):
jitted = jax_jit(
functools.partial(self.functional_call, key),
kwargs_for_jax_jit=self._extra_jit_args)
def call(*args, **kwargs):
return jitted(self.params, self.buffers, *args, **kwargs)
self._jitted[key] = call





class CompileMixin:

def functional_call(
self, method, params, buffers, *args, **kwargs):
kwargs = kwargs or {}
params_copy = copy.copy(params)
params_copy.update(buffers)
with torch_stateless._reparametrize_module(self, params_copy):
res = method(*args, **kwargs)
return res

def jit(self, method):
jitted = jax_jit(functools.partial(self.functional_call, method_name))
def call(*args, **kwargs):
return jitted(self.named_paramters(), self.named_buffers(), *args, **kwargs)
return call


def compile_nn_module(m: torch.nn.Module, methods=None):
if methods is None:
methods = ['forward']


new_parent = type(
m.__class__.__name__ + '_with_CompileMixin',
(CompileMixin, m.__class__),
)
m.__class__ = NewParent


def _torch_view(t: JaxValue) -> TorchValue:
Expand Down
2 changes: 0 additions & 2 deletions experimental/torch_xla2/torch_xla2/ops/jtorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,6 @@ def getitem(self, indexes):
indexes = (indexes, )
elif isinstance(indexes, list):
indexes = tuple(indexes)
if not isinstance(self, jax.Array):
breakpoint()
return self[indexes]

@register_function(torch.corrcoef)
Expand Down

0 comments on commit 7dd0663

Please sign in to comment.