Add int8 per channel weight-only quantized matmul (#7201)
Co-authored-by: Siyuan Liu <lsiyuan@google.com>
lsy323 and Siyuan Liu committed Jun 7, 2024
1 parent 4a30ea7 commit 56ddd5d
Showing 4 changed files with 309 additions and 0 deletions.
114 changes: 114 additions & 0 deletions docs/quantized_ops.md
@@ -0,0 +1,114 @@
Quantized Operations for XLA device (Experimental feature)
--------------------------

This document outlines how to utilize quantized operations to enable quantization on XLA devices.

XLA Quantized ops offer a high-level abstraction for quantized operations (e.g., blockwise int4 quantized matrix multiplication). These ops are analogous to quantized CUDA kernels ([example](https://github.com/vllm-project/vllm/blob/main/csrc/quantization/gptq/q_gemm.cu)) in the CUDA ecosystem, providing similar functionality and performance benefits within the XLA framework.

**NOTE:** Currently this is classified as an experimental feature. Its API specifics
will change in the next (2.5) release.


## How to use:

XLA quantized operations can be used as a `torch op`, or as a `torch.nn.Module` that wraps the `torch op`. These two options give model developers the flexibility to choose the best way to integrate XLA quantized ops into their solution.

Both the `torch op` and the `nn.Module` are compatible with `torch.compile(backend='openxla')`.

### Call XLA quantized op in model code

Users can call XLA quantized ops in the same way as other regular PyTorch ops. This provides maximum flexibility in integrating XLA quantized ops into their applications. The quantized ops work in both eager mode and Dynamo, with regular PyTorch CPU tensors and XLA tensors.

**Note** Please check the docstrings of the quantized ops for the expected layout of the quantized weights.

```Python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_quantized_matmul

N_INPUT_FEATURES=10
N_OUTPUT_FEATURES=20
x = torch.randn((3, N_INPUT_FEATURES), dtype=torch.bfloat16)
w_int = torch.randint(-128, 127, (N_OUTPUT_FEATURES, N_INPUT_FEATURES), dtype=torch.int8)
scaler = torch.randn((N_OUTPUT_FEATURES,), dtype=torch.bfloat16)

# Call with a torch CPU tensor (for debugging purposes)
matmul_output = torch.ops.xla.quantized_matmul(x, w_int, scaler)

device = xm.xla_device()
x_xla = x.to(device)
w_int_xla = w_int.to(device)
scaler_xla = scaler.to(device)

# Call with XLA Tensor to run on XLA device
matmul_output_xla = torch.ops.xla.quantized_matmul(x_xla, w_int_xla, scaler_xla)

# Use with torch.compile(backend='openxla')
def f(x, w, s):
  return torch.ops.xla.quantized_matmul(x, w, s)

f_dynamo = torch.compile(f, backend="openxla")
dynamo_out_xla = f_dynamo(x_xla, w_int_xla, scaler_xla)
```

It's common to wrap the quantized op in a custom `nn.Module` in the model developer's code:

```Python
class MyQLinearForXLABackend(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.weight = ...
    self.scaler = ...

  def load_weight(self, w, scaler):
    # Load quantized Linear weights
    # Customized way to preprocess the weights
    ...
    self.weight = processed_w
    self.scaler = processed_scaler

  def forward(self, x):
    # Do some random stuff with x
    ...
    matmul_output = torch.ops.xla.quantized_matmul(x, self.weight, self.scaler)
    # Do some random stuff with matmul_output
    ...
```
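For illustration, a minimal runnable version of such a wrapper, with no custom weight preprocessing, might look like the sketch below. The class name `MinimalQLinear` and its buffer layout are assumptions made for this sketch, not part of the torch_xla API.

```Python
import torch
import torch_xla.experimental.xla_quantized_matmul  # registers torch.ops.xla.quantized_matmul


class MinimalQLinear(torch.nn.Module):
  """Hypothetical minimal wrapper around the XLA weight-only quantized matmul."""

  def __init__(self, out_features, in_features):
    super().__init__()
    # int8 per-channel weight [out_features, in_features], bfloat16 scaler [out_features].
    self.register_buffer(
        'weight', torch.zeros(out_features, in_features, dtype=torch.int8))
    self.register_buffer('scaler',
                         torch.zeros(out_features, dtype=torch.bfloat16))

  def load_weight(self, w_int, scaler):
    # No extra preprocessing in this sketch; a real wrapper may repack or cast here.
    self.weight = w_int
    self.scaler = scaler

  def forward(self, x):
    return torch.ops.xla.quantized_matmul(x, self.weight, self.scaler)
```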

### Module Swap

Alternatively, users can use the `nn.Module` that wraps the XLA quantized ops and perform a module swap in the model code:

```Python
orig_model = MyModel()
# Quantize the model and get quantized weights
q_weights = quantize(orig_model)
# Process the quantized weight to the format that XLA quantized op expects.
q_weights_for_xla = process_for_xla(q_weights)

# Do module swap
q_linear = XlaQuantizedLinear(orig_model.linear.in_features,
                              orig_model.linear.out_features)
q_linear.load_quantized_weight(q_weights_for_xla)
orig_model.linear = q_linear
```
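The `quantize` and `process_for_xla` steps above are placeholders for the developer's own flow. As one concrete possibility, a simple round-to-nearest (RTN) per-channel scheme, mirroring the helper used in the test added by this commit, can produce the `int8` weight and per-channel scaler that the XLA op expects. This is a sketch of one option, not the only supported quantization flow:

```Python
import torch
from torch.ao.quantization.utils import determine_qparams


def rtn_per_channel_int8(w_fp):
  """Symmetric per-channel RTN quantization of a [out_channel, in_channel] weight."""
  min_val, max_val = torch.aminmax(w_fp, dim=1)  # range per output channel
  scaler, zero_point = determine_qparams(
      min_val,
      max_val,
      -128,
      127,
      dtype=torch.int8,
      eps=torch.Tensor([1e-5]),
      has_customized_qrange=False,
      qscheme=torch.per_channel_symmetric)
  w_int = torch.ops.quantized_decomposed.quantize_per_channel(
      w_fp, scaler, zero_point, 0, -128, 127, torch.int8)
  return w_int, scaler.to(w_fp.dtype)
```

The resulting `w_int` and `scaler` can then be passed to `XlaQuantizedLinear.load_quantized_weight` or used directly with `torch.ops.xla.quantized_matmul`.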

## Supported Quantized Operations:

### Matrix Multiply

| Weight Quantization Type | Activation Quantization Type | Dtype | Supported |
|---|---|---|---|
| per-channel | N/A | W8A16 | Yes |
| per-channel | N/A | W4A16 | No |
| per-channel | per-token | W8A8 | No |
| per-channel | per-token | W4A8 | No |
| blockwise | N/A | W8A16 | No |
| blockwise | N/A | W4A16 | No |
| blockwise | per-token | W8A8 | No |
| blockwise | per-token | W4A8 | No |

**Note** `W[X]A[Y]` denotes an `X`-bit weight and a `Y`-bit activation. A value of 4 or 8 refers to `int4`/`int8`; 16 refers to `bfloat16`.
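As a small worked example of the supported per-channel W8A16 row: the op is, in effect, a matmul against the `int8` weight followed by a per-output-channel rescale, matching the CompositeExplicitAutograd reference implementation added in this commit. The shapes below are only illustrative.

```Python
import torch
import torch.nn.functional as F

x = torch.randn((3, 10), dtype=torch.bfloat16)                # A16: bfloat16 activation
w_int = torch.randint(-128, 127, (20, 10), dtype=torch.int8)  # W8: int8 per-channel weight
scaler = torch.randn((20,), dtype=torch.bfloat16)             # one scale per output channel

# Reference computation: cast the weight up, matmul, then scale each output channel.
out = F.linear(x, w_int.to(x.dtype)) * scaler                 # bfloat16, shape [3, 20]
```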

### Embedding
To be added
105 changes: 105 additions & 0 deletions test/quantized_ops/test_quantized_matmul.py
@@ -0,0 +1,105 @@
import re
import unittest

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_quantized_matmul
from torch_xla.experimental.xla_quantized_matmul import XlaQuantizedLinear
from torch.ao.quantization.utils import determine_qparams

torch.manual_seed(123456)

device = xm.xla_device()


class M(torch.nn.Module):

  def __init__(self, input_dim, output_dim):
    super(M, self).__init__()
    # Define a linear layer
    self.linear = torch.nn.Linear(input_dim, output_dim, bias=False)

  def weight_quantization_rtn(self,
                              linear,
                              quant_method=torch.per_channel_symmetric):
    '''
    Quantize linear weight using Round-To-Nearest(RTN) algorithm.
    '''
    assert isinstance(self.linear, torch.nn.Linear)
    w_fp = linear.weight.data
    min_val, max_val = torch.aminmax(w_fp, dim=1)  # min_val, max_val [out_dim]
    n_bits = 8
    int_min = -2**(n_bits - 1)
    int_max = 2**(n_bits - 1) - 1
    scaler, zero_point = determine_qparams(
        min_val,
        max_val,
        int_min,
        int_max,
        dtype=torch.int8,
        eps=torch.Tensor([1e-5]),
        has_customized_qrange=False,
        qscheme=quant_method)
    w_int = torch.ops.quantized_decomposed.quantize_per_channel(
        w_fp, scaler, zero_point, 0, int_min, int_max, torch.int8)
    return w_int, scaler.to(w_fp.dtype), zero_point

  def replace_with_xla_quantized_matmul(self):
    assert isinstance(self.linear, torch.nn.Linear)
    w_int, scaler, _ = self.weight_quantization_rtn(self.linear)
    q_linear = XlaQuantizedLinear(self.linear.in_features,
                                  self.linear.out_features)
    q_linear.load_quantized_weight(w_int, scaler)
    self.linear = q_linear

  def forward(self, x):
    # Forward pass through the linear layer
    return self.linear(x)


class QuantizedTest(unittest.TestCase):

  def test_q_linear_module_per_channel(self):

    with torch.no_grad():
      m = M(5, 8)
      x = torch.randn(3, 5)
      out_fp = m(x)
      m.replace_with_xla_quantized_matmul()
      out_quant = m(x)

      m = m.to(device)
      x = x.to(device)
      out_quant_xla = m(x)
      self.assertTrue(torch.allclose(out_fp, out_quant, atol=0.01))
      self.assertTrue(torch.allclose(out_quant_xla.cpu(), out_quant))

  def test_q_linear_module_dynamo(self):

    with torch.no_grad():
      m = M(5, 8)
      x = torch.randn(3, 5)
      out_fp = m(x)
      m.replace_with_xla_quantized_matmul()
      out_quant = m(x)
      m = m.to(device)
      m_dynamo = torch.compile(m, backend="openxla")
      out_quant_dynamo = m_dynamo(x.to(device))
      self.assertTrue(torch.allclose(out_fp, out_quant, atol=0.01))
      self.assertTrue(torch.allclose(out_quant_dynamo.cpu(), out_quant))

  def test_q_linear_hlo(self):
    with torch.no_grad():
      x = torch.randn((3, 5), dtype=torch.bfloat16).to(device)
      w_int = torch.randint(-128, 127, (8, 5), dtype=torch.int8).to(device)
      scaler = torch.randn((8,), dtype=torch.bfloat16).to(device)

      output = torch.ops.xla.quantized_matmul(x, w_int, scaler)
      hlo = torch_xla._XLAC._get_xla_tensors_hlo([output])
      print(hlo)
      self.assertTrue(re.search(r'bf16.*dot.*bf16.*s8', hlo) is not None)


if __name__ == '__main__':
  unittest.main()
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -221,6 +221,7 @@ function run_xla_op_tests3 {
run_xla_hlo_debug "$CDIR/stablehlo/test_stablehlo_inference.py"
run_test "$CDIR/stablehlo/test_stablehlo_compile.py"
run_test "$CDIR/stablehlo/test_unbounded_dynamism.py"
run_test "$CDIR/quantized_ops/test_quantized_matmul.py"
run_test "$CDIR/spmd/test_xla_sharding.py"
run_test "$CDIR/spmd/test_xla_sharding_hlo.py"
run_test "$CDIR/spmd/test_xla_virtual_device.py"
89 changes: 89 additions & 0 deletions torch_xla/experimental/xla_quantized_matmul.py
@@ -0,0 +1,89 @@
import torch
import torch.nn.functional as F
import torch_xla
from torch.library import impl
from torch_xla.core.xla_model import XLA_LIB

XLA_LIB.define(
    "quantized_matmul(Tensor x, Tensor w, Tensor scale, int? blocksize=-1, bool? quantize_activation=False) -> Tensor"
)


def _check_per_channel_quant_weight_dtype_shapes(input_dim, output_dim, w,
                                                 w_scaler):
  assert w.dtype == torch.int8, f"Weight dtype is expected to be torch.int8, got {w.dtype}."
  assert w.dim() == 2, f"Weight tensor is expected to be 2D, got {w.dim()}D Tensor."
  assert output_dim == w.shape[0] and input_dim == w.shape[1], (
      f"Weight shape is expected to be [output_dim, input_dim], output_dim: "
      f"{output_dim}, input_dim: {input_dim}, but got {w.shape}.")
  assert w_scaler.dim() == 1 and w_scaler.shape[0] == w.shape[0], (
      f"Weight scaler shape is expected to be [out_channel,], got "
      f"{w_scaler.shape}, weight shape {w.shape}.")


@impl(XLA_LIB, "quantized_matmul", "XLA")
def quantized_matmul_xla(x: torch.Tensor,
                         w: torch.Tensor,
                         scaler: torch.Tensor,
                         blocksize: int = -1):
  """Quantized Matrix Multiply op on XLA devices.

  Args:
    x: torch.Tensor - Activation of Matmul [..., in_channel].
    w: torch.Tensor - Weight Tensor.
       per-channel quant: torch.int8 x [out_channel, in_channel].
    scaler: torch.Tensor - Weight scaler.
       per-channel quant: [out_channel,].
    blocksize: blocksize for blockwise quantization, -1 for per-channel
       quantization.
  """
  assert blocksize == -1, "blockwise quantization is not supported yet."
  # Per-channel quant.
  _check_per_channel_quant_weight_dtype_shapes(x.shape[-1], scaler.shape[0], w,
                                               scaler)
  return F.linear(x, w) * scaler


@impl(XLA_LIB, "quantized_matmul", "CompositeExplicitAutograd")
def quantized_matmul(x: torch.Tensor,
                     w: torch.Tensor,
                     scaler: torch.Tensor,
                     blocksize: int = -1):
  assert blocksize == -1, "blockwise quantization is not supported yet."
  # Per-channel quant.
  _check_per_channel_quant_weight_dtype_shapes(x.shape[-1], scaler.shape[0], w,
                                               scaler)
  w = w.to(x.dtype)
  return torch.mul(F.linear(x, w), scaler)


class XlaQuantizedLinear(torch.nn.Module):

  def __init__(self, input_dim, output_dim, blocksize=-1):
    super().__init__()
    assert blocksize == -1, "Only per-channel quantization is supported."
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.blocksize = blocksize
    self.register_buffer('weight',
                         torch.zeros(output_dim, input_dim).to(torch.int8))
    self.register_buffer('weight_scaler', torch.zeros(output_dim))

  def load_quantized_weight(self, weight, weight_scaler):
    '''
    Weight shape: [output_channel, input_channel]
    Weight scaler shape: [output_channel]
    '''
    if self.blocksize == -1:
      # Per-channel quant.
      _check_per_channel_quant_weight_dtype_shapes(self.input_dim,
                                                   self.output_dim, weight,
                                                   weight_scaler)
      self.weight = weight
      self.weight_scaler = weight_scaler
    else:
      assert False, "Only per-channel quantization is supported."

  def forward(self, x):
    if self.blocksize == -1:
      return torch.ops.xla.quantized_matmul(x, self.weight, self.weight_scaler)
    else:
      assert False, "Only per-channel quantization is supported."
