File tree Expand file tree Collapse file tree 5 files changed +46
-5
lines changed Expand file tree Collapse file tree 5 files changed +46
-5
lines changed Original file line number Diff line number Diff line change 1+ name : trecs
2+
3+ on :
4+ push :
5+ branches : ["main"]
6+ pull_request :
7+ branches : ["main"]
8+
9+ jobs :
10+ build :
11+ name : Build
12+ runs-on : " ubuntu-latest"
13+ steps :
14+ - uses : " actions/checkout@v4"
15+ - name : Install uv
16+ uses : astral-sh/setup-uv@v5
17+ with :
18+ enable-cache : true
19+ cache-dependency-glob : uv.lock
20+ - name : Set up Python
21+ uses : actions/setup-python@v5
22+ with :
23+ python-version-file : .python-version
24+ - name : Install project
25+ run : uv sync --all-groups
26+ - name : Lint the code
27+ run : uv run black --check .
28+ - name : Run the tests
29+ run : uv run pytest -v tests
Original file line number Diff line number Diff line change 1- # 🎶 trecs: Transformer-Based Playlist Recommendation
1+ # 🎶 trecs: Transformer-Based Playlist Recommendation [ ![ trecs ] ( https://github.com/tillahoffmann/trecs/actions/workflows/main.yaml/badge.svg )] ( https://github.com/tillahoffmann/trecs/actions/workflows/main.yaml )
22
33This repository implements a transformer architecture to complete Spotify playlists given a set of seed tracks. The model is trained on the [ Million Playlist Dataset] ( https://www.aicrowd.com/challenges/spotify-million-playlist-dataset-challenge ) (MPD).
44
Original file line number Diff line number Diff line change @@ -90,6 +90,19 @@ def _get_conn(self) -> sqlite3.Connection:
9090 self ._local .conn = sqlite3 .connect (self .conn )
9191 return self ._local .conn
9292
93+ def _close_conn (self ) -> None :
94+ # Close the connection if there is one.
95+ if isinstance (self .conn , sqlite3 .Connection ):
96+ self .conn .close ()
97+ elif isinstance (self ._local .conn , sqlite3 .Connection ):
98+ self ._local .conn .close ()
99+
100+ def __enter__ (self ) -> Self :
101+ return self
102+
103+ def __exit__ (self , * _ ) -> None :
104+ self ._close_conn ()
105+
93106 @property
94107 def idx (self ) -> list [Any ]:
95108 if self ._idx is None :
Original file line number Diff line number Diff line change 11import argparse
2-
3-
4-
Original file line number Diff line number Diff line change @@ -49,7 +49,9 @@ def restore(
4949 rngs : nnx .Rngs | None = None ,
5050) -> int :
5151 flax_nodes = {"model" : model , "optimizer" : optimizer , "rngs" : rngs }
52- flax_states = {key : nnx .state (value ) for key , value in flax_nodes .items () if value is not None }
52+ flax_states = {
53+ key : nnx .state (value ) for key , value in flax_nodes .items () if value is not None
54+ }
5355 data_iterators = data_iterators or {}
5456 restored = checkpoint_manager .restore (
5557 step ,
You can’t perform that action at this time.
0 commit comments