diff --git a/experimental/reference_models/README.md b/experimental/reference_models/README.md index c715042bd13..49d4a2030e8 100644 --- a/experimental/reference_models/README.md +++ b/experimental/reference_models/README.md @@ -1,12 +1,20 @@ This directory will contain a list of reference models that -we have optimized and runs well with torch_xla or torch_xla2. +we have optimized and runs well on TPU. Contents of this directory is organized in the following way: * Every subdirectory is a self-contained model, as a seperate pip package. + +* Each subdirectory must has a README indicating: +** is this training or inference +** on what devices it has been tested / developed +** instructions on running. + * Every subdirectory contains it's own set of shell scripts do with all the flags set for the best performance that we turned, be it training or inference. + * Each subdirectory can specify their own dependencies, and can depend on models / layers - defined in well-known OSS libraries, such as HuggingFace transformers. + defined in well-known OSS libraries, such as HuggingFace transformers. But should ideally not depend on each other. +* (Optional) Each model can also have a GPU "original" version that illustrates and attributes where this model code came from, if any. This also helps to show case what changes we have done to make it performant on TPU. diff --git a/experimental/reference_models/tpu_sdxl/sdxl.py b/experimental/reference_models/tpu_sdxl/sdxl.py deleted file mode 100644 index 20e8c30d08e..00000000000 --- a/experimental/reference_models/tpu_sdxl/sdxl.py +++ /dev/null @@ -1,48 +0,0 @@ -import jax -import torch -import torch_xla2 -from torch_xla2.interop import JittableModule - -from transformers.modeling_outputs import BaseModelOutputWithPooling - -from jax.tree_util import register_pytree_node - -def base_model_output_with_pooling_flatten(v): - return (v.last_hidden_state, v.pooler_output, v.hidden_states, v.attentions), None - -def base_model_output_with_pooling_unflatten(aux_data, children): - return BaseModelOutputWithPooling(*children) - -register_pytree_node( - BaseModelOutputWithPooling, - base_model_output_with_pooling_flatten, - base_model_output_with_pooling_unflatten -) - - -from diffusers import StableDiffusionPipeline -pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") - -prompt = "a photograph of an astronaut riding a horse" -# image = pipe(prompt).images[0] - - -env = torch_xla2.default_env() - -def move_scheduler(scheduler): - for k, v in scheduler.__dict__.items(): - if isinstance(v, torch.Tensor): - setattr(scheduler, k, v.to('jax')) - - -with env: - pipe.to('jax:1') - #import pdb; pdb.set_trace() - move_scheduler(pipe.scheduler) - pipe.unet = JittableModule(pipe.unet, extra_jit_args={'static_argnames': ('return_dict',)}) - pipe.text_encoder = JittableModule(pipe.text_encoder) - image = pipe(prompt).images[0] - image.save(f"astronaut_rides_horse.png") - - - diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index e6f4bb652c3..f449983abf4 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -339,7 +339,8 @@ def setUpClass(cls): print('op_db size: ', len(op_db), 'testing: ', len(ops_to_test)) def setUp(self): - self.env = tensor.Environment() + self.env = torch_xla2.default_env() + torch_xla2.enable_accuracy_mode() #self.env.config.debug_accuracy_for_each_op = True torch.manual_seed(0) diff --git a/experimental/torch_xla2/torch_xla2/interop.py b/experimental/torch_xla2/torch_xla2/interop.py index 9e32e8fc5f0..604ce8b7184 100644 --- a/experimental/torch_xla2/torch_xla2/interop.py +++ b/experimental/torch_xla2/torch_xla2/interop.py @@ -47,9 +47,12 @@ def set_one(module, prefix): set_one(m, '') -class JittableModule: +class JittableModule(torch.nn.Module): + + # TODO: add statedict loading hook def __init__(self, m: torch.nn.Module, extra_jit_args={}): + super().__init__() self.params, self.buffers = extract_all_buffers(m) self._model = m self._jitted = {} @@ -83,14 +86,51 @@ def jitted_forward(*args, **kwargs): return self._jitted['forward'](*args, **kwargs) def __getattr__(self, key): + if key == '_model': + return super().__getattr__(key) + if key in self._jitted: + return self._jitted[key] return getattr(self._model, key) + def make_jitted(self, key): + jitted = jax_jit( + functools.partial(self.functional_call, key), + kwargs_for_jax_jit=self._extra_jit_args) + def call(*args, **kwargs): + return jitted(self.params, self.buffers, *args, **kwargs) + self._jitted[key] = call + + + + + +class CompileMixin: + def functional_call( + self, method, params, buffers, *args, **kwargs): + kwargs = kwargs or {} + params_copy = copy.copy(params) + params_copy.update(buffers) + with torch_stateless._reparametrize_module(self, params_copy): + res = method(*args, **kwargs) + return res + def jit(self, method): + jitted = jax_jit(functools.partial(self.functional_call, method_name)) + def call(*args, **kwargs): + return jitted(self.named_paramters(), self.named_buffers(), *args, **kwargs) + return call +def compile_nn_module(m: torch.nn.Module, methods=None): + if methods is None: + methods = ['forward'] - + new_parent = type( + m.__class__.__name__ + '_with_CompileMixin', + (CompileMixin, m.__class__), + ) + m.__class__ = NewParent def _torch_view(t: JaxValue) -> TorchValue: