diff --git a/networks/configuration_extrapolation.py b/networks/configuration_extrapolation.py
index 5a664e9..3f662b7 100644
--- a/networks/configuration_extrapolation.py
+++ b/networks/configuration_extrapolation.py
@@ -2,14 +2,19 @@
 # create: @time: 8/7/23 18:42

 def set_config_for_extrapolation(config):
+    # RoPE config
     config.use_rope_attention_bias = True
+    config.rope_type = "dynamic"  # one of "dynamic", "linear" or "mixed_base"
+    # when rope_scaling_factor = 1.0, RoPE is NTKScaleRoPE; when rope_scaling_factor > 1, RoPE becomes DynamicallyNTKScaleRope
     config.rope_scaling_factor = 1.0
-    config.rope_type = "dynamic"
+    config.fix_base = False  # please refer to https://normxu.github.io/Rethinking-Rotary-Position-Embedding-2/
+    config.b = 0.75  # please refer to https://normxu.github.io/Rethinking-Rotary-Position-Embedding-2/

     # alibi for encoder https://github.com/lucidrains/x-transformers/pull/88
     config.use_alibi = False
     config.learnable_alibi = False

+    # attention scale
     config.use_entropy_scale = True  # https://openreview.net/forum?id=qc9O2EtrMI-

     # Others
diff --git a/networks/modeling_erine_layout_extrapolation.py b/networks/modeling_erine_layout_extrapolation.py
index b89fd5a..9ff1463 100644
--- a/networks/modeling_erine_layout_extrapolation.py
+++ b/networks/modeling_erine_layout_extrapolation.py
@@ -200,20 +200,53 @@ def forward(self, x, seq_len=None):
             self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
         )

+
+class MixedNTKScalingRotaryEmbedding(RotaryEmbedding):
+    """
+    copied from LlamaRotaryEmbedding, extended with mixed-base NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
+    the per-dimension exponent 'b' is inspired by https://normxu.github.io/Rethinking-Rotary-Position-Embedding-2/
+    """
+
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, b=0.75, device=None):
+        self.b = b
+        super().__init__(dim, max_position_embeddings, base, device)
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        if seq_len > self.max_position_embeddings:
+            k = seq_len / self.max_position_embeddings
+            # a is chosen so that exp(a * (dim // 2) ** b) == k: the lowest-frequency
+            # dimension is stretched by the full factor k, while the highest-frequency
+            # dimensions stay almost unchanged
+            a = np.log(k) / ((self.dim // 2) ** self.b)
+            inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+            inv_freq /= (a * torch.arange(1, self.dim // 2 + 1).float().to(device) ** self.b).exp()
+            self.register_buffer("inv_freq", inv_freq)
+
+        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
+
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+

 class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
-    """copied from LLamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+    """
+    copied from LlamaRotaryEmbedding, extended with Dynamic NTK scaling.
+    Credits to the Reddit users /u/bloc97 and /u/emozilla.
+    'fix_base' is inspired by https://normxu.github.io/Rethinking-Rotary-Position-Embedding-2/
+    """

-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, fix_base=False):
         self.scaling_factor = scaling_factor
+        self.fix_base = fix_base
         super().__init__(dim, max_position_embeddings, base, device)

     def _set_cos_sin_cache(self, seq_len, device, dtype):
         if seq_len > self.max_position_embeddings:
-            base = self.base * (
-                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
-            ) ** (self.dim / (self.dim - 2))
+            lambda_factor = (
+                (self.scaling_factor * seq_len / self.max_position_embeddings)
+                - (self.scaling_factor - 1)
+            ) ** (self.dim / (self.dim - 2))
+            base = self.base * lambda_factor
             inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+            if self.fix_base:
+                inv_freq = inv_freq * 1.0 / lambda_factor
             self.register_buffer("inv_freq", inv_freq)

             t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
@@ -513,7 +546,13 @@ def _init_rope(self, config, max_position_embeddings):
         elif scaling_type == "dynamic":
             rotary_emb = DynamicNTKScalingRotaryEmbedding(
                 self.attention_head_size, max_position_embeddings=max_position_embeddings,
-                scaling_factor=scaling_factor
+                scaling_factor=scaling_factor,
+                fix_base=config.fix_base
+            )
+        elif scaling_type == "mixed_base":
+            rotary_emb = MixedNTKScalingRotaryEmbedding(
+                self.attention_head_size, max_position_embeddings=max_position_embeddings,
+                b=config.b
             )
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
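
For orientation, here is a minimal sketch (not part of the patch) of how the new configuration fields map onto the RoPE branches in _init_rope. The SimpleNamespace stand-in for the real ERNIE-Layout config object and the concrete values are assumptions for illustration only.

from types import SimpleNamespace

# Hypothetical stand-in for the real model config object.
config = SimpleNamespace()

# Same fields that set_config_for_extrapolation() assigns in this patch:
config.use_rope_attention_bias = True
config.rope_type = "mixed_base"   # "dynamic", "linear" or "mixed_base"
config.rope_scaling_factor = 1.0  # scaling factor for the scaled RoPE variants
config.fix_base = False           # forwarded only by the "dynamic" branch
config.b = 0.75                   # forwarded only by the "mixed_base" branch
config.use_alibi = False
config.learnable_alibi = False
config.use_entropy_scale = True

# _init_rope(config, ...) then selects the rotary embedding class from
# config.rope_type and forwards the matching keyword arguments.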
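
The mixed-base branch is easiest to see with numbers. Below is a small NumPy-only sketch of the per-dimension divisor that MixedNTKScalingRotaryEmbedding applies to inv_freq; dim=64, k=4 and b=0.75 are illustrative values, not taken from the patch.

import numpy as np

dim, b = 64, 0.75
k = 4.0                              # seq_len / max_position_embeddings overshoot
a = np.log(k) / ((dim // 2) ** b)    # picked so that exp(a * (dim // 2) ** b) == k

i = np.arange(1, dim // 2 + 1)       # one index per rotary frequency pair
divisor = np.exp(a * i ** b)         # inv_freq is divided element-wise by this

print(round(divisor[0], 2))          # ~1.11: highest frequency is barely touched
print(round(divisor[-1], 2))         # 4.0: lowest frequency is stretched by the full k

Where the dynamic variant folds the stretch into a single enlarged base, the mixed-base variant shapes it per dimension through the exponent b.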
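
Similarly, a NumPy-only sketch of what the dynamic branch computes once seq_len exceeds max_position_embeddings, including the extra fix_base correction; the concrete lengths below (2048 trained, 4096 seen) are made-up example values.

import numpy as np

dim = 64
base = 10000.0
max_position_embeddings = 2048
scaling_factor = 1.0
seq_len = 4096                      # twice the trained context length

# lambda_factor exactly as in DynamicNTKScalingRotaryEmbedding._set_cos_sin_cache
lambda_factor = (
    (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
) ** (dim / (dim - 2))
new_base = base * lambda_factor     # the NTK trick: enlarge the RoPE base

inv_freq = 1.0 / (new_base ** (np.arange(0, dim, 2) / dim))
inv_freq_fixed = inv_freq / lambda_factor   # applied only when fix_base=True

print(round(lambda_factor, 3))         # ~2.045 for a 2x overshoot
print(inv_freq[0], inv_freq_fixed[0])  # 1.0 vs ~0.489: fix_base also shrinks dim 0

Without fix_base only the base changes, so the highest-frequency dimension is left untouched; with fix_base every dimension is additionally divided by lambda_factor, which is the adjustment the patch attributes to the linked blog post.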