Add databricks_automl to the conda env for arima and deepar #151

Merged: 11 commits into branch-0.2.20.3 on Oct 7, 2024

Conversation

@sabhya-db commented on Oct 5, 2024

databricks_automl was already passed in for Prophet. Adding it for ARIMA and DeepAR, too, so we can log those models correctly.
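The intent of the fix can be sketched as passing an explicit pip pin for the runtime package alongside the model-logging call, so the logged environment includes it even when MLflow's automatic requirement inference misses it. A minimal illustration (the helper name and version string are assumptions, not code from this PR):

```python
# Hypothetical helper: build the extra pip requirement for the runtime package
# so it is pinned in the logged model environment.

def automl_extra_pip_requirements(version: str) -> list:
    """Return the pip pin for databricks-automl-runtime at the given version."""
    return ["databricks-automl-runtime=={}".format(version)]

print(automl_extra_pip_requirements("0.2.20.3"))
```

In MLflow, such a list can be forwarded via the `extra_pip_requirements` parameter of `log_model`, which merges it into the logged environment files (requirements.txt and the pip section of conda.yaml).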

@sabhya-db changed the base branch from main to branch-0.2.20.3 on October 5, 2024 17:44
@sabhya-db changed the title from "Pass databricks_automl as extra pip dependency in case log_model(...) can't capture it" to "Add databricks_automl to the conda env for arima and deepar" on Oct 5, 2024
@apeforest left a comment

Can you add a unit test?

@sabhya-db

> Can you add a unit test?

@apeforest the existing UT that logs, loads, and predicts an ARIMA model passes; see here: https://github.com/databricks/automl/blob/branch-0.2.20.3/runtime/tests/automl_runtime/forecast/pmdarima/model_test.py#L432

I could add a check that the requirements.txt logged for all 3 models contains "automl_runtime=0.2.20.x"?
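Such a check could look roughly like this (the helper name and package-name pattern are assumptions; in a real test the requirements.txt text would be fetched from the run's logged artifacts, e.g. via `mlflow.artifacts.download_artifacts`):

```python
import re

def requirements_pin_automl_runtime(requirements_text: str) -> bool:
    """Return True if any line of a logged requirements.txt pins the
    automl runtime package (package-name spelling is an assumption)."""
    pattern = re.compile(r"^(databricks[-_]automl[-_]runtime|automl[-_]runtime)==\d")
    return any(pattern.match(line.strip()) for line in requirements_text.splitlines())

sample = "mlflow==2.9.2\npmdarima==2.0.4\ndatabricks-automl-runtime==0.2.20.3\n"
print(requirements_pin_automl_runtime(sample))  # expect True
```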

@sabhya-db requested a review from apeforest on October 7, 2024 16:55
@@ -57,10 +57,14 @@ def mlflow_forecast_log_model(forecast_model: ForecastModel,
:param forecast_model: Forecast model wrapper
:param sample_input: sample input Dataframes for model inference
"""
# log the model without signature if infer_signature is failed.
# TODO: we should not be logging without a signature since it cannot be registered to UC then

Please add a JIRA ticket number to the TODO:

e.g. # TODO(ML-XXXXX): we should not be logging...

@@ -25,8 +25,10 @@
from mlflow.protos.databricks_pb2 import ErrorCode, INVALID_PARAMETER_VALUE
from pmdarima.arima import ARIMA

from databricks.automl_runtime.forecast.pmdarima.model import ArimaModel, MultiSeriesArimaModel, AbstractArimaModel, \
mlflow_arima_log_model
from databricks.automl_runtime.forecast.pmdarima.model import (

Please break imports into multiple lines:

See https://peps.python.org/pep-0008/#imports
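For illustration, the PEP 8 style being requested is a parenthesized import with one name per line (stdlib names used here as a stand-in for the databricks.automl_runtime import in the diff above):

```python
# Parenthesized multi-line import, one name per line with a trailing comma:
from collections import (
    Counter,
    OrderedDict,
    defaultdict,
)

print(Counter("aab"))
```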

pred_df = loaded_model.predict(sample_input)

assert pred_df.columns.tolist() == [time_col, "yhat", id_col]
assert len(pred_df) == self.prediction_length * 2
assert pred_df[time_col].min() > sample_input[time_col].max()

def _check_requirements(self, run_id: str):

nit: if a method is being tested in a TestClass, make it a public method (i.e. without the _ prefix)

@sabhya-db

It's not a unit test itself; it's a helper method called only from other unit test cases. Should I still remove the prefix?

@apeforest left a comment

Thanks @sabhya-db. A few nits; otherwise, LGTM.

@sabhya-db

> Thanks @sabhya-db. A few nits; otherwise, LGTM.

Thanks for checking, will fix nits soon


@es94129 left a comment

LGTM. Do you know if it's a serverless-only issue, or should we also port this to the main branch?

@sabhya-db

> LGTM. Do you know if it's a serverless-only issue, or should we also port this to the main branch?

@es94129

We should do this on the main branch too; an ARIMA model logged from the main branch currently can't be served. For a quick hotfix you can just add "automl_runtime=xyz" to ARIMA_CONDA_ENV here: sabhya-db/ML-45532-sample-weight-2
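For illustration, the hotfix amounts to adding the pin to the pip section of the conda env constant. A minimal sketch (the key layout and version string are assumptions; "xyz" in the comment above stands for the actual release version):

```python
import platform

# Assumed pin; in the real hotfix "xyz" would be the actual release version.
AUTOML_RUNTIME_PIN = "databricks-automl-runtime==0.2.20.3"

# Sketch of an MLflow-style conda env dict with the pin in its pip section.
ARIMA_CONDA_ENV = {
    "channels": ["conda-forge"],
    "dependencies": [
        "python={}".format(platform.python_version()),
        "pip",
        {"pip": ["pmdarima", AUTOML_RUNTIME_PIN]},
    ],
    "name": "arima_env",
}

# The pip section is the dict entry inside "dependencies".
pip_deps = [d for d in ARIMA_CONDA_ENV["dependencies"] if isinstance(d, dict)][0]["pip"]
print(AUTOML_RUNTIME_PIN in pip_deps)  # expect True
```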

@es94129

@es94129 commented on Oct 7, 2024

Can you also port the changes for ARIMA to the main branch then?

@sabhya-db

> Can you also port the changes for ARIMA to the main branch then?

Will do in a separate PR.

@sabhya-db merged commit f6de201 into branch-0.2.20.3 on Oct 7, 2024
5 checks passed