Skip to content

Commit aeca240

Browse files
committed
made all defaults above workload path fns
1 parent 2ebb113 commit aeca240

File tree

5 files changed

+60
-45
lines changed

5 files changed

+60
-45
lines changed

scripts/run_protox_e2e_test.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@
1111
from benchmark.tpch.constants import DEFAULT_TPCH_SEED
1212
from util.pg import get_is_postgres_running
1313
from util.workspace import (
14-
default_embedder_path,
15-
default_hpoed_agent_params_path,
1614
default_pristine_dbdata_snapshot_path,
1715
default_replay_data_fpath,
1816
default_repo_path,
19-
default_tables_path,
20-
default_traindata_path,
2117
default_tuning_steps_dpath,
18+
get_default_embedder_path,
19+
get_default_hpoed_agent_params_path,
20+
get_default_tables_path,
21+
get_default_traindata_path,
2222
get_default_workload_path,
2323
get_workload_name,
2424
)
@@ -96,7 +96,9 @@ def run_e2e_for_benchmark(benchmark_name: str, intended_dbdata_hardware: str) ->
9696

9797
# Run the full Proto-X training pipeline, asserting things along the way
9898
# Setup (workload and database)
99-
tables_dpath = default_tables_path(workspace_dpath, benchmark_name, scale_factor)
99+
tables_dpath = get_default_tables_path(
100+
workspace_dpath, benchmark_name, scale_factor
101+
)
100102
if Stage.Tables in STAGES_TO_RUN:
101103
assert not tables_dpath.exists()
102104
subprocess.run(
@@ -135,7 +137,7 @@ def run_e2e_for_benchmark(benchmark_name: str, intended_dbdata_hardware: str) ->
135137
assert pristine_dbdata_snapshot_fpath.exists()
136138

137139
# Tuning (embedding, HPO, and actual tuning)
138-
traindata_dpath = default_traindata_path(
140+
traindata_dpath = get_default_traindata_path(
139141
workspace_dpath, benchmark_name, workload_name
140142
)
141143
if Stage.EmbeddingData in STAGES_TO_RUN:
@@ -146,7 +148,7 @@ def run_e2e_for_benchmark(benchmark_name: str, intended_dbdata_hardware: str) ->
146148
)
147149
assert traindata_dpath.exists()
148150

149-
embedder_dpath = default_embedder_path(
151+
embedder_dpath = get_default_embedder_path(
150152
workspace_dpath, benchmark_name, workload_name
151153
)
152154
if Stage.EmbeddingModel in STAGES_TO_RUN:
@@ -157,7 +159,7 @@ def run_e2e_for_benchmark(benchmark_name: str, intended_dbdata_hardware: str) ->
157159
)
158160
assert embedder_dpath.exists()
159161

160-
hpoed_agent_params_fpath = default_hpoed_agent_params_path(
162+
hpoed_agent_params_fpath = get_default_hpoed_agent_params_path(
161163
workspace_dpath, benchmark_name, workload_name
162164
)
163165
if Stage.TuneHPO in STAGES_TO_RUN:

tune/protox/agent/hpo.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@
3636
DBGymConfig,
3737
TuningMode,
3838
default_dbdata_parent_dpath,
39-
default_embedder_path,
4039
default_pgbin_path,
4140
default_pristine_dbdata_snapshot_path,
4241
fully_resolve_path,
4342
get_default_benchbase_config_path,
4443
get_default_benchmark_config_path,
44+
get_default_embedder_path,
4545
get_default_hpoed_agent_params_fname,
4646
get_default_workload_name_suffix,
4747
get_default_workload_path,
@@ -120,7 +120,7 @@ def __init__(
120120
"--embedder-path",
121121
type=Path,
122122
default=None,
123-
help=f"The path to the directory that contains an `embedder.pth` file with a trained encoder and decoder as well as a `config` file. The default is {default_embedder_path(WORKSPACE_PATH_PLACEHOLDER, BENCHMARK_NAME_PLACEHOLDER, WORKLOAD_NAME_PLACEHOLDER)}",
123+
help=f"The path to the directory that contains an `embedder.pth` file with a trained encoder and decoder as well as a `config` file. The default is {get_default_embedder_path(WORKSPACE_PATH_PLACEHOLDER, BENCHMARK_NAME_PLACEHOLDER, WORKLOAD_NAME_PLACEHOLDER)}",
124124
)
125125
@click.option(
126126
"--benchmark-config-path",
@@ -270,7 +270,7 @@ def hpo(
270270
workload_name_suffix = get_default_workload_name_suffix(benchmark_name)
271271
workload_name = get_workload_name(scale_factor, workload_name_suffix)
272272
if embedder_path is None:
273-
embedder_path = default_embedder_path(
273+
embedder_path = get_default_embedder_path(
274274
dbgym_cfg.dbgym_workspace_path, benchmark_name, workload_name
275275
)
276276
if benchmark_config_path is None:

tune/protox/agent/tune.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
WORKSPACE_PATH_PLACEHOLDER,
1919
DBGymConfig,
2020
TuningMode,
21-
default_hpoed_agent_params_path,
2221
fully_resolve_path,
22+
get_default_hpoed_agent_params_path,
2323
get_default_tuning_steps_dname,
2424
get_default_workload_name_suffix,
2525
get_workload_name,
@@ -48,7 +48,7 @@
4848
"--hpoed-agent-params-path",
4949
default=None,
5050
type=Path,
51-
help=f"The path to best params found by the agent HPO process. The default is {default_hpoed_agent_params_path(WORKSPACE_PATH_PLACEHOLDER, BENCHMARK_NAME_PLACEHOLDER, WORKLOAD_NAME_PLACEHOLDER)}",
51+
help=f"The path to best params found by the agent HPO process. The default is {get_default_hpoed_agent_params_path(WORKSPACE_PATH_PLACEHOLDER, BENCHMARK_NAME_PLACEHOLDER, WORKLOAD_NAME_PLACEHOLDER)}",
5252
)
5353
@click.option(
5454
"--enable-boot-during-tune",
@@ -83,7 +83,7 @@ def tune(
8383
workload_name_suffix = get_default_workload_name_suffix(benchmark_name)
8484
workload_name = get_workload_name(scale_factor, workload_name_suffix)
8585
if hpoed_agent_params_path is None:
86-
hpoed_agent_params_path = default_hpoed_agent_params_path(
86+
hpoed_agent_params_path = get_default_hpoed_agent_params_path(
8787
dbgym_cfg.dbgym_workspace_path, benchmark_name, workload_name
8888
)
8989

tune/protox/embedding/train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@
2929
WORKLOAD_NAME_PLACEHOLDER,
3030
WORKSPACE_PATH_PLACEHOLDER,
3131
DBGymConfig,
32-
default_traindata_path,
3332
fully_resolve_path,
3433
get_default_benchmark_config_path,
34+
get_default_traindata_path,
3535
get_default_workload_name_suffix,
3636
get_default_workload_path,
3737
get_workload_name,
@@ -66,7 +66,7 @@
6666
"--traindata-path",
6767
type=Path,
6868
default=None,
69-
help=f"The path to the .parquet file containing the training data to use to train the embedding models. The default is {default_traindata_path(WORKSPACE_PATH_PLACEHOLDER, BENCHMARK_NAME_PLACEHOLDER, WORKLOAD_NAME_PLACEHOLDER)}.",
69+
help=f"The path to the .parquet file containing the training data to use to train the embedding models. The default is {get_default_traindata_path(WORKSPACE_PATH_PLACEHOLDER, BENCHMARK_NAME_PLACEHOLDER, WORKLOAD_NAME_PLACEHOLDER)}.",
7070
)
7171
@click.option(
7272
"--seed",
@@ -201,7 +201,7 @@ def train(
201201
workload_name_suffix = get_default_workload_name_suffix(benchmark_name)
202202
workload_name = get_workload_name(scale_factor, workload_name_suffix)
203203
if traindata_path is None:
204-
traindata_path = default_traindata_path(
204+
traindata_path = get_default_traindata_path(
205205
dbgym_cfg.dbgym_workspace_path, benchmark_name, workload_name
206206
)
207207
# TODO(phw2): figure out whether different scale factors use the same config

util/workspace.py

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -149,38 +149,51 @@ def get_default_replay_data_fname(
149149
# folder called run_*/dbgym_agent_protox_tune/tuning_steps. However, replay itself generates an replay_info.log file, which goes in
150150
# run_*/dbgym_agent_protox_tune/tuning_steps/. The bug was that my replay function was overwriting the replay_info.log file of the
151151
# tuning run. By naming all symlinks "*.link", we avoid the possibility of subtle bugs like this happening.
152-
default_traindata_path: Callable[[Path, str, str], Path] = (
153-
lambda workspace_path, benchmark_name, workload_name: get_symlinks_path_from_workspace_path(
154-
workspace_path
152+
def get_default_traindata_path(
153+
workspace_path: Path, benchmark_name: str, workload_name: str
154+
) -> Path:
155+
return (
156+
get_symlinks_path_from_workspace_path(workspace_path)
157+
/ "dbgym_tune_protox_embedding"
158+
/ "data"
159+
/ (get_default_traindata_fname(benchmark_name, workload_name) + ".link")
155160
)
156-
/ "dbgym_tune_protox_embedding"
157-
/ "data"
158-
/ (get_default_traindata_fname(benchmark_name, workload_name) + ".link")
159-
)
160-
default_embedder_path: Callable[[Path, str, str], Path] = (
161-
lambda workspace_path, benchmark_name, workload_name: get_symlinks_path_from_workspace_path(
162-
workspace_path
161+
162+
163+
def get_default_embedder_path(
164+
workspace_path: Path, benchmark_name: str, workload_name: str
165+
) -> Path:
166+
return (
167+
get_symlinks_path_from_workspace_path(workspace_path)
168+
/ "dbgym_tune_protox_embedding"
169+
/ "data"
170+
/ (get_default_embedder_dname(benchmark_name, workload_name) + ".link")
163171
)
164-
/ "dbgym_tune_protox_embedding"
165-
/ "data"
166-
/ (get_default_embedder_dname(benchmark_name, workload_name) + ".link")
167-
)
168-
default_hpoed_agent_params_path: Callable[[Path, str, str], Path] = (
169-
lambda workspace_path, benchmark_name, workload_name: get_symlinks_path_from_workspace_path(
170-
workspace_path
172+
173+
174+
def get_default_hpoed_agent_params_path(
175+
workspace_path: Path, benchmark_name: str, workload_name: str
176+
) -> Path:
177+
return (
178+
get_symlinks_path_from_workspace_path(workspace_path)
179+
/ "dbgym_tune_protox_agent"
180+
/ "data"
181+
/ (
182+
get_default_hpoed_agent_params_fname(benchmark_name, workload_name)
183+
+ ".link"
184+
)
171185
)
172-
/ "dbgym_tune_protox_agent"
173-
/ "data"
174-
/ (get_default_hpoed_agent_params_fname(benchmark_name, workload_name) + ".link")
175-
)
176-
default_tables_path: Callable[[Path, str, float | str], Path] = (
177-
lambda workspace_path, benchmark_name, scale_factor: get_symlinks_path_from_workspace_path(
178-
workspace_path
186+
187+
188+
def get_default_tables_path(
189+
workspace_path: Path, benchmark_name: str, scale_factor: float | str
190+
) -> Path:
191+
return (
192+
get_symlinks_path_from_workspace_path(workspace_path)
193+
/ f"dbgym_benchmark_{benchmark_name}"
194+
/ "data"
195+
/ (get_default_tables_dname(scale_factor) + ".link")
179196
)
180-
/ f"dbgym_benchmark_{benchmark_name}"
181-
/ "data"
182-
/ (get_default_tables_dname(scale_factor) + ".link")
183-
)
184197

185198

186199
def get_default_workload_path(

0 commit comments

Comments
 (0)