Skip to content

Commit 7a8f48a

Browse files
Merge pull request #2 from tillahoffmann/action
Add GitHub action.
2 parents e0e1fa5 + 2599452 commit 7a8f48a

File tree

5 files changed

+46
-5
lines changed

5 files changed

+46
-5
lines changed

.github/workflows/main.yaml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
name: trecs
2+
3+
on:
4+
push:
5+
branches: ["main"]
6+
pull_request:
7+
branches: ["main"]
8+
9+
jobs:
10+
build:
11+
name: Build
12+
runs-on: "ubuntu-latest"
13+
steps:
14+
- uses: "actions/checkout@v4"
15+
- name: Install uv
16+
uses: astral-sh/setup-uv@v5
17+
with:
18+
enable-cache: true
19+
cache-dependency-glob: uv.lock
20+
- name: Set up Python
21+
uses: actions/setup-python@v5
22+
with:
23+
python-version-file: .python-version
24+
- name: Install project
25+
run: uv sync --all-groups
26+
- name: Lint the code
27+
run: uv run black --check .
28+
- name: Run the tests
29+
run: uv run pytest -v tests

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# 🎶 trecs: Transformer-Based Playlist Recommendation
1+
# 🎶 trecs: Transformer-Based Playlist Recommendation [![trecs](https://github.com/tillahoffmann/trecs/actions/workflows/main.yaml/badge.svg)](https://github.com/tillahoffmann/trecs/actions/workflows/main.yaml)
22

33
This repository implements a transformer architecture to complete Spotify playlists given a set of seed tracks. The model is trained on the [Million Playlist Dataset](https://www.aicrowd.com/challenges/spotify-million-playlist-dataset-challenge) (MPD).
44

src/trecs/data/util.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,19 @@ def _get_conn(self) -> sqlite3.Connection:
9090
self._local.conn = sqlite3.connect(self.conn)
9191
return self._local.conn
9292

93+
def _close_conn(self) -> None:
94+
# Close the connection if there is one.
95+
if isinstance(self.conn, sqlite3.Connection):
96+
self.conn.close()
97+
elif isinstance(self._local.conn, sqlite3.Connection):
98+
self._local.conn.close()
99+
100+
def __enter__(self) -> Self:
101+
return self
102+
103+
def __exit__(self, *_) -> None:
104+
self._close_conn()
105+
93106
@property
94107
def idx(self) -> list[Any]:
95108
if self._idx is None:

src/trecs/scripts/predict.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1 @@
11
import argparse
2-
3-
4-

src/trecs/scripts/train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ def restore(
4949
rngs: nnx.Rngs | None = None,
5050
) -> int:
5151
flax_nodes = {"model": model, "optimizer": optimizer, "rngs": rngs}
52-
flax_states = {key: nnx.state(value) for key, value in flax_nodes.items() if value is not None}
52+
flax_states = {
53+
key: nnx.state(value) for key, value in flax_nodes.items() if value is not None
54+
}
5355
data_iterators = data_iterators or {}
5456
restored = checkpoint_manager.restore(
5557
step,

0 commit comments

Comments
 (0)