
Commit b685bc7

Merge pull request #4 from tillahoffmann/sampling
Use uniform sampled softmax and add `<START>` token.
2 parents 6b1676d + d94836c

13 files changed: +230 -886 lines

.gitignore (2 additions, 1 deletion)

@@ -6,8 +6,9 @@ __pycache__/
 .vscode
 *.egg-info
 *.ipynb
-*.ipynb-checkpoint
+*.ipynb_checkpoints
 /data/
 ~*
 htmlcov/
 playground/
+workspace/

Makefile (13 additions, 1 deletion)

@@ -1,4 +1,4 @@
-.PHONY : tests
+.PHONY : experiments tests
 
 tests :
 	pytest tests --cov=trecs --cov-report=term-missing -v
@@ -21,3 +21,15 @@ data/spotify_million_playlist_dataset/md5sums.check : data/spotify_million_playl
 data/mpd.db : data/spotify_million_playlist_dataset/md5sums.check
 	# Build the database.
 	python -m trecs.scripts.build_db data/mpd.db data/spotify_million_playlist_dataset/data/mpd.slice.*.json
+
+# Training.
+
+WORKDIR ?= workspace
+MPD_PATH ?= data/mpd.db
+EXPERIMENT_SETUPS = $(filter-out $(wildcard src/trecs/experiments/*/_*.py),$(wildcard src/trecs/experiments/*/*.py))
+EXPERIMENT_OUTPUTS = $(addprefix ${WORKDIR}/,${EXPERIMENT_SETUPS:src/trecs/experiments/%.py=%})
+
+experiments : ${EXPERIMENT_OUTPUTS}
+
+${EXPERIMENT_OUTPUTS} : ${WORKDIR}/% : src/trecs/experiments/%.py data/mpd.db
+	MPD=${MPD_PATH} python -m trecs.scripts.train $@ $<
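With these rules, each setup module `src/trecs/experiments/<group>/<name>.py` (underscore-prefixed files such as the `_base.py` below are filtered out) maps to one output directory `${WORKDIR}/<group>/<name>`, and `make experiments` runs `trecs.scripts.train` once per setup with the database path passed via the `MPD` environment variable; `WORKDIR` and `MPD_PATH` can be overridden on the command line.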

README.md (0 additions, 1 deletion)

@@ -48,7 +48,6 @@ We directly sample from the next-token distribution with top-$k$ sampling with u
 
 ## 🚀 Next Steps
 
-* We should prepend a special `<start>` token to all playlists because the first token is never used as a label.
 * The sampled softmax cross-entropy is biased due to the non-linearity in the denominator. We should be able to apply a low-order bias correction to get a better estimate, although it remains to be established if the bias affects the gradients. The bias leads to an *optimistic* estimate of the perplexity (see appendix for details). First-order bias correction is definitely feasible for the loss.
 * The model can be readily extended to include album, artist, and stylistic coherence as well as conditioning on a representation of user taste or expressed user preference for a particular session—future work.
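The bias in the first remaining bullet follows from Jensen's inequality; a sketch in ad-hoc notation (the symbols $z_c$, $C$, $S$ are ours, not the appendix's): with logits $z_1, \dots, z_C$ and $S$ uniformly sampled indices $i_s$, the estimator

$$\hat{Z} = \frac{C}{S} \sum_{s=1}^{S} \exp z_{i_s}$$

is unbiased for the normalizer $Z = \sum_{c=1}^{C} \exp z_c$, but the loss uses $\log \hat{Z}$, and concavity of the logarithm gives $\mathbb{E}[\log \hat{Z}] \le \log \mathbb{E}[\hat{Z}] = \log Z$. The sampled loss $-z_y + \log \hat{Z}$ therefore underestimates the true cross-entropy on average, which is why the reported perplexity is optimistic.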

pyproject.toml (2 additions, 1 deletion)

@@ -17,7 +17,6 @@ dependencies = [
     "orbax-checkpoint>=0.11.20",
     "pandas>=2.3.1",
     "pydantic>=2.11.7",
-    "pydantic-ai>=0.4.11",
     "python-dotenv>=1.1.1",
     "tensorboard>=2.20.0",
     "tensorboardx>=2.6.4",
@@ -29,6 +28,8 @@ dependencies = [
 dev = [
     "black>=25.1.0",
     "jupyter>=1.1.1",
+    "jupytext>=1.17.2",
+    "localscope>=0.2.5",
     "matplotlib>=3.10.5",
     "pyright>=1.1.403",
     "pytest>=8.4.1",
src/trecs/experiments/decoder_only/__init__.py (new file, 6 additions)

@@ -0,0 +1,6 @@
+from ._base import DecoderOnlyExperiment
+
+
+__all__ = [
+    "DecoderOnlyExperiment",
+]

src/trecs/experiments/decoder_only.py renamed to src/trecs/experiments/decoder_only/_base.py (50 additions, 10 deletions)

@@ -10,10 +10,10 @@
 from jax import numpy as jnp
 import pydantic
 import sqlite3
-from typing import cast
-from .util import Experiment
-from ..models import PlaylistDecoder
-from ..data import (
+from typing import cast, Literal
+from ..util import Experiment
+from ...models import PlaylistDecoder
+from ...data import (
     Sqlite3Dataset,
     SELECT_DISTINCT_TRACK_IDS_BY_SPLIT,
     SELECT_PLAYLISTS_BY_SPLIT,
@@ -24,7 +24,11 @@
     BatchTransform,
     Encoder,
 )
-from ..util import sampled_dot_cross_entropy_with_integer_labels, evaluate_eop_loss_mask
+from ...util import (
+    sampled_dot_cross_entropy_with_integer_labels_and_label_in_denominator,
+    sampled_dot_cross_entropy_with_integer_labels_uniform,
+    evaluate_eop_loss_mask,
+)
 
 
 class DecoderOnlyExperiment(Experiment):
@@ -33,11 +37,18 @@ class DecoderOnlyExperiment(Experiment):
     num_heads: int
     num_features: int
     num_hidden: int
+    loss_function: Literal[
+        "label_in_denominator",
+        "uniform",
+    ]
     dropout: float = pydantic.Field(ge=0, le=1)
     num_tracks: int | None
     unk_proba: float = pydantic.Field(ge=0, le=1)
     weight_decay: float = pydantic.Field(ge=0)
+
+    start_token: int | None = None
     eop_token: int | None = None
+    unk_token: int | None = None
     track_encoder: Encoder | None = None
 
     # Because the `Encoder` is not a standard class.
@@ -78,7 +89,7 @@ def create_data_source(self, split: str) -> RandomAccessDataSource:
             ON ptm.track_id = tracks.id
             WHERE ptm.playlist_id = :id
             ORDER BY ptm.pos
-            LIMIT :context_length + 1
+            LIMIT :context_length
             """,
             {"split": split},
             {"context_length": self.context_length},
@@ -90,7 +101,14 @@ def create_data_loader(
     ) -> DataLoader:
         assert self.track_encoder, "Create track encoder first."
         operations = [
-            # {START}: {"pos": [0, 1, ...], "track_id": [43, 7, ...]}
+            # {INPUT}: {"pos": [0, 1, ...], "track_id": [43, 7, ...]}
+            # Inject a start token.
+            LambdaMap[dict, dict](
+                lambda x: {
+                    "track_id": ["<START>", *x["track_id"]],
+                    "pos": list(range(len(x["pos"]) + 1)),
+                }
+            ),
             # Encode tracks and truncate to the maximum context length:
             # {"pos": [0, 1, ...], "track_id": [0, 1, ...]}
             LambdaMap[dict, dict](
@@ -108,7 +126,11 @@
             # Batch records: [{"track_id": [0, 1], ...}, {"track_id": [4, 5], ...}, ...]
             BatchTransform(self.batch_size, on_short="drop"),
             # Pad values to the same length.
-            LambdaMap(pad_batch, fill_value={"track_id": self.eop_token, "pos": 0}),
+            LambdaMap(
+                pad_batch,
+                fill_value={"track_id": self.eop_token, "pos": 0},
+                length=self.context_length + 1,
+            ),
             # Transpose to get a dictionary keyed by `track_id`, `pos`, etc. Then
             # convert to jax arrays.
             LambdaMap[dict, dict](
@@ -166,7 +188,16 @@ def evaluate_loss(
         flat_labels = labels.reshape((batch_size * num_tokens,))
 
         # Evaluate the loss.
-        sampled_loss = sampled_dot_cross_entropy_with_integer_labels(
+        if self.loss_function == "label_in_denominator":
+            func = (
+                sampled_dot_cross_entropy_with_integer_labels_and_label_in_denominator
+            )
+        elif self.loss_function == "uniform":
+            func = sampled_dot_cross_entropy_with_integer_labels_uniform
+        else:
+            raise ValueError(self.loss_function)
+
+        sampled_loss = func(
             prng_key,
             flat_embeddings,
             model.track_embedding.embedding.value,
@@ -214,14 +245,23 @@ def setup_output(self, output: Path) -> None:
                 SELECT_DISTINCT_TRACK_IDS_BY_SPLIT, {"split": "train"}
             )
             self.track_encoder = Encoder(
-                ["<UNK>", "<EOP>", *(track_id for (track_id,) in cursor)],
+                [
+                    "<START>",
+                    "<EOP>",
+                    "<UNK>",
+                    *(track_id for (track_id,) in cursor),
+                ],
                 on_unknown="default",
                 default="<UNK>",
            )
             self.track_encoder.to_pickle(encoder_path)
             print(f"Built new track encoder with {len(self.track_encoder):,} tokens.")
 
+        # Get named special tokens.
+        self.start_token = self.track_encoder("<START>")
         self.eop_token = self.track_encoder("<EOP>")
+        self.unk_token = self.track_encoder("<UNK>")
+
         num_tracks = len(self.track_encoder)
         if self.num_tracks is None:
             self.num_tracks = num_tracks
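To make the injection step concrete, a hypothetical example (not part of the commit) of what the start-token `LambdaMap` does to one record before encoding:

# One raw record as produced by the playlist query.
record = {"pos": [0, 1, 2], "track_id": [43, 7, 19]}

# Prepend "<START>" and re-derive positions, mirroring the lambda above.
injected = {
    "track_id": ["<START>", *record["track_id"]],
    "pos": list(range(len(record["pos"]) + 1)),
}
assert injected == {"track_id": ["<START>", 43, 7, 19], "pos": [0, 1, 2, 3]}

Every real track can now serve as a label, including the first one.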
New file (26 additions, 0 deletions)

@@ -0,0 +1,26 @@
+from trecs.experiments.decoder_only import DecoderOnlyExperiment
+
+
+def setup() -> DecoderOnlyExperiment:
+    train_size = 800_000
+    batch_size = 16
+    num_steps = train_size // batch_size
+    return DecoderOnlyExperiment(
+        seed=42,
+        num_steps=num_steps,
+        batch_size=batch_size,
+        learning_rate=0.0005,
+        weight_decay=0.01,
+        context_length=50,
+        num_layers=6,
+        num_features=128,
+        num_hidden=256,
+        num_heads=8,
+        dropout=0.1,
+        eval_every=100,
+        checkpoint_every=1000,
+        loss_function="uniform",
+        unk_proba=0.01,
+        # This will be determined by the encoder.
+        num_tracks=None,
+    )
New file (26 additions, 0 deletions)

@@ -0,0 +1,26 @@
+from trecs.experiments.decoder_only import DecoderOnlyExperiment
+
+
+def setup() -> DecoderOnlyExperiment:
+    train_size = 800_000
+    batch_size = 16
+    num_steps = train_size // batch_size
+    return DecoderOnlyExperiment(
+        seed=42,
+        num_steps=num_steps,
+        batch_size=batch_size,
+        learning_rate=0.0005,
+        weight_decay=0.01,
+        context_length=50,
+        num_layers=6,
+        num_features=128,
+        num_hidden=256,
+        num_heads=8,
+        dropout=0.1,
+        eval_every=100,
+        checkpoint_every=1000,
+        loss_function="label_in_denominator",
+        unk_proba=0.01,
+        # This will be determined by the encoder.
+        num_tracks=None,
+    )

src/trecs/util.py (44 additions, 1 deletion)

@@ -1,6 +1,7 @@
 import contextlib
 from jax import numpy as jnp
 from jax import random
+from jax.scipy.special import logsumexp
 from pathlib import Path
 from typing import Generator, IO
 import importlib.util
@@ -45,7 +46,48 @@ def safe_write(
     tmp_path.rename(path)
 
 
-def sampled_dot_cross_entropy_with_integer_labels(
+def sampled_dot_cross_entropy_with_integer_labels_uniform(
+    key: jnp.ndarray,
+    query: jnp.ndarray,
+    embedding: jnp.ndarray,
+    labels: jnp.ndarray,
+    num_samples: int = 20,
+):
+    """Evaluate the sampled cross entropy based on logits obtained through a dot
+    product `query @ embedding.T`. This function never evaluates the full dot product
+    but only considers a sampled subset of the embedding matrix.
+
+    Args:
+        key: Random number generator key.
+        query: Context to contract with the output embedding with shape
+            `(batch_size, num_features)`.
+        embedding: Output embedding with shape `(num_classes, num_features)`.
+        labels: Target labels with shape `(batch_size,)`.
+        num_samples: Number of samples for the sampled softmax cross-entropy.
+
+    Returns:
+        Sampled cross-entropy with shape `(batch_size,)`.
+    """
+    batch_size, query_num_features = query.shape
+    num_classes, embedding_num_features = embedding.shape
+    assert query_num_features == embedding_num_features
+
+    # Logits for the labels we are after.
+    label_logits = jnp.vecdot(query, embedding[labels])
+
+    # Sample indices uniformly at random with replacement. This introduces extra
+    # variance because we can double-sample certain indices, but this effect is small
+    # when num_samples << num_classes.
+    idx = random.randint(key, (num_samples, batch_size), 0, num_classes)
+    sampled_logits = jnp.vecdot(query, embedding[idx]).T
+    return (
+        -label_logits
+        + logsumexp(sampled_logits, axis=1)
+        + jnp.log(num_classes / num_samples)
+    )
+
+
+def sampled_dot_cross_entropy_with_integer_labels_and_label_in_denominator(
     key: jnp.ndarray,
     query: jnp.ndarray,
     embedding: jnp.ndarray,
@@ -62,6 +104,7 @@ def sampled_dot_cross_entropy_with_integer_labels(
             `(batch_size, num_features)`.
         embedding: Output embedding with shape `(num_classes, num_features)`.
         labels: Target labels with shape `(batch_size,)`.
+        num_samples: Number of samples for the sampled softmax cross-entropy.
 
     Returns:
         Sampled cross-entropy with shape `(batch_size,)`.
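As a sanity check, a minimal sketch (not part of the commit) of how the new uniform estimator might be exercised; the import path assumes an installed `trecs`, and the shapes, seeds, and sample counts are arbitrary:

import jax
from jax import numpy as jnp
from jax import random
from jax.scipy.special import logsumexp
from trecs.util import sampled_dot_cross_entropy_with_integer_labels_uniform

query_key, embedding_key, loss_key = random.split(random.key(0), 3)
batch_size, num_features, num_classes = 4, 8, 1_000
query = random.normal(query_key, (batch_size, num_features))
embedding = random.normal(embedding_key, (num_classes, num_features))
labels = jnp.array([3, 1, 4, 1])

# Exact cross-entropy for reference; this materializes the full
# `query @ embedding.T` matrix that the sampled version avoids.
logits = query @ embedding.T
label_logits = jnp.take_along_axis(logits, labels[:, None], axis=1).squeeze(1)
exact = -label_logits + logsumexp(logits, axis=1)

# Averaging the sampled estimate over many keys settles slightly below the
# exact value because of the Jensen bias discussed in the README.
sampled = jax.vmap(
    lambda k: sampled_dot_cross_entropy_with_integer_labels_uniform(
        k, query, embedding, labels, num_samples=50
    )
)(random.split(loss_key, 512)).mean(axis=0)
print(exact, sampled)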

tests/assets/decoder_only_mini_experiment.py (1 addition, 0 deletions)

@@ -18,4 +18,5 @@ def setup():
         unk_proba=0.05,
         weight_decay=0.01,
         num_tracks=None,
+        loss_function="uniform",
     )
