From f088810ac87a98bfbda4678af8fb069d9d9a3f1d Mon Sep 17 00:00:00 2001 From: qihqi Date: Thu, 26 Sep 2024 12:20:02 -0700 Subject: [PATCH] Add stablediffusion inference reference model (#8027) --- .github/workflows/build_and_test.yml | 4 +- .github/workflows/build_upstream_image.yml | 2 +- experimental/reference_models/README.md | 20 +++++ experimental/torch_xla2/test/test_exports.py | 1 + .../torch_xla2/test/test_functions.py | 1 + experimental/torch_xla2/test/test_ops.py | 4 +- .../torch_xla2/torch_xla2/__init__.py | 15 +++- experimental/torch_xla2/torch_xla2/config.py | 1 + experimental/torch_xla2/torch_xla2/interop.py | 75 +++++++++++++++++-- experimental/torch_xla2/torch_xla2/tensor.py | 4 +- 10 files changed, 114 insertions(+), 13 deletions(-) create mode 100644 experimental/reference_models/README.md diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index af95dc955b6..8576c908e0a 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -5,13 +5,13 @@ on: - master - r[0-9]+.[0-9]+ paths-ignore: - - 'experimental/torch_xla2/**' + - 'experimental/**' push: branches: - master - r[0-9]+.[0-9]+ paths-ignore: - - 'experimental/torch_xla2/**' + - 'experimental/**' workflow_dispatch: concurrency: diff --git a/.github/workflows/build_upstream_image.yml b/.github/workflows/build_upstream_image.yml index bb8ce87f01c..37992bc20f8 100644 --- a/.github/workflows/build_upstream_image.yml +++ b/.github/workflows/build_upstream_image.yml @@ -5,7 +5,7 @@ on: - master - r[0-9]+.[0-9]+ paths-ignore: - - 'experimental/torch_xla2/**' + - 'experimental/**' workflow_dispatch: jobs: build: diff --git a/experimental/reference_models/README.md b/experimental/reference_models/README.md new file mode 100644 index 00000000000..49d4a2030e8 --- /dev/null +++ b/experimental/reference_models/README.md @@ -0,0 +1,20 @@ +This directory will contain a list of reference models that +we have optimized and runs well on TPU. + +Contents of this directory is organized in the following way: + +* Every subdirectory is a self-contained model, as a seperate pip package. + +* Each subdirectory must has a README indicating: +** is this training or inference +** on what devices it has been tested / developed +** instructions on running. + +* Every subdirectory contains it's own set of shell scripts do with all the flags + set for the best performance that we turned, be it training or inference. + +* Each subdirectory can specify their own dependencies, and can depend on models / layers + defined in well-known OSS libraries, such as HuggingFace transformers. But should ideally not depend on each other. + +* (Optional) Each model can also have a GPU "original" version that illustrates and attributes where this model code came from, if any. This also helps to show case what changes we have done to make it performant on TPU. + diff --git a/experimental/torch_xla2/test/test_exports.py b/experimental/torch_xla2/test/test_exports.py index 60dcbeb856b..9cf296df3cf 100644 --- a/experimental/torch_xla2/test/test_exports.py +++ b/experimental/torch_xla2/test/test_exports.py @@ -34,6 +34,7 @@ class ExportTest(unittest.TestCase): def setUp(self): torch.manual_seed(0) + torch_xla2.enable_accuracy_mode() def test_interpolate(self): diff --git a/experimental/torch_xla2/test/test_functions.py b/experimental/torch_xla2/test/test_functions.py index 092f38a7e84..9e291dc802a 100644 --- a/experimental/torch_xla2/test/test_functions.py +++ b/experimental/torch_xla2/test/test_functions.py @@ -10,6 +10,7 @@ class TestTorchFunctions(parameterized.TestCase): def setUp(self): self.env = torch_xla2.tensor.Environment() + torch_xla2.enable_accuracy_mode() @parameterized.named_parameters( ('tensor_2d', lambda: torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])), diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index b9c7e232214..2a96b010ab7 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -7,6 +7,7 @@ instantiate_device_type_tests, ops) from torch.utils import _pytree as pytree from torch_xla2 import tensor +import torch_xla2 skiplist = { @@ -259,7 +260,8 @@ def setUpClass(cls): print('op_db size: ', len(op_db), 'testing: ', len(ops_to_test)) def setUp(self): - self.env = tensor.Environment() + self.env = torch_xla2.default_env() + torch_xla2.enable_accuracy_mode() #self.env.config.debug_accuracy_for_each_op = True torch.manual_seed(0) diff --git a/experimental/torch_xla2/torch_xla2/__init__.py b/experimental/torch_xla2/torch_xla2/__init__.py index 54af0eccab4..f7dbde71263 100644 --- a/experimental/torch_xla2/torch_xla2/__init__.py +++ b/experimental/torch_xla2/torch_xla2/__init__.py @@ -16,7 +16,6 @@ from jax._src import xla_bridge os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1') -jax.config.update('jax_enable_x64', True) # torch_xla2:oss-begin old_pjrt_options = jax.config.jax_pjrt_client_create_options @@ -80,4 +79,16 @@ def disable_globally(): unsupported_dtype=unsupported_dtype) import jax -torch._register_device_module('jax', jax) \ No newline at end of file +torch._register_device_module('jax', jax) + + +def enable_accuracy_mode(): + jax.config.update('jax_enable_x64', True) + jax.config.update('jax_default_matmul_precision', 'highest') + default_env().config.internal_respect_torch_return_dtypes = True + + +def enable_performance_mode(): + jax.config.update('jax_enable_x64', False) + jax.config.update('jax_default_matmul_precision', 'default') + default_env().config.internal_respect_torch_return_dtypes = False \ No newline at end of file diff --git a/experimental/torch_xla2/torch_xla2/config.py b/experimental/torch_xla2/torch_xla2/config.py index 119f3b44d7e..8a0870996a2 100644 --- a/experimental/torch_xla2/torch_xla2/config.py +++ b/experimental/torch_xla2/torch_xla2/config.py @@ -15,3 +15,4 @@ class Configuration: # device treat_cuda_as_jax_device: bool = True use_torch_native_for_cpu_tensor: bool = False + internal_respect_torch_return_dtypes: bool = False diff --git a/experimental/torch_xla2/torch_xla2/interop.py b/experimental/torch_xla2/torch_xla2/interop.py index e4357ef5a51..604ce8b7184 100644 --- a/experimental/torch_xla2/torch_xla2/interop.py +++ b/experimental/torch_xla2/torch_xla2/interop.py @@ -47,19 +47,25 @@ def set_one(module, prefix): set_one(m, '') -class JittableModule: +class JittableModule(torch.nn.Module): - def __init__(self, m: torch.nn.Module): + # TODO: add statedict loading hook + + def __init__(self, m: torch.nn.Module, extra_jit_args={}): + super().__init__() self.params, self.buffers = extract_all_buffers(m) self._model = m + self._jitted = {} + + self._extra_jit_args = extra_jit_args def __call__(self, *args, **kwargs): - res = self._model(*args, **kwargs) - return res + return self.forward(*args, **kwargs) + def functional_call( - self, method_name, params, buffers, args, kwargs=None): + self, method_name, params, buffers, *args, **kwargs): kwargs = kwargs or {} params_copy = copy.copy(params) params_copy.update(buffers) @@ -68,6 +74,65 @@ def functional_call( return res + def forward(self, *args, **kwargs): + if 'forward' not in self._jitted: + jitted = jax_jit( + functools.partial(self.functional_call, 'forward'), + kwargs_for_jax_jit=self._extra_jit_args, + ) + def jitted_forward(*args, **kwargs): + return jitted(self.params, self.buffers, *args, **kwargs) + self._jitted['forward'] = jitted_forward + return self._jitted['forward'](*args, **kwargs) + + def __getattr__(self, key): + if key == '_model': + return super().__getattr__(key) + if key in self._jitted: + return self._jitted[key] + return getattr(self._model, key) + + def make_jitted(self, key): + jitted = jax_jit( + functools.partial(self.functional_call, key), + kwargs_for_jax_jit=self._extra_jit_args) + def call(*args, **kwargs): + return jitted(self.params, self.buffers, *args, **kwargs) + self._jitted[key] = call + + + + + +class CompileMixin: + + def functional_call( + self, method, params, buffers, *args, **kwargs): + kwargs = kwargs or {} + params_copy = copy.copy(params) + params_copy.update(buffers) + with torch_stateless._reparametrize_module(self, params_copy): + res = method(*args, **kwargs) + return res + + def jit(self, method): + jitted = jax_jit(functools.partial(self.functional_call, method_name)) + def call(*args, **kwargs): + return jitted(self.named_paramters(), self.named_buffers(), *args, **kwargs) + return call + + +def compile_nn_module(m: torch.nn.Module, methods=None): + if methods is None: + methods = ['forward'] + + new_parent = type( + m.__class__.__name__ + '_with_CompileMixin', + (CompileMixin, m.__class__), + ) + m.__class__ = NewParent + + def _torch_view(t: JaxValue) -> TorchValue: # t is an object from jax land # view it as-if it's a torch land object diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/experimental/torch_xla2/torch_xla2/tensor.py index 861dd8aaf89..6e446fb3874 100644 --- a/experimental/torch_xla2/torch_xla2/tensor.py +++ b/experimental/torch_xla2/torch_xla2/tensor.py @@ -331,8 +331,8 @@ def _to_copy(self, the_tensor, new_dtype, new_device): the_tensor = the_tensor.to(new_dtype) jax_device = self.get_as_jax_device(new_device) if jax_device: - with jax.default_device(jax_device): - arr = t2j(the_tensor) + arr = t2j(the_tensor) + arr = jax.device_put(arr, jax_device) else: with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): return torch_tensor.to(new_device)