Skip to content

Commit f3498a2

Browse files
add mlflow
1 parent 988d3c4 commit f3498a2

File tree

4 files changed

+1958
-234
lines changed

4 files changed

+1958
-234
lines changed

.dvc/config

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
[core]
2-
remote = origin
3-
['remote "origin"']
4-
url = https://dagshub.com/khuyentran1401/prefect-dvc.dvc
2+
remote = storage
3+
['remote "storage"']
4+
url = s3://khuyen-dvc-demo

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ authors = [{ name = "Khuyen" }]
66
requires-python = ">=3.8"
77
dependencies = [
88
"dvc",
9+
"dvc-s3>=3.0.1",
10+
"mlflow>=2.17.2",
911
"pandas>=2.0.3",
1012
"scikit-learn>=1.3.2",
1113
"yellowbrick>=1.5",

src/segment.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from yellowbrick.cluster import KElbowVisualizer
1111
import hydra
1212
from pathlib import Path
13+
import mlflow
14+
from sklearn.metrics import silhouette_score
15+
from mlflow.models import infer_signature
1316

1417
warnings.simplefilter(action="ignore", category=DeprecationWarning)
1518

@@ -65,15 +68,30 @@ def save_data_and_model(data: pd.DataFrame, model: KMeans, config: DictConfig):
6568

6669
@hydra.main(config_path="../config", config_name="main", version_base="1.2")
6770
def segment(config: DictConfig) -> None:
71+
72+
# Data processing
6873
data = read_process_data(config)
6974
pca = get_pca_model(data)
7075
pca_df = reduce_dimension(data, pca)
7176
k_best = get_best_k_cluster(pca_df)
7277
model = get_clusters_model(pca_df, k_best)
7378
pred = predict(model, pca_df)
7479
data = insert_clusters_to_df(data, pred)
80+
silhouette_avg = silhouette_score(pca_df, pred)
81+
82+
# Save data and model locally
7583
save_data_and_model(data, model, config)
7684

85+
with mlflow.start_run():
86+
87+
mlflow.log_params({"n_components": 3, "random_state": 42, "best_k": k_best})
88+
mlflow.log_metric("silhouette_score", silhouette_avg)
89+
signature = infer_signature(pca_df, pred)
90+
mlflow.sklearn.log_model(
91+
model, "kmeans_model", signature=signature, input_example=pca_df.head()
92+
)
93+
mlflow.log_artifact(config.final.path, "processed_data")
94+
7795

7896
if __name__ == "__main__":
7997
segment()

0 commit comments

Comments
 (0)