Skip to content

Commit edba0c3

Browse files
committed
enable parquet
1 parent 8ec68bf commit edba0c3

File tree

3 files changed

+41
-38
lines changed

3 files changed

+41
-38
lines changed

ads/opctl/operator/lowcode/common/utils.py

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,32 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

43
# Copyright (c) 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

7-
import argparse
86
import logging
97
import os
108
import shutil
119
import sys
1210
import tempfile
13-
import time
14-
from string import Template
15-
from typing import Any, Dict, List, Tuple
16-
import pandas as pd
17-
from ads.opctl import logger
18-
import oracledb
11+
from typing import List, Union
1912

2013
import fsspec
21-
import yaml
22-
from typing import Union
14+
import oracledb
15+
import pandas as pd
2316

17+
from ads.common.object_storage_details import ObjectStorageDetails
2418
from ads.opctl import logger
19+
from ads.opctl.operator.common.operator_config import OutputDirectory
2520
from ads.opctl.operator.lowcode.common.errors import (
26-
InputDataError,
2721
InvalidParameterError,
28-
PermissionsError,
29-
DataMismatchError,
3022
)
31-
from ads.opctl.operator.common.operator_config import OutputDirectory
32-
from ads.common.object_storage_details import ObjectStorageDetails
3323
from ads.secrets import ADBSecretKeeper
3424

3525

3626
def call_pandas_fsspec(pd_fn, filename, storage_options, **kwargs):
37-
if fsspec.utils.get_protocol(filename) == "file":
38-
return pd_fn(filename, **kwargs)
39-
elif fsspec.utils.get_protocol(filename) in ["http", "https"]:
27+
if fsspec.utils.get_protocol(filename) == "file" or fsspec.utils.get_protocol(
28+
filename
29+
) in ["http", "https"]:
4030
return pd_fn(filename, **kwargs)
4131

4232
storage_options = storage_options or (
@@ -48,7 +38,7 @@ def call_pandas_fsspec(pd_fn, filename, storage_options, **kwargs):
4838

4939
def load_data(data_spec, storage_options=None, **kwargs):
5040
if data_spec is None:
51-
raise InvalidParameterError(f"No details provided for this data source.")
41+
raise InvalidParameterError("No details provided for this data source.")
5242
filename = data_spec.url
5343
format = data_spec.format
5444
columns = data_spec.columns
@@ -67,7 +57,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
6757
if not format:
6858
_, format = os.path.splitext(filename)
6959
format = format[1:]
70-
if format in ["json", "clipboard", "excel", "csv", "feather", "hdf"]:
60+
if format in ["json", "clipboard", "excel", "csv", "feather", "hdf", "parquet"]:
7161
read_fn = getattr(pd, f"read_{format}")
7262
data = call_pandas_fsspec(
7363
read_fn, filename, storage_options=storage_options
@@ -84,19 +74,31 @@ def load_data(data_spec, storage_options=None, **kwargs):
8474
with tempfile.TemporaryDirectory() as temp_dir:
8575
if vault_secret_id is not None:
8676
try:
87-
with ADBSecretKeeper.load_secret(vault_secret_id, wallet_dir=temp_dir) as adwsecret:
88-
if 'wallet_location' in adwsecret and 'wallet_location' not in connect_args:
89-
shutil.unpack_archive(adwsecret["wallet_location"], temp_dir)
90-
connect_args['wallet_location'] = temp_dir
91-
if 'user_name' in adwsecret and 'user' not in connect_args:
92-
connect_args['user'] = adwsecret['user_name']
93-
if 'password' in adwsecret and 'password' not in connect_args:
94-
connect_args['password'] = adwsecret['password']
95-
if 'service_name' in adwsecret and 'service_name' not in connect_args:
96-
connect_args['service_name'] = adwsecret['service_name']
77+
with ADBSecretKeeper.load_secret(
78+
vault_secret_id, wallet_dir=temp_dir
79+
) as adwsecret:
80+
if (
81+
"wallet_location" in adwsecret
82+
and "wallet_location" not in connect_args
83+
):
84+
shutil.unpack_archive(
85+
adwsecret["wallet_location"], temp_dir
86+
)
87+
connect_args["wallet_location"] = temp_dir
88+
if "user_name" in adwsecret and "user" not in connect_args:
89+
connect_args["user"] = adwsecret["user_name"]
90+
if "password" in adwsecret and "password" not in connect_args:
91+
connect_args["password"] = adwsecret["password"]
92+
if (
93+
"service_name" in adwsecret
94+
and "service_name" not in connect_args
95+
):
96+
connect_args["service_name"] = adwsecret["service_name"]
9797

9898
except Exception as e:
99-
raise Exception(f"Could not retrieve database credentials from vault {vault_secret_id}: {e}")
99+
raise Exception(
100+
f"Could not retrieve database credentials from vault {vault_secret_id}: {e}"
101+
)
100102

101103
con = oracledb.connect(**connect_args)
102104
if table_name is not None:
@@ -105,11 +107,11 @@ def load_data(data_spec, storage_options=None, **kwargs):
105107
data = pd.read_sql(sql, con)
106108
else:
107109
raise InvalidParameterError(
108-
f"Database `connect_args` provided without sql query or table name. Please specify either `sql` or `table_name`."
110+
"Database `connect_args` provided without sql query or table name. Please specify either `sql` or `table_name`."
109111
)
110112
else:
111113
raise InvalidParameterError(
112-
f"No filename/url provided, and no connect_args provided. Please specify one of these if you want to read data from a file or a database respectively."
114+
"No filename/url provided, and no connect_args provided. Please specify one of these if you want to read data from a file or a database respectively."
113115
)
114116
if columns:
115117
# keep only these columns, done after load because only CSV supports stream filtering
@@ -232,7 +234,7 @@ def human_time_friendly(seconds):
232234
accumulator.append(
233235
"{} {}{}".format(int(amount), unit, "" if amount == 1 else "s")
234236
)
235-
accumulator.append("{} secs".format(round(seconds, 2)))
237+
accumulator.append(f"{round(seconds, 2)} secs")
236238
return ", ".join(accumulator)
237239

238240

@@ -248,9 +250,7 @@ def find_output_dirname(output_dir: OutputDirectory):
248250
unique_output_dir = f"{output_dir}_{counter}"
249251
counter += 1
250252
logger.warn(
251-
"Since the output directory was not specified, the output will be saved to {} directory.".format(
252-
unique_output_dir
253-
)
253+
f"Since the output directory was not specified, the output will be saved to {unique_output_dir} directory."
254254
)
255255
return unique_output_dir
256256

ads/opctl/operator/lowcode/forecast/model/prophet.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,9 @@ def _build_model(self) -> pd.DataFrame:
142142
dt_column=self.spec.datetime_column.name,
143143
)
144144

145+
# if os.environ["OCI__IS_SPARK"]:
146+
# pass
147+
# else:
145148
Parallel(n_jobs=-1, require="sharedmem")(
146149
delayed(ProphetOperatorModel._train_model)(
147150
self, i, series_id, df, model_kwargs.copy()

ads/opctl/operator/lowcode/forecast/schema.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ spec:
311311
missing_value_imputation:
312312
type: boolean
313313
required: false
314-
default: false
314+
default: true
315315
outlier_treatment:
316316
type: boolean
317317
required: false

0 commit comments

Comments
 (0)