Skip to content

Commit 4d06a9c

Browse files
authored
ODSC-39731/support_running_sparkpipelinemodel_in_df (#152)
1 parent 5b4f70d commit 4d06a9c

File tree

16 files changed

+343
-118
lines changed

16 files changed

+343
-118
lines changed

ads/common/object_storage_details.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,20 @@ def is_valid_uri(uri: str) -> bool:
121121
"It must follow the pattern `oci://<bucket_name>@<namespace>/<prefix>`."
122122
)
123123
return True
124+
125+
@staticmethod
126+
def is_oci_path(uri: str = None) -> bool:
127+
"""Check if the given path is oci object storage uri.
128+
129+
Parameters
130+
----------
131+
uri: str
132+
The URI of the target.
133+
134+
Returns
135+
-------
136+
bool: return True if the path is oci object storage uri.
137+
"""
138+
if not uri:
139+
return False
140+
return uri.startswith("oci://")

ads/common/utils.py

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -495,9 +495,7 @@ def print_user_message(
495495
)
496496

497497
if is_documentation_mode() and is_notebook():
498-
499498
if display_type.lower() == "tip":
500-
501499
if "\n" in msg:
502500
t = "<b>{}:</b>".format(title.upper().strip()) if title else ""
503501

@@ -567,7 +565,6 @@ def print_user_message(
567565
)
568566

569567
elif display_type.startswith("info"):
570-
571568
user_message = msg.strip().replace("\n", "<br>")
572569

573570
if see_also_links:
@@ -640,7 +637,6 @@ def ellipsis_strings(raw, n=24):
640637

641638
result = []
642639
for s in sequence:
643-
644640
if len(str(s)) <= n:
645641
result.append(s)
646642
else:
@@ -1136,36 +1132,44 @@ def is_data_too_wide(
11361132
return col_num > max_col_num
11371133

11381134

1139-
def get_files(directory: str):
1135+
def get_files(directory: str, auth: Optional[Dict] = None):
11401136
"""List out all the file names under this directory.
11411137
11421138
Parameters
11431139
----------
11441140
directory: str
11451141
The directory to list out all the files from.
1142+
auth: (Dict, optional). Defaults to None.
1143+
The default authentication is set using `ads.set_auth` API. If you need to override the
1144+
default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate
1145+
authentication signer and kwargs required to instantiate IdentityClient object.
11461146
11471147
Returns
11481148
-------
11491149
List
11501150
List of the files in the directory.
11511151
"""
11521152
directory = directory.rstrip("/")
1153-
if os.path.exists(os.path.join(directory, ".model-ignore")):
1154-
ignore_patterns = (
1155-
Path(os.path.join(directory), ".model-ignore")
1156-
.read_text()
1157-
.strip()
1158-
.split("\n")
1159-
)
1153+
path_scheme = urlparse(directory).scheme or "file"
1154+
storage_options = auth or authutil.default_signer()
1155+
model_ignore_path = os.path.join(directory, ".model-ignore")
1156+
if is_path_exists(model_ignore_path, auth=auth):
1157+
with fsspec.open(model_ignore_path, "r", **storage_options) as f:
1158+
ignore_patterns = f.read().strip().split("\n")
11601159
else:
11611160
ignore_patterns = []
11621161
file_names = []
1163-
for root, dirs, files in os.walk(directory):
1162+
fs = fsspec.filesystem(path_scheme, **storage_options)
1163+
for root, dirs, files in fs.walk(directory):
11641164
for name in files:
11651165
file_names.append(os.path.join(root, name))
11661166
for name in dirs:
11671167
file_names.append(os.path.join(root, name))
11681168

1169+
# return all files in remote directory.
1170+
if directory.startswith("oci://"):
1171+
directory = directory.lstrip("oci://")
1172+
11691173
for ignore in ignore_patterns:
11701174
if not ignore.startswith("#") and ignore.strip() != "":
11711175
matches = []
@@ -1228,7 +1232,7 @@ def copy_from_uri(
12281232
force_overwrite: (bool, optional). Defaults to False.
12291233
Whether to overwrite existing files or not.
12301234
auth: (Dict, optional). Defaults to None.
1231-
The default authetication is set using `ads.set_auth` API. If you need to override the
1235+
The default authentication is set using `ads.set_auth` API. If you need to override the
12321236
default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate
12331237
authentication signer and kwargs required to instantiate IdentityClient object.
12341238
@@ -1294,7 +1298,7 @@ def copy_file(
12941298
force_overwrite: (bool, optional). Defaults to False.
12951299
Whether to overwrite existing files or not.
12961300
auth: (Dict, optional). Defaults to None.
1297-
The default authetication is set using `ads.set_auth` API. If you need to override the
1301+
The default authentication is set using `ads.set_auth` API. If you need to override the
12981302
default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate
12991303
authentication signer and kwargs required to instantiate IdentityClient object.
13001304
chunk_size: (int, optinal). Defaults to `DEFAULT_BUFFER_SIZE`
@@ -1357,7 +1361,7 @@ def remove_file(file_path: str, auth: Optional[Dict] = None) -> None:
13571361
file_path: str
13581362
The path of the source file, which can be local path or OCI object storage URI.
13591363
auth: (Dict, optional). Defaults to None.
1360-
The default authetication is set using `ads.set_auth` API. If you need to override the
1364+
The default authentication is set using `ads.set_auth` API. If you need to override the
13611365
default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate
13621366
authentication signer and kwargs required to instantiate IdentityClient object.
13631367
@@ -1570,3 +1574,26 @@ def extract_region(auth: Optional[Dict] = None) -> Union[str, None]:
15701574
pass
15711575

15721576
return None
1577+
1578+
1579+
def is_path_exists(uri: str, auth: Optional[Dict] = None) -> bool:
1580+
"""Check if the given path which can be local path or OCI object storage URI exists.
1581+
1582+
Parameters
1583+
----------
1584+
uri: str
1585+
The URI of the target, which can be local path or OCI object storage URI.
1586+
auth: (Dict, optional). Defaults to None.
1587+
The default authentication is set using `ads.set_auth` API. If you need to override the
1588+
default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate
1589+
authentication signer and kwargs required to instantiate IdentityClient object.
1590+
1591+
Returns
1592+
-------
1593+
bool: return True if the path exists.
1594+
"""
1595+
path_scheme = urlparse(uri).scheme or "file"
1596+
storage_options = auth or authutil.default_signer()
1597+
if fsspec.filesystem(path_scheme, **storage_options).exists(uri):
1598+
return True
1599+
return False

ads/feature_engineering/schema.py

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,21 @@
44
# Copyright (c) 2021, 2022 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

7+
import asteval
8+
import fsspec
9+
import json
10+
import os
11+
import sys
12+
import yaml
713
from abc import ABC, abstractmethod
14+
from cerberus import Validator
15+
from copy import deepcopy
816
from dataclasses import dataclass, field
917
from typing import Dict, List, Optional
10-
from copy import deepcopy
11-
12-
import yaml
18+
from string import Template
19+
from os import path
1320
from ads.common.serializer import DataClassSerializable
21+
from ads.common.object_storage_details import ObjectStorageDetails
1422

1523
try:
1624
from yaml import CDumper as dumper
@@ -19,24 +27,13 @@
1927
from yaml import Dumper as dumper
2028
from yaml import Loader as loader
2129

22-
import json
23-
import os
24-
import asteval
25-
from os import path
26-
27-
import fsspec
28-
import yaml
29-
from cerberus import Validator
30-
import sys
31-
from string import Template
32-
33-
3430
SCHEMA_VALIDATOR_NAME = "data_schema.json"
3531
INPUT_OUTPUT_SCHENA_SIZE_LIMIT = 32000
3632
SCHEMA_VERSION = "1.1"
3733
DEFAULT_SCHEMA_VERSION = "1.0"
3834
SCHEMA_KEY = "schema"
3935
SCHEMA_VERSION_KEY = "version"
36+
DEFAULT_STORAGE_OPTIONS = None
4037

4138

4239
class SchemaSizeTooLarge(ValueError):
@@ -685,13 +682,16 @@ def to_json(self):
685682
"""
686683
return json.dumps(self.to_dict()).replace("NaN", "null")
687684

688-
def to_json_file(self, file_path):
685+
def to_json_file(self, file_path, storage_options: dict = None):
689686
"""Saves the data schema into a json file.
690687
691688
Parameters
692689
----------
693690
file_path : str
694691
File Path to store the schema in json format.
692+
storage_options: dict. Default None
693+
Parameters passed on to the backend filesystem class.
694+
Defaults to `storage_options` set using `DatasetFactory.set_default_storage()`.
695695
696696
Returns
697697
-------
@@ -704,12 +704,19 @@ def to_json_file(self, file_path):
704704
".json"
705705
], f"The file `{basename}` is not a valid JSON file. The `{file_path}` must have the extension .json."
706706
if directory and not os.path.exists(directory):
707-
try:
708-
os.mkdir(directory)
709-
except:
710-
raise Exception(f"Error creating the directory.")
711-
with open(os.path.join(directory, basename), "w") as json_file:
712-
json.dump(self.to_dict(), json_file)
707+
if not ObjectStorageDetails.is_oci_path(directory):
708+
try:
709+
os.mkdir(directory)
710+
except:
711+
raise Exception(f"Error creating the directory.")
712+
if not storage_options:
713+
storage_options = DEFAULT_STORAGE_OPTIONS or {"config": {}}
714+
with fsspec.open(
715+
os.path.join(directory, basename),
716+
mode="w",
717+
**(storage_options),
718+
) as f:
719+
f.write(json.dumps(self.to_dict()))
713720

714721
def to_yaml_file(self, file_path):
715722
"""Saves the data schema into a yaml file.

0 commit comments

Comments
 (0)