Skip to content

Commit

Permalink
[quantization] llm quant (#222)
Browse files Browse the repository at this point in the history
* [quantization] init commit for llm quant

* [tests] delete dequantize

* [misc] refactor

* [llm_quant] add readme and fix typo

* [llm_quant] add readme and fix typo
  • Loading branch information
zk1998 committed May 31, 2023
1 parent edf2852 commit 841294e
Show file tree
Hide file tree
Showing 8 changed files with 652 additions and 0 deletions.
30 changes: 30 additions & 0 deletions tinynn/llm_quant/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# LLM QUANT

## 安装依赖

- PyTorch: tested on PyTorch 1.13 & CUDA 11.6
- transformers: tested on v4.28.1
- easyquant: 需要到[Releases](https://github.com/alibaba/TinyNeuralNetwork/releases)手动下载安装包进行安装, 提供权重动态解压和动态量化的cuda加速kernel

## 量化模式

- 8bit仅权重量化: 权重压缩为8-bit,显存需求降低,计算时还原为FP16进行计算,相比于FP16的模型推理存在额外开销。模型精度几乎没有影响。
- 4bit仅权重量化: 权重压缩为4-bit,显存需求大幅度降低, 计算时还原为FP16进行计算,相比于FP16的模型推理存在额外开销。模型精度下降较严重。
- token-wise动态量化: 权重压缩为8-bit, 激活值运行时动态量化为8-bit, 结合easyquant库的int8 GEMM可以有效提升推理性能。在Llama-family模型中精度小幅度下降,基本没有影响。

## Llama 量化
我们对llama模型进行了详细的量化分析和测试,推荐使用8-bit的动态量化,其可以有效提升推理速度并降低显存需求,同时精度几乎不受影响。

| 量化模式 | wikitext2(ppl⬇️) | 推理性能(ms/token) <br/>GPU:2080Ti | 推理性能(ms/token)<br/> GPU:T4 | 模型占用显存(GB) |
|-------------------------|------------------|--------------------------------|----------------------------|------------|
| llama-7b fp16 | 5.68 | - | 61.5882 | 12.90 |
| llama-7b weight8 | 5.68 | 68.6845 | 151.1209 | 7.10 |
| llama-7b token-wise动态量化 | 5.82(+0.14) | 43.0228 | 47.1649 | 7.09 |
| llama-7b weight4 | 6.5657(+0.89) | 63.7035 | 141.1330 | 3.99 |

> 除了模型占用显存外,在模型推理过程中还存在激活值和上下文的显存占用,需要预留1~2GB的额外显存。
## 未来工作

- 4-bit量化精度恢复及加速推理
- 8-bit静态量化
Empty file added tinynn/llm_quant/__init__.py
Empty file.
61 changes: 61 additions & 0 deletions tinynn/llm_quant/examples/chatglm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# This script is based on https://github.com/THUDM/ChatGLM-6B
import signal
import os
import torch
from transformers import AutoModel, AutoTokenizer

from tinynn.llm_quant.modules import quant_fc


def basic_usage(model_path='THUDM/chatglm-6b', quant_mod='dynamic'):
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
device = torch.device('cuda')

# Do quantization.
if quant_mod != 'fp16':
quant_fc(model, quant_mod=quant_mod)
model.to(device)

clear_command = 'clear'
stop_stream = False

def build_prompt(history):
prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
for query, response in history:
prompt += f"\n\n用户:{query}"
prompt += f"\n\nChatGLM-6B:{response}"
return prompt

def signal_handler(signal, frame):
global stop_stream
stop_stream = True

history = []
print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
while True:
query = input("\n用户:")
if query.strip() == "stop":
break
if query.strip() == "clear":
history = []
os.system(clear_command)
print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
continue
count = 0
for response, history in model.stream_chat(tokenizer, query, history=history):
if stop_stream:
stop_stream = False
break
else:
count += 1
if count % 8 == 0:
os.system(clear_command)
print(build_prompt(history), flush=True)
signal.signal(signal.SIGINT, signal_handler)
os.system(clear_command)
print(build_prompt(history), flush=True)


if __name__ == '__main__':
basic_usage()
42 changes: 42 additions & 0 deletions tinynn/llm_quant/examples/llama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer

from tinynn.llm_quant.modules import quant_fc


def basic_usage(model_path='huggyllama/llama-7b', quant_mod='dynamic'):
device = torch.device('cuda')

# load LLM model from huggingface or local path
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

# Do quantization.
if quant_mod != 'fp16':
# If your LLM model is Llama-family, you can set fuse_qkv to fuse qkv linear and scaled-dot-product-attention.
quant_fc(model, quant_mod=quant_mod, fuse_qkv=True)
model.to(device)

prompt = "Building a website can be done in 10 simple steps:\n"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
input_ids = input_ids.to(device)

generated_ids = model.generate(
input_ids,
max_new_tokens=1024,
do_sample=True,
top_k=1,
top_p=0.95,
temperature=0.8,
repetition_penalty=1.2,
use_cache=True,
)

outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
for output in outputs:
print(output)


if __name__ == '__main__':
basic_usage()
120 changes: 120 additions & 0 deletions tinynn/llm_quant/llama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import math
from typing import Optional, Tuple
from distutils.version import LooseVersion

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from transformers.modeling_utils import set_module_tensor_to_device


class LlamaAttentionFused(nn.Module):
def __init__(self, origin_attention):
super().__init__()
self.config = origin_attention.config
self.hidden_size = origin_attention.hidden_size
self.num_heads = origin_attention.num_heads
self.head_dim = origin_attention.head_dim
self.max_position_embeddings = origin_attention.max_position_embeddings

self.qkv_proj = nn.Linear(
origin_attention.hidden_size, origin_attention.num_heads * origin_attention.head_dim * 3, bias=False
)
fused_weight = torch.cat(
[
fc_node.weight.data
for fc_node in [origin_attention.q_proj, origin_attention.k_proj, origin_attention.v_proj]
],
dim=0,
)
set_module_tensor_to_device(
self.qkv_proj, 'weight', fused_weight.device, value=fused_weight, dtype=fused_weight.dtype
)
self.o_proj = origin_attention.o_proj
self.rotary_emb = origin_attention.rotary_emb

origin_attention.q_proj = None
origin_attention.k_proj = None
origin_attention.v_proj = None

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
# use fused fc output to get qkv states
qkv_states = self.qkv_proj(hidden_states).view(bsz, q_len, self.num_heads * 3, self.head_dim).transpose(1, 2)
(query_states, key_states, value_states) = torch.chunk(qkv_states, 3, 1)

is_causal = past_key_value is None

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]

if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None
if LooseVersion(torch.__version__) == LooseVersion('1.13.0'):
with torch.backends.cuda.sdp_kernel(enable_math=False):
attn_output, attn_weights = F._scaled_dot_product_attention(
query_states, key_states, value_states, is_causal=is_causal
)
elif LooseVersion(torch.__version__) >= LooseVersion('2.0.0'):
with torch.backends.cuda.sdp_kernel(enable_math=False):
attn_output, attn_weights = F.scaled_dot_product_attention(
query_states, key_states, value_states, is_causal=is_causal
)
else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is"
f" {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
del query_states, key_states, value_states

attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value
Loading

0 comments on commit 841294e

Please sign in to comment.