Add basic sdxl inference example.
qihqi committed Sep 25, 2024
1 parent bdbfaa0 commit c4862ef
Showing 4 changed files with 75 additions and 7 deletions.
12 changes: 10 additions & 2 deletions experimental/reference_models/README.md
@@ -1,12 +1,20 @@
This directory will contain a list of reference models that
-we have optimized and runs well with torch_xla or torch_xla2.
+we have optimized and that run well on TPU.

The contents of this directory are organized in the following way:

* Every subdirectory is a self-contained model, as a separate pip package.

+* Each subdirectory must have a README indicating:
+  * whether it is for training or inference,
+  * on what devices it has been tested / developed,
+  * instructions for running it.

* Every subdirectory contains its own set of shell scripts with all the flags
  set for the best performance that we tuned, be it training or inference.

* Each subdirectory can specify its own dependencies, and can depend on models / layers
-  defined in well-known OSS libraries, such as HuggingFace transformers.
+  defined in well-known OSS libraries, such as HuggingFace transformers, but subdirectories should ideally not depend on each other.

+* (Optional) Each model can also have a GPU "original" version that illustrates and attributes where this model code came from, if any. This also helps to showcase what changes we have made to make it performant on TPU.

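A hypothetical subdirectory layout that follows these conventions (file names are illustrative, not taken from this commit):

```
reference_models/
  tpu_sdxl/
    README.md          # inference; devices tested on; how to run
    pyproject.toml     # makes the directory its own pip package
    sdxl.py            # model / pipeline code
    run_inference.sh   # flags pre-set to the tuned best configuration
```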
23 changes: 21 additions & 2 deletions experimental/reference_models/tpu_sdxl/sdxl.py
@@ -1,11 +1,15 @@
+import time
+import functools
+import jax
import torch
import torch_xla2
from torch_xla2 import interop
+from torch_xla2.interop import JittableModule

from transformers.modeling_outputs import BaseModelOutputWithPooling

from jax.tree_util import register_pytree_node
import jax

def base_model_output_with_pooling_flatten(v):
  return (v.last_hidden_state, v.pooler_output, v.hidden_states, v.attentions), None
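The diff collapses the matching unflatten function and the registration call. A minimal sketch of what that standard JAX pytree registration typically looks like, assuming `BaseModelOutputWithPooling` is reconstructed from the same four fields the flatten emits:

```python
def base_model_output_with_pooling_unflatten(aux_data, children):
  # children arrive in the order the flatten function emitted them
  return BaseModelOutputWithPooling(
      last_hidden_state=children[0],
      pooler_output=children[1],
      hidden_states=children[2],
      attentions=children[3],
  )

# Teach jax.jit how to trace through this HuggingFace output dataclass.
register_pytree_node(
    BaseModelOutputWithPooling,
    base_model_output_with_pooling_flatten,
    base_model_output_with_pooling_unflatten,
)
```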
@@ -21,13 +25,14 @@ def base_model_output_with_pooling_unflatten(aux_data, children):


from diffusers import StableDiffusionPipeline
-pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base")

prompt = "a photograph of an astronaut riding a horse"
# image = pipe(prompt).images[0]


env = torch_xla2.default_env()
+jax.config.update('jax_enable_x64', False)

def move_scheduler(scheduler):
  for k, v in scheduler.__dict__.items():
@@ -37,11 +42,25 @@ def move_scheduler(scheduler):

with env:
  pipe.to('jax:1')
  #import pdb; pdb.set_trace()
  move_scheduler(pipe.scheduler)
+  pipe.unet = JittableModule(pipe.unet, extra_jit_args={'static_argnames': ('return_dict',)})
+  pipe.text_encoder = JittableModule(pipe.text_encoder)

+  BS = 4
+  prompt = [prompt] * BS
+  pipe.vae = JittableModule(
+      pipe.vae,
+      extra_jit_args={'static_argnames': ('return_dict', )})
+  pipe.vae.make_jitted('decode')

  image = pipe(prompt).images[0]

+  jax.profiler.start_trace('/tmp/sdxl')
+  start = time.perf_counter()
+  image = pipe(prompt, num_inference_steps=20).images[0]
+  end = time.perf_counter()
+  jax.profiler.stop_trace()
+  print('Total time is ', end - start, 'bs = ', BS)
  image.save(f"astronaut_rides_horse.png")
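An editorial note on the two pipeline invocations: the first, un-timed `pipe(prompt)` call forces `jax.jit` to trace and compile the wrapped UNet, text encoder, and VAE decode, so the profiled second call measures steady-state throughput. The same warmup pattern in miniature (a standalone sketch, not part of the commit):

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
  # stand-in for one compiled denoising step
  return jnp.sin(x) @ x

x = jnp.ones((1024, 1024))
step(x).block_until_ready()   # first call: trace + compile

start = time.perf_counter()
step(x).block_until_ready()   # second call: cached executable only
print('steady-state step time:', time.perf_counter() - start)
```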


3 changes: 2 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
@@ -339,7 +339,8 @@ def setUpClass(cls):
    print('op_db size: ', len(op_db), 'testing: ', len(ops_to_test))

  def setUp(self):
-    self.env = tensor.Environment()
+    self.env = torch_xla2.default_env()
+    torch_xla2.enable_accuracy_mode()
    #self.env.config.debug_accuracy_for_each_op = True
    torch.manual_seed(0)

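`torch_xla2.enable_accuracy_mode()` is not shown in this diff; judging by the `jax_enable_x64` flag toggled off in sdxl.py, it plausibly flips that flag the other way so op-by-op comparisons against eager PyTorch run at full precision. A guess at its shape, clearly an assumption rather than the repository's code:

```python
import jax

def enable_accuracy_mode():
  # Assumed sketch: allow 64-bit types so accuracy tests are not
  # dominated by float32 rounding differences.
  jax.config.update('jax_enable_x64', True)
```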
44 changes: 42 additions & 2 deletions experimental/torch_xla2/torch_xla2/interop.py
@@ -47,9 +47,12 @@ def set_one(module, prefix):
set_one(m, '')


-class JittableModule:
+class JittableModule(torch.nn.Module):

+  # TODO: add statedict loading hook

  def __init__(self, m: torch.nn.Module, extra_jit_args={}):
+    super().__init__()
    self.params, self.buffers = extract_all_buffers(m)
    self._model = m
    self._jitted = {}
@@ -83,14 +86,51 @@ def jitted_forward(*args, **kwargs):
return self._jitted['forward'](*args, **kwargs)

  def __getattr__(self, key):
+    if key == '_model':
+      return super().__getattr__(key)
    if key in self._jitted:
      return self._jitted[key]
    return getattr(self._model, key)

+  def make_jitted(self, key):
+    jitted = jax_jit(
+        functools.partial(self.functional_call, key),
+        kwargs_for_jax_jit=self._extra_jit_args)
+    def call(*args, **kwargs):
+      return jitted(self.params, self.buffers, *args, **kwargs)
+    self._jitted[key] = call
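How this is consumed (a hedged illustration; `pipe` and `latents` are placeholder names): `__getattr__` consults `_jitted` before delegating to the wrapped model, so once a method is registered via `make_jitted`, a plain attribute call runs the jitted functional version:

```python
vae = JittableModule(
    pipe.vae,
    extra_jit_args={'static_argnames': ('return_dict',)})
vae.make_jitted('decode')

# Lookup hits vae._jitted['decode'], so this executes the
# jax.jit-compiled functional_call('decode', params, buffers, latents)
# rather than the eager torch method.
out = vae.decode(latents)
```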





+class CompileMixin:

+  def functional_call(
+      self, method, params, buffers, *args, **kwargs):
+    kwargs = kwargs or {}
+    params_copy = copy.copy(params)
+    params_copy.update(buffers)
+    # Temporarily swap the module's real parameters for the passed-in
+    # (traced) ones while the method runs.
+    with torch_stateless._reparametrize_module(self, params_copy):
+      res = method(*args, **kwargs)
+    return res

+  def jit(self, method):
+    jitted = jax_jit(functools.partial(self.functional_call, method))
+    def call(*args, **kwargs):
+      return jitted(dict(self.named_parameters()),
+                    dict(self.named_buffers()),
+                    *args, **kwargs)
+    return call


+def compile_nn_module(m: torch.nn.Module, methods=None):
+  if methods is None:
+    methods = ['forward']  # NOTE: `methods` is not used yet

+  new_parent = type(
+      m.__class__.__name__ + '_with_CompileMixin',
+      (CompileMixin, m.__class__),
+      {},
+  )
+  m.__class__ = new_parent
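What the class swap buys (an editorial sketch with hypothetical names, assuming the fixed definitions above): after `compile_nn_module`, the instance's MRO begins at `CompileMixin`, so `functional_call` and `jit` become available on the existing object without reconstructing it:

```python
import torch

class TinyNet(torch.nn.Module):  # hypothetical model
  def __init__(self):
    super().__init__()
    self.lin = torch.nn.Linear(4, 4)

  def forward(self, x):
    return self.lin(x)

net = TinyNet()
compile_nn_module(net)        # swaps in the mixin subclass
print(type(net).__name__)     # -> TinyNet_with_CompileMixin
fwd = net.jit(net.forward)    # jitted, functionalized forward
```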


def _torch_view(t: JaxValue) -> TorchValue:
