Commit
[Update] support fixedNTKScaleRoPE & MixedNTKScaleRoPE
NormXU committed Aug 18, 2023
1 parent c0fc9e1 commit d7a2304
Showing 2 changed files with 51 additions and 7 deletions.
7 changes: 6 additions & 1 deletion networks/configuration_extrapolation.py
@@ -2,14 +2,19 @@
# create: @time: 8/7/23 18:42

def set_config_for_extrapolation(config):
# RoPE config
config.use_rope_attention_bias = True
config.rope_type = "dynamic"  # "dynamic" or "linear" or "mixed_base"
# when rope_scaling_factor = 1.0, RoPE is NTKScaleRoPE; when rope_scaling_factor > 1, it becomes DynamicNTKScaleRoPE
config.rope_scaling_factor = 1.0
config.rope_type = "dynamic"
config.fix_base = False # please refer to https://normxu.github.io/Rethinking-Rotary-Position-Embedding-2/
config.b = 0.75 # please refer to https://normxu.github.io/Rethinking-Rotary-Position-Embedding-2/

# alibi for encoder https://github.com/lucidrains/x-transformers/pull/88
config.use_alibi = False
config.learnable_alibi = False

# attention scale
config.use_entropy_scale = True # https://openreview.net/forum?id=qc9O2EtrMI-

# Others
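
To make the new options concrete, here is a minimal usage sketch. It is an illustration, not code from this commit: the SimpleNamespace stand-in and the value choices are assumptions; only the function and attribute names come from the file above.

from types import SimpleNamespace
from networks.configuration_extrapolation import set_config_for_extrapolation

config = SimpleNamespace()            # stand-in for the model's real config object
set_config_for_extrapolation(config)  # applies the defaults shown in the diff

# pick one of the supported RoPE variants
config.rope_type = "mixed_base"       # or "dynamic" / "linear"
config.b = 0.75                       # mixed-base exponent (see the linked blog post)

# or, for the fixed-NTK variant:
# config.rope_type = "dynamic"
# config.fix_base = True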
51 changes: 45 additions & 6 deletions networks/modeling_erine_layout_extrapolation.py
@@ -200,20 +200,53 @@ def forward(self, x, seq_len=None):
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)

class MixedNTKScalingRotaryEmbedding(RotaryEmbedding):
"""
Copied from LlamaRotaryEmbedding, extended with mixed-base NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla.
The per-dimension exponent 'b' is inspired by https://normxu.github.io/Rethinking-Rotary-Position-Embedding-2/
"""

def __init__(self, dim, max_position_embeddings=2048, base=10000, b=0.75, device=None):
self.b = b
super().__init__(dim, max_position_embeddings, base, device)

def _set_cos_sin_cache(self, seq_len, device, dtype):
if seq_len > self.max_position_embeddings:
k = seq_len / self.max_position_embeddings
a = np.log(k) / ((self.dim // 2) ** self.b)  # (d/2) ** b, so the per-dimension divisor below reaches exactly k at the last dimension
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
inv_freq /= (a * torch.arange(1, self.dim // 2 + 1).float().to(device) ** self.b).exp()
self.register_buffer("inv_freq", inv_freq)

t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)

freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
"""copied from LLamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
"""
Copied from LlamaRotaryEmbedding, extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla.
'fix_base' is inspired by https://normxu.github.io/Rethinking-Rotary-Position-Embedding-2/
"""

def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, fix_base=False):
self.scaling_factor = scaling_factor
self.fix_base = fix_base
super().__init__(dim, max_position_embeddings, base, device)

def _set_cos_sin_cache(self, seq_len, device, dtype):
if seq_len > self.max_position_embeddings:
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
lamda_factor = (
    (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
base = self.base * lamda_factor
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
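# 'fix_base': additionally divide every inverse frequency by lamda_factor (the fixed-NTK variant from the blog post linked above)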
if self.fix_base:
inv_freq = inv_freq * 1.0 / lamda_factor
self.register_buffer("inv_freq", inv_freq)

t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
@@ -513,7 +546,13 @@ def _init_rope(self, config, max_position_embeddings):
elif scaling_type == "dynamic":
rotary_emb = DynamicNTKScalingRotaryEmbedding(
self.attention_head_size, max_position_embeddings=max_position_embeddings,
scaling_factor=scaling_factor
scaling_factor=scaling_factor,
fix_base=config.fix_base
)
elif scaling_type == "mixed_base":
rotary_emb = MixedNTKScalingRotaryEmbedding(
self.attention_head_size, max_position_embeddings=max_position_embeddings,
b=config.b
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
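
In this dispatch, rope_type = "dynamic" with fix_base=True selects the corrected dynamic-NTK variant, while rope_type = "mixed_base" selects MixedNTKScalingRotaryEmbedding. To see what fix_base changes numerically, here is a small standalone sketch that mirrors the formulas in DynamicNTKScalingRotaryEmbedding above; the dimension, base, and lengths are made-up values, not code from the repository.

import torch

dim, base, max_pos, seq_len, s = 64, 10000.0, 2048, 8192, 1.0
lamda_factor = ((s * seq_len / max_pos) - (s - 1)) ** (dim / (dim - 2))
new_base = base * lamda_factor                        # plain dynamic NTK rescales only the base

inv_freq = 1.0 / (new_base ** (torch.arange(0, dim, 2).float() / dim))
inv_freq_fixed = inv_freq / lamda_factor              # extra correction applied when fix_base=True

print(inv_freq[0].item(), inv_freq_fixed[0].item())   # 1.0 vs 1/lamda_factor for the highest-frequency dim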
