# Reviewer notes. This text sits before the first "diff --git" header and is not part of any hunk.
#
# What the patch does:
# - dialects/mindsdb/__init__.py: the import from .create_predictor now also brings in CreateForecastingModel.
# - dialects/mindsdb/create_predictor.py: adds CreateForecastingModel(CreatePredictorBase). Its __init__ forwards
#   all positional and keyword arguments to the base class, then sets _command to 'CREATE FORECASTING MODEL' and
#   task to Identifier('Forecasting').
# - dialects/mindsdb/lexer.py: adds FORECASTING to the token names and defines its pattern r'\bFORECASTING\b'
#   under the "Typed models" comment.
# - dialects/mindsdb/parser.py: imports CreateForecastingModel, adds 'create_forecasting_model' to the list of
#   rule names that already contains 'create_anomaly_detection_model', and adds the grammar rules:
#     * two base forms: CREATE FORECASTING MODEL identifier PREDICT result_columns, and
#       CREATE FORECASTING MODEL identifier FROM identifier LPAREN raw_query RPAREN PREDICT result_columns;
#       the base rule reads p.identifier when that attribute exists, otherwise p.identifier0 for the model name
#       and p.identifier1 (via getattr, default None) for integration_name;
#     * clause rules WINDOW integer, HORIZON integer, GROUP_BY expr_list, ORDER_BY ordering_terms and
#       USING kw_parameter_list, each of which sets one attribute (window, horizon, group_by, order_by, using)
#       on the node produced by the nested create_forecasting_model and returns that node;
#       GROUP_BY wraps a non-list expr_list in a list first.
# - tests/test_parser/test_mindsdb/test_create_predictor.py: adds test_create_forecasting_model, which parses a
#   statement using every clause and compares both str() (through to_single_line) and to_tree() against a
#   hand-built CreateForecastingModel.
#
# NOTE(review): neither base production contains an IF_NOT_EXISTS symbol, so hasattr(p, 'IF_NOT_EXISTS') looks
#   like it always evaluates to False in this rule; the in-patch TODO mentions this -- confirm it is intended.
# NOTE(review): FORECASTING is now a lexer token; presumably statements that use "forecasting" as a bare
#   identifier are still accepted by the identifier rule -- verify against the parser.
# NOTE(review): the line structure of this patch appears to have been collapsed (several hunks share one
#   physical line) -- presumably an extraction artifact; confirm against the original patch before applying.
diff --git a/mindsdb_sql/parser/dialects/mindsdb/__init__.py b/mindsdb_sql/parser/dialects/mindsdb/__init__.py index aa26ef4b..f0c945bd 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/__init__.py +++ b/mindsdb_sql/parser/dialects/mindsdb/__init__.py @@ -1,7 +1,7 @@ from .agents import CreateAgent, DropAgent, UpdateAgent from .create_view import CreateView from .create_database import CreateDatabase -from .create_predictor import CreatePredictor, CreateAnomalyDetectionModel +from .create_predictor import CreatePredictor, CreateAnomalyDetectionModel, CreateForecastingModel from .drop_predictor import DropPredictor from .retrain_predictor import RetrainPredictor from .finetune_predictor import FinetunePredictor diff --git a/mindsdb_sql/parser/dialects/mindsdb/create_predictor.py b/mindsdb_sql/parser/dialects/mindsdb/create_predictor.py index f98d48f1..9c103888 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/create_predictor.py +++ b/mindsdb_sql/parser/dialects/mindsdb/create_predictor.py @@ -152,3 +152,10 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._command = 'CREATE ANOMALY DETECTION MODEL' self.task = Identifier('AnomalyDetection') + + +class CreateForecastingModel(CreatePredictorBase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._command = 'CREATE FORECASTING MODEL' + self.task = Identifier('Forecasting') diff --git a/mindsdb_sql/parser/dialects/mindsdb/lexer.py b/mindsdb_sql/parser/dialects/mindsdb/lexer.py index 752dddaf..10bd205c 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/lexer.py +++ b/mindsdb_sql/parser/dialects/mindsdb/lexer.py @@ -30,6 +30,7 @@ class MindsDBLexer(Lexer): LATEST, LAST, HORIZON, USING, ENGINE, TRAIN, PREDICT, PARAMETERS, JOB, CHATBOT, EVERY,PROJECT, ANOMALY, DETECTION, + FORECASTING, KNOWLEDGE_BASE, KNOWLEDGE_BASES, SKILL, AGENT, @@ -118,6 +119,7 @@ class MindsDBLexer(Lexer): # Typed models ANOMALY = r'\bANOMALY\b' DETECTION = r'\bDETECTION\b' + FORECASTING = 
r'\bFORECASTING\b' KNOWLEDGE_BASE = r'\bKNOWLEDGE[_|\s]BASE\b' KNOWLEDGE_BASES = r'\bKNOWLEDGE[_|\s]BASES\b' diff --git a/mindsdb_sql/parser/dialects/mindsdb/parser.py b/mindsdb_sql/parser/dialects/mindsdb/parser.py index 41c010cd..4f1a6b7a 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/parser.py +++ b/mindsdb_sql/parser/dialects/mindsdb/parser.py @@ -6,7 +6,8 @@ from mindsdb_sql.parser.dialects.mindsdb.drop_predictor import DropPredictor from mindsdb_sql.parser.dialects.mindsdb.drop_dataset import DropDataset from mindsdb_sql.parser.dialects.mindsdb.drop_ml_engine import DropMLEngine -from mindsdb_sql.parser.dialects.mindsdb.create_predictor import CreatePredictor, CreateAnomalyDetectionModel +from mindsdb_sql.parser.dialects.mindsdb.create_predictor import CreatePredictor +from mindsdb_sql.parser.dialects.mindsdb.create_predictor import CreateAnomalyDetectionModel, CreateForecastingModel from mindsdb_sql.parser.dialects.mindsdb.create_database import CreateDatabase from mindsdb_sql.parser.dialects.mindsdb.create_ml_engine import CreateMLEngine from mindsdb_sql.parser.dialects.mindsdb.create_view import CreateView @@ -64,6 +65,7 @@ class MindsDBParser(Parser): 'create_integration', 'create_view', 'create_anomaly_detection_model', + 'create_forecasting_model', 'drop_predictor', 'drop_datasource', 'drop_dataset', @@ -817,6 +819,60 @@ def create_anomaly_detection_model(self, p): p.create_anomaly_detection_model.using = p.kw_parameter_list return p.create_anomaly_detection_model + ## Forecasting + @_( + 'CREATE FORECASTING MODEL identifier PREDICT result_columns', # for pre-trained models (e.g. 
TimeGPT) + 'CREATE FORECASTING MODEL identifier FROM identifier LPAREN raw_query RPAREN PREDICT result_columns', + # TODO add IF_NOT_EXISTS elegantly (should be low level BNF expansion) + ) + def create_forecasting_model(self, p): + query_str = None + if hasattr(p, 'raw_query'): + query_str = tokens_to_string(p.raw_query) + + if hasattr(p, 'identifier'): + # single identifier field + name = p.identifier + else: + name = p.identifier0 + + return CreateForecastingModel( + name=name, + targets=getattr(p, 'result_columns', None), + integration_name=getattr(p, 'identifier1', None), + query_str=query_str, + if_not_exists=hasattr(p, 'IF_NOT_EXISTS') + ) + + @_('create_forecasting_model WINDOW integer') + def create_forecasting_model(self, p): + p.create_forecasting_model.window = p.integer + return p.create_forecasting_model + + @_('create_forecasting_model HORIZON integer') + def create_forecasting_model(self, p): + p.create_forecasting_model.horizon = p.integer + return p.create_forecasting_model + + @_('create_forecasting_model GROUP_BY expr_list') + def create_forecasting_model(self, p): + group_by = p.expr_list + if not isinstance(group_by, list): + group_by = [group_by] + + p.create_forecasting_model.group_by = group_by + return p.create_forecasting_model + + @_('create_forecasting_model ORDER_BY ordering_terms') + def create_forecasting_model(self, p): + p.create_forecasting_model.order_by = p.ordering_terms + return p.create_forecasting_model + + @_('create_forecasting_model USING kw_parameter_list') + def create_forecasting_model(self, p): + p.create_forecasting_model.using = p.kw_parameter_list + return p.create_forecasting_model + # RETRAIN PREDICTOR @_('RETRAIN identifier', diff --git a/tests/test_parser/test_mindsdb/test_create_predictor.py b/tests/test_parser/test_mindsdb/test_create_predictor.py index b035ab7a..cbb42da4 100644 --- a/tests/test_parser/test_mindsdb/test_create_predictor.py +++ b/tests/test_parser/test_mindsdb/test_create_predictor.py @@ 
-185,3 +185,36 @@ def test_create_anomaly_detection_model(self): assert to_single_line(str(ast)) == to_single_line(str(expected_ast)) assert ast.to_tree() == expected_ast.to_tree() + + def test_create_forecasting_model(self): + create_clause = "CREATE FORECASTING MODEL forecasting_model" + rest_clause = """ + FROM integration_name (select * FROM table) + PREDICT y + WINDOW 10 + HORIZON 5 + ORDER BY time + GROUP BY group + USING + param='a' + """ + sql = create_clause + rest_clause + ast = parse_sql(sql, dialect='mindsdb') + + expected_ast = CreateForecastingModel( + name=Identifier('forecasting_model'), + task=Identifier('Forecasting'), + integration_name=Identifier('integration_name'), + query_str='select * FROM table', + targets=[Identifier('y')], + window=10, + horizon=5, + order_by=[OrderBy(Identifier('time'), direction='default')], + group_by=[Identifier('group')], + using={ + 'param': 'a' + } + ) + + assert to_single_line(str(ast)) == to_single_line(str(expected_ast)) + assert ast.to_tree() == expected_ast.to_tree()