# 🎶 trecs: Transformer-Based Playlist Recommendation

This repository implements a transformer architecture to complete Spotify playlists given a set of seed tracks. The model is trained on the [Million Playlist Dataset](https://www.aicrowd.com/challenges/spotify-million-playlist-dataset-challenge) (MPD).

It can easily be extended to integrate artist, album, and audio features, as well as implicit user taste and expressed user preferences for a particular session.

## 🌟 Highlights

* A 6-layer causal transformer that completes playlists from a set of seed tracks, trained on the MPD.
* Pure JAX/Flax NNX implementation with fully reproducible training via Orbax checkpointing and Grain data loading.
* Sampled softmax training over a vocabulary of roughly two million tracks.
* Efficient top-$k$ playlist generation via maximum inner-product search in the output head.
* Trains from scratch in about 48 hours on a 2020 laptop.
## 🏗️ Architecture

The model is a 6-layer causal transformer with 8 heads, 128 embedding dimensions, pre-layer norm, and 0.1 dropout during training. The feed-forward network of each transformer block is a two-layer dense network with ReLU activations and a hidden dimension of 256. The model uses weight tying, i.e., the output head is identical to the token embeddings. We use learned, token-style positional embeddings (as opposed to Fourier-style positional encodings) with a maximum context length of 50 tracks.
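
Below is a minimal Flax NNX sketch of this architecture. All class, attribute, and parameter names are illustrative rather than taken from the repository; the hyperparameters are the ones stated above, and the exact Flax NNX API may differ slightly between versions.

```python
import jax
import jax.numpy as jnp
from flax import nnx


class Block(nnx.Module):
    """Pre-layer-norm transformer block with a two-layer ReLU feed-forward network."""

    def __init__(self, dim: int, heads: int, hidden: int, dropout: float, rngs: nnx.Rngs):
        self.norm1 = nnx.LayerNorm(dim, rngs=rngs)
        self.attn = nnx.MultiHeadAttention(
            num_heads=heads, in_features=dim, decode=False, rngs=rngs
        )
        self.norm2 = nnx.LayerNorm(dim, rngs=rngs)
        self.ff1 = nnx.Linear(dim, hidden, rngs=rngs)
        self.ff2 = nnx.Linear(hidden, dim, rngs=rngs)
        self.dropout = nnx.Dropout(dropout, rngs=rngs)

    def __call__(self, x: jax.Array, mask: jax.Array) -> jax.Array:
        x = x + self.dropout(self.attn(self.norm1(x), mask=mask))
        return x + self.dropout(self.ff2(jax.nn.relu(self.ff1(self.norm2(x)))))


class Trecs(nnx.Module):
    """Causal transformer over track tokens with a tied output head."""

    def __init__(self, vocab_size: int, *, dim: int = 128, heads: int = 8,
                 layers: int = 6, hidden: int = 256, context: int = 50,
                 dropout: float = 0.1, rngs: nnx.Rngs):
        self.tokens = nnx.Embed(vocab_size, dim, rngs=rngs)  # one row per track
        self.positions = nnx.Embed(context, dim, rngs=rngs)  # learned positional embeddings
        self.blocks = [Block(dim, heads, hidden, dropout, rngs) for _ in range(layers)]
        self.norm = nnx.LayerNorm(dim, rngs=rngs)

    def __call__(self, ids: jax.Array) -> jax.Array:  # ids: (batch, seq) of track tokens
        seq = ids.shape[-1]
        mask = jnp.tril(jnp.ones((seq, seq), dtype=bool))  # causal attention mask
        x = self.tokens(ids) + self.positions(jnp.arange(seq))
        for block in self.blocks:
            x = block(x, mask)
        # Weight tying: logits are dot products with the token embedding table.
        return self.tokens.attend(self.norm(x))
```

With the full vocabulary of roughly two million tracks, the token embedding table alone accounts for about a gigabyte of parameters, consistent with the estimate below.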

Tracks are treated as individual tokens because limited semantic information is available in the MPD. Semantic embeddings or semantic side information could be integrated more easily for content such as movies (e.g., using reviews and synopses) or podcasts (e.g., using concise vectors obtained through pre-trained encoders like DistilBERT applied to transcripts). The tokenizer is thus a lookup from track URI to consecutive integers, each representing a 128-dimensional embedding. The embedding dimension is small because the training set comprises approximately two million tracks, corresponding to about one gigabyte of embeddings. This is not large by modern standards, but large enough to make my laptop sweat.
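
A lookup tokenizer of this kind can be sketched in a few lines (the class and its methods are hypothetical, not the repository's API):

```python
class TrackTokenizer:
    """Maps track URIs to consecutive integer tokens (hypothetical sketch)."""

    def __init__(self, uris: list[str]):
        self.unk, self.eop = 0, 1  # reserve the first ids for special tokens
        self.uri_to_id = {uri: i + 2 for i, uri in enumerate(dict.fromkeys(uris))}

    def encode(self, uris: list[str]) -> list[int]:
        # Unseen tracks map to <unk>, mirroring how validation playlists are handled.
        return [self.uri_to_id.get(uri, self.unk) for uri in uris]


tokenizer = TrackTokenizer(["spotify:track:abc", "spotify:track:def"])
print(tokenizer.encode(["spotify:track:def", "spotify:track:xyz"]))  # [3, 0]
```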

## 🏋️ Training

This section outlines how to train the transformer model from scratch, which takes only about 48 hours on a 2020 laptop.

### ▶️ Runtime Environment

The runtime environment is managed with [`uv`](https://docs.astral.sh/uv/). Make sure `uv` is installed and run `uv sync` to set up the environment. You can validate the installation by running `uv run pytest`.

### 📊 Data

Download the [Million Playlist Dataset](https://www.aicrowd.com/challenges/spotify-million-playlist-dataset-challenge) from AIcrowd and place the archive at `data/spotify_million_playlist_dataset.zip`. Then run `uv run make data/mpd.db`, which transforms the data into a sqlite database and takes about 15 minutes. This may seem cumbersome, but it facilitates efficient queries to generate training data on the fly and persists train-validation-test splits, ensuring there is no accidental data leakage that might affect evaluation. For details, see [`schema.sql`](./src/trecs/schema.sql) and the [`build_db.py`](./src/trecs/scripts/build_db.py) script.
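
For illustration, training samples can then be fetched with plain SQL. The table and column names below are hypothetical; the actual schema is defined in `schema.sql`.

```python
import sqlite3

con = sqlite3.connect("data/mpd.db")
# Hypothetical table and column names; the real schema lives in schema.sql.
track_uris = [
    row[0]
    for row in con.execute(
        "SELECT track_uri FROM playlist_tracks WHERE playlist_id = ? ORDER BY pos",
        (123,),
    )
]
```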

The dataset is split into train-validation-test sets with an 80-10-10 ratio. The model must be able to handle `<unk>` tokens, which often appear in validation samples because the validation set contains tracks that do not occur in the training set. This is important for two reasons. First, we would not be able to evaluate the validation loss without accounting for unknown tokens. Second, the model would not be able to predict tracks beyond the first `<unk>` token it encounters in a validation playlist. Injecting `<unk>` tokens during training addresses both challenges: the model learns to ignore `<unk>` tokens in its input and assigns a low but non-negligible probability to predicting `<unk>`. This may seem like leakage from the validation set, but the expected fraction of `<unk>` tokens can easily be estimated in a production environment. If a playlist is shorter than the 50-track context size, we pad it with `<eop>` tokens, representing the end of a playlist.
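
A minimal sketch of this preprocessing, assuming (hypothetically) that token ids 0 and 1 are reserved for `<unk>` and `<eop>`:

```python
import numpy as np

UNK, EOP, CONTEXT = 0, 1, 50  # assumed special-token ids and context size


def prepare(tokens: np.ndarray, unk_rate: float, rng: np.random.Generator) -> np.ndarray:
    """Randomly replaces tracks with <unk> and pads to the context size with <eop>."""
    tokens = np.where(rng.random(tokens.shape) < unk_rate, UNK, tokens)
    padded = np.full(CONTEXT, EOP, dtype=np.int64)
    padded[: min(tokens.size, CONTEXT)] = tokens[:CONTEXT]
    return padded
```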

### ⚙️ Model Implementation and Training Details

The model is implemented in Flax NNX on top of JAX, using Google's Orbax for checkpointing and Grain for data loading, with a custom data source implementation that fetches samples from the sqlite database. The Orbax-Grain combination facilitates fully reproducible training because model, optimizer, and data loader states can all be saved and restored.

The model is trained with batch size 16 for one epoch, corresponding to 50,000 iterations. We use the AdamW optimizer with a constant learning rate of 0.0005 and weight decay of 0.01. Playlist samples are drawn uniformly at random without replacement from the training set for each epoch. The validation loss is evaluated on a single batch every 100 iterations, and the model is checkpointed every 1,000 iterations. Training progress is monitored using TensorBoard.
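
A sketch of the corresponding training step is shown below. It reuses the hypothetical `Trecs` model from the architecture sketch, and the exact `nnx.Optimizer` API varies across Flax versions.

```python
import jax
import optax
from flax import nnx

model = Trecs(vocab_size=2_000_000, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adamw(learning_rate=5e-4, weight_decay=0.01))


@nnx.jit
def train_step(model: Trecs, optimizer: nnx.Optimizer, batch: jax.Array) -> jax.Array:
    """One AdamW update on a (16, 50) batch of track tokens."""

    def loss_fn(model: Trecs) -> jax.Array:
        logits = model(batch[:, :-1])  # predict token t+1 from tokens up to t
        labels = batch[:, 1:]
        # The repository uses a sampled loss (sketched below); the full softmax
        # cross-entropy is shown here only for brevity.
        return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)  # in-place AdamW update of the model parameters
    return loss
```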

Evaluating the full softmax cross-entropy for next-token prediction is computationally prohibitive for large vocabularies, so we use a sampled softmax cross-entropy instead. Sampling is applied at the embedding level such that logits for tokens not included in the contrastive sample never need to be evaluated. We use a sample size of 20, reducing the computational burden of evaluating the loss by five orders of magnitude.
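
A sketch of the naive estimator from the appendix for a single position, assuming a tied embedding table and uniformly sampled negatives (function and argument names are illustrative):

```python
import jax
import jax.numpy as jnp


def sampled_xent(h: jax.Array, target: jax.Array, embed: jax.Array,
                 key: jax.Array, vocab_size: int, num_negatives: int = 20) -> jax.Array:
    """Naive sampled softmax cross-entropy for a single position.

    h: (dim,) final hidden state; embed: (vocab, dim) tied embedding table.
    Implements -xi_i + log((n / |S|) * sum_{j in S} exp(xi_j)) from the appendix.
    """
    negatives = jax.random.randint(key, (num_negatives,), 0, vocab_size)
    target_logit = embed[target] @ h        # xi_i; only |S| + 1 rows are ever touched
    negative_logits = embed[negatives] @ h  # xi_j for j in S
    return -target_logit + jax.nn.logsumexp(negative_logits) + jnp.log(vocab_size / num_negatives)
```

With a sample size of 20 and a vocabulary of about two million tracks, roughly a factor of $10^5$ fewer logits are evaluated, matching the five orders of magnitude quoted above.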

*(Figure: training and validation loss over one training epoch.)*

As shown in the figure above (and as we might expect given a single training epoch), the training and validation losses decrease consistently and reach a value of about 3. Because the loss is sampled and biased (see below), this is not equivalent to the log-perplexity. Nevertheless, we expect the model to have learned *something*. Running a second epoch and introducing a learning rate schedule could further improve the results.

## 🪄 Inference

We sample directly from the next-token distribution using top-$k$ sampling with unit temperature. This preserves diversity while preventing the sampling of low-probability tokens that, taken together, make up a non-negligible probability mass. The few most likely tokens often carry substantial posterior mass, so top-$p$ sampling would likely be close to greedy decoding, and greedy decoding did not generate desirable playlists in experiments. Top-$k$ sampling has the additional benefit that it can be performed efficiently even for *very* large vocabularies: it is equivalent to [$k$-nearest-neighbor search with the dot product as similarity](https://en.wikipedia.org/wiki/Maximum_inner-product_search) over the output head of the transformer, using the embedding of the last token as the query.
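
A minimal sketch of the sampling step; the value of $k$ is an illustrative assumption, and at the full vocabulary size the dense `jax.lax.top_k` would be replaced by an approximate maximum inner-product search index:

```python
import jax


def sample_next(logits: jax.Array, key: jax.Array, k: int = 10) -> jax.Array:
    """Samples the next track from the k most likely tokens at unit temperature."""
    top_logits, top_ids = jax.lax.top_k(logits, k)
    # categorical renormalizes the k retained logits via softmax before sampling.
    return top_ids[jax.random.categorical(key, top_logits)]
```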

`<unk>` tokens are assigned zero probability because we cannot predict unknown songs (although predicting out-of-vocabulary tracks would become feasible with artist, album, and audio context). We also mask out `<eop>` tokens to generate playlists that fill the entire context window: because `<eop>` tokens appear in almost all training playlists (all except the ones that exceed the context window size), they are likely to be predicted, which would often end playlist generation early. Alternatively, we could artificially down-weight the probability of the `<eop>` token to control the typical playlist length. Finally, we assign zero probability to tracks that already occur in the playlist to prevent repetition.
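
These constraints can be applied directly to the logits before sampling, e.g. (using the special-token ids assumed earlier):

```python
import jax
import jax.numpy as jnp

UNK, EOP = 0, 1  # assumed special-token ids, as above


def mask_logits(logits: jax.Array, history: jax.Array) -> jax.Array:
    """Assigns zero probability to <unk>, <eop>, and tracks already in the playlist."""
    logits = logits.at[jnp.array([UNK, EOP])].set(-jnp.inf)
    return logits.at[history].set(-jnp.inf)  # no repeated tracks
```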

## 🚀 Next Steps

* Prepend a special `<start>` token to all playlists, because the first token of a playlist is currently never used as a label.
* The sampled softmax cross-entropy is biased due to the non-linearity in the denominator, and the bias leads to an *optimistic* estimate of the perplexity (see the appendix for details). A first-order bias correction of the loss is definitely feasible, although it remains to be established whether the bias also affects the gradients.
* The model can be readily extended to include album, artist, and audio features, as well as conditioning on a representation of implicit user taste or expressed user preferences for a particular session.

## 📒 Appendix

### Bias of Sampled Softmax Cross-Entropy

The (naive) sampled softmax cross-entropy is $-\xi_i + \log\left(\frac{n}{\left\vert S\right\vert}\sum_{j\in S} \exp \xi_j\right)$, where $\xi$ are the token logits, $n$ is the vocabulary size, and $S$ is a random sample of negatives, i.e., tokens other than the target token $i$. While the sum is an unbiased estimator of the desired quantity, the $\log$ non-linearity introduces a bias. Consider a random variable $x = \bar x + \delta$ with mean $\bar x$, so that $\mathbb{E}\left[\delta\right] = 0$ and $\mathrm{var}\,\delta = \sigma^2$. Expanding to second order in the perturbation $\delta$ gives $\mathbb{E}\left[\log x\right] = \mathbb{E}\left[\log\left(\bar x + \delta\right)\right] \approx \log\bar x - \frac{\sigma^2}{2\bar{x}^2}$, where the first-order term vanishes because $\mathbb{E}\left[\delta\right] = 0$. Applied to the sum inside the $\log$, this shows that the second term, and hence the loss, is underestimated: our sampled softmax cross-entropy is somewhat optimistic.
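
The second-order approximation is easy to verify numerically, e.g., with a Gaussian perturbation:

```python
import numpy as np

rng = np.random.default_rng(0)
x_bar, sigma = 5.0, 0.5
x = rng.normal(x_bar, sigma, size=1_000_000)

print(np.log(x).mean())                           # empirical E[log x], approx. 1.6044
print(np.log(x_bar) - sigma**2 / (2 * x_bar**2))  # second-order approximation, 1.6044
```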