Commit cb79647

Add examples

1 parent c0418c2 commit cb79647

5 files changed: +460 -0 lines changed

examples/1-ubm/main.py

Lines changed: 73 additions & 0 deletions
from pathlib import Path
from typing import Tuple

import optax
from flax import nnx
from flax.training.early_stopping import EarlyStopping
from torch.utils.data import DataLoader

from clax import UserBrowsingModel
from clax.datasets import YandexDataset
from clax.trainer import Trainer


def get_yandex_loader(
    dataset_dir: Path,
    session_range: Tuple[int, int],
):
    dataset = YandexDataset(
        dataset_dir=dataset_dir,
        session_range=session_range,
    )

    return DataLoader(
        dataset,
        batch_size=4_096,
        collate_fn=dataset.collate_fn,
        num_workers=4,
        persistent_workers=True,
    )


def main():
    # Load a few sessions from the Yandex WSCD-2012 dataset:
    dataset_dir = Path("../../clax-datasets/yandex")
    train_loader = get_yandex_loader(dataset_dir, session_range=(0, 1_000_000))
    val_loader = get_yandex_loader(dataset_dir, session_range=(1_000_000, 1_500_000))
    test_loader = get_yandex_loader(dataset_dir, session_range=(1_500_000, 2_000_000))

    # Instantiate a UBM:
    rngs = nnx.Rngs(42)
    model = UserBrowsingModel(
        query_doc_pairs=10_000_000,
        positions=10,
        rngs=rngs,
    )

    # Train and evaluate the UBM:
    trainer = Trainer(
        optax.adamw(0.0003),
        epochs=10,
        early_stopping=EarlyStopping(patience=0),
    )
    train_df = trainer.train(model, train_loader, val_loader)
    test_df = trainer.test_clicks(model, test_loader)

    # Use the trained UBM:
    batch = next(iter(test_loader))

    print("Predict unconditional click probabilities:")
    print(model.predict_clicks(batch))

    print("Predict conditional click probabilities:")
    print(model.predict_conditional_clicks(batch))

    print("Predict query-doc relevance for ranking:")
    print(model.predict_relevance(batch))

    print("Sample clicks:")
    print(model.sample(batch, rngs=rngs))


if __name__ == "__main__":
    main()
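For context, the UBM factorizes a click at rank r into attraction and examination conditioned on the rank of the previous click: P(C_r = 1 | last click at r') = alpha_{q,d} * gamma_{r,r'}. Below is a minimal sketch of how unconditional click probabilities, like those printed above, can be derived from such parameters by tracking a distribution over the "last click" state; the names, shapes, and state encoding are illustrative assumptions, not the clax API.

import jax.numpy as jnp


def ubm_click_probs(alpha: jnp.ndarray, gamma: jnp.ndarray) -> jnp.ndarray:
    # alpha[r]: attraction probability of the document at rank r.
    # gamma[r, p]: examination probability at rank r given state p, where
    # p = 0 means "no click yet" and p = k means "last click at rank k - 1".
    n = alpha.shape[0]
    state = jnp.zeros(n + 1).at[0].set(1.0)  # distribution over last-click states
    clicks = []
    for r in range(n):
        cond = gamma[r] * alpha[r]       # click probability in each state
        click_r = jnp.sum(state * cond)  # marginalize over states
        # Non-clicking mass keeps its state; clicking mass moves to state r + 1:
        state = (state * (1.0 - cond)).at[r + 1].set(click_r)
        clicks.append(click_r)
    return jnp.stack(clicks)


print(ubm_click_probs(
    alpha=jnp.array([0.8, 0.5, 0.3]),
    gamma=jnp.full((3, 4), 0.7),  # toy values; real gammas are learned
))

The running state distribution over the most recent click is what separates the UBM from a plain position-based model, whose examination term depends on the rank alone.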
Lines changed: 81 additions & 0 deletions
from functools import partial
from pathlib import Path
from typing import Tuple

import optax
from flax import nnx
from flax.training.early_stopping import EarlyStopping
from torch.utils.data import DataLoader

from clax import ClickChainModel
from clax.datasets import YandexDataset
from clax.parameters import EmbeddingParameterConfig, QREmbedding
from clax.trainer import Trainer


def get_yandex_loader(
    dataset_dir: Path,
    session_range: Tuple[int, int],
):
    dataset = YandexDataset(
        dataset_dir=dataset_dir,
        session_range=session_range,
    )

    return DataLoader(
        dataset,
        batch_size=4_096,
        collate_fn=dataset.collate_fn,
        num_workers=4,
        persistent_workers=True,
    )


def main():
    # Scale to the entire Yandex WSCD-2012 dataset with 346_711_929 query-doc pairs:
    dataset_dir = Path("../../clax-datasets/yandex")
    query_doc_pairs = 346_711_929

    train_loader = get_yandex_loader(
        dataset_dir,
        session_range=(0, 100_000_000),
    )
    val_loader = get_yandex_loader(
        dataset_dir,
        session_range=(100_000_000, 120_000_000),
    )
    test_loader = get_yandex_loader(
        dataset_dir,
        session_range=(120_000_000, 145_000_000),
    )

    # Instantiate a CCM with Quotient-Remainder compression, reducing the number
    # of allocated embeddings by a factor of 1000, with multiplicative combination:
    rngs = nnx.Rngs(42)

    model = ClickChainModel(
        attraction=EmbeddingParameterConfig(
            use_feature="query_doc_ids",
            embedding_fn=partial(
                QREmbedding,  # Use HashEmbedding for hashing-trick compression
                compression_ratio=1000,
            ),
            parameters=query_doc_pairs,
            add_baseline=True,
        ),
        rngs=rngs,
    )

    # Train and evaluate the CCM:
    trainer = Trainer(
        optax.adamw(0.0003),
        epochs=10,
        early_stopping=EarlyStopping(patience=0),
    )
    train_df = trainer.train(model, train_loader, val_loader)
    test_df = trainer.test_clicks(model, test_loader)


if __name__ == "__main__":
    main()
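The Quotient-Remainder trick used above splits every id into a quotient and a remainder and looks each up in a much smaller table. A hedged sketch of the idea, where QREmbeddingSketch and its constructor arguments are hypothetical stand-ins rather than clax's QREmbedding:

import jax.numpy as jnp
from flax import nnx


class QREmbeddingSketch(nnx.Module):
    # Two small tables replace one table with num_ids rows: one indexed
    # by id // ratio (quotient), the other by id % ratio (remainder).
    def __init__(self, num_ids: int, features: int, ratio: int, rngs: nnx.Rngs):
        self.ratio = ratio
        self.quotient = nnx.Embed(-(-num_ids // ratio), features, rngs=rngs)  # ceil division
        self.remainder = nnx.Embed(ratio, features, rngs=rngs)

    def __call__(self, ids: jnp.ndarray) -> jnp.ndarray:
        # Multiplicative combination: the element-wise product of both
        # lookups yields a distinct vector for every id.
        return self.quotient(ids // self.ratio) * self.remainder(ids % self.ratio)

With 346_711_929 ids and a ratio of 1000, the two tables hold about 346_712 + 1_000 rows, roughly 1000x fewer parameters, at the cost of ids sharing their quotient and remainder vectors.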

examples/3-two-tower-model/main.py

Lines changed: 100 additions & 0 deletions
from pathlib import Path
from typing import Tuple

import optax
from flax import nnx
from flax.training.early_stopping import EarlyStopping
from torch.utils.data import DataLoader

from clax import PositionBasedModel
from clax.datasets import (
    BaiduUltrFeatureClickDataset,
    BaiduUltrFeatureAnnotationDataset,
)
from clax.parameters import DeepCrossParameterConfig
from clax.trainer import Trainer


def get_baidu_click_loader(
    dataset_dir: Path,
    session_range: Tuple[int, int],
):
    dataset = BaiduUltrFeatureClickDataset(
        dataset_dir=dataset_dir,
        session_range=session_range,
    )

    return DataLoader(
        dataset,
        batch_size=256,
        collate_fn=dataset.collate_fn,
        num_workers=2,
        persistent_workers=True,
    )


def get_baidu_annotation_loader(
    dataset_dir: Path,
    session_range: Tuple[int, int],
):
    dataset = BaiduUltrFeatureAnnotationDataset(
        dataset_dir=dataset_dir,
        session_range=session_range,
    )

    return DataLoader(
        dataset,
        batch_size=256,
        collate_fn=dataset.collate_fn,
        num_workers=2,
        persistent_workers=True,
    )


def main():
    # Load sessions from a subset of the Baidu-ULTR dataset with pre-processed
    # query-doc features:
    dataset_dir = Path("../../clax-datasets/baidu-ultr-uva")
    query_doc_features = 768

    train_loader = get_baidu_click_loader(
        dataset_dir,
        session_range=(0, 1_000_000),
    )
    val_loader = get_baidu_click_loader(
        dataset_dir,
        session_range=(1_000_000, 1_500_000),
    )
    test_loader = get_baidu_click_loader(
        dataset_dir,
        session_range=(1_500_000, 2_000_000),
    )
    annotation_loader = get_baidu_annotation_loader(
        dataset_dir,
        session_range=(0, 400_000),
    )

    # Instantiate a PBM with a Deep & Cross v2 network for document attraction.
    # Note: this might be slow on CPU.
    rngs = nnx.Rngs(42)

    model = PositionBasedModel(
        attraction=DeepCrossParameterConfig(
            use_feature="query_doc_features",
            features=query_doc_features,
        ),
        positions=10,
        rngs=rngs,
    )

    trainer = Trainer(
        optax.adamw(0.0003),
        epochs=3,
        early_stopping=EarlyStopping(patience=0),
    )
    train_df = trainer.train(model, train_loader, val_loader)
    click_df = trainer.test_clicks(model, test_loader)
    ranking_df = trainer.test_ranking(model, annotation_loader)


if __name__ == "__main__":
    main()
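DeepCrossParameterConfig above refers to a Deep & Cross Network v2 (DCN-v2). Its defining building block is the cross layer, x_{l+1} = x_0 * (W x_l + b) + x_l, which learns explicit pairwise feature interactions with a residual connection. A sketch of that recurrence alone, assuming nothing about clax's actual layer stack:

import jax.numpy as jnp
from flax import nnx


class CrossLayerSketch(nnx.Module):
    # One DCN-v2 cross layer: the input features x0 gate a learned linear
    # transform of the current layer xl, plus a residual connection.
    def __init__(self, features: int, rngs: nnx.Rngs):
        self.linear = nnx.Linear(features, features, rngs=rngs)

    def __call__(self, x0: jnp.ndarray, xl: jnp.ndarray) -> jnp.ndarray:
        return x0 * self.linear(xl) + xl

Stacking a few such layers (feeding the unchanged x0 into each) and combining them with a standard MLP tower gives the full DCN-v2 architecture.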
Lines changed: 127 additions & 0 deletions
from pathlib import Path
from typing import Tuple, Dict

import optax
from flax import nnx
from flax.training.early_stopping import EarlyStopping
from jax import Array
from torch.utils.data import DataLoader

from clax import PositionBasedModel
from clax.datasets import (
    BaiduUltrFeatureClickDataset,
    BaiduUltrFeatureAnnotationDataset,
)
from clax.trainer import Trainer


def get_baidu_click_loader(
    dataset_dir: Path,
    session_range: Tuple[int, int],
):
    dataset = BaiduUltrFeatureClickDataset(
        dataset_dir=dataset_dir,
        session_range=session_range,
    )

    return DataLoader(
        dataset,
        batch_size=256,
        collate_fn=dataset.collate_fn,
        num_workers=2,
        persistent_workers=True,
    )


def get_baidu_annotation_loader(
    dataset_dir: Path,
    session_range: Tuple[int, int],
):
    dataset = BaiduUltrFeatureAnnotationDataset(
        dataset_dir=dataset_dir,
        session_range=session_range,
    )

    return DataLoader(
        dataset,
        batch_size=256,
        collate_fn=dataset.collate_fn,
        num_workers=2,
        persistent_workers=True,
    )


class CustomAttraction(nnx.Module):
    """
    Example of a custom Flax module with attention. Every module needs to
    specify how to compute a logit, a log probability, and a probability
    for a given batch.

    In the simplest case, the logit layer can be re-used for the probability
    and log-probability computations.
    """

    def __init__(self, query_doc_features, rngs):
        super().__init__()
        self.attention = nnx.MultiHeadAttention(
            num_heads=1,
            in_features=query_doc_features,
            qkv_features=8,
            decode=False,
            rngs=rngs,
        )
        self.projection = nnx.Linear(query_doc_features, 1, rngs=rngs)

    def logit(self, batch: Dict) -> Array:
        return self.projection(self.attention(batch["query_doc_features"])).squeeze()

    def prob(self, batch: Dict) -> Array:
        return nnx.sigmoid(self.logit(batch))

    def log_prob(self, batch: Dict) -> Array:
        return nnx.log_sigmoid(self.logit(batch))


def main():
    # Load sessions from a subset of the Baidu-ULTR dataset with pre-processed
    # query-doc features:
    dataset_dir = Path("../../clax-datasets/baidu-ultr-uva")
    query_doc_features = 768

    train_loader = get_baidu_click_loader(
        dataset_dir,
        session_range=(0, 100_000),
    )
    val_loader = get_baidu_click_loader(
        dataset_dir,
        session_range=(1_000_000, 1_500_000),
    )
    test_loader = get_baidu_click_loader(
        dataset_dir,
        session_range=(1_500_000, 2_000_000),
    )
    annotation_loader = get_baidu_annotation_loader(
        dataset_dir,
        session_range=(0, 400_000),
    )

    # Instantiate a PBM with a custom module for document attraction.
    # Note: this might be slow on CPU.
    rngs = nnx.Rngs(42)

    model = PositionBasedModel(
        attraction=CustomAttraction(query_doc_features, rngs),
        positions=10,
        rngs=rngs,
    )

    trainer = Trainer(
        optax.adamw(0.0003),
        epochs=3,
        early_stopping=EarlyStopping(patience=0),
    )
    train_df = trainer.train(model, train_loader, val_loader)
    click_df = trainer.test_clicks(model, test_loader)
    ranking_df = trainer.test_ranking(model, annotation_loader)


if __name__ == "__main__":
    main()
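As the docstring notes, any module exposing logit, prob, and log_prob over a batch can be dropped in. The minimal version re-uses a single logit layer for all three; a hedged sketch assuming the same batch layout as above:

from typing import Dict

from flax import nnx
from jax import Array


class LinearAttraction(nnx.Module):
    # The simplest attraction module: one linear projection whose logit
    # is re-used for the probability and log-probability computations.
    def __init__(self, features: int, rngs: nnx.Rngs):
        self.projection = nnx.Linear(features, 1, rngs=rngs)

    def logit(self, batch: Dict) -> Array:
        return self.projection(batch["query_doc_features"]).squeeze(-1)

    def prob(self, batch: Dict) -> Array:
        return nnx.sigmoid(self.logit(batch))

    def log_prob(self, batch: Dict) -> Array:
        return nnx.log_sigmoid(self.logit(batch))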
