Skip to content

Commit ed8655d

Browse files
authored
Add a common rope impl (#107)
This rope impl supports the OG rope and the scaling used in Llama 3.1
1 parent a22bf60 commit ed8655d

File tree

3 files changed

+182
-0
lines changed

3 files changed

+182
-0
lines changed

torchprime/rope/__init__.py

Whitespace-only changes.

torchprime/rope/rope.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""
2+
Rotary Positional Embeddings (RoPE) implementation.
3+
Reference: https://github.com/adalkiran/llama-nuts-and-bolts/blob/main/docs/10-ROPE-ROTARY-POSITIONAL-EMBEDDINGS.md
4+
"""
5+
6+
import math
7+
from dataclasses import dataclass
8+
9+
import torch
10+
11+
12+
@dataclass
13+
class RopeScaling:
14+
"""
15+
RoPE scaling parameters. The defaults are what was selected in Llama 3.1.
16+
"""
17+
18+
factor: float = 8.0
19+
low_freq_factor: float = 1.0
20+
high_freq_factor: float = 4.0
21+
original_context_len: int = 8192
22+
23+
24+
def default_rope_frequencies(
25+
head_dim: int,
26+
theta: float = 10000.0,
27+
) -> torch.Tensor:
28+
"""
29+
Computes the original RoPE frequencies in e.g. Llama 2.
30+
Args:
31+
head_dim: the size of a single attention head.
32+
theta: a hyperparameter controlling how fast the embeddings rotate.
33+
Returns:
34+
The frequencies for the RoPE embeddings.
35+
"""
36+
return 1.0 / (
37+
theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).float() / head_dim)
38+
)
39+
40+
41+
def llama3_rope_frequencies(
42+
head_dim: int,
43+
theta: float = 10000.0,
44+
scaling: RopeScaling | None = None,
45+
) -> torch.Tensor:
46+
"""
47+
Computes Llama 3 and 3.1 RoPE frequencies. In Llama 3.1, RoPE frequencies
48+
may be scaled and interpolated as we move beyond the original context length.
49+
"""
50+
freqs = default_rope_frequencies(head_dim=head_dim, theta=theta)
51+
if scaling is None:
52+
return freqs
53+
54+
low_freq_wavelen = scaling.original_context_len / scaling.low_freq_factor
55+
high_freq_wavelen = scaling.original_context_len / scaling.high_freq_factor
56+
57+
assert low_freq_wavelen > high_freq_wavelen, (
58+
f"low_freq_wavelen {low_freq_wavelen} must be greater than "
59+
f"high_freq_wavelen {high_freq_wavelen}"
60+
)
61+
62+
wavelen = 2 * math.pi / freqs
63+
# wavelen < high_freq_wavelen: do nothing
64+
# wavelen > low_freq_wavelen: divide by factor
65+
freqs = torch.where(wavelen > low_freq_wavelen, freqs / scaling.factor, freqs)
66+
# otherwise: interpolate between the two, using a smooth factor
67+
smooth_factor = (scaling.original_context_len / wavelen - scaling.low_freq_factor) / (
68+
scaling.high_freq_factor - scaling.low_freq_factor
69+
)
70+
smoothed_freqs = (1 - smooth_factor) * freqs / scaling.factor + smooth_factor * freqs
71+
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
72+
freqs = torch.where(is_medium_freq, smoothed_freqs, freqs)
73+
74+
return freqs

torchprime/tests/test_rope.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import math
2+
3+
import pytest
4+
import torch
5+
from transformers import PretrainedConfig
6+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
7+
8+
from torchprime.rope import rope
9+
10+
LLAMA3_SCALING = rope.RopeScaling(
11+
factor=8,
12+
low_freq_factor=1,
13+
high_freq_factor=4,
14+
original_context_len=8192,
15+
)
16+
17+
18+
@pytest.mark.parametrize(
19+
"hidden_size, num_attention_heads, theta",
20+
[(4096, 32, 500000.0), (16384, 128, 500000.0), (65536, 128, 500000.0)],
21+
)
22+
class TestRope:
23+
def test_default_rope(self, hidden_size, num_attention_heads, theta):
24+
head_dim = hidden_size // num_attention_heads
25+
ours = rope.default_rope_frequencies(head_dim=head_dim, theta=theta)
26+
27+
hf_rope_fn = ROPE_INIT_FUNCTIONS["default"]
28+
hf, scale = hf_rope_fn(
29+
PretrainedConfig.from_dict(
30+
{
31+
"hidden_size": hidden_size,
32+
"num_attention_heads": num_attention_heads,
33+
"rope_theta": theta,
34+
}
35+
)
36+
)
37+
38+
assert scale == 1
39+
torch.testing.assert_close(ours, hf)
40+
41+
def test_llama3_rope_against_hf(self, hidden_size, num_attention_heads, theta):
42+
head_dim = hidden_size // num_attention_heads
43+
ours = rope.llama3_rope_frequencies(
44+
head_dim=head_dim,
45+
theta=theta,
46+
scaling=LLAMA3_SCALING,
47+
)
48+
49+
hf_rope_fn = ROPE_INIT_FUNCTIONS["llama3"]
50+
hf, scale = hf_rope_fn(
51+
PretrainedConfig.from_dict(
52+
{
53+
"hidden_size": hidden_size,
54+
"num_attention_heads": num_attention_heads,
55+
"rope_theta": theta,
56+
"rope_scaling": {
57+
"factor": 8,
58+
"low_freq_factor": 1,
59+
"high_freq_factor": 4,
60+
"original_max_position_embeddings": 8192,
61+
},
62+
}
63+
),
64+
device="cpu",
65+
)
66+
67+
assert scale == 1
68+
torch.testing.assert_close(ours, hf)
69+
70+
def test_llama3_rope_against_reference(self, hidden_size, num_attention_heads, theta):
71+
head_dim = hidden_size // num_attention_heads
72+
ours = rope.llama3_rope_frequencies(
73+
head_dim=head_dim,
74+
theta=theta,
75+
scaling=LLAMA3_SCALING,
76+
)
77+
reference = _llama3_reference_apply_scaling(
78+
rope.default_rope_frequencies(head_dim=head_dim, theta=theta)
79+
)
80+
torch.testing.assert_close(ours, reference)
81+
82+
83+
def _llama3_reference_apply_scaling(freqs: torch.Tensor):
84+
"""
85+
Reference from https://github.com/karpathy/llm.c/blob/7ecd8906afe6ed7a2b2cdb731c042f26d525b820/train_llama3.py#L80
86+
"""
87+
# Values obtained from grid search
88+
scale_factor = 8
89+
low_freq_factor = 1
90+
high_freq_factor = 4
91+
old_context_len = 8192 # original llama3 length
92+
93+
low_freq_wavelen = old_context_len / low_freq_factor
94+
high_freq_wavelen = old_context_len / high_freq_factor
95+
new_freqs = []
96+
for freq in freqs:
97+
wavelen = 2 * math.pi / freq
98+
if wavelen < high_freq_wavelen:
99+
new_freqs.append(freq)
100+
elif wavelen > low_freq_wavelen:
101+
new_freqs.append(freq / scale_factor)
102+
else:
103+
assert low_freq_wavelen != high_freq_wavelen
104+
smooth = (old_context_len / wavelen - low_freq_factor) / (
105+
high_freq_factor - low_freq_factor
106+
)
107+
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
108+
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)

0 commit comments

Comments
 (0)