diff --git a/.codegen/_openapi_sha b/.codegen/_openapi_sha index f5d482394..b65888c6f 100644 --- a/.codegen/_openapi_sha +++ b/.codegen/_openapi_sha @@ -1 +1 @@ -e6971360e2752b3513d44a25d25f6cc5448056c8 \ No newline at end of file +28e61a3e6657d349b135445af2a0f03ddfa8c3fc \ No newline at end of file diff --git a/databricks/sdk/apps/v2/apps.py b/databricks/sdk/apps/v2/apps.py index 0beb5434c..b4315be38 100644 --- a/databricks/sdk/apps/v2/apps.py +++ b/databricks/sdk/apps/v2/apps.py @@ -3,11 +3,15 @@ from __future__ import annotations import logging +import random +import time from dataclasses import dataclass from enum import Enum from typing import Any, Dict, Iterator, List, Optional -from ...service._internal import _enum, _from_dict, _repeated_dict +from ...databricks.errors import OperationFailed +from ...service._internal import (WaitUntilDoneOptions, _enum, _from_dict, + _repeated_dict) _LOG = logging.getLogger("databricks.sdk") @@ -1061,6 +1065,178 @@ class StopAppRequest: """The name of the app.""" +class AppsCreateWaiter: + raw_response: App + """raw_response is the raw response of the Create call.""" + _service: AppsAPI + _name: str + + def __init__(self, raw_response: App, service: AppsAPI, name: str): + self._service = service + self.raw_response = raw_response + self._name = name + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> App: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (ComputeState.ACTIVE,) + failure_states = ( + ComputeState.ERROR, + ComputeState.STOPPED, + ) + status_message = "polling..." + attempt = 1 + while time.time() < deadline: + poll = self._service.get(name=self._name) + status = poll.compute_status.state + status_message = f"current status: {status}" + if poll.compute_status: + status_message = poll.compute_status.message + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach ACTIVE, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"name={self._name}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + +class AppsDeployWaiter: + raw_response: AppDeployment + """raw_response is the raw response of the Deploy call.""" + _service: AppsAPI + _app_name: str + _deployment_id: str + + def __init__(self, raw_response: AppDeployment, service: AppsAPI, app_name: str, deployment_id: str): + self._service = service + self.raw_response = raw_response + self._app_name = app_name + self._deployment_id = deployment_id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> AppDeployment: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (AppDeploymentState.SUCCEEDED,) + failure_states = (AppDeploymentState.FAILED,) + status_message = "polling..." 
+ attempt = 1 + while time.time() < deadline: + poll = self._service.get_deployment(app_name=self._app_name, deployment_id=self._deployment_id) + status = poll.status.state + status_message = f"current status: {status}" + if poll.status: + status_message = poll.status.message + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach SUCCEEDED, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"app_name={self._app_name}, deployment_id={self._deployment_id}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + +class AppsStartWaiter: + raw_response: App + """raw_response is the raw response of the Start call.""" + _service: AppsAPI + _name: str + + def __init__(self, raw_response: App, service: AppsAPI, name: str): + self._service = service + self.raw_response = raw_response + self._name = name + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> App: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (ComputeState.ACTIVE,) + failure_states = ( + ComputeState.ERROR, + ComputeState.STOPPED, + ) + status_message = "polling..." + attempt = 1 + while time.time() < deadline: + poll = self._service.get(name=self._name) + status = poll.compute_status.state + status_message = f"current status: {status}" + if poll.compute_status: + status_message = poll.compute_status.message + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach ACTIVE, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"name={self._name}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + +class AppsStopWaiter: + raw_response: App + """raw_response is the raw response of the Stop call.""" + _service: AppsAPI + _name: str + + def __init__(self, raw_response: App, service: AppsAPI, name: str): + self._service = service + self.raw_response = raw_response + self._name = name + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> App: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (ComputeState.STOPPED,) + failure_states = (ComputeState.ERROR,) + status_message = "polling..." 
+ attempt = 1 + while time.time() < deadline: + poll = self._service.get(name=self._name) + status = poll.compute_status.state + status_message = f"current status: {status}" + if poll.compute_status: + status_message = poll.compute_status.message + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach STOPPED, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"name={self._name}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + class AppsAPI: """Apps run directly on a customer’s Databricks instance, integrate with their data, use and extend Databricks services, and enable users to interact through single sign-on.""" @@ -1068,7 +1244,7 @@ class AppsAPI: def __init__(self, api_client): self._api = api_client - def create(self, *, app: Optional[App] = None, no_compute: Optional[bool] = None) -> App: + def create(self, *, app: Optional[App] = None, no_compute: Optional[bool] = None) -> AppsCreateWaiter: """Create an app. Creates a new app. @@ -1090,8 +1266,8 @@ def create(self, *, app: Optional[App] = None, no_compute: Optional[bool] = None "Content-Type": "application/json", } - res = self._api.do("POST", "/api/2.0/apps", query=query, body=body, headers=headers) - return App.from_dict(res) + op_response = self._api.do("POST", "/api/2.0/apps", query=query, body=body, headers=headers) + return AppsCreateWaiter(service=self, raw_response=App.from_dict(op_response), name=op_response["name"]) def delete(self, name: str) -> App: """Delete an app. @@ -1111,7 +1287,7 @@ def delete(self, name: str) -> App: res = self._api.do("DELETE", f"/api/2.0/apps/{name}", headers=headers) return App.from_dict(res) - def deploy(self, app_name: str, *, app_deployment: Optional[AppDeployment] = None) -> AppDeployment: + def deploy(self, app_name: str, *, app_deployment: Optional[AppDeployment] = None) -> AppsDeployWaiter: """Create an app deployment. Creates an app deployment for the app with the supplied name. @@ -1130,8 +1306,13 @@ def deploy(self, app_name: str, *, app_deployment: Optional[AppDeployment] = Non "Content-Type": "application/json", } - res = self._api.do("POST", f"/api/2.0/apps/{app_name}/deployments", body=body, headers=headers) - return AppDeployment.from_dict(res) + op_response = self._api.do("POST", f"/api/2.0/apps/{app_name}/deployments", body=body, headers=headers) + return AppsDeployWaiter( + service=self, + raw_response=AppDeployment.from_dict(op_response), + app_name=app_name, + deployment_id=op_response["deployment_id"], + ) def get(self, name: str) -> App: """Get an app. @@ -1298,7 +1479,7 @@ def set_permissions( res = self._api.do("PUT", f"/api/2.0/permissions/apps/{app_name}", body=body, headers=headers) return AppPermissions.from_dict(res) - def start(self, name: str) -> App: + def start(self, name: str) -> AppsStartWaiter: """Start an app. Start the last active deployment of the app in the workspace. 
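
(Usage note, not part of the patch: the Apps methods changed above — create, deploy, start, stop — now return waiter objects rather than the finished resource. A minimal sketch for start(), with the same shape applying to the other three. It assumes `apps` is an already-authenticated AppsAPI instance, that the import paths match the relative imports in this diff, and that WaitUntilDoneOptions accepts a timedelta `timeout` keyword — implied by `opts.timeout.total_seconds()` in the polling loop above but not shown here.)

    import datetime

    from databricks.sdk.apps.v2.apps import App, AppsAPI
    from databricks.sdk.service._internal import WaitUntilDoneOptions

    def start_app_and_wait(apps: AppsAPI, name: str) -> App:
        # start() now returns an AppsStartWaiter immediately; the raw Start
        # response is still available on waiter.raw_response.
        waiter = apps.start(name=name)
        opts = WaitUntilDoneOptions(timeout=datetime.timedelta(minutes=20))
        # WaitUntilDone polls apps.get(name=...) until ACTIVE; ERROR/STOPPED raise
        # OperationFailed, and exceeding the deadline raises TimeoutError.
        return waiter.WaitUntilDone(opts)
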
@@ -1316,10 +1497,10 @@ def start(self, name: str) -> App: "Content-Type": "application/json", } - res = self._api.do("POST", f"/api/2.0/apps/{name}/start", headers=headers) - return App.from_dict(res) + op_response = self._api.do("POST", f"/api/2.0/apps/{name}/start", headers=headers) + return AppsStartWaiter(service=self, raw_response=App.from_dict(op_response), name=op_response["name"]) - def stop(self, name: str) -> App: + def stop(self, name: str) -> AppsStopWaiter: """Stop an app. Stops the active deployment of the app in the workspace. @@ -1337,8 +1518,8 @@ def stop(self, name: str) -> App: "Content-Type": "application/json", } - res = self._api.do("POST", f"/api/2.0/apps/{name}/stop", headers=headers) - return App.from_dict(res) + op_response = self._api.do("POST", f"/api/2.0/apps/{name}/stop", headers=headers) + return AppsStopWaiter(service=self, raw_response=App.from_dict(op_response), name=op_response["name"]) def update(self, name: str, *, app: Optional[App] = None) -> App: """Update an app. diff --git a/databricks/sdk/catalog/v2/catalog.py b/databricks/sdk/catalog/v2/catalog.py index 05c75356d..393b6e183 100755 --- a/databricks/sdk/catalog/v2/catalog.py +++ b/databricks/sdk/catalog/v2/catalog.py @@ -3,12 +3,15 @@ from __future__ import annotations import logging +import random +import time from dataclasses import dataclass from enum import Enum from typing import Any, Dict, Iterator, List, Optional -from ...service._internal import (_enum, _from_dict, _repeated_dict, - _repeated_enum) +from ...databricks.errors import OperationFailed +from ...service._internal import (WaitUntilDoneOptions, _enum, _from_dict, + _repeated_dict, _repeated_enum) _LOG = logging.getLogger("databricks.sdk") @@ -12283,13 +12286,52 @@ def update(self, full_name: str, version: int, *, comment: Optional[str] = None) return ModelVersionInfo.from_dict(res) +class OnlineTablesCreateWaiter: + raw_response: OnlineTable + """raw_response is the raw response of the Create call.""" + _service: OnlineTablesAPI + _name: str + + def __init__(self, raw_response: OnlineTable, service: OnlineTablesAPI, name: str): + self._service = service + self.raw_response = raw_response + self._name = name + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> OnlineTable: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (ProvisioningInfoState.ACTIVE,) + failure_states = (ProvisioningInfoState.FAILED,) + status_message = "polling..." 
+ attempt = 1 + while time.time() < deadline: + poll = self._service.get(name=self._name) + status = poll.unity_catalog_provisioning_state + status_message = f"current status: {status}" + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach ACTIVE, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"name={self._name}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + class OnlineTablesAPI: """Online tables provide lower latency and higher QPS access to data from Delta tables.""" def __init__(self, api_client): self._api = api_client - def create(self, *, table: Optional[OnlineTable] = None) -> OnlineTable: + def create(self, *, table: Optional[OnlineTable] = None) -> OnlineTablesCreateWaiter: """Create an Online Table. Create a new Online Table. @@ -12307,8 +12349,10 @@ def create(self, *, table: Optional[OnlineTable] = None) -> OnlineTable: "Content-Type": "application/json", } - res = self._api.do("POST", "/api/2.0/online-tables", body=body, headers=headers) - return OnlineTable.from_dict(res) + op_response = self._api.do("POST", "/api/2.0/online-tables", body=body, headers=headers) + return OnlineTablesCreateWaiter( + service=self, raw_response=OnlineTable.from_dict(op_response), name=op_response["name"] + ) def delete(self, name: str): """Delete an Online Table. diff --git a/databricks/sdk/cleanrooms/v2/cleanrooms.py b/databricks/sdk/cleanrooms/v2/cleanrooms.py index 2e4bff8b9..1c6b4b7ae 100755 --- a/databricks/sdk/cleanrooms/v2/cleanrooms.py +++ b/databricks/sdk/cleanrooms/v2/cleanrooms.py @@ -837,10 +837,11 @@ class CleanRoomTaskRunState: life_cycle_state: Optional[CleanRoomTaskRunLifeCycleState] = None """A value indicating the run's current lifecycle state. This field is always available in the - response.""" + response. Note: Additional states might be introduced in future releases.""" result_state: Optional[CleanRoomTaskRunResultState] = None - """A value indicating the run's result. This field is only available for terminal lifecycle states.""" + """A value indicating the run's result. This field is only available for terminal lifecycle states. 
+ Note: Additional states might be introduced in future releases.""" def as_dict(self) -> dict: """Serializes the CleanRoomTaskRunState into a dictionary suitable for use as a JSON request body.""" @@ -1153,6 +1154,7 @@ class ComplianceStandard(Enum): IRAP_PROTECTED = "IRAP_PROTECTED" ISMAP = "ISMAP" ITAR_EAR = "ITAR_EAR" + K_FSI = "K_FSI" NONE = "NONE" PCI_DSS = "PCI_DSS" diff --git a/databricks/sdk/compute/v2/compute.py b/databricks/sdk/compute/v2/compute.py index 54b3da663..6080feb41 100755 --- a/databricks/sdk/compute/v2/compute.py +++ b/databricks/sdk/compute/v2/compute.py @@ -3,12 +3,15 @@ from __future__ import annotations import logging +import random +import time from dataclasses import dataclass from enum import Enum from typing import Any, Dict, Iterator, List, Optional -from ...service._internal import (_enum, _from_dict, _repeated_dict, - _repeated_enum) +from ...databricks.errors import OperationFailed +from ...service._internal import (WaitUntilDoneOptions, _enum, _from_dict, + _repeated_dict, _repeated_enum) _LOG = logging.getLogger("databricks.sdk") @@ -4455,6 +4458,10 @@ class EditInstancePool: min_idle_instances: Optional[int] = None """Minimum number of idle instances to keep in the instance pool""" + node_type_flexibility: Optional[NodeTypeFlexibility] = None + """For Fleet-pool V2, this object contains the information about the alternate node type ids to use + when attempting to launch a cluster if the node type id is not available.""" + def as_dict(self) -> dict: """Serializes the EditInstancePool into a dictionary suitable for use as a JSON request body.""" body = {} @@ -4470,6 +4477,8 @@ def as_dict(self) -> dict: body["max_capacity"] = self.max_capacity if self.min_idle_instances is not None: body["min_idle_instances"] = self.min_idle_instances + if self.node_type_flexibility: + body["node_type_flexibility"] = self.node_type_flexibility.as_dict() if self.node_type_id is not None: body["node_type_id"] = self.node_type_id return body @@ -4489,6 +4498,8 @@ def as_shallow_dict(self) -> dict: body["max_capacity"] = self.max_capacity if self.min_idle_instances is not None: body["min_idle_instances"] = self.min_idle_instances + if self.node_type_flexibility: + body["node_type_flexibility"] = self.node_type_flexibility if self.node_type_id is not None: body["node_type_id"] = self.node_type_id return body @@ -4503,6 +4514,7 @@ def from_dict(cls, d: Dict[str, Any]) -> EditInstancePool: instance_pool_name=d.get("instance_pool_name", None), max_capacity=d.get("max_capacity", None), min_idle_instances=d.get("min_idle_instances", None), + node_type_flexibility=_from_dict(d, "node_type_flexibility", NodeTypeFlexibility), node_type_id=d.get("node_type_id", None), ) @@ -5340,6 +5352,10 @@ class GetInstancePool: min_idle_instances: Optional[int] = None """Minimum number of idle instances to keep in the instance pool""" + node_type_flexibility: Optional[NodeTypeFlexibility] = None + """For Fleet-pool V2, this object contains the information about the alternate node type ids to use + when attempting to launch a cluster if the node type id is not available.""" + node_type_id: Optional[str] = None """This field encodes, through a single value, the resources available to each of the Spark nodes in this cluster. 
For example, the Spark nodes can be provisioned and optimized for memory or @@ -5390,6 +5406,8 @@ def as_dict(self) -> dict: body["max_capacity"] = self.max_capacity if self.min_idle_instances is not None: body["min_idle_instances"] = self.min_idle_instances + if self.node_type_flexibility: + body["node_type_flexibility"] = self.node_type_flexibility.as_dict() if self.node_type_id is not None: body["node_type_id"] = self.node_type_id if self.preloaded_docker_images: @@ -5431,6 +5449,8 @@ def as_shallow_dict(self) -> dict: body["max_capacity"] = self.max_capacity if self.min_idle_instances is not None: body["min_idle_instances"] = self.min_idle_instances + if self.node_type_flexibility: + body["node_type_flexibility"] = self.node_type_flexibility if self.node_type_id is not None: body["node_type_id"] = self.node_type_id if self.preloaded_docker_images: @@ -5461,6 +5481,7 @@ def from_dict(cls, d: Dict[str, Any]) -> GetInstancePool: instance_pool_name=d.get("instance_pool_name", None), max_capacity=d.get("max_capacity", None), min_idle_instances=d.get("min_idle_instances", None), + node_type_flexibility=_from_dict(d, "node_type_flexibility", NodeTypeFlexibility), node_type_id=d.get("node_type_id", None), preloaded_docker_images=_repeated_dict(d, "preloaded_docker_images", DockerImage), preloaded_spark_versions=d.get("preloaded_spark_versions", None), @@ -6295,6 +6316,10 @@ class InstancePoolAndStats: min_idle_instances: Optional[int] = None """Minimum number of idle instances to keep in the instance pool""" + node_type_flexibility: Optional[NodeTypeFlexibility] = None + """For Fleet-pool V2, this object contains the information about the alternate node type ids to use + when attempting to launch a cluster if the node type id is not available.""" + node_type_id: Optional[str] = None """This field encodes, through a single value, the resources available to each of the Spark nodes in this cluster. 
For example, the Spark nodes can be provisioned and optimized for memory or @@ -6345,6 +6370,8 @@ def as_dict(self) -> dict: body["max_capacity"] = self.max_capacity if self.min_idle_instances is not None: body["min_idle_instances"] = self.min_idle_instances + if self.node_type_flexibility: + body["node_type_flexibility"] = self.node_type_flexibility.as_dict() if self.node_type_id is not None: body["node_type_id"] = self.node_type_id if self.preloaded_docker_images: @@ -6386,6 +6413,8 @@ def as_shallow_dict(self) -> dict: body["max_capacity"] = self.max_capacity if self.min_idle_instances is not None: body["min_idle_instances"] = self.min_idle_instances + if self.node_type_flexibility: + body["node_type_flexibility"] = self.node_type_flexibility if self.node_type_id is not None: body["node_type_id"] = self.node_type_id if self.preloaded_docker_images: @@ -6416,6 +6445,7 @@ def from_dict(cls, d: Dict[str, Any]) -> InstancePoolAndStats: instance_pool_name=d.get("instance_pool_name", None), max_capacity=d.get("max_capacity", None), min_idle_instances=d.get("min_idle_instances", None), + node_type_flexibility=_from_dict(d, "node_type_flexibility", NodeTypeFlexibility), node_type_id=d.get("node_type_id", None), preloaded_docker_images=_repeated_dict(d, "preloaded_docker_images", DockerImage), preloaded_spark_versions=d.get("preloaded_spark_versions", None), @@ -7581,6 +7611,9 @@ def from_dict(cls, d: Dict[str, Any]) -> LogSyncStatus: return cls(last_attempted=d.get("last_attempted", None), last_exception=d.get("last_exception", None)) +MapAny = Dict[str, Any] + + @dataclass class MavenLibrary: coordinates: str @@ -7875,6 +7908,28 @@ def from_dict(cls, d: Dict[str, Any]) -> NodeType: ) +@dataclass +class NodeTypeFlexibility: + """For Fleet-V2 using classic clusters, this object contains the information about the alternate + node type ids to use when attempting to launch a cluster. It can be used with both the driver + and worker node types.""" + + def as_dict(self) -> dict: + """Serializes the NodeTypeFlexibility into a dictionary suitable for use as a JSON request body.""" + body = {} + return body + + def as_shallow_dict(self) -> dict: + """Serializes the NodeTypeFlexibility into a shallow dictionary of its immediate attributes.""" + body = {} + return body + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> NodeTypeFlexibility: + """Deserializes the NodeTypeFlexibility from a dictionary.""" + return cls() + + @dataclass class PendingInstanceError: """Error message of a failed pending instances""" @@ -9967,6 +10022,297 @@ def update_permissions( return ClusterPolicyPermissions.from_dict(res) +class ClustersCreateWaiter: + raw_response: CreateClusterResponse + """raw_response is the raw response of the Create call.""" + _service: ClustersAPI + _cluster_id: str + + def __init__(self, raw_response: CreateClusterResponse, service: ClustersAPI, cluster_id: str): + self._service = service + self.raw_response = raw_response + self._cluster_id = cluster_id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> ClusterDetails: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (State.RUNNING,) + failure_states = ( + State.ERROR, + State.TERMINATED, + ) + status_message = "polling..." 
+ attempt = 1 + while time.time() < deadline: + poll = self._service.get(cluster_id=self._cluster_id) + status = poll.state + status_message = f"current status: {status}" + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach RUNNING, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"cluster_id={self._cluster_id}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + +class ClustersDeleteWaiter: + raw_response: DeleteClusterResponse + """raw_response is the raw response of the Delete call.""" + _service: ClustersAPI + _cluster_id: str + + def __init__(self, raw_response: DeleteClusterResponse, service: ClustersAPI, cluster_id: str): + self._service = service + self.raw_response = raw_response + self._cluster_id = cluster_id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> ClusterDetails: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (State.TERMINATED,) + failure_states = (State.ERROR,) + status_message = "polling..." + attempt = 1 + while time.time() < deadline: + poll = self._service.get(cluster_id=self._cluster_id) + status = poll.state + status_message = f"current status: {status}" + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach TERMINATED, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"cluster_id={self._cluster_id}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + +class ClustersEditWaiter: + raw_response: EditClusterResponse + """raw_response is the raw response of the Edit call.""" + _service: ClustersAPI + _cluster_id: str + + def __init__(self, raw_response: EditClusterResponse, service: ClustersAPI, cluster_id: str): + self._service = service + self.raw_response = raw_response + self._cluster_id = cluster_id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> ClusterDetails: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (State.RUNNING,) + failure_states = ( + State.ERROR, + State.TERMINATED, + ) + status_message = "polling..." 
+ attempt = 1 + while time.time() < deadline: + poll = self._service.get(cluster_id=self._cluster_id) + status = poll.state + status_message = f"current status: {status}" + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach RUNNING, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"cluster_id={self._cluster_id}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + +class ClustersResizeWaiter: + raw_response: ResizeClusterResponse + """raw_response is the raw response of the Resize call.""" + _service: ClustersAPI + _cluster_id: str + + def __init__(self, raw_response: ResizeClusterResponse, service: ClustersAPI, cluster_id: str): + self._service = service + self.raw_response = raw_response + self._cluster_id = cluster_id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> ClusterDetails: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (State.RUNNING,) + failure_states = ( + State.ERROR, + State.TERMINATED, + ) + status_message = "polling..." + attempt = 1 + while time.time() < deadline: + poll = self._service.get(cluster_id=self._cluster_id) + status = poll.state + status_message = f"current status: {status}" + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach RUNNING, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"cluster_id={self._cluster_id}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + +class ClustersRestartWaiter: + raw_response: RestartClusterResponse + """raw_response is the raw response of the Restart call.""" + _service: ClustersAPI + _cluster_id: str + + def __init__(self, raw_response: RestartClusterResponse, service: ClustersAPI, cluster_id: str): + self._service = service + self.raw_response = raw_response + self._cluster_id = cluster_id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> ClusterDetails: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (State.RUNNING,) + failure_states = ( + State.ERROR, + State.TERMINATED, + ) + status_message = "polling..." 
+ attempt = 1 + while time.time() < deadline: + poll = self._service.get(cluster_id=self._cluster_id) + status = poll.state + status_message = f"current status: {status}" + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach RUNNING, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"cluster_id={self._cluster_id}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + +class ClustersStartWaiter: + raw_response: StartClusterResponse + """raw_response is the raw response of the Start call.""" + _service: ClustersAPI + _cluster_id: str + + def __init__(self, raw_response: StartClusterResponse, service: ClustersAPI, cluster_id: str): + self._service = service + self.raw_response = raw_response + self._cluster_id = cluster_id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> ClusterDetails: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (State.RUNNING,) + failure_states = ( + State.ERROR, + State.TERMINATED, + ) + status_message = "polling..." + attempt = 1 + while time.time() < deadline: + poll = self._service.get(cluster_id=self._cluster_id) + status = poll.state + status_message = f"current status: {status}" + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach RUNNING, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"cluster_id={self._cluster_id}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + +class ClustersUpdateWaiter: + raw_response: UpdateClusterResponse + """raw_response is the raw response of the Update call.""" + _service: ClustersAPI + _cluster_id: str + + def __init__(self, raw_response: UpdateClusterResponse, service: ClustersAPI, cluster_id: str): + self._service = service + self.raw_response = raw_response + self._cluster_id = cluster_id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> ClusterDetails: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (State.RUNNING,) + failure_states = ( + State.ERROR, + State.TERMINATED, + ) + status_message = "polling..." + attempt = 1 + while time.time() < deadline: + poll = self._service.get(cluster_id=self._cluster_id) + status = poll.state + status_message = f"current status: {status}" + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach RUNNING, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"cluster_id={self._cluster_id}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + class ClustersAPI: """The Clusters API allows you to create, start, edit, list, terminate, and delete clusters. 
@@ -10051,7 +10397,7 @@ def create( ssh_public_keys: Optional[List[str]] = None, use_ml_runtime: Optional[bool] = None, workload_type: Optional[WorkloadType] = None, - ) -> CreateClusterResponse: + ) -> ClustersCreateWaiter: """Create new cluster. Creates a new Spark cluster. This method will acquire new instances from the cloud provider if @@ -10303,10 +10649,14 @@ def create( "Content-Type": "application/json", } - res = self._api.do("POST", "/api/2.1/clusters/create", body=body, headers=headers) - return CreateClusterResponse.from_dict(res) + op_response = self._api.do("POST", "/api/2.1/clusters/create", body=body, headers=headers) + return ClustersCreateWaiter( + service=self, + raw_response=CreateClusterResponse.from_dict(op_response), + cluster_id=op_response["cluster_id"], + ) - def delete(self, cluster_id: str): + def delete(self, cluster_id: str) -> ClustersDeleteWaiter: """Terminate cluster. Terminates the Spark cluster with the specified ID. The cluster is removed asynchronously. Once the @@ -10328,7 +10678,10 @@ def delete(self, cluster_id: str): "Content-Type": "application/json", } - self._api.do("POST", "/api/2.1/clusters/delete", body=body, headers=headers) + op_response = self._api.do("POST", "/api/2.1/clusters/delete", body=body, headers=headers) + return ClustersDeleteWaiter( + service=self, raw_response=DeleteClusterResponse.from_dict(op_response), cluster_id=cluster_id + ) def edit( self, @@ -10364,7 +10717,7 @@ def edit( ssh_public_keys: Optional[List[str]] = None, use_ml_runtime: Optional[bool] = None, workload_type: Optional[WorkloadType] = None, - ): + ) -> ClustersEditWaiter: """Update cluster configuration. Updates the configuration of a cluster to match the provided attributes and size. A cluster can be @@ -10613,7 +10966,10 @@ def edit( "Content-Type": "application/json", } - self._api.do("POST", "/api/2.1/clusters/edit", body=body, headers=headers) + op_response = self._api.do("POST", "/api/2.1/clusters/edit", body=body, headers=headers) + return ClustersEditWaiter( + service=self, raw_response=EditClusterResponse.from_dict(op_response), cluster_id=cluster_id + ) def events( self, @@ -10861,7 +11217,9 @@ def pin(self, cluster_id: str): self._api.do("POST", "/api/2.1/clusters/pin", body=body, headers=headers) - def resize(self, cluster_id: str, *, autoscale: Optional[AutoScale] = None, num_workers: Optional[int] = None): + def resize( + self, cluster_id: str, *, autoscale: Optional[AutoScale] = None, num_workers: Optional[int] = None + ) -> ClustersResizeWaiter: """Resize cluster. Resizes a cluster to have a desired number of workers. This will fail unless the cluster is in a @@ -10898,9 +11256,12 @@ def resize(self, cluster_id: str, *, autoscale: Optional[AutoScale] = None, num_ "Content-Type": "application/json", } - self._api.do("POST", "/api/2.1/clusters/resize", body=body, headers=headers) + op_response = self._api.do("POST", "/api/2.1/clusters/resize", body=body, headers=headers) + return ClustersResizeWaiter( + service=self, raw_response=ResizeClusterResponse.from_dict(op_response), cluster_id=cluster_id + ) - def restart(self, cluster_id: str, *, restart_user: Optional[str] = None): + def restart(self, cluster_id: str, *, restart_user: Optional[str] = None) -> ClustersRestartWaiter: """Restart cluster. Restarts a Spark cluster with the supplied ID. 
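
(Usage note, not part of the patch: the cluster lifecycle methods — create, delete, edit, resize, restart, start, update — now hand back waiter objects instead of returning the raw response or None. A sketch only, assuming `clusters` is an authenticated ClustersAPI instance; everything else is taken from the definitions above.)

    from databricks.sdk.compute.v2.compute import ClusterDetails, ClustersAPI

    def restart_and_wait(clusters: ClustersAPI, cluster_id: str) -> ClusterDetails:
        # restart() returns a ClustersRestartWaiter immediately instead of None.
        waiter = clusters.restart(cluster_id=cluster_id)
        # WaitUntilDone polls clusters.get(cluster_id=...) until RUNNING, raises
        # OperationFailed on ERROR/TERMINATED, or TimeoutError after the default
        # deadline when no WaitUntilDoneOptions is passed.
        return waiter.WaitUntilDone()
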
If the cluster is not currently in a `RUNNING` state, @@ -10924,7 +11285,10 @@ def restart(self, cluster_id: str, *, restart_user: Optional[str] = None): "Content-Type": "application/json", } - self._api.do("POST", "/api/2.1/clusters/restart", body=body, headers=headers) + op_response = self._api.do("POST", "/api/2.1/clusters/restart", body=body, headers=headers) + return ClustersRestartWaiter( + service=self, raw_response=RestartClusterResponse.from_dict(op_response), cluster_id=cluster_id + ) def set_permissions( self, cluster_id: str, *, access_control_list: Optional[List[ClusterAccessControlRequest]] = None @@ -10966,7 +11330,7 @@ def spark_versions(self) -> GetSparkVersionsResponse: res = self._api.do("GET", "/api/2.1/clusters/spark-versions", headers=headers) return GetSparkVersionsResponse.from_dict(res) - def start(self, cluster_id: str): + def start(self, cluster_id: str) -> ClustersStartWaiter: """Start terminated cluster. Starts a terminated Spark cluster with the supplied ID. This works similar to `createCluster` except: @@ -10990,7 +11354,10 @@ def start(self, cluster_id: str): "Content-Type": "application/json", } - self._api.do("POST", "/api/2.1/clusters/start", body=body, headers=headers) + op_response = self._api.do("POST", "/api/2.1/clusters/start", body=body, headers=headers) + return ClustersStartWaiter( + service=self, raw_response=StartClusterResponse.from_dict(op_response), cluster_id=cluster_id + ) def unpin(self, cluster_id: str): """Unpin cluster. @@ -11013,7 +11380,9 @@ def unpin(self, cluster_id: str): self._api.do("POST", "/api/2.1/clusters/unpin", body=body, headers=headers) - def update(self, cluster_id: str, update_mask: str, *, cluster: Optional[UpdateClusterResource] = None): + def update( + self, cluster_id: str, update_mask: str, *, cluster: Optional[UpdateClusterResource] = None + ) -> ClustersUpdateWaiter: """Update cluster configuration (partial). Updates the configuration of a cluster to match the partial set of attributes and size. Denote which @@ -11059,7 +11428,10 @@ def update(self, cluster_id: str, update_mask: str, *, cluster: Optional[UpdateC "Content-Type": "application/json", } - self._api.do("POST", "/api/2.1/clusters/update", body=body, headers=headers) + op_response = self._api.do("POST", "/api/2.1/clusters/update", body=body, headers=headers) + return ClustersUpdateWaiter( + service=self, raw_response=UpdateClusterResponse.from_dict(op_response), cluster_id=cluster_id + ) def update_permissions( self, cluster_id: str, *, access_control_list: Optional[List[ClusterAccessControlRequest]] = None @@ -11086,6 +11458,154 @@ def update_permissions( return ClusterPermissions.from_dict(res) +class CommandExecutionCancelWaiter: + raw_response: CancelResponse + """raw_response is the raw response of the Cancel call.""" + _service: CommandExecutionAPI + _cluster_id: str + _command_id: str + _context_id: str + + def __init__( + self, + raw_response: CancelResponse, + service: CommandExecutionAPI, + cluster_id: str, + command_id: str, + context_id: str, + ): + self._service = service + self.raw_response = raw_response + self._cluster_id = cluster_id + self._command_id = command_id + self._context_id = context_id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> CommandStatusResponse: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (CommandStatus.CANCELLED,) + failure_states = (CommandStatus.ERROR,) + status_message = "polling..." 
+ attempt = 1 + while time.time() < deadline: + poll = self._service.command_status( + cluster_id=self._cluster_id, command_id=self._command_id, context_id=self._context_id + ) + status = poll.status + status_message = f"current status: {status}" + if poll.results: + status_message = poll.results.cause + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach Cancelled, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"cluster_id={self._cluster_id}, command_id={self._command_id}, context_id={self._context_id}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + +class CommandExecutionCreateWaiter: + raw_response: Created + """raw_response is the raw response of the Create call.""" + _service: CommandExecutionAPI + _cluster_id: str + _context_id: str + + def __init__(self, raw_response: Created, service: CommandExecutionAPI, cluster_id: str, context_id: str): + self._service = service + self.raw_response = raw_response + self._cluster_id = cluster_id + self._context_id = context_id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> ContextStatusResponse: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (ContextStatus.RUNNING,) + failure_states = (ContextStatus.ERROR,) + status_message = "polling..." + attempt = 1 + while time.time() < deadline: + poll = self._service.context_status(cluster_id=self._cluster_id, context_id=self._context_id) + status = poll.status + status_message = f"current status: {status}" + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach Running, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"cluster_id={self._cluster_id}, context_id={self._context_id}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + +class CommandExecutionExecuteWaiter: + raw_response: Created + """raw_response is the raw response of the Execute call.""" + _service: CommandExecutionAPI + _cluster_id: str + _command_id: str + _context_id: str + + def __init__( + self, raw_response: Created, service: CommandExecutionAPI, cluster_id: str, command_id: str, context_id: str + ): + self._service = service + self.raw_response = raw_response + self._cluster_id = cluster_id + self._command_id = command_id + self._context_id = context_id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> CommandStatusResponse: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = ( + CommandStatus.FINISHED, + CommandStatus.ERROR, + ) + failure_states = ( + CommandStatus.CANCELLED, + CommandStatus.CANCELLING, + ) + status_message = "polling..." 
+ attempt = 1 + while time.time() < deadline: + poll = self._service.command_status( + cluster_id=self._cluster_id, command_id=self._command_id, context_id=self._context_id + ) + status = poll.status + status_message = f"current status: {status}" + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach Finished or Error, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"cluster_id={self._cluster_id}, command_id={self._command_id}, context_id={self._context_id}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + class CommandExecutionAPI: """This API allows execution of Python, Scala, SQL, or R commands on running Databricks Clusters. This API only supports (classic) all-purpose clusters. Serverless compute is not supported.""" @@ -11095,7 +11615,7 @@ def __init__(self, api_client): def cancel( self, *, cluster_id: Optional[str] = None, command_id: Optional[str] = None, context_id: Optional[str] = None - ): + ) -> CommandExecutionCancelWaiter: """Cancel a command. Cancels a currently running command within an execution context. @@ -11122,7 +11642,14 @@ def cancel( "Content-Type": "application/json", } - self._api.do("POST", "/api/1.2/commands/cancel", body=body, headers=headers) + op_response = self._api.do("POST", "/api/1.2/commands/cancel", body=body, headers=headers) + return CommandExecutionCancelWaiter( + service=self, + raw_response=CancelResponse.from_dict(op_response), + cluster_id=cluster_id, + command_id=command_id, + context_id=context_id, + ) def command_status(self, cluster_id: str, context_id: str, command_id: str) -> CommandStatusResponse: """Get command info. @@ -11175,7 +11702,9 @@ def context_status(self, cluster_id: str, context_id: str) -> ContextStatusRespo res = self._api.do("GET", "/api/1.2/contexts/status", query=query, headers=headers) return ContextStatusResponse.from_dict(res) - def create(self, *, cluster_id: Optional[str] = None, language: Optional[Language] = None) -> Created: + def create( + self, *, cluster_id: Optional[str] = None, language: Optional[Language] = None + ) -> CommandExecutionCreateWaiter: """Create an execution context. Creates an execution context for running cluster commands. @@ -11200,8 +11729,13 @@ def create(self, *, cluster_id: Optional[str] = None, language: Optional[Languag "Content-Type": "application/json", } - res = self._api.do("POST", "/api/1.2/contexts/create", body=body, headers=headers) - return Created.from_dict(res) + op_response = self._api.do("POST", "/api/1.2/contexts/create", body=body, headers=headers) + return CommandExecutionCreateWaiter( + service=self, + raw_response=Created.from_dict(op_response), + cluster_id=cluster_id, + context_id=op_response["id"], + ) def destroy(self, cluster_id: str, context_id: str): """Delete an execution context. @@ -11232,7 +11766,7 @@ def execute( command: Optional[str] = None, context_id: Optional[str] = None, language: Optional[Language] = None, - ) -> Created: + ) -> CommandExecutionExecuteWaiter: """Run a command. Runs a cluster command in the given execution context, using the provided language. 
@@ -11265,8 +11799,14 @@ def execute( "Content-Type": "application/json", } - res = self._api.do("POST", "/api/1.2/commands/execute", body=body, headers=headers) - return Created.from_dict(res) + op_response = self._api.do("POST", "/api/1.2/commands/execute", body=body, headers=headers) + return CommandExecutionExecuteWaiter( + service=self, + raw_response=Created.from_dict(op_response), + cluster_id=cluster_id, + command_id=op_response["id"], + context_id=context_id, + ) class GlobalInitScriptsAPI: @@ -11577,6 +12117,7 @@ def edit( idle_instance_autotermination_minutes: Optional[int] = None, max_capacity: Optional[int] = None, min_idle_instances: Optional[int] = None, + node_type_flexibility: Optional[NodeTypeFlexibility] = None, ): """Edit an existing instance pool. @@ -11609,6 +12150,9 @@ def edit( upsize requests. :param min_idle_instances: int (optional) Minimum number of idle instances to keep in the instance pool + :param node_type_flexibility: :class:`NodeTypeFlexibility` (optional) + For Fleet-pool V2, this object contains the information about the alternate node type ids to use + when attempting to launch a cluster if the node type id is not available. """ @@ -11625,6 +12169,8 @@ def edit( body["max_capacity"] = max_capacity if min_idle_instances is not None: body["min_idle_instances"] = min_idle_instances + if node_type_flexibility is not None: + body["node_type_flexibility"] = node_type_flexibility.as_dict() if node_type_id is not None: body["node_type_id"] = node_type_id headers = { diff --git a/databricks/sdk/dashboards/v2/dashboards.py b/databricks/sdk/dashboards/v2/dashboards.py index c242458e1..c5459a3ef 100755 --- a/databricks/sdk/dashboards/v2/dashboards.py +++ b/databricks/sdk/dashboards/v2/dashboards.py @@ -3,11 +3,15 @@ from __future__ import annotations import logging +import random +import time from dataclasses import dataclass from enum import Enum from typing import Any, Dict, Iterator, List, Optional -from ...service._internal import _enum, _from_dict, _repeated_dict +from ...databricks.errors import OperationFailed +from ...service._internal import (WaitUntilDoneOptions, _enum, _from_dict, + _repeated_dict) _LOG = logging.getLogger("databricks.sdk") @@ -868,45 +872,27 @@ def from_dict(cls, d: Dict[str, Any]) -> GenieCreateConversationMessageRequest: @dataclass class GenieGenerateDownloadFullQueryResultResponse: - error: Optional[str] = None - """Error message if Genie failed to download the result""" - - status: Optional[MessageStatus] = None - """Download result status""" - - transient_statement_id: Optional[str] = None - """Transient Statement ID. Use this ID to track the download request in subsequent polling calls""" + download_id: Optional[str] = None + """Download ID. 
Use this ID to track the download request in subsequent polling calls""" def as_dict(self) -> dict: """Serializes the GenieGenerateDownloadFullQueryResultResponse into a dictionary suitable for use as a JSON request body.""" body = {} - if self.error is not None: - body["error"] = self.error - if self.status is not None: - body["status"] = self.status.value - if self.transient_statement_id is not None: - body["transient_statement_id"] = self.transient_statement_id + if self.download_id is not None: + body["download_id"] = self.download_id return body def as_shallow_dict(self) -> dict: """Serializes the GenieGenerateDownloadFullQueryResultResponse into a shallow dictionary of its immediate attributes.""" body = {} - if self.error is not None: - body["error"] = self.error - if self.status is not None: - body["status"] = self.status - if self.transient_statement_id is not None: - body["transient_statement_id"] = self.transient_statement_id + if self.download_id is not None: + body["download_id"] = self.download_id return body @classmethod def from_dict(cls, d: Dict[str, Any]) -> GenieGenerateDownloadFullQueryResultResponse: """Deserializes the GenieGenerateDownloadFullQueryResultResponse from a dictionary.""" - return cls( - error=d.get("error", None), - status=_enum(d, "status", MessageStatus), - transient_statement_id=d.get("transient_statement_id", None), - ) + return cls(download_id=d.get("download_id", None)) @dataclass @@ -915,16 +901,11 @@ class GenieGetDownloadFullQueryResultResponse: """SQL Statement Execution response. See [Get status, manifest, and result first chunk](:method:statementexecution/getstatement) for more details.""" - transient_statement_id: Optional[str] = None - """Transient Statement ID""" - def as_dict(self) -> dict: """Serializes the GenieGetDownloadFullQueryResultResponse into a dictionary suitable for use as a JSON request body.""" body = {} if self.statement_response: body["statement_response"] = self.statement_response.as_dict() - if self.transient_statement_id is not None: - body["transient_statement_id"] = self.transient_statement_id return body def as_shallow_dict(self) -> dict: @@ -932,17 +913,12 @@ def as_shallow_dict(self) -> dict: body = {} if self.statement_response: body["statement_response"] = self.statement_response - if self.transient_statement_id is not None: - body["transient_statement_id"] = self.transient_statement_id return body @classmethod def from_dict(cls, d: Dict[str, Any]) -> GenieGetDownloadFullQueryResultResponse: """Deserializes the GenieGetDownloadFullQueryResultResponse from a dictionary.""" - return cls( - statement_response=_from_dict(d, "statement_response", StatementResponse), - transient_statement_id=d.get("transient_statement_id", None), - ) + return cls(statement_response=_from_dict(d, "statement_response", StatementResponse)) @dataclass @@ -1210,7 +1186,7 @@ def from_dict(cls, d: Dict[str, Any]) -> GenieResultMetadata: @dataclass class GenieSpace: space_id: str - """Space ID""" + """Genie space ID""" title: str """Title of the Genie Space""" @@ -2652,6 +2628,109 @@ def from_dict(cls, d: Dict[str, Any]) -> UnpublishDashboardResponse: return cls() +class GenieCreateMessageWaiter: + raw_response: GenieMessage + """raw_response is the raw response of the CreateMessage call.""" + _service: GenieAPI + _conversation_id: str + _message_id: str + _space_id: str + + def __init__( + self, raw_response: GenieMessage, service: GenieAPI, conversation_id: str, message_id: str, space_id: str + ): + self._service = service + 
self.raw_response = raw_response + self._conversation_id = conversation_id + self._message_id = message_id + self._space_id = space_id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> GenieMessage: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (MessageStatus.COMPLETED,) + failure_states = (MessageStatus.FAILED,) + status_message = "polling..." + attempt = 1 + while time.time() < deadline: + poll = self._service.get_message( + conversation_id=self._conversation_id, message_id=self._message_id, space_id=self._space_id + ) + status = poll.status + status_message = f"current status: {status}" + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach COMPLETED, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = ( + f"conversation_id={self._conversation_id}, message_id={self._message_id}, space_id={self._space_id}" + ) + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + +class GenieStartConversationWaiter: + raw_response: GenieStartConversationResponse + """raw_response is the raw response of the StartConversation call.""" + _service: GenieAPI + _conversation_id: str + _message_id: str + _space_id: str + + def __init__( + self, + raw_response: GenieStartConversationResponse, + service: GenieAPI, + conversation_id: str, + message_id: str, + space_id: str, + ): + self._service = service + self.raw_response = raw_response + self._conversation_id = conversation_id + self._message_id = message_id + self._space_id = space_id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> GenieMessage: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (MessageStatus.COMPLETED,) + failure_states = (MessageStatus.FAILED,) + status_message = "polling..." + attempt = 1 + while time.time() < deadline: + poll = self._service.get_message( + conversation_id=self._conversation_id, message_id=self._message_id, space_id=self._space_id + ) + status = poll.status + status_message = f"current status: {status}" + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach COMPLETED, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = ( + f"conversation_id={self._conversation_id}, message_id={self._message_id}, space_id={self._space_id}" + ) + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + class GenieAPI: """Genie provides a no-code experience for business users, powered by AI/BI. Analysts set up spaces that business users can use to ask questions using natural language. Genie uses data registered to Unity @@ -2661,7 +2740,7 @@ class GenieAPI: def __init__(self, api_client): self._api = api_client - def create_message(self, space_id: str, conversation_id: str, content: str) -> GenieMessage: + def create_message(self, space_id: str, conversation_id: str, content: str) -> GenieCreateMessageWaiter: """Create conversation message. 
Create new message in a [conversation](:method:genie/startconversation). The AI response uses all @@ -2686,13 +2765,19 @@ def create_message(self, space_id: str, conversation_id: str, content: str) -> G "Content-Type": "application/json", } - res = self._api.do( + op_response = self._api.do( "POST", f"/api/2.0/genie/spaces/{space_id}/conversations/{conversation_id}/messages", body=body, headers=headers, ) - return GenieMessage.from_dict(res) + return GenieCreateMessageWaiter( + service=self, + raw_response=GenieMessage.from_dict(op_response), + conversation_id=conversation_id, + message_id=op_response["id"], + space_id=space_id, + ) def execute_message_attachment_query( self, space_id: str, conversation_id: str, message_id: str, attachment_id: str @@ -2758,15 +2843,14 @@ def generate_download_full_query_result( ) -> GenieGenerateDownloadFullQueryResultResponse: """Generate full query result download. - Initiate full SQL query result download and obtain a transient ID for tracking the download progress. - This call initiates a new SQL execution to generate the query result. The result is stored in an - external link can be retrieved using the [Get Download Full Query - Result](:method:genie/getdownloadfullqueryresult) API. Warning: Databricks strongly recommends that - you protect the URLs that are returned by the `EXTERNAL_LINKS` disposition. See [Execute - Statement](:method:statementexecution/executestatement) for more details. + Initiates a new SQL execution and returns a `download_id` that you can use to track the progress of + the download. The query result is stored in an external link and can be retrieved using the [Get + Download Full Query Result](:method:genie/getdownloadfullqueryresult) API. Warning: Databricks + strongly recommends that you protect the URLs that are returned by the `EXTERNAL_LINKS` disposition. + See [Execute Statement](:method:statementexecution/executestatement) for more details. :param space_id: str - Space ID + Genie space ID :param conversation_id: str Conversation ID :param message_id: str @@ -2783,51 +2867,46 @@ def generate_download_full_query_result( res = self._api.do( "POST", - f"/api/2.0/genie/spaces/{space_id}/conversations/{conversation_id}/messages/{message_id}/attachments/{attachment_id}/generate-download", + f"/api/2.0/genie/spaces/{space_id}/conversations/{conversation_id}/messages/{message_id}/attachments/{attachment_id}/downloads", headers=headers, ) return GenieGenerateDownloadFullQueryResultResponse.from_dict(res) def get_download_full_query_result( - self, space_id: str, conversation_id: str, message_id: str, attachment_id: str, transient_statement_id: str + self, space_id: str, conversation_id: str, message_id: str, attachment_id: str, download_id: str ) -> GenieGetDownloadFullQueryResultResponse: - """Get download full query result status. - - Poll download progress and retrieve the SQL query result external link(s) upon completion. Warning: - Databricks strongly recommends that you protect the URLs that are returned by the `EXTERNAL_LINKS` - disposition. When you use the `EXTERNAL_LINKS` disposition, a short-lived, presigned URL is generated, - which can be used to download the results directly from Amazon S3. As a short-lived access credential - is embedded in this presigned URL, you should protect the URL. Because presigned URLs are already - generated with embedded temporary access credentials, you must not set an Authorization header in the - download requests. 
See [Execute Statement](:method:statementexecution/executestatement) for more - details. + """Get download full query result. + + After [Generating a Full Query Result Download](:method:genie/getdownloadfullqueryresult) and + successfully receiving a `download_id`, use this API to poll the download progress. When the download + is complete, the API returns one or more external links to the query result files. Warning: Databricks + strongly recommends that you protect the URLs that are returned by the `EXTERNAL_LINKS` disposition. + You must not set an Authorization header in download requests. When using the `EXTERNAL_LINKS` + disposition, Databricks returns presigned URLs that grant temporary access to data. See [Execute + Statement](:method:statementexecution/executestatement) for more details. :param space_id: str - Space ID + Genie space ID :param conversation_id: str Conversation ID :param message_id: str Message ID :param attachment_id: str Attachment ID - :param transient_statement_id: str - Transient Statement ID. This ID is provided by the [Start Download - endpoint](:method:genie/startdownloadfullqueryresult) + :param download_id: str + Download ID. This ID is provided by the [Generate Download + endpoint](:method:genie/generateDownloadFullQueryResult) :returns: :class:`GenieGetDownloadFullQueryResultResponse` """ - query = {} - if transient_statement_id is not None: - query["transient_statement_id"] = transient_statement_id headers = { "Accept": "application/json", } res = self._api.do( "GET", - f"/api/2.0/genie/spaces/{space_id}/conversations/{conversation_id}/messages/{message_id}/attachments/{attachment_id}/get-download", - query=query, + f"/api/2.0/genie/spaces/{space_id}/conversations/{conversation_id}/messages/{message_id}/attachments/{attachment_id}/downloads/{download_id}", headers=headers, ) return GenieGetDownloadFullQueryResultResponse.from_dict(res) @@ -2967,7 +3046,7 @@ def get_space(self, space_id: str) -> GenieSpace: res = self._api.do("GET", f"/api/2.0/genie/spaces/{space_id}", headers=headers) return GenieSpace.from_dict(res) - def start_conversation(self, space_id: str, content: str) -> GenieStartConversationResponse: + def start_conversation(self, space_id: str, content: str) -> GenieStartConversationWaiter: """Start conversation. Start a new conversation. 
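# --- Editor's note: hedged usage sketch, not part of the generated patch ---
# With this change, create_message and start_conversation return waiter
# objects instead of the raw response dataclasses. The sketch below shows how
# the new polling interface might be driven. `genie` (a GenieAPI instance),
# `space_id`, and the question text are placeholders, and the import paths
# are inferred from the relative imports in this diff.
import datetime

from databricks.sdk.databricks.errors import OperationFailed
from databricks.sdk.service._internal import WaitUntilDoneOptions


def ask_genie(genie, space_id: str) -> None:
    waiter = genie.start_conversation(space_id=space_id, content="How many orders were placed last week?")
    # The immediate API payload is still available before polling begins.
    conversation_id = waiter.raw_response.conversation_id

    opts = WaitUntilDoneOptions()
    opts.timeout = datetime.timedelta(minutes=5)  # default is 20 minutes
    try:
        message = waiter.WaitUntilDone(opts)  # polls get_message until COMPLETED
        print(f"first answer finished with status {message.status}")
    except OperationFailed as err:
        # Raised when the message ends in FAILED instead of COMPLETED.
        print(f"Genie could not answer: {err}")

    # Follow-up questions in the same conversation use the same pattern.
    genie.create_message(
        space_id=space_id, conversation_id=conversation_id, content="Break that down by day."
    ).WaitUntilDone()
# ----------------------------------------------------------------------------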
@@ -2989,8 +3068,16 @@ def start_conversation(self, space_id: str, content: str) -> GenieStartConversat "Content-Type": "application/json", } - res = self._api.do("POST", f"/api/2.0/genie/spaces/{space_id}/start-conversation", body=body, headers=headers) - return GenieStartConversationResponse.from_dict(res) + op_response = self._api.do( + "POST", f"/api/2.0/genie/spaces/{space_id}/start-conversation", body=body, headers=headers + ) + return GenieStartConversationWaiter( + service=self, + raw_response=GenieStartConversationResponse.from_dict(op_response), + conversation_id=op_response["conversation_id"], + message_id=op_response["message_id"], + space_id=space_id, + ) class LakeviewAPI: diff --git a/databricks/sdk/jobs/v2/jobs.py b/databricks/sdk/jobs/v2/jobs.py index 76aeaac34..54ad72fd0 100755 --- a/databricks/sdk/jobs/v2/jobs.py +++ b/databricks/sdk/jobs/v2/jobs.py @@ -3,11 +3,15 @@ from __future__ import annotations import logging +import random +import time from dataclasses import dataclass from enum import Enum from typing import Any, Dict, Iterator, List, Optional -from ...service._internal import _enum, _from_dict, _repeated_dict +from ...databricks.errors import OperationFailed +from ...service._internal import (WaitUntilDoneOptions, _enum, _from_dict, + _repeated_dict) _LOG = logging.getLogger("databricks.sdk") @@ -417,7 +421,11 @@ class BaseRun: effective_performance_target: Optional[PerformanceTarget] = None """The actual performance target used by the serverless run during execution. This can differ from the client-set performance target on the request depending on whether the performance mode is - supported by the job type.""" + supported by the job type. + + * `STANDARD`: Enables cost-efficient execution of serverless workloads. * + `PERFORMANCE_OPTIMIZED`: Prioritizes fast startup and execution times through rapid scaling and + optimized cluster performance.""" end_time: Optional[int] = None """The time at which this run ended in epoch milliseconds (milliseconds since 1/1/1970 UTC). This @@ -856,10 +864,11 @@ class CleanRoomTaskRunState: life_cycle_state: Optional[CleanRoomTaskRunLifeCycleState] = None """A value indicating the run's current lifecycle state. This field is always available in the - response.""" + response. Note: Additional states might be introduced in future releases.""" result_state: Optional[CleanRoomTaskRunResultState] = None - """A value indicating the run's result. This field is only available for terminal lifecycle states.""" + """A value indicating the run's result. This field is only available for terminal lifecycle states. + Note: Additional states might be introduced in future releases.""" def as_dict(self) -> dict: """Serializes the CleanRoomTaskRunState into a dictionary suitable for use as a JSON request body.""" @@ -1362,8 +1371,7 @@ class CreateJob: job_clusters: Optional[List[JobCluster]] = None """A list of job cluster specifications that can be shared and reused by tasks of this job. Libraries cannot be declared in a shared job cluster. You must declare dependent libraries in - task settings. If more than 100 job clusters are available, you can paginate through them using - :method:jobs/get.""" + task settings.""" max_concurrent_runs: Optional[int] = None """An optional maximum allowed number of concurrent runs of the job. Set this value if you want to @@ -1387,7 +1395,11 @@ class CreateJob: performance_target: Optional[PerformanceTarget] = None """The performance mode on a serverless job. 
The performance target determines the level of compute - performance or cost-efficiency for the run.""" + performance or cost-efficiency for the run. + + * `STANDARD`: Enables cost-efficient execution of serverless workloads. * + `PERFORMANCE_OPTIMIZED`: Prioritizes fast startup and execution times through rapid scaling and + optimized cluster performance.""" queue: Optional[QueueSettings] = None """The queue settings of the job.""" @@ -1408,9 +1420,11 @@ class CreateJob: be added to the job.""" tasks: Optional[List[Task]] = None - """A list of task specifications to be executed by this job. If more than 100 tasks are available, - you can paginate through them using :method:jobs/get. Use the `next_page_token` field at the - object root to determine if more results are available.""" + """A list of task specifications to be executed by this job. It supports up to 1000 elements in + write endpoints (:method:jobs/create, :method:jobs/reset, :method:jobs/update, + :method:jobs/submit). Read endpoints return only 100 tasks. If more than 100 tasks are + available, you can paginate through them using :method:jobs/get. Use the `next_page_token` field + at the object root to determine if more results are available.""" timeout_seconds: Optional[int] = None """An optional timeout applied to each run of this job. A value of `0` means no timeout.""" @@ -1679,11 +1693,14 @@ class DashboardTask: """Configures the Lakeview Dashboard job task type.""" dashboard_id: Optional[str] = None + """The identifier of the dashboard to refresh.""" subscription: Optional[Subscription] = None + """Optional: subscription configuration for sending the dashboard snapshot.""" warehouse_id: Optional[str] = None - """The warehouse id to execute the dashboard with for the schedule""" + """Optional: The warehouse id to execute the dashboard with for the schedule. If not specified, the + default warehouse of the dashboard will be used.""" def as_dict(self) -> dict: """Serializes the DashboardTask into a dictionary suitable for use as a JSON request body.""" @@ -3837,8 +3854,7 @@ class JobSettings: job_clusters: Optional[List[JobCluster]] = None """A list of job cluster specifications that can be shared and reused by tasks of this job. Libraries cannot be declared in a shared job cluster. You must declare dependent libraries in - task settings. If more than 100 job clusters are available, you can paginate through them using - :method:jobs/get.""" + task settings.""" max_concurrent_runs: Optional[int] = None """An optional maximum allowed number of concurrent runs of the job. Set this value if you want to @@ -3862,7 +3878,11 @@ class JobSettings: performance_target: Optional[PerformanceTarget] = None """The performance mode on a serverless job. The performance target determines the level of compute - performance or cost-efficiency for the run.""" + performance or cost-efficiency for the run. + + * `STANDARD`: Enables cost-efficient execution of serverless workloads. * + `PERFORMANCE_OPTIMIZED`: Prioritizes fast startup and execution times through rapid scaling and + optimized cluster performance.""" queue: Optional[QueueSettings] = None """The queue settings of the job.""" @@ -3883,9 +3903,11 @@ class JobSettings: be added to the job.""" tasks: Optional[List[Task]] = None - """A list of task specifications to be executed by this job. If more than 100 tasks are available, - you can paginate through them using :method:jobs/get. 
Use the `next_page_token` field at the - object root to determine if more results are available.""" + """A list of task specifications to be executed by this job. It supports up to 1000 elements in + write endpoints (:method:jobs/create, :method:jobs/reset, :method:jobs/update, + :method:jobs/submit). Read endpoints return only 100 tasks. If more than 100 tasks are + available, you can paginate through them using :method:jobs/get. Use the `next_page_token` field + at the object root to determine if more results are available.""" timeout_seconds: Optional[int] = None """An optional timeout applied to each run of this job. A value of `0` means no timeout.""" @@ -5090,7 +5112,6 @@ class PerformanceTarget(Enum): on serverless compute should be. The performance mode on the job or pipeline should map to a performance setting that is passed to Cluster Manager (see cluster-common PerformanceTarget).""" - COST_OPTIMIZED = "COST_OPTIMIZED" PERFORMANCE_OPTIMIZED = "PERFORMANCE_OPTIMIZED" STANDARD = "STANDARD" @@ -5555,6 +5576,15 @@ def from_dict(cls, d: Dict[str, Any]) -> RCranLibrary: @dataclass class RepairHistoryItem: + effective_performance_target: Optional[PerformanceTarget] = None + """The actual performance target used by the serverless run during execution. This can differ from + the client-set performance target on the request depending on whether the performance mode is + supported by the job type. + + * `STANDARD`: Enables cost-efficient execution of serverless workloads. * + `PERFORMANCE_OPTIMIZED`: Prioritizes fast startup and execution times through rapid scaling and + optimized cluster performance.""" + end_time: Optional[int] = None """The end time of the (repaired) run.""" @@ -5579,6 +5609,8 @@ class RepairHistoryItem: def as_dict(self) -> dict: """Serializes the RepairHistoryItem into a dictionary suitable for use as a JSON request body.""" body = {} + if self.effective_performance_target is not None: + body["effective_performance_target"] = self.effective_performance_target.value if self.end_time is not None: body["end_time"] = self.end_time if self.id is not None: @@ -5598,6 +5630,8 @@ def as_dict(self) -> dict: def as_shallow_dict(self) -> dict: """Serializes the RepairHistoryItem into a shallow dictionary of its immediate attributes.""" body = {} + if self.effective_performance_target is not None: + body["effective_performance_target"] = self.effective_performance_target if self.end_time is not None: body["end_time"] = self.end_time if self.id is not None: @@ -5618,6 +5652,7 @@ def as_shallow_dict(self) -> dict: def from_dict(cls, d: Dict[str, Any]) -> RepairHistoryItem: """Deserializes the RepairHistoryItem from a dictionary.""" return cls( + effective_performance_target=_enum(d, "effective_performance_target", PerformanceTarget), end_time=d.get("end_time", None), id=d.get("id", None), start_time=d.get("start_time", None), @@ -5679,6 +5714,15 @@ class RepairRun: [Task parameter variables]: https://docs.databricks.com/jobs.html#parameter-variables [dbutils.widgets.get]: https://docs.databricks.com/dev-tools/databricks-utils.html""" + performance_target: Optional[PerformanceTarget] = None + """The performance mode on a serverless job. The performance target determines the level of compute + performance or cost-efficiency for the run. This field overrides the performance target defined + on the job level. + + * `STANDARD`: Enables cost-efficient execution of serverless workloads. 
* + `PERFORMANCE_OPTIMIZED`: Prioritizes fast startup and execution times through rapid scaling and + optimized cluster performance.""" + pipeline_params: Optional[PipelineParams] = None """Controls whether the pipeline should perform a full refresh""" @@ -5745,6 +5789,8 @@ def as_dict(self) -> dict: body["latest_repair_id"] = self.latest_repair_id if self.notebook_params: body["notebook_params"] = self.notebook_params + if self.performance_target is not None: + body["performance_target"] = self.performance_target.value if self.pipeline_params: body["pipeline_params"] = self.pipeline_params.as_dict() if self.python_named_params: @@ -5778,6 +5824,8 @@ def as_shallow_dict(self) -> dict: body["latest_repair_id"] = self.latest_repair_id if self.notebook_params: body["notebook_params"] = self.notebook_params + if self.performance_target is not None: + body["performance_target"] = self.performance_target if self.pipeline_params: body["pipeline_params"] = self.pipeline_params if self.python_named_params: @@ -5807,6 +5855,7 @@ def from_dict(cls, d: Dict[str, Any]) -> RepairRun: job_parameters=d.get("job_parameters", None), latest_repair_id=d.get("latest_repair_id", None), notebook_params=d.get("notebook_params", None), + performance_target=_enum(d, "performance_target", PerformanceTarget), pipeline_params=_from_dict(d, "pipeline_params", PipelineParams), python_named_params=d.get("python_named_params", None), python_params=d.get("python_params", None), @@ -6209,7 +6258,11 @@ class Run: effective_performance_target: Optional[PerformanceTarget] = None """The actual performance target used by the serverless run during execution. This can differ from the client-set performance target on the request depending on whether the performance mode is - supported by the job type.""" + supported by the job type. + + * `STANDARD`: Enables cost-efficient execution of serverless workloads. * + `PERFORMANCE_OPTIMIZED`: Prioritizes fast startup and execution times through rapid scaling and + optimized cluster performance.""" end_time: Optional[int] = None """The time at which this run ended in epoch milliseconds (milliseconds since 1/1/1970 UTC). This @@ -6933,7 +6986,11 @@ class RunNow: performance_target: Optional[PerformanceTarget] = None """The performance mode on a serverless job. The performance target determines the level of compute performance or cost-efficiency for the run. This field overrides the performance target defined - on the job-level.""" + on the job level. + + * `STANDARD`: Enables cost-efficient execution of serverless workloads. * + `PERFORMANCE_OPTIMIZED`: Prioritizes fast startup and execution times through rapid scaling and + optimized cluster performance.""" pipeline_params: Optional[PipelineParams] = None """Controls whether the pipeline should perform a full refresh""" @@ -7393,13 +7450,14 @@ class RunState: life_cycle_state: Optional[RunLifeCycleState] = None """A value indicating the run's current lifecycle state. This field is always available in the - response.""" + response. Note: Additional states might be introduced in future releases.""" queue_reason: Optional[str] = None """The reason indicating why the run was queued.""" result_state: Optional[RunResultState] = None - """A value indicating the run's result. This field is only available for terminal lifecycle states.""" + """A value indicating the run's result. This field is only available for terminal lifecycle states. 
+ Note: Additional states might be introduced in future releases.""" state_message: Optional[str] = None """A descriptive message for the current state. This field is unstructured, and its exact format is @@ -7534,7 +7592,7 @@ class RunTask: does not support retries or notifications.""" dashboard_task: Optional[DashboardTask] = None - """The task runs a DashboardTask when the `dashboard_task` field is present.""" + """The task refreshes a dashboard and sends a snapshot to subscribers.""" dbt_task: Optional[DbtTask] = None """The task runs one or more dbt commands when the `dbt_task` field is present. The dbt task @@ -7554,7 +7612,11 @@ class RunTask: effective_performance_target: Optional[PerformanceTarget] = None """The actual performance target used by the serverless run during execution. This can differ from the client-set performance target on the request depending on whether the performance mode is - supported by the job type.""" + supported by the job type. + + * `STANDARD`: Enables cost-efficient execution of serverless workloads. * + `PERFORMANCE_OPTIMIZED`: Prioritizes fast startup and execution times through rapid scaling and + optimized cluster performance.""" email_notifications: Optional[JobEmailNotifications] = None """An optional set of email addresses notified when the task run begins or completes. The default @@ -9014,7 +9076,7 @@ class SubmitTask: does not support retries or notifications.""" dashboard_task: Optional[DashboardTask] = None - """The task runs a DashboardTask when the `dashboard_task` field is present.""" + """The task refreshes a dashboard and sends a snapshot to subscribers.""" dbt_task: Optional[DbtTask] = None """The task runs one or more dbt commands when the `dbt_task` field is present. The dbt task @@ -9054,7 +9116,7 @@ class SubmitTask: """An optional list of libraries to be installed on the cluster. 
The default value is an empty list.""" - new_cluster: Optional[JobsClusterSpec] = None + new_cluster: Optional[ClusterSpec] = None """If new_cluster, a description of a new cluster that is created for each run.""" notebook_task: Optional[NotebookTask] = None @@ -9256,7 +9318,7 @@ def from_dict(cls, d: Dict[str, Any]) -> SubmitTask: gen_ai_compute_task=_from_dict(d, "gen_ai_compute_task", GenAiComputeTask), health=_from_dict(d, "health", JobsHealthRules), libraries=_repeated_dict(d, "libraries", Library), - new_cluster=_from_dict(d, "new_cluster", JobsClusterSpec), + new_cluster=_from_dict(d, "new_cluster", ClusterSpec), notebook_task=_from_dict(d, "notebook_task", NotebookTask), notification_settings=_from_dict(d, "notification_settings", TaskNotificationSettings), pipeline_task=_from_dict(d, "pipeline_task", PipelineTask), @@ -9283,6 +9345,7 @@ class Subscription: """When true, the subscription will not send emails.""" subscribers: Optional[List[SubscriptionSubscriber]] = None + """The list of subscribers to send the snapshot of the dashboard to.""" def as_dict(self) -> dict: """Serializes the Subscription into a dictionary suitable for use as a JSON request body.""" @@ -9319,8 +9382,12 @@ def from_dict(cls, d: Dict[str, Any]) -> Subscription: @dataclass class SubscriptionSubscriber: destination_id: Optional[str] = None + """A snapshot of the dashboard will be sent to the destination when the `destination_id` field is + present.""" user_name: Optional[str] = None + """A snapshot of the dashboard will be sent to the user's email when the `user_name` field is + present.""" def as_dict(self) -> dict: """Serializes the SubscriptionSubscriber into a dictionary suitable for use as a JSON request body.""" @@ -9419,7 +9486,7 @@ class Task: does not support retries or notifications.""" dashboard_task: Optional[DashboardTask] = None - """The task runs a DashboardTask when the `dashboard_task` field is present.""" + """The task refreshes a dashboard and sends a snapshot to subscribers.""" dbt_task: Optional[DbtTask] = None """The task runs one or more dbt commands when the `dbt_task` field is present. The dbt task @@ -9924,7 +9991,7 @@ class TerminationCodeCode(Enum): invalid configuration. Refer to the state message for further details. * `CLOUD_FAILURE`: The run failed due to a cloud provider issue. Refer to the state message for further details. * `MAX_JOB_QUEUE_SIZE_EXCEEDED`: The run was skipped due to reaching the job level queue size - limit. + limit. * `DISABLED`: The run was never executed because it was disabled explicitly by the user. [Link]: https://kb.databricks.com/en_US/notebooks/too-many-execution-contexts-are-open-right-now""" @@ -9933,6 +10000,7 @@ class TerminationCodeCode(Enum): CLOUD_FAILURE = "CLOUD_FAILURE" CLUSTER_ERROR = "CLUSTER_ERROR" CLUSTER_REQUEST_LIMIT_EXCEEDED = "CLUSTER_REQUEST_LIMIT_EXCEEDED" + DISABLED = "DISABLED" DRIVER_ERROR = "DRIVER_ERROR" FEATURE_DISABLED = "FEATURE_DISABLED" INTERNAL_ERROR = "INTERNAL_ERROR" @@ -9988,7 +10056,7 @@ class TerminationDetails: invalid configuration. Refer to the state message for further details. * `CLOUD_FAILURE`: The run failed due to a cloud provider issue. Refer to the state message for further details. * `MAX_JOB_QUEUE_SIZE_EXCEEDED`: The run was skipped due to reaching the job level queue size - limit. + limit. * `DISABLED`: The run was never executed because it was disabled explicitly by the user. 
[Link]: https://kb.databricks.com/en_US/notebooks/too-many-execution-contexts-are-open-right-now""" @@ -10485,6 +10553,182 @@ def from_dict(cls, d: Dict[str, Any]) -> WorkspaceStorageInfo: return cls(destination=d.get("destination", None)) +class JobsCancelRunWaiter: + raw_response: CancelRunResponse + """raw_response is the raw response of the CancelRun call.""" + _service: JobsAPI + _run_id: int + + def __init__(self, raw_response: CancelRunResponse, service: JobsAPI, run_id: int): + self._service = service + self.raw_response = raw_response + self._run_id = run_id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> Run: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = ( + RunLifeCycleState.TERMINATED, + RunLifeCycleState.SKIPPED, + ) + failure_states = (RunLifeCycleState.INTERNAL_ERROR,) + status_message = "polling..." + attempt = 1 + while time.time() < deadline: + poll = self._service.get_run(run_id=self._run_id) + status = poll.state.life_cycle_state + status_message = f"current status: {status}" + if poll.state: + status_message = poll.state.state_message + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach TERMINATED or SKIPPED, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"run_id={self._run_id}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + +class JobsRepairRunWaiter: + raw_response: RepairRunResponse + """raw_response is the raw response of the RepairRun call.""" + _service: JobsAPI + _run_id: int + + def __init__(self, raw_response: RepairRunResponse, service: JobsAPI, run_id: int): + self._service = service + self.raw_response = raw_response + self._run_id = run_id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> Run: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = ( + RunLifeCycleState.TERMINATED, + RunLifeCycleState.SKIPPED, + ) + failure_states = (RunLifeCycleState.INTERNAL_ERROR,) + status_message = "polling..." 
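+        # The loop below polls get_run until a target or failure state is seen:
+        # the sleep grows by one second per attempt (capped at 10s) and a random
+        # sub-second jitter is added so concurrent waiters do not poll in lockstep.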
+ attempt = 1 + while time.time() < deadline: + poll = self._service.get_run(run_id=self._run_id) + status = poll.state.life_cycle_state + status_message = f"current status: {status}" + if poll.state: + status_message = poll.state.state_message + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach TERMINATED or SKIPPED, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"run_id={self._run_id}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + +class JobsRunNowWaiter: + raw_response: RunNowResponse + """raw_response is the raw response of the RunNow call.""" + _service: JobsAPI + _run_id: int + + def __init__(self, raw_response: RunNowResponse, service: JobsAPI, run_id: int): + self._service = service + self.raw_response = raw_response + self._run_id = run_id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> Run: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = ( + RunLifeCycleState.TERMINATED, + RunLifeCycleState.SKIPPED, + ) + failure_states = (RunLifeCycleState.INTERNAL_ERROR,) + status_message = "polling..." + attempt = 1 + while time.time() < deadline: + poll = self._service.get_run(run_id=self._run_id) + status = poll.state.life_cycle_state + status_message = f"current status: {status}" + if poll.state: + status_message = poll.state.state_message + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach TERMINATED or SKIPPED, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"run_id={self._run_id}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + +class JobsSubmitWaiter: + raw_response: SubmitRunResponse + """raw_response is the raw response of the Submit call.""" + _service: JobsAPI + _run_id: int + + def __init__(self, raw_response: SubmitRunResponse, service: JobsAPI, run_id: int): + self._service = service + self.raw_response = raw_response + self._run_id = run_id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> Run: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = ( + RunLifeCycleState.TERMINATED, + RunLifeCycleState.SKIPPED, + ) + failure_states = (RunLifeCycleState.INTERNAL_ERROR,) + status_message = "polling..." 
+ attempt = 1 + while time.time() < deadline: + poll = self._service.get_run(run_id=self._run_id) + status = poll.state.life_cycle_state + status_message = f"current status: {status}" + if poll.state: + status_message = poll.state.state_message + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach TERMINATED or SKIPPED, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"run_id={self._run_id}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + class JobsAPI: """The Jobs API allows you to create, edit, and delete jobs. @@ -10530,7 +10774,7 @@ def cancel_all_runs(self, *, all_queued_runs: Optional[bool] = None, job_id: Opt self._api.do("POST", "/api/2.2/jobs/runs/cancel-all", body=body, headers=headers) - def cancel_run(self, run_id: int): + def cancel_run(self, run_id: int) -> JobsCancelRunWaiter: """Cancel a run. Cancels a job run or a task run. The run is canceled asynchronously, so it may still be running when @@ -10550,7 +10794,8 @@ def cancel_run(self, run_id: int): "Content-Type": "application/json", } - self._api.do("POST", "/api/2.2/jobs/runs/cancel", body=body, headers=headers) + op_response = self._api.do("POST", "/api/2.2/jobs/runs/cancel", body=body, headers=headers) + return JobsCancelRunWaiter(service=self, raw_response=CancelRunResponse.from_dict(op_response), run_id=run_id) def create( self, @@ -10628,7 +10873,6 @@ def create( :param job_clusters: List[:class:`JobCluster`] (optional) A list of job cluster specifications that can be shared and reused by tasks of this job. Libraries cannot be declared in a shared job cluster. You must declare dependent libraries in task settings. - If more than 100 job clusters are available, you can paginate through them using :method:jobs/get. :param max_concurrent_runs: int (optional) An optional maximum allowed number of concurrent runs of the job. Set this value if you want to be able to execute multiple runs of the same job concurrently. This is useful for example if you @@ -10648,6 +10892,10 @@ def create( :param performance_target: :class:`PerformanceTarget` (optional) The performance mode on a serverless job. The performance target determines the level of compute performance or cost-efficiency for the run. + + * `STANDARD`: Enables cost-efficient execution of serverless workloads. * `PERFORMANCE_OPTIMIZED`: + Prioritizes fast startup and execution times through rapid scaling and optimized cluster + performance. :param queue: :class:`QueueSettings` (optional) The queue settings of the job. :param run_as: :class:`JobRunAs` (optional) @@ -10663,9 +10911,11 @@ def create( clusters, and are subject to the same limitations as cluster tags. A maximum of 25 tags can be added to the job. :param tasks: List[:class:`Task`] (optional) - A list of task specifications to be executed by this job. If more than 100 tasks are available, you - can paginate through them using :method:jobs/get. Use the `next_page_token` field at the object root - to determine if more results are available. + A list of task specifications to be executed by this job. It supports up to 1000 elements in write + endpoints (:method:jobs/create, :method:jobs/reset, :method:jobs/update, :method:jobs/submit). Read + endpoints return only 100 tasks. 
If more than 100 tasks are available, you can paginate through them + using :method:jobs/get. Use the `next_page_token` field at the object root to determine if more + results are available. :param timeout_seconds: int (optional) An optional timeout applied to each run of this job. A value of `0` means no timeout. :param trigger: :class:`TriggerSettings` (optional) @@ -10958,8 +11208,8 @@ def list( Retrieves a list of jobs. :param expand_tasks: bool (optional) - Whether to include task and cluster details in the response. Note that in API 2.2, only the first - 100 elements will be shown. Use :method:jobs/get to paginate through all tasks and clusters. + Whether to include task and cluster details in the response. Note that only the first 100 elements + will be shown. Use :method:jobs/get to paginate through all tasks and clusters. :param limit: int (optional) The number of jobs to return. This value must be greater than 0 and less or equal to 100. The default value is 20. @@ -11025,8 +11275,8 @@ def list_runs( If completed_only is `true`, only completed runs are included in the results; otherwise, lists both active and completed runs. This field cannot be `true` when active_only is `true`. :param expand_tasks: bool (optional) - Whether to include task and cluster details in the response. Note that in API 2.2, only the first - 100 elements will be shown. Use :method:jobs/getrun to paginate through all tasks and clusters. + Whether to include task and cluster details in the response. Note that only the first 100 elements + will be shown. Use :method:jobs/getrun to paginate through all tasks and clusters. :param job_id: int (optional) The job for which to list runs. If omitted, the Jobs service lists runs from all jobs. :param limit: int (optional) @@ -11093,6 +11343,7 @@ def repair_run( job_parameters: Optional[Dict[str, str]] = None, latest_repair_id: Optional[int] = None, notebook_params: Optional[Dict[str, str]] = None, + performance_target: Optional[PerformanceTarget] = None, pipeline_params: Optional[PipelineParams] = None, python_named_params: Optional[Dict[str, str]] = None, python_params: Optional[List[str]] = None, @@ -11101,7 +11352,7 @@ def repair_run( rerun_tasks: Optional[List[str]] = None, spark_submit_params: Optional[List[str]] = None, sql_params: Optional[Dict[str, str]] = None, - ) -> RepairRunResponse: + ) -> JobsRepairRunWaiter: """Repair a job run. Re-run one or more tasks. Tasks are re-run as part of the original job run. They use the current job @@ -11143,6 +11394,14 @@ def repair_run( [Task parameter variables]: https://docs.databricks.com/jobs.html#parameter-variables [dbutils.widgets.get]: https://docs.databricks.com/dev-tools/databricks-utils.html + :param performance_target: :class:`PerformanceTarget` (optional) + The performance mode on a serverless job. The performance target determines the level of compute + performance or cost-efficiency for the run. This field overrides the performance target defined on + the job level. + + * `STANDARD`: Enables cost-efficient execution of serverless workloads. * `PERFORMANCE_OPTIMIZED`: + Prioritizes fast startup and execution times through rapid scaling and optimized cluster + performance. 
:param pipeline_params: :class:`PipelineParams` (optional) Controls whether the pipeline should perform a full refresh :param python_named_params: Dict[str,str] (optional) @@ -11203,6 +11462,8 @@ def repair_run( body["latest_repair_id"] = latest_repair_id if notebook_params is not None: body["notebook_params"] = notebook_params + if performance_target is not None: + body["performance_target"] = performance_target.value if pipeline_params is not None: body["pipeline_params"] = pipeline_params.as_dict() if python_named_params is not None: @@ -11226,8 +11487,8 @@ def repair_run( "Content-Type": "application/json", } - res = self._api.do("POST", "/api/2.2/jobs/runs/repair", body=body, headers=headers) - return RepairRunResponse.from_dict(res) + op_response = self._api.do("POST", "/api/2.2/jobs/runs/repair", body=body, headers=headers) + return JobsRepairRunWaiter(service=self, raw_response=RepairRunResponse.from_dict(op_response), run_id=run_id) def reset(self, job_id: int, new_settings: JobSettings): """Update all job settings (reset). @@ -11273,7 +11534,7 @@ def run_now( queue: Optional[QueueSettings] = None, spark_submit_params: Optional[List[str]] = None, sql_params: Optional[Dict[str, str]] = None, - ) -> RunNowResponse: + ) -> JobsRunNowWaiter: """Trigger a new job run. Run a job and return the `run_id` of the triggered run. @@ -11330,7 +11591,11 @@ def run_now( :param performance_target: :class:`PerformanceTarget` (optional) The performance mode on a serverless job. The performance target determines the level of compute performance or cost-efficiency for the run. This field overrides the performance target defined on - the job-level. + the job level. + + * `STANDARD`: Enables cost-efficient execution of serverless workloads. * `PERFORMANCE_OPTIMIZED`: + Prioritizes fast startup and execution times through rapid scaling and optimized cluster + performance. :param pipeline_params: :class:`PipelineParams` (optional) Controls whether the pipeline should perform a full refresh :param python_named_params: Dict[str,str] (optional) @@ -11409,8 +11674,10 @@ def run_now( "Content-Type": "application/json", } - res = self._api.do("POST", "/api/2.2/jobs/run-now", body=body, headers=headers) - return RunNowResponse.from_dict(res) + op_response = self._api.do("POST", "/api/2.2/jobs/run-now", body=body, headers=headers) + return JobsRunNowWaiter( + service=self, raw_response=RunNowResponse.from_dict(op_response), run_id=op_response["run_id"] + ) def set_permissions( self, job_id: str, *, access_control_list: Optional[List[JobAccessControlRequest]] = None @@ -11454,7 +11721,7 @@ def submit( tasks: Optional[List[SubmitTask]] = None, timeout_seconds: Optional[int] = None, webhook_notifications: Optional[WebhookNotifications] = None, - ) -> SubmitRunResponse: + ) -> JobsSubmitWaiter: """Create and trigger a one-time run. Submit a one-time run. This endpoint allows you to submit a workload directly without creating a job. 
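# --- Editor's note: hedged usage sketch, not part of the generated patch ---
# run_now, submit, repair_run, and cancel_run now return waiter objects rather
# than the plain response dataclasses. The sketch below shows one way a caller
# could drive the new pattern. `jobs` (a JobsAPI instance) and `job_id` are
# placeholders, and the import path is inferred from the relative imports in
# this diff.
from databricks.sdk.databricks.errors import OperationFailed


def run_job_and_wait(jobs, job_id: int) -> None:
    waiter = jobs.run_now(job_id=job_id)
    # The run_id of the triggered run is still available on the raw response.
    run_id = waiter.raw_response.run_id
    try:
        run = waiter.WaitUntilDone()  # polls get_run until TERMINATED or SKIPPED
        print(f"run {run_id} finished with result {run.state.result_state}")
    except OperationFailed:
        # Raised when the run ends in INTERNAL_ERROR; repairing it returns
        # another waiter with the same WaitUntilDone interface.
        jobs.repair_run(run_id=run_id).WaitUntilDone()
    except TimeoutError:
        # WaitUntilDone gives up after the (default 20-minute) timeout;
        # cancel_run also returns a waiter that can be awaited the same way.
        jobs.cancel_run(run_id=run_id).WaitUntilDone()
# ----------------------------------------------------------------------------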
@@ -11548,8 +11815,10 @@ def submit( "Content-Type": "application/json", } - res = self._api.do("POST", "/api/2.2/jobs/runs/submit", body=body, headers=headers) - return SubmitRunResponse.from_dict(res) + op_response = self._api.do("POST", "/api/2.2/jobs/runs/submit", body=body, headers=headers) + return JobsSubmitWaiter( + service=self, raw_response=SubmitRunResponse.from_dict(op_response), run_id=op_response["run_id"] + ) def update( self, job_id: int, *, fields_to_remove: Optional[List[str]] = None, new_settings: Optional[JobSettings] = None diff --git a/databricks/sdk/ml/v2/ml.py b/databricks/sdk/ml/v2/ml.py index 44f5f28c5..f6b1ae6d2 100755 --- a/databricks/sdk/ml/v2/ml.py +++ b/databricks/sdk/ml/v2/ml.py @@ -3,12 +3,15 @@ from __future__ import annotations import logging +import random +import time from dataclasses import dataclass from enum import Enum from typing import Any, Dict, Iterator, List, Optional -from ...service._internal import (_enum, _from_dict, _repeated_dict, - _repeated_enum) +from ...databricks.errors import OperationFailed +from ...service._internal import (WaitUntilDoneOptions, _enum, _from_dict, + _repeated_dict, _repeated_enum) _LOG = logging.getLogger("databricks.sdk") @@ -7267,6 +7270,48 @@ def update_run( return UpdateRunResponse.from_dict(res) +class ForecastingCreateExperimentWaiter: + raw_response: CreateForecastingExperimentResponse + """raw_response is the raw response of the CreateExperiment call.""" + _service: ForecastingAPI + _experiment_id: str + + def __init__(self, raw_response: CreateForecastingExperimentResponse, service: ForecastingAPI, experiment_id: str): + self._service = service + self.raw_response = raw_response + self._experiment_id = experiment_id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> ForecastingExperiment: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (ForecastingExperimentState.SUCCEEDED,) + failure_states = ( + ForecastingExperimentState.FAILED, + ForecastingExperimentState.CANCELLED, + ) + status_message = "polling..." + attempt = 1 + while time.time() < deadline: + poll = self._service.get_experiment(experiment_id=self._experiment_id) + status = poll.state + status_message = f"current status: {status}" + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach SUCCEEDED, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"experiment_id={self._experiment_id}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + class ForecastingAPI: """The Forecasting API allows you to create and get serverless forecasting experiments""" @@ -7292,7 +7337,7 @@ def create_experiment( split_column: Optional[str] = None, timeseries_identifier_columns: Optional[List[str]] = None, training_frameworks: Optional[List[str]] = None, - ) -> CreateForecastingExperimentResponse: + ) -> ForecastingCreateExperimentWaiter: """Create a forecasting experiment. Creates a serverless forecasting experiment. Returns the experiment ID. 
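# --- Editor's note: hedged usage sketch, not part of the generated patch ---
# create_experiment now returns ForecastingCreateExperimentWaiter, which polls
# get_experiment until the experiment reaches SUCCEEDED (and raises
# OperationFailed on FAILED or CANCELLED). `forecasting` (a ForecastingAPI
# instance) and `experiment_spec` (the keyword arguments accepted by
# create_experiment, elided here) are placeholders, and the import path is
# inferred from the relative imports in this diff.
import datetime

from databricks.sdk.service._internal import WaitUntilDoneOptions


def train_forecasting_model(forecasting, **experiment_spec) -> str:
    waiter = forecasting.create_experiment(**experiment_spec)
    experiment_id = waiter.raw_response.experiment_id

    opts = WaitUntilDoneOptions()
    opts.timeout = datetime.timedelta(hours=2)  # training can outlast the 20-minute default
    experiment = waiter.WaitUntilDone(opts)
    print(f"experiment {experiment_id} finished in state {experiment.state}")
    return experiment_id
# ----------------------------------------------------------------------------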
@@ -7385,8 +7430,12 @@ def create_experiment( "Content-Type": "application/json", } - res = self._api.do("POST", "/api/2.0/automl/create-forecasting-experiment", body=body, headers=headers) - return CreateForecastingExperimentResponse.from_dict(res) + op_response = self._api.do("POST", "/api/2.0/automl/create-forecasting-experiment", body=body, headers=headers) + return ForecastingCreateExperimentWaiter( + service=self, + raw_response=CreateForecastingExperimentResponse.from_dict(op_response), + experiment_id=op_response["experiment_id"], + ) def get_experiment(self, experiment_id: str) -> ForecastingExperiment: """Get a forecasting experiment. diff --git a/databricks/sdk/pipelines/v2/pipelines.py b/databricks/sdk/pipelines/v2/pipelines.py index d6e732f3a..1279ba74b 100755 --- a/databricks/sdk/pipelines/v2/pipelines.py +++ b/databricks/sdk/pipelines/v2/pipelines.py @@ -3,12 +3,15 @@ from __future__ import annotations import logging +import random +import time from dataclasses import dataclass from enum import Enum from typing import Any, Dict, Iterator, List, Optional -from ...service._internal import (_enum, _from_dict, _repeated_dict, - _repeated_enum) +from ...databricks.errors import OperationFailed +from ...service._internal import (WaitUntilDoneOptions, _enum, _from_dict, + _repeated_dict, _repeated_enum) _LOG = logging.getLogger("databricks.sdk") @@ -4160,6 +4163,45 @@ def from_dict(cls, d: Dict[str, Any]) -> WorkspaceStorageInfo: return cls(destination=d.get("destination", None)) +class PipelinesStopWaiter: + raw_response: StopPipelineResponse + """raw_response is the raw response of the Stop call.""" + _service: PipelinesAPI + _pipeline_id: str + + def __init__(self, raw_response: StopPipelineResponse, service: PipelinesAPI, pipeline_id: str): + self._service = service + self.raw_response = raw_response + self._pipeline_id = pipeline_id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> GetPipelineResponse: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (PipelineState.IDLE,) + failure_states = (PipelineState.FAILED,) + status_message = "polling..." + attempt = 1 + while time.time() < deadline: + poll = self._service.get(pipeline_id=self._pipeline_id) + status = poll.state + status_message = f"current status: {status}" + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach IDLE, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"pipeline_id={self._pipeline_id}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + class PipelinesAPI: """The Delta Live Tables API allows you to create, edit, delete, start, and view details about pipelines. @@ -4655,7 +4697,7 @@ def start_update( res = self._api.do("POST", f"/api/2.0/pipelines/{pipeline_id}/updates", body=body, headers=headers) return StartUpdateResponse.from_dict(res) - def stop(self, pipeline_id: str): + def stop(self, pipeline_id: str) -> PipelinesStopWaiter: """Stop a pipeline. Stops the pipeline by canceling the active update. 
If there is no active update for the pipeline, this @@ -4672,7 +4714,10 @@ def stop(self, pipeline_id: str): "Accept": "application/json", } - self._api.do("POST", f"/api/2.0/pipelines/{pipeline_id}/stop", headers=headers) + op_response = self._api.do("POST", f"/api/2.0/pipelines/{pipeline_id}/stop", headers=headers) + return PipelinesStopWaiter( + service=self, raw_response=StopPipelineResponse.from_dict(op_response), pipeline_id=pipeline_id + ) def update( self, diff --git a/databricks/sdk/provisioning/v2/provisioning.py b/databricks/sdk/provisioning/v2/provisioning.py index 263203849..60e7c5c2f 100755 --- a/databricks/sdk/provisioning/v2/provisioning.py +++ b/databricks/sdk/provisioning/v2/provisioning.py @@ -3,12 +3,15 @@ from __future__ import annotations import logging +import random +import time from dataclasses import dataclass from enum import Enum from typing import Any, Dict, Iterator, List, Optional -from ...service._internal import (_enum, _from_dict, _repeated_dict, - _repeated_enum) +from ...databricks.errors import OperationFailed +from ...service._internal import (WaitUntilDoneOptions, _enum, _from_dict, + _repeated_dict, _repeated_enum) _LOG = logging.getLogger("databricks.sdk") @@ -3151,6 +3154,90 @@ def list(self) -> Iterator[VpcEndpoint]: return [VpcEndpoint.from_dict(v) for v in res] +class WorkspacesCreateWaiter: + raw_response: Workspace + """raw_response is the raw response of the Create call.""" + _service: WorkspacesAPI + _workspace_id: int + + def __init__(self, raw_response: Workspace, service: WorkspacesAPI, workspace_id: int): + self._service = service + self.raw_response = raw_response + self._workspace_id = workspace_id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> Workspace: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (WorkspaceStatus.RUNNING,) + failure_states = ( + WorkspaceStatus.BANNED, + WorkspaceStatus.FAILED, + ) + status_message = "polling..." + attempt = 1 + while time.time() < deadline: + poll = self._service.get(workspace_id=self._workspace_id) + status = poll.workspace_status + status_message = f"current status: {status}" + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach RUNNING, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"workspace_id={self._workspace_id}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + +class WorkspacesUpdateWaiter: + raw_response: UpdateResponse + """raw_response is the raw response of the Update call.""" + _service: WorkspacesAPI + _workspace_id: int + + def __init__(self, raw_response: UpdateResponse, service: WorkspacesAPI, workspace_id: int): + self._service = service + self.raw_response = raw_response + self._workspace_id = workspace_id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> Workspace: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (WorkspaceStatus.RUNNING,) + failure_states = ( + WorkspaceStatus.BANNED, + WorkspaceStatus.FAILED, + ) + status_message = "polling..." 
+ attempt = 1 + while time.time() < deadline: + poll = self._service.get(workspace_id=self._workspace_id) + status = poll.workspace_status + status_message = f"current status: {status}" + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach RUNNING, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"workspace_id={self._workspace_id}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + class WorkspacesAPI: """These APIs manage workspaces for this account. A Databricks workspace is an environment for accessing all of your Databricks assets. The workspace organizes objects (notebooks, libraries, and experiments) into @@ -3182,7 +3269,7 @@ def create( private_access_settings_id: Optional[str] = None, storage_configuration_id: Optional[str] = None, storage_customer_managed_key_id: Optional[str] = None, - ) -> Workspace: + ) -> WorkspacesCreateWaiter: """Create a new workspace. Creates a new workspace. @@ -3327,8 +3414,12 @@ def create( "Content-Type": "application/json", } - res = self._api.do("POST", f"/api/2.0/accounts/{self._api.account_id}/workspaces", body=body, headers=headers) - return Workspace.from_dict(res) + op_response = self._api.do( + "POST", f"/api/2.0/accounts/{self._api.account_id}/workspaces", body=body, headers=headers + ) + return WorkspacesCreateWaiter( + service=self, raw_response=Workspace.from_dict(op_response), workspace_id=op_response["workspace_id"] + ) def delete(self, workspace_id: int): """Delete a workspace. @@ -3414,7 +3505,7 @@ def update( private_access_settings_id: Optional[str] = None, storage_configuration_id: Optional[str] = None, storage_customer_managed_key_id: Optional[str] = None, - ): + ) -> WorkspacesUpdateWaiter: """Update workspace configuration. Updates a workspace configuration for either a running workspace or a failed workspace. 
The elements @@ -3569,6 +3660,9 @@ def update( "Content-Type": "application/json", } - self._api.do( + op_response = self._api.do( "PATCH", f"/api/2.0/accounts/{self._api.account_id}/workspaces/{workspace_id}", body=body, headers=headers ) + return WorkspacesUpdateWaiter( + service=self, raw_response=UpdateResponse.from_dict(op_response), workspace_id=workspace_id + ) diff --git a/databricks/sdk/service/_internal.py b/databricks/sdk/service/_internal.py index d1c124832..4cf48a72a 100644 --- a/databricks/sdk/service/_internal.py +++ b/databricks/sdk/service/_internal.py @@ -49,6 +49,10 @@ def _escape_multi_segment_path_parameter(param: str) -> str: ReturnType = TypeVar("ReturnType") +class WaitUntilDoneOptions: + timeout: datetime.timedelta = datetime.timedelta(minutes=20) + + class Wait(Generic[ReturnType]): def __init__(self, waiter: Callable, response: Any = None, **kwargs) -> None: diff --git a/databricks/sdk/serving/v2/serving.py b/databricks/sdk/serving/v2/serving.py index 4ca074a96..e4cdc07ee 100755 --- a/databricks/sdk/serving/v2/serving.py +++ b/databricks/sdk/serving/v2/serving.py @@ -3,11 +3,15 @@ from __future__ import annotations import logging +import random +import time from dataclasses import dataclass from enum import Enum from typing import Any, BinaryIO, Dict, Iterator, List, Optional -from ...service._internal import _enum, _from_dict, _repeated_dict +from ...databricks.errors import OperationFailed +from ...service._internal import (WaitUntilDoneOptions, _enum, _from_dict, + _repeated_dict) _LOG = logging.getLogger("databricks.sdk") @@ -4148,6 +4152,90 @@ def from_dict(cls, d: Dict[str, Any]) -> V1ResponseChoiceElement: ) +class ServingEndpointsCreateWaiter: + raw_response: ServingEndpointDetailed + """raw_response is the raw response of the Create call.""" + _service: ServingEndpointsAPI + _name: str + + def __init__(self, raw_response: ServingEndpointDetailed, service: ServingEndpointsAPI, name: str): + self._service = service + self.raw_response = raw_response + self._name = name + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> ServingEndpointDetailed: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (EndpointStateConfigUpdate.NOT_UPDATING,) + failure_states = ( + EndpointStateConfigUpdate.UPDATE_FAILED, + EndpointStateConfigUpdate.UPDATE_CANCELED, + ) + status_message = "polling..." 
+ attempt = 1 + while time.time() < deadline: + poll = self._service.get(name=self._name) + status = poll.state.config_update + status_message = f"current status: {status}" + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach NOT_UPDATING, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"name={self._name}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + +class ServingEndpointsUpdateConfigWaiter: + raw_response: ServingEndpointDetailed + """raw_response is the raw response of the UpdateConfig call.""" + _service: ServingEndpointsAPI + _name: str + + def __init__(self, raw_response: ServingEndpointDetailed, service: ServingEndpointsAPI, name: str): + self._service = service + self.raw_response = raw_response + self._name = name + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> ServingEndpointDetailed: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (EndpointStateConfigUpdate.NOT_UPDATING,) + failure_states = ( + EndpointStateConfigUpdate.UPDATE_FAILED, + EndpointStateConfigUpdate.UPDATE_CANCELED, + ) + status_message = "polling..." + attempt = 1 + while time.time() < deadline: + poll = self._service.get(name=self._name) + status = poll.state.config_update + status_message = f"current status: {status}" + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach NOT_UPDATING, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"name={self._name}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + class ServingEndpointsAPI: """The Serving Endpoints API allows you to create, update, and delete model serving endpoints. @@ -4195,7 +4283,7 @@ def create( rate_limits: Optional[List[RateLimit]] = None, route_optimized: Optional[bool] = None, tags: Optional[List[EndpointTag]] = None, - ) -> ServingEndpointDetailed: + ) -> ServingEndpointsCreateWaiter: """Create a new serving endpoint. :param name: str @@ -4241,8 +4329,10 @@ def create( "Content-Type": "application/json", } - res = self._api.do("POST", "/api/2.0/serving-endpoints", body=body, headers=headers) - return ServingEndpointDetailed.from_dict(res) + op_response = self._api.do("POST", "/api/2.0/serving-endpoints", body=body, headers=headers) + return ServingEndpointsCreateWaiter( + service=self, raw_response=ServingEndpointDetailed.from_dict(op_response), name=op_response["name"] + ) def delete(self, name: str): """Delete a serving endpoint. @@ -4687,7 +4777,7 @@ def update_config( served_entities: Optional[List[ServedEntityInput]] = None, served_models: Optional[List[ServedModelInput]] = None, traffic_config: Optional[TrafficConfig] = None, - ) -> ServingEndpointDetailed: + ) -> ServingEndpointsUpdateConfigWaiter: """Update config of a serving endpoint. 
Updates any combination of the serving endpoint's served entities, the compute configuration of those @@ -4727,8 +4817,10 @@ def update_config( "Content-Type": "application/json", } - res = self._api.do("PUT", f"/api/2.0/serving-endpoints/{name}/config", body=body, headers=headers) - return ServingEndpointDetailed.from_dict(res) + op_response = self._api.do("PUT", f"/api/2.0/serving-endpoints/{name}/config", body=body, headers=headers) + return ServingEndpointsUpdateConfigWaiter( + service=self, raw_response=ServingEndpointDetailed.from_dict(op_response), name=op_response["name"] + ) def update_permissions( self, diff --git a/databricks/sdk/settings/v2/client.py b/databricks/sdk/settings/v2/client.py index 8c1642bb8..0a39b656b 100755 --- a/databricks/sdk/settings/v2/client.py +++ b/databricks/sdk/settings/v2/client.py @@ -13,7 +13,9 @@ CredentialsManagerAPI, CspEnablementAccountAPI, DefaultNamespaceAPI, DisableLegacyAccessAPI, DisableLegacyDbfsAPI, DisableLegacyFeaturesAPI, - EnableIpAccessListsAPI, EnableResultsDownloadingAPI, + EnableExportNotebookAPI, EnableIpAccessListsAPI, + EnableNotebookTableClipboardAPI, + EnableResultsDownloadingAPI, EnhancedSecurityMonitoringAPI, EsmEnablementAccountAPI, IpAccessListsAPI, NetworkConnectivityAPI, NotificationDestinationsAPI, PersonalComputeAPI, @@ -867,6 +869,73 @@ def __init__( super().__init__(client.ApiClient(config)) +class EnableExportNotebookClient(EnableExportNotebookAPI): + """ + Controls whether users can export notebooks and files from the Workspace UI. By default, this + setting is enabled. + """ + + def __init__( + self, + *, + host: Optional[str] = None, + account_id: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + client_id: Optional[str] = None, + client_secret: Optional[str] = None, + token: Optional[str] = None, + profile: Optional[str] = None, + config_file: Optional[str] = None, + azure_workspace_resource_id: Optional[str] = None, + azure_client_secret: Optional[str] = None, + azure_client_id: Optional[str] = None, + azure_tenant_id: Optional[str] = None, + azure_environment: Optional[str] = None, + auth_type: Optional[str] = None, + cluster_id: Optional[str] = None, + google_credentials: Optional[str] = None, + google_service_account: Optional[str] = None, + debug_truncate_bytes: Optional[int] = None, + debug_headers: Optional[bool] = None, + product="unknown", + product_version="0.0.0", + credentials_strategy: Optional[CredentialsStrategy] = None, + credentials_provider: Optional[CredentialsStrategy] = None, + config: Optional[client.Config] = None, + ): + + if not config: + config = client.Config( + host=host, + account_id=account_id, + username=username, + password=password, + client_id=client_id, + client_secret=client_secret, + token=token, + profile=profile, + config_file=config_file, + azure_workspace_resource_id=azure_workspace_resource_id, + azure_client_secret=azure_client_secret, + azure_client_id=azure_client_id, + azure_tenant_id=azure_tenant_id, + azure_environment=azure_environment, + auth_type=auth_type, + cluster_id=cluster_id, + google_credentials=google_credentials, + google_service_account=google_service_account, + credentials_strategy=credentials_strategy, + credentials_provider=credentials_provider, + debug_truncate_bytes=debug_truncate_bytes, + debug_headers=debug_headers, + product=product, + product_version=product_version, + ) + self._config = config.copy() + super().__init__(client.ApiClient(config)) + + class 
EnableIpAccessListsClient(EnableIpAccessListsAPI): """ Controls the enforcement of IP access lists for accessing the account console. Allowing you to @@ -934,10 +1003,76 @@ def __init__( super().__init__(client.ApiClient(config)) +class EnableNotebookTableClipboardClient(EnableNotebookTableClipboardAPI): + """ + Controls whether users can copy tabular data to the clipboard via the UI. By default, this + setting is enabled. + """ + + def __init__( + self, + *, + host: Optional[str] = None, + account_id: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + client_id: Optional[str] = None, + client_secret: Optional[str] = None, + token: Optional[str] = None, + profile: Optional[str] = None, + config_file: Optional[str] = None, + azure_workspace_resource_id: Optional[str] = None, + azure_client_secret: Optional[str] = None, + azure_client_id: Optional[str] = None, + azure_tenant_id: Optional[str] = None, + azure_environment: Optional[str] = None, + auth_type: Optional[str] = None, + cluster_id: Optional[str] = None, + google_credentials: Optional[str] = None, + google_service_account: Optional[str] = None, + debug_truncate_bytes: Optional[int] = None, + debug_headers: Optional[bool] = None, + product="unknown", + product_version="0.0.0", + credentials_strategy: Optional[CredentialsStrategy] = None, + credentials_provider: Optional[CredentialsStrategy] = None, + config: Optional[client.Config] = None, + ): + + if not config: + config = client.Config( + host=host, + account_id=account_id, + username=username, + password=password, + client_id=client_id, + client_secret=client_secret, + token=token, + profile=profile, + config_file=config_file, + azure_workspace_resource_id=azure_workspace_resource_id, + azure_client_secret=azure_client_secret, + azure_client_id=azure_client_id, + azure_tenant_id=azure_tenant_id, + azure_environment=azure_environment, + auth_type=auth_type, + cluster_id=cluster_id, + google_credentials=google_credentials, + google_service_account=google_service_account, + credentials_strategy=credentials_strategy, + credentials_provider=credentials_provider, + debug_truncate_bytes=debug_truncate_bytes, + debug_headers=debug_headers, + product=product, + product_version=product_version, + ) + self._config = config.copy() + super().__init__(client.ApiClient(config)) + + class EnableResultsDownloadingClient(EnableResultsDownloadingAPI): """ - The Enable Results Downloading API controls the workspace level conf for the enablement of - downloading results. + Controls whether users can download notebook results. By default, this setting is enabled. """ def __init__( diff --git a/databricks/sdk/settings/v2/settings.py b/databricks/sdk/settings/v2/settings.py index 67ae16b9b..a1567fda4 100755 --- a/databricks/sdk/settings/v2/settings.py +++ b/databricks/sdk/settings/v2/settings.py @@ -635,6 +635,7 @@ class ComplianceStandard(Enum): IRAP_PROTECTED = "IRAP_PROTECTED" ISMAP = "ISMAP" ITAR_EAR = "ITAR_EAR" + K_FSI = "K_FSI" NONE = "NONE" PCI_DSS = "PCI_DSS" @@ -1705,6 +1706,74 @@ def from_dict(cls, d: Dict[str, Any]) -> Empty: return cls() +@dataclass +class EnableExportNotebook: + boolean_val: Optional[BooleanMessage] = None + + setting_name: Optional[str] = None + """Name of the corresponding setting. This field is populated in the response, but it will not be + respected even if it's set in the request body. The setting name in the path parameter will be + respected instead. 
Setting name is required to be 'default' if the setting only has one instance + per workspace.""" + + def as_dict(self) -> dict: + """Serializes the EnableExportNotebook into a dictionary suitable for use as a JSON request body.""" + body = {} + if self.boolean_val: + body["boolean_val"] = self.boolean_val.as_dict() + if self.setting_name is not None: + body["setting_name"] = self.setting_name + return body + + def as_shallow_dict(self) -> dict: + """Serializes the EnableExportNotebook into a shallow dictionary of its immediate attributes.""" + body = {} + if self.boolean_val: + body["boolean_val"] = self.boolean_val + if self.setting_name is not None: + body["setting_name"] = self.setting_name + return body + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> EnableExportNotebook: + """Deserializes the EnableExportNotebook from a dictionary.""" + return cls(boolean_val=_from_dict(d, "boolean_val", BooleanMessage), setting_name=d.get("setting_name", None)) + + +@dataclass +class EnableNotebookTableClipboard: + boolean_val: Optional[BooleanMessage] = None + + setting_name: Optional[str] = None + """Name of the corresponding setting. This field is populated in the response, but it will not be + respected even if it's set in the request body. The setting name in the path parameter will be + respected instead. Setting name is required to be 'default' if the setting only has one instance + per workspace.""" + + def as_dict(self) -> dict: + """Serializes the EnableNotebookTableClipboard into a dictionary suitable for use as a JSON request body.""" + body = {} + if self.boolean_val: + body["boolean_val"] = self.boolean_val.as_dict() + if self.setting_name is not None: + body["setting_name"] = self.setting_name + return body + + def as_shallow_dict(self) -> dict: + """Serializes the EnableNotebookTableClipboard into a shallow dictionary of its immediate attributes.""" + body = {} + if self.boolean_val: + body["boolean_val"] = self.boolean_val + if self.setting_name is not None: + body["setting_name"] = self.setting_name + return body + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> EnableNotebookTableClipboard: + """Deserializes the EnableNotebookTableClipboard from a dictionary.""" + return cls(boolean_val=_from_dict(d, "boolean_val", BooleanMessage), setting_name=d.get("setting_name", None)) + + @dataclass class EnableResultsDownloading: boolean_val: Optional[BooleanMessage] = None @@ -4415,6 +4484,110 @@ def from_dict(cls, d: Dict[str, Any]) -> UpdateDisableLegacyFeaturesRequest: ) +@dataclass +class UpdateEnableExportNotebookRequest: + """Details required to update a setting.""" + + allow_missing: bool + """This should always be set to true for Settings API. Added for AIP compliance.""" + + setting: EnableExportNotebook + + field_mask: str + """The field mask must be a single string, with multiple fields separated by commas (no spaces). + The field path is relative to the resource object, using a dot (`.`) to navigate sub-fields + (e.g., `author.given_name`). Specification of elements in sequence or map fields is not allowed, + as only the entire collection field can be specified. Field names must exactly match the + resource field names. + + A field mask of `*` indicates full replacement. 
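All of these generated settings dataclasses share one wire-format convention: `as_dict` builds the JSON request body with nested messages serialized recursively, `as_shallow_dict` keeps nested objects as-is, and `from_dict` rebuilds the dataclass from a response payload. A small round-trip sketch; the `{"value": true}` shape assumed for `BooleanMessage` is not defined in this hunk:

```python
from databricks.sdk.settings.v2.settings import EnableNotebookTableClipboard

# Deserialize a hypothetical GET response body; the BooleanMessage payload shape
# ({"value": ...}) is an assumption, since that class is defined elsewhere in the module.
setting = EnableNotebookTableClipboard.from_dict(
    {"boolean_val": {"value": True}, "setting_name": "default"}
)

print(setting.setting_name)   # "default"
print(setting.as_dict())      # nested BooleanMessage re-serialized for a request body
```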
It’s recommended to always explicitly list the + fields being updated and avoid using `*` wildcards, as it can lead to unintended results if the + API changes in the future.""" + + def as_dict(self) -> dict: + """Serializes the UpdateEnableExportNotebookRequest into a dictionary suitable for use as a JSON request body.""" + body = {} + if self.allow_missing is not None: + body["allow_missing"] = self.allow_missing + if self.field_mask is not None: + body["field_mask"] = self.field_mask + if self.setting: + body["setting"] = self.setting.as_dict() + return body + + def as_shallow_dict(self) -> dict: + """Serializes the UpdateEnableExportNotebookRequest into a shallow dictionary of its immediate attributes.""" + body = {} + if self.allow_missing is not None: + body["allow_missing"] = self.allow_missing + if self.field_mask is not None: + body["field_mask"] = self.field_mask + if self.setting: + body["setting"] = self.setting + return body + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> UpdateEnableExportNotebookRequest: + """Deserializes the UpdateEnableExportNotebookRequest from a dictionary.""" + return cls( + allow_missing=d.get("allow_missing", None), + field_mask=d.get("field_mask", None), + setting=_from_dict(d, "setting", EnableExportNotebook), + ) + + +@dataclass +class UpdateEnableNotebookTableClipboardRequest: + """Details required to update a setting.""" + + allow_missing: bool + """This should always be set to true for Settings API. Added for AIP compliance.""" + + setting: EnableNotebookTableClipboard + + field_mask: str + """The field mask must be a single string, with multiple fields separated by commas (no spaces). + The field path is relative to the resource object, using a dot (`.`) to navigate sub-fields + (e.g., `author.given_name`). Specification of elements in sequence or map fields is not allowed, + as only the entire collection field can be specified. Field names must exactly match the + resource field names. + + A field mask of `*` indicates full replacement. 
It’s recommended to always explicitly list the + fields being updated and avoid using `*` wildcards, as it can lead to unintended results if the + API changes in the future.""" + + def as_dict(self) -> dict: + """Serializes the UpdateEnableNotebookTableClipboardRequest into a dictionary suitable for use as a JSON request body.""" + body = {} + if self.allow_missing is not None: + body["allow_missing"] = self.allow_missing + if self.field_mask is not None: + body["field_mask"] = self.field_mask + if self.setting: + body["setting"] = self.setting.as_dict() + return body + + def as_shallow_dict(self) -> dict: + """Serializes the UpdateEnableNotebookTableClipboardRequest into a shallow dictionary of its immediate attributes.""" + body = {} + if self.allow_missing is not None: + body["allow_missing"] = self.allow_missing + if self.field_mask is not None: + body["field_mask"] = self.field_mask + if self.setting: + body["setting"] = self.setting + return body + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> UpdateEnableNotebookTableClipboardRequest: + """Deserializes the UpdateEnableNotebookTableClipboardRequest from a dictionary.""" + return cls( + allow_missing=d.get("allow_missing", None), + field_mask=d.get("field_mask", None), + setting=_from_dict(d, "setting", EnableNotebookTableClipboard), + ) + + @dataclass class UpdateEnableResultsDownloadingRequest: """Details required to update a setting.""" @@ -5987,6 +6160,70 @@ def update(self, allow_missing: bool, setting: DisableLegacyFeatures, field_mask return DisableLegacyFeatures.from_dict(res) +class EnableExportNotebookAPI: + """Controls whether users can export notebooks and files from the Workspace UI. By default, this setting is + enabled.""" + + def __init__(self, api_client): + self._api = api_client + + def get_enable_export_notebook(self) -> EnableExportNotebook: + """Get the Notebook and File exporting setting. + + Gets the Notebook and File exporting setting. + + :returns: :class:`EnableExportNotebook` + """ + + headers = { + "Accept": "application/json", + } + + res = self._api.do("GET", "/api/2.0/settings/types/enable-export-notebook/names/default", headers=headers) + return EnableExportNotebook.from_dict(res) + + def patch_enable_export_notebook( + self, allow_missing: bool, setting: EnableExportNotebook, field_mask: str + ) -> EnableExportNotebook: + """Update the Notebook and File exporting setting. + + Updates the Notebook and File exporting setting. The model follows eventual consistency, which means + the get after the update operation might receive stale values for some time. + + :param allow_missing: bool + This should always be set to true for Settings API. Added for AIP compliance. + :param setting: :class:`EnableExportNotebook` + :param field_mask: str + The field mask must be a single string, with multiple fields separated by commas (no spaces). The + field path is relative to the resource object, using a dot (`.`) to navigate sub-fields (e.g., + `author.given_name`). Specification of elements in sequence or map fields is not allowed, as only + the entire collection field can be specified. Field names must exactly match the resource field + names. + + A field mask of `*` indicates full replacement. It’s recommended to always explicitly list the + fields being updated and avoid using `*` wildcards, as it can lead to unintended results if the API + changes in the future. 
+ + :returns: :class:`EnableExportNotebook` + """ + body = {} + if allow_missing is not None: + body["allow_missing"] = allow_missing + if field_mask is not None: + body["field_mask"] = field_mask + if setting is not None: + body["setting"] = setting.as_dict() + headers = { + "Accept": "application/json", + "Content-Type": "application/json", + } + + res = self._api.do( + "PATCH", "/api/2.0/settings/types/enable-export-notebook/names/default", body=body, headers=headers + ) + return EnableExportNotebook.from_dict(res) + + class EnableIpAccessListsAPI: """Controls the enforcement of IP access lists for accessing the account console. Allowing you to enable or disable restricted access based on IP addresses.""" @@ -6096,17 +6333,82 @@ def update(self, allow_missing: bool, setting: AccountIpAccessEnable, field_mask return AccountIpAccessEnable.from_dict(res) +class EnableNotebookTableClipboardAPI: + """Controls whether users can copy tabular data to the clipboard via the UI. By default, this setting is + enabled.""" + + def __init__(self, api_client): + self._api = api_client + + def get_enable_notebook_table_clipboard(self) -> EnableNotebookTableClipboard: + """Get the Results Table Clipboard features setting. + + Gets the Results Table Clipboard features setting. + + :returns: :class:`EnableNotebookTableClipboard` + """ + + headers = { + "Accept": "application/json", + } + + res = self._api.do( + "GET", "/api/2.0/settings/types/enable-notebook-table-clipboard/names/default", headers=headers + ) + return EnableNotebookTableClipboard.from_dict(res) + + def patch_enable_notebook_table_clipboard( + self, allow_missing: bool, setting: EnableNotebookTableClipboard, field_mask: str + ) -> EnableNotebookTableClipboard: + """Update the Results Table Clipboard features setting. + + Updates the Results Table Clipboard features setting. The model follows eventual consistency, which + means the get after the update operation might receive stale values for some time. + + :param allow_missing: bool + This should always be set to true for Settings API. Added for AIP compliance. + :param setting: :class:`EnableNotebookTableClipboard` + :param field_mask: str + The field mask must be a single string, with multiple fields separated by commas (no spaces). The + field path is relative to the resource object, using a dot (`.`) to navigate sub-fields (e.g., + `author.given_name`). Specification of elements in sequence or map fields is not allowed, as only + the entire collection field can be specified. Field names must exactly match the resource field + names. + + A field mask of `*` indicates full replacement. It’s recommended to always explicitly list the + fields being updated and avoid using `*` wildcards, as it can lead to unintended results if the API + changes in the future. 
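The new setting APIs all expose the same read-modify-write shape: `get_*` fetches the singleton `default` setting and `patch_*` writes it back with `allow_missing=True` plus an explicit field mask. A sketch against the `EnableExportNotebookClient` wrapper added earlier in this diff; the host/token values, the `BooleanMessage.value` attribute, and the field-mask path are placeholders or assumptions:

```python
from databricks.sdk.settings.v2.client import EnableExportNotebookClient

# Hypothetical workspace coordinates; any auth option supported by Config works here.
client = EnableExportNotebookClient(host="https://example.cloud.databricks.com", token="dapi-example")

current = client.get_enable_export_notebook()

# Flip the setting and write it back. The "value" attribute on BooleanMessage and
# the "boolean_val.value" field-mask path are assumptions, not part of this hunk.
if current.boolean_val is not None:
    current.boolean_val.value = False
updated = client.patch_enable_export_notebook(
    allow_missing=True,
    setting=current,
    field_mask="boolean_val.value",
)
print(updated.as_dict())
```

`EnableNotebookTableClipboardAPI` and `EnableResultsDownloadingAPI` follow the identical get/patch pattern.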
+ + :returns: :class:`EnableNotebookTableClipboard` + """ + body = {} + if allow_missing is not None: + body["allow_missing"] = allow_missing + if field_mask is not None: + body["field_mask"] = field_mask + if setting is not None: + body["setting"] = setting.as_dict() + headers = { + "Accept": "application/json", + "Content-Type": "application/json", + } + + res = self._api.do( + "PATCH", "/api/2.0/settings/types/enable-notebook-table-clipboard/names/default", body=body, headers=headers + ) + return EnableNotebookTableClipboard.from_dict(res) + + class EnableResultsDownloadingAPI: - """The Enable Results Downloading API controls the workspace level conf for the enablement of downloading - results.""" + """Controls whether users can download notebook results. By default, this setting is enabled.""" def __init__(self, api_client): self._api = api_client def get_enable_results_downloading(self) -> EnableResultsDownloading: - """Get the Enable Results Downloading setting. + """Get the Notebook results download setting. - Gets the Enable Results Downloading setting. + Gets the Notebook results download setting. :returns: :class:`EnableResultsDownloading` """ @@ -6121,10 +6423,10 @@ def get_enable_results_downloading(self) -> EnableResultsDownloading: def patch_enable_results_downloading( self, allow_missing: bool, setting: EnableResultsDownloading, field_mask: str ) -> EnableResultsDownloading: - """Update the Enable Results Downloading setting. + """Update the Notebook results download setting. - Updates the Enable Results Downloading setting. The model follows eventual consistency, which means - the get after the update operation might receive stale values for some time. + Updates the Notebook results download setting. The model follows eventual consistency, which means the + get after the update operation might receive stale values for some time. :param allow_missing: bool This should always be set to true for Settings API. Added for AIP compliance. diff --git a/databricks/sdk/sql/v2/sql.py b/databricks/sdk/sql/v2/sql.py index 89bfffd79..05f849b21 100755 --- a/databricks/sdk/sql/v2/sql.py +++ b/databricks/sdk/sql/v2/sql.py @@ -3,12 +3,15 @@ from __future__ import annotations import logging +import random +import time from dataclasses import dataclass from enum import Enum from typing import Any, Dict, Iterator, List, Optional -from ...service._internal import (_enum, _from_dict, _repeated_dict, - _repeated_enum) +from ...databricks.errors import OperationFailed +from ...service._internal import (WaitUntilDoneOptions, _enum, _from_dict, + _repeated_dict, _repeated_enum) _LOG = logging.getLogger("databricks.sdk") @@ -9735,6 +9738,175 @@ def get_statement_result_chunk_n(self, statement_id: str, chunk_index: int) -> R return ResultData.from_dict(res) +class WarehousesCreateWaiter: + raw_response: CreateWarehouseResponse + """raw_response is the raw response of the Create call.""" + _service: WarehousesAPI + _id: str + + def __init__(self, raw_response: CreateWarehouseResponse, service: WarehousesAPI, id: str): + self._service = service + self.raw_response = raw_response + self._id = id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> GetWarehouseResponse: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (State.RUNNING,) + failure_states = ( + State.STOPPED, + State.DELETED, + ) + status_message = "polling..." 
+ attempt = 1 + while time.time() < deadline: + poll = self._service.get(id=self._id) + status = poll.state + status_message = f"current status: {status}" + if poll.health: + status_message = poll.health.summary + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach RUNNING, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"id={self._id}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + +class WarehousesEditWaiter: + raw_response: EditWarehouseResponse + """raw_response is the raw response of the Edit call.""" + _service: WarehousesAPI + _id: str + + def __init__(self, raw_response: EditWarehouseResponse, service: WarehousesAPI, id: str): + self._service = service + self.raw_response = raw_response + self._id = id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> GetWarehouseResponse: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (State.RUNNING,) + failure_states = ( + State.STOPPED, + State.DELETED, + ) + status_message = "polling..." + attempt = 1 + while time.time() < deadline: + poll = self._service.get(id=self._id) + status = poll.state + status_message = f"current status: {status}" + if poll.health: + status_message = poll.health.summary + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach RUNNING, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"id={self._id}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + +class WarehousesStartWaiter: + raw_response: StartWarehouseResponse + """raw_response is the raw response of the Start call.""" + _service: WarehousesAPI + _id: str + + def __init__(self, raw_response: StartWarehouseResponse, service: WarehousesAPI, id: str): + self._service = service + self.raw_response = raw_response + self._id = id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> GetWarehouseResponse: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (State.RUNNING,) + failure_states = ( + State.STOPPED, + State.DELETED, + ) + status_message = "polling..." 
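Every waiter in this diff polls on the same cadence: sleep `attempt` seconds, capped at 10, plus up to one second of jitter, so roughly 1s, 2s, up to 10s and then 10s thereafter. A standalone sketch of that schedule, independent of the SDK:

```python
import random

# Print the sleep schedule used by the WaitUntilDone loops for the first twelve polls.
for attempt in range(1, 13):
    sleep = min(attempt, 10) + random.random()
    print(f"attempt {attempt}: sleep ~{sleep:.2f}s")
```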
+ attempt = 1 + while time.time() < deadline: + poll = self._service.get(id=self._id) + status = poll.state + status_message = f"current status: {status}" + if poll.health: + status_message = poll.health.summary + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach RUNNING, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"id={self._id}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + +class WarehousesStopWaiter: + raw_response: StopWarehouseResponse + """raw_response is the raw response of the Stop call.""" + _service: WarehousesAPI + _id: str + + def __init__(self, raw_response: StopWarehouseResponse, service: WarehousesAPI, id: str): + self._service = service + self.raw_response = raw_response + self._id = id + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> GetWarehouseResponse: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (State.STOPPED,) + status_message = "polling..." + attempt = 1 + while time.time() < deadline: + poll = self._service.get(id=self._id) + status = poll.state + status_message = f"current status: {status}" + if poll.health: + status_message = poll.health.summary + if status in target_states: + return poll + prefix = f"id={self._id}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + class WarehousesAPI: """A SQL warehouse is a compute resource that lets you run SQL commands on data objects within Databricks SQL. Compute resources are infrastructure resources that provide processing capabilities in the cloud.""" @@ -9758,7 +9930,7 @@ def create( spot_instance_policy: Optional[SpotInstancePolicy] = None, tags: Optional[EndpointTags] = None, warehouse_type: Optional[CreateWarehouseRequestWarehouseType] = None, - ) -> CreateWarehouseResponse: + ) -> WarehousesCreateWaiter: """Create a warehouse. Creates a new SQL warehouse. @@ -9855,8 +10027,10 @@ def create( "Content-Type": "application/json", } - res = self._api.do("POST", "/api/2.0/sql/warehouses", body=body, headers=headers) - return CreateWarehouseResponse.from_dict(res) + op_response = self._api.do("POST", "/api/2.0/sql/warehouses", body=body, headers=headers) + return WarehousesCreateWaiter( + service=self, raw_response=CreateWarehouseResponse.from_dict(op_response), id=op_response["id"] + ) def delete(self, id: str): """Delete a warehouse. @@ -9892,7 +10066,7 @@ def edit( spot_instance_policy: Optional[SpotInstancePolicy] = None, tags: Optional[EndpointTags] = None, warehouse_type: Optional[EditWarehouseRequestWarehouseType] = None, - ): + ) -> WarehousesEditWaiter: """Update a warehouse. Updates the configuration for a SQL warehouse. 
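`start` and `stop` below now return waiters as well, so callers can block until the warehouse actually reaches `RUNNING` or `STOPPED`; `create` and `edit` return analogous waiters targeting `RUNNING`, and a warehouse that lands in `STOPPED` or `DELETED` while a RUNNING-targeting waiter is polling raises `OperationFailed`. A sketch, assuming `warehouses` is an already-constructed `WarehousesAPI` instance and the id is a placeholder:

```python
# Sketch only: `warehouses` is assumed to be a WarehousesAPI instance.
warehouse_id = "1234567890abcdef"   # placeholder warehouse id

running = warehouses.start(id=warehouse_id).WaitUntilDone()   # polls get() until RUNNING
print(running.state)

stopped = warehouses.stop(id=warehouse_id).WaitUntilDone()    # polls until STOPPED
print(stopped.state)
```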
@@ -9990,7 +10164,8 @@ def edit( "Content-Type": "application/json", } - self._api.do("POST", f"/api/2.0/sql/warehouses/{id}/edit", body=body, headers=headers) + op_response = self._api.do("POST", f"/api/2.0/sql/warehouses/{id}/edit", body=body, headers=headers) + return WarehousesEditWaiter(service=self, raw_response=EditWarehouseResponse.from_dict(op_response), id=id) def get(self, id: str) -> GetWarehouseResponse: """Get warehouse info. @@ -10178,7 +10353,7 @@ def set_workspace_warehouse_config( self._api.do("PUT", "/api/2.0/sql/config/warehouses", body=body, headers=headers) - def start(self, id: str): + def start(self, id: str) -> WarehousesStartWaiter: """Start a warehouse. Starts a SQL warehouse. @@ -10195,9 +10370,10 @@ def start(self, id: str): "Accept": "application/json", } - self._api.do("POST", f"/api/2.0/sql/warehouses/{id}/start", headers=headers) + op_response = self._api.do("POST", f"/api/2.0/sql/warehouses/{id}/start", headers=headers) + return WarehousesStartWaiter(service=self, raw_response=StartWarehouseResponse.from_dict(op_response), id=id) - def stop(self, id: str): + def stop(self, id: str) -> WarehousesStopWaiter: """Stop a warehouse. Stops a SQL warehouse. @@ -10214,7 +10390,8 @@ def stop(self, id: str): "Accept": "application/json", } - self._api.do("POST", f"/api/2.0/sql/warehouses/{id}/stop", headers=headers) + op_response = self._api.do("POST", f"/api/2.0/sql/warehouses/{id}/stop", headers=headers) + return WarehousesStopWaiter(service=self, raw_response=StopWarehouseResponse.from_dict(op_response), id=id) def update_permissions( self, warehouse_id: str, *, access_control_list: Optional[List[WarehouseAccessControlRequest]] = None diff --git a/databricks/sdk/vectorsearch/v2/client.py b/databricks/sdk/vectorsearch/v2/client.py index d78b24681..a4e48071b 100755 --- a/databricks/sdk/vectorsearch/v2/client.py +++ b/databricks/sdk/vectorsearch/v2/client.py @@ -82,9 +82,9 @@ class VectorSearchIndexesClient(VectorSearchIndexesAPI): **Index**: An efficient representation of your embedding vectors that supports real-time and efficient approximate nearest neighbor (ANN) search queries. - There are 2 types of Vector Search indexes: * **Delta Sync Index**: An index that automatically + There are 2 types of Vector Search indexes: - **Delta Sync Index**: An index that automatically syncs with a source Delta Table, automatically and incrementally updating the index as the - underlying data in the Delta Table changes. * **Direct Vector Access Index**: An index that + underlying data in the Delta Table changes. - **Direct Vector Access Index**: An index that supports direct read and write of vectors and metadata through our REST and SDK APIs. With this model, the user manages index updates. 
""" diff --git a/databricks/sdk/vectorsearch/v2/vectorsearch.py b/databricks/sdk/vectorsearch/v2/vectorsearch.py index a0145cbf7..836f020f3 100755 --- a/databricks/sdk/vectorsearch/v2/vectorsearch.py +++ b/databricks/sdk/vectorsearch/v2/vectorsearch.py @@ -3,11 +3,15 @@ from __future__ import annotations import logging +import random +import time from dataclasses import dataclass from enum import Enum from typing import Any, Dict, Iterator, List, Optional -from ...service._internal import _enum, _from_dict, _repeated_dict +from ...databricks.errors import OperationFailed +from ...service._internal import (WaitUntilDoneOptions, _enum, _from_dict, + _repeated_dict) _LOG = logging.getLogger("databricks.sdk") @@ -42,14 +46,19 @@ def from_dict(cls, d: Dict[str, Any]) -> ColumnInfo: @dataclass class CreateEndpoint: name: str - """Name of endpoint""" + """Name of the vector search endpoint""" endpoint_type: EndpointType - """Type of endpoint.""" + """Type of endpoint""" + + budget_policy_id: Optional[str] = None + """The budget policy id to be applied""" def as_dict(self) -> dict: """Serializes the CreateEndpoint into a dictionary suitable for use as a JSON request body.""" body = {} + if self.budget_policy_id is not None: + body["budget_policy_id"] = self.budget_policy_id if self.endpoint_type is not None: body["endpoint_type"] = self.endpoint_type.value if self.name is not None: @@ -59,6 +68,8 @@ def as_dict(self) -> dict: def as_shallow_dict(self) -> dict: """Serializes the CreateEndpoint into a shallow dictionary of its immediate attributes.""" body = {} + if self.budget_policy_id is not None: + body["budget_policy_id"] = self.budget_policy_id if self.endpoint_type is not None: body["endpoint_type"] = self.endpoint_type if self.name is not None: @@ -68,7 +79,11 @@ def as_shallow_dict(self) -> dict: @classmethod def from_dict(cls, d: Dict[str, Any]) -> CreateEndpoint: """Deserializes the CreateEndpoint from a dictionary.""" - return cls(endpoint_type=_enum(d, "endpoint_type", EndpointType), name=d.get("name", None)) + return cls( + budget_policy_id=d.get("budget_policy_id", None), + endpoint_type=_enum(d, "endpoint_type", EndpointType), + name=d.get("name", None), + ) @dataclass @@ -83,12 +98,11 @@ class CreateVectorIndexRequest: """Primary key of the index""" index_type: VectorIndexType - """There are 2 types of Vector Search indexes: - - - `DELTA_SYNC`: An index that automatically syncs with a source Delta Table, automatically and - incrementally updating the index as the underlying data in the Delta Table changes. - - `DIRECT_ACCESS`: An index that supports direct read and write of vectors and metadata through - our REST and SDK APIs. With this model, the user manages index updates.""" + """There are 2 types of Vector Search indexes: - `DELTA_SYNC`: An index that automatically syncs + with a source Delta Table, automatically and incrementally updating the index as the underlying + data in the Delta Table changes. - `DIRECT_ACCESS`: An index that supports direct read and write + of vectors and metadata through our REST and SDK APIs. With this model, the user manages index + updates.""" delta_sync_index_spec: Optional[DeltaSyncVectorIndexSpecRequest] = None """Specification for Delta Sync Index. 
Required if `index_type` is `DELTA_SYNC`.""" @@ -144,33 +158,39 @@ def from_dict(cls, d: Dict[str, Any]) -> CreateVectorIndexRequest: @dataclass -class CreateVectorIndexResponse: - vector_index: Optional[VectorIndex] = None +class CustomTag: + key: str + """Key field for a vector search endpoint tag.""" + + value: Optional[str] = None + """[Optional] Value field for a vector search endpoint tag.""" def as_dict(self) -> dict: - """Serializes the CreateVectorIndexResponse into a dictionary suitable for use as a JSON request body.""" + """Serializes the CustomTag into a dictionary suitable for use as a JSON request body.""" body = {} - if self.vector_index: - body["vector_index"] = self.vector_index.as_dict() + if self.key is not None: + body["key"] = self.key + if self.value is not None: + body["value"] = self.value return body def as_shallow_dict(self) -> dict: - """Serializes the CreateVectorIndexResponse into a shallow dictionary of its immediate attributes.""" + """Serializes the CustomTag into a shallow dictionary of its immediate attributes.""" body = {} - if self.vector_index: - body["vector_index"] = self.vector_index + if self.key is not None: + body["key"] = self.key + if self.value is not None: + body["value"] = self.value return body @classmethod - def from_dict(cls, d: Dict[str, Any]) -> CreateVectorIndexResponse: - """Deserializes the CreateVectorIndexResponse from a dictionary.""" - return cls(vector_index=_from_dict(d, "vector_index", VectorIndex)) + def from_dict(cls, d: Dict[str, Any]) -> CustomTag: + """Deserializes the CustomTag from a dictionary.""" + return cls(key=d.get("key", None), value=d.get("value", None)) @dataclass class DeleteDataResult: - """Result of the upsert or delete operation.""" - failed_primary_keys: Optional[List[str]] = None """List of primary keys for rows that failed to process.""" @@ -204,51 +224,14 @@ def from_dict(cls, d: Dict[str, Any]) -> DeleteDataResult: class DeleteDataStatus(Enum): - """Status of the delete operation.""" FAILURE = "FAILURE" PARTIAL_SUCCESS = "PARTIAL_SUCCESS" SUCCESS = "SUCCESS" -@dataclass -class DeleteDataVectorIndexRequest: - """Request payload for deleting data from a vector index.""" - - primary_keys: List[str] - """List of primary keys for the data to be deleted.""" - - index_name: Optional[str] = None - """Name of the vector index where data is to be deleted. 
Must be a Direct Vector Access Index.""" - - def as_dict(self) -> dict: - """Serializes the DeleteDataVectorIndexRequest into a dictionary suitable for use as a JSON request body.""" - body = {} - if self.index_name is not None: - body["index_name"] = self.index_name - if self.primary_keys: - body["primary_keys"] = [v for v in self.primary_keys] - return body - - def as_shallow_dict(self) -> dict: - """Serializes the DeleteDataVectorIndexRequest into a shallow dictionary of its immediate attributes.""" - body = {} - if self.index_name is not None: - body["index_name"] = self.index_name - if self.primary_keys: - body["primary_keys"] = self.primary_keys - return body - - @classmethod - def from_dict(cls, d: Dict[str, Any]) -> DeleteDataVectorIndexRequest: - """Deserializes the DeleteDataVectorIndexRequest from a dictionary.""" - return cls(index_name=d.get("index_name", None), primary_keys=d.get("primary_keys", None)) - - @dataclass class DeleteDataVectorIndexResponse: - """Response to a delete data vector index request.""" - result: Optional[DeleteDataResult] = None """Result of the upsert or delete operation.""" @@ -326,20 +309,17 @@ class DeltaSyncVectorIndexSpecRequest: """The columns that contain the embedding source.""" embedding_vector_columns: Optional[List[EmbeddingVectorColumn]] = None - """The columns that contain the embedding vectors. The format should be array[double].""" + """The columns that contain the embedding vectors.""" embedding_writeback_table: Optional[str] = None - """[Optional] Automatically sync the vector index contents and computed embeddings to the specified - Delta table. The only supported table name is the index name with the suffix `_writeback_table`.""" + """[Optional] Name of the Delta table to sync the vector index contents and computed embeddings to.""" pipeline_type: Optional[PipelineType] = None - """Pipeline execution mode. - - - `TRIGGERED`: If the pipeline uses the triggered execution mode, the system stops processing - after successfully refreshing the source table in the pipeline once, ensuring the table is - updated based on the data available when the update started. - `CONTINUOUS`: If the pipeline - uses continuous execution, the pipeline processes new data as it arrives in the source table to - keep vector index fresh.""" + """Pipeline execution mode. - `TRIGGERED`: If the pipeline uses the triggered execution mode, the + system stops processing after successfully refreshing the source table in the pipeline once, + ensuring the table is updated based on the data available when the update started. - + `CONTINUOUS`: If the pipeline uses continuous execution, the pipeline processes new data as it + arrives in the source table to keep vector index fresh.""" source_table: Optional[str] = None """The name of the source table.""" @@ -406,13 +386,11 @@ class DeltaSyncVectorIndexSpecResponse: """The ID of the pipeline that is used to sync the index.""" pipeline_type: Optional[PipelineType] = None - """Pipeline execution mode. - - - `TRIGGERED`: If the pipeline uses the triggered execution mode, the system stops processing - after successfully refreshing the source table in the pipeline once, ensuring the table is - updated based on the data available when the update started. - `CONTINUOUS`: If the pipeline - uses continuous execution, the pipeline processes new data as it arrives in the source table to - keep vector index fresh.""" + """Pipeline execution mode. 
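The `pipeline_type` docstrings above describe the two sync modes; building a triggered delta-sync spec looks like the sketch below. The `EmbeddingSourceColumn` field names are assumptions, as that dataclass is not part of this hunk:

```python
from databricks.sdk.vectorsearch.v2.vectorsearch import (
    DeltaSyncVectorIndexSpecRequest, EmbeddingSourceColumn, PipelineType)

spec = DeltaSyncVectorIndexSpecRequest(
    source_table="main.docs.chunks",          # placeholder source Delta table
    pipeline_type=PipelineType.TRIGGERED,     # refresh once per sync rather than continuously
    embedding_source_columns=[
        # field names below (name, embedding_model_endpoint_name) are assumptions
        EmbeddingSourceColumn(name="text", embedding_model_endpoint_name="databricks-bge-large-en"),
    ],
)
print(spec.as_dict())
```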
- `TRIGGERED`: If the pipeline uses the triggered execution mode, the + system stops processing after successfully refreshing the source table in the pipeline once, + ensuring the table is updated based on the data available when the update started. - + `CONTINUOUS`: If the pipeline uses continuous execution, the pipeline processes new data as it + arrives in the source table to keep vector index fresh.""" source_table: Optional[str] = None """The name of the source table.""" @@ -467,17 +445,15 @@ def from_dict(cls, d: Dict[str, Any]) -> DeltaSyncVectorIndexSpecResponse: @dataclass class DirectAccessVectorIndexSpec: embedding_source_columns: Optional[List[EmbeddingSourceColumn]] = None - """Contains the optional model endpoint to use during query time.""" + """The columns that contain the embedding source. The format should be array[double].""" embedding_vector_columns: Optional[List[EmbeddingVectorColumn]] = None + """The columns that contain the embedding vectors. The format should be array[double].""" schema_json: Optional[str] = None - """The schema of the index in JSON format. - - Supported types are `integer`, `long`, `float`, `double`, `boolean`, `string`, `date`, - `timestamp`. - - Supported types for vector column: `array`, `array`,`.""" + """The schema of the index in JSON format. Supported types are `integer`, `long`, `float`, + `double`, `boolean`, `string`, `date`, `timestamp`. Supported types for vector column: + `array`, `array`,`.""" def as_dict(self) -> dict: """Serializes the DirectAccessVectorIndexSpec into a dictionary suitable for use as a JSON request body.""" @@ -583,11 +559,17 @@ class EndpointInfo: creator: Optional[str] = None """Creator of the endpoint""" + custom_tags: Optional[List[CustomTag]] = None + """The custom tags assigned to the endpoint""" + + effective_budget_policy_id: Optional[str] = None + """The budget policy id applied to the endpoint""" + endpoint_status: Optional[EndpointStatus] = None """Current status of the endpoint""" endpoint_type: Optional[EndpointType] = None - """Type of endpoint.""" + """Type of endpoint""" id: Optional[str] = None """Unique identifier of the endpoint""" @@ -599,7 +581,7 @@ class EndpointInfo: """User who last updated the endpoint""" name: Optional[str] = None - """Name of endpoint""" + """Name of the vector search endpoint""" num_indexes: Optional[int] = None """Number of indexes on the endpoint""" @@ -611,6 +593,10 @@ def as_dict(self) -> dict: body["creation_timestamp"] = self.creation_timestamp if self.creator is not None: body["creator"] = self.creator + if self.custom_tags: + body["custom_tags"] = [v.as_dict() for v in self.custom_tags] + if self.effective_budget_policy_id is not None: + body["effective_budget_policy_id"] = self.effective_budget_policy_id if self.endpoint_status: body["endpoint_status"] = self.endpoint_status.as_dict() if self.endpoint_type is not None: @@ -634,6 +620,10 @@ def as_shallow_dict(self) -> dict: body["creation_timestamp"] = self.creation_timestamp if self.creator is not None: body["creator"] = self.creator + if self.custom_tags: + body["custom_tags"] = self.custom_tags + if self.effective_budget_policy_id is not None: + body["effective_budget_policy_id"] = self.effective_budget_policy_id if self.endpoint_status: body["endpoint_status"] = self.endpoint_status if self.endpoint_type is not None: @@ -656,6 +646,8 @@ def from_dict(cls, d: Dict[str, Any]) -> EndpointInfo: return cls( creation_timestamp=d.get("creation_timestamp", None), creator=d.get("creator", None), + 
custom_tags=_repeated_dict(d, "custom_tags", CustomTag), + effective_budget_policy_id=d.get("effective_budget_policy_id", None), endpoint_status=_from_dict(d, "endpoint_status", EndpointStatus), endpoint_type=_enum(d, "endpoint_type", EndpointType), id=d.get("id", None), @@ -751,7 +743,14 @@ def from_dict(cls, d: Dict[str, Any]) -> ListEndpointResponse: @dataclass class ListValue: + """copied from proto3 / Google Well Known Types, source: + https://github.com/protocolbuffers/protobuf/blob/450d24ca820750c5db5112a6f0b0c2efb9758021/src/google/protobuf/struct.proto + `ListValue` is a wrapper around a repeated field of values. + + The JSON representation for `ListValue` is JSON array.""" + values: Optional[List[Value]] = None + """Repeated field of dynamically typed values.""" def as_dict(self) -> dict: """Serializes the ListValue into a dictionary suitable for use as a JSON request body.""" @@ -851,12 +850,11 @@ class MiniVectorIndex: """Name of the endpoint associated with the index""" index_type: Optional[VectorIndexType] = None - """There are 2 types of Vector Search indexes: - - - `DELTA_SYNC`: An index that automatically syncs with a source Delta Table, automatically and - incrementally updating the index as the underlying data in the Delta Table changes. - - `DIRECT_ACCESS`: An index that supports direct read and write of vectors and metadata through - our REST and SDK APIs. With this model, the user manages index updates.""" + """There are 2 types of Vector Search indexes: - `DELTA_SYNC`: An index that automatically syncs + with a source Delta Table, automatically and incrementally updating the index as the underlying + data in the Delta Table changes. - `DIRECT_ACCESS`: An index that supports direct read and write + of vectors and metadata through our REST and SDK APIs. With this model, the user manages index + updates.""" name: Optional[str] = None """Name of the index""" @@ -906,14 +904,69 @@ def from_dict(cls, d: Dict[str, Any]) -> MiniVectorIndex: ) -class PipelineType(Enum): - """Pipeline execution mode. +@dataclass +class PatchEndpointBudgetPolicyRequest: + budget_policy_id: str + """The budget policy id to be applied""" + + endpoint_name: Optional[str] = None + """Name of the vector search endpoint""" + + def as_dict(self) -> dict: + """Serializes the PatchEndpointBudgetPolicyRequest into a dictionary suitable for use as a JSON request body.""" + body = {} + if self.budget_policy_id is not None: + body["budget_policy_id"] = self.budget_policy_id + if self.endpoint_name is not None: + body["endpoint_name"] = self.endpoint_name + return body + + def as_shallow_dict(self) -> dict: + """Serializes the PatchEndpointBudgetPolicyRequest into a shallow dictionary of its immediate attributes.""" + body = {} + if self.budget_policy_id is not None: + body["budget_policy_id"] = self.budget_policy_id + if self.endpoint_name is not None: + body["endpoint_name"] = self.endpoint_name + return body + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> PatchEndpointBudgetPolicyRequest: + """Deserializes the PatchEndpointBudgetPolicyRequest from a dictionary.""" + return cls(budget_policy_id=d.get("budget_policy_id", None), endpoint_name=d.get("endpoint_name", None)) + - - `TRIGGERED`: If the pipeline uses the triggered execution mode, the system stops processing - after successfully refreshing the source table in the pipeline once, ensuring the table is - updated based on the data available when the update started. 
- `CONTINUOUS`: If the pipeline - uses continuous execution, the pipeline processes new data as it arrives in the source table to - keep vector index fresh.""" +@dataclass +class PatchEndpointBudgetPolicyResponse: + effective_budget_policy_id: Optional[str] = None + """The budget policy applied to the vector search endpoint.""" + + def as_dict(self) -> dict: + """Serializes the PatchEndpointBudgetPolicyResponse into a dictionary suitable for use as a JSON request body.""" + body = {} + if self.effective_budget_policy_id is not None: + body["effective_budget_policy_id"] = self.effective_budget_policy_id + return body + + def as_shallow_dict(self) -> dict: + """Serializes the PatchEndpointBudgetPolicyResponse into a shallow dictionary of its immediate attributes.""" + body = {} + if self.effective_budget_policy_id is not None: + body["effective_budget_policy_id"] = self.effective_budget_policy_id + return body + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> PatchEndpointBudgetPolicyResponse: + """Deserializes the PatchEndpointBudgetPolicyResponse from a dictionary.""" + return cls(effective_budget_policy_id=d.get("effective_budget_policy_id", None)) + + +class PipelineType(Enum): + """Pipeline execution mode. - `TRIGGERED`: If the pipeline uses the triggered execution mode, the + system stops processing after successfully refreshing the source table in the pipeline once, + ensuring the table is updated based on the data available when the update started. - + `CONTINUOUS`: If the pipeline uses continuous execution, the pipeline processes new data as it + arrives in the source table to keep vector index fresh.""" CONTINUOUS = "CONTINUOUS" TRIGGERED = "TRIGGERED" @@ -975,9 +1028,11 @@ class QueryVectorIndexRequest: filters_json: Optional[str] = None """JSON string representing query filters. - Example filters: - `{"id <": 5}`: Filter for id less than 5. - `{"id >": 5}`: Filter for id - greater than 5. - `{"id <=": 5}`: Filter for id less than equal to 5. - `{"id >=": 5}`: Filter - for id greater than equal to 5. - `{"id": 5}`: Filter for id equal to 5.""" + Example filters: + + - `{"id <": 5}`: Filter for id less than 5. - `{"id >": 5}`: Filter for id greater than 5. - + `{"id <=": 5}`: Filter for id less than equal to 5. - `{"id >=": 5}`: Filter for id greater than + equal to 5. 
- `{"id": 5}`: Filter for id equal to 5.""" index_name: Optional[str] = None """Name of the vector index to query.""" @@ -1109,7 +1164,7 @@ def from_dict(cls, d: Dict[str, Any]) -> QueryVectorIndexResponse: class ResultData: """Data returned in the query result.""" - data_array: Optional[List[List[str]]] = None + data_array: Optional[List[ListValue]] = None """Data rows returned in the query.""" row_count: Optional[int] = None @@ -1119,7 +1174,7 @@ def as_dict(self) -> dict: """Serializes the ResultData into a dictionary suitable for use as a JSON request body.""" body = {} if self.data_array: - body["data_array"] = [v for v in self.data_array] + body["data_array"] = [v.as_dict() for v in self.data_array] if self.row_count is not None: body["row_count"] = self.row_count return body @@ -1136,7 +1191,7 @@ def as_shallow_dict(self) -> dict: @classmethod def from_dict(cls, d: Dict[str, Any]) -> ResultData: """Deserializes the ResultData from a dictionary.""" - return cls(data_array=d.get("data_array", None), row_count=d.get("row_count", None)) + return cls(data_array=_repeated_dict(d, "data_array", ListValue), row_count=d.get("row_count", None)) @dataclass @@ -1175,8 +1230,6 @@ def from_dict(cls, d: Dict[str, Any]) -> ResultManifest: @dataclass class ScanVectorIndexRequest: - """Request payload for scanning data from a vector index.""" - index_name: Optional[str] = None """Name of the vector index to scan.""" @@ -1254,6 +1307,15 @@ def from_dict(cls, d: Dict[str, Any]) -> ScanVectorIndexResponse: @dataclass class Struct: + """copied from proto3 / Google Well Known Types, source: + https://github.com/protocolbuffers/protobuf/blob/450d24ca820750c5db5112a6f0b0c2efb9758021/src/google/protobuf/struct.proto + `Struct` represents a structured data value, consisting of fields which map to dynamically typed + values. In some languages, `Struct` might be supported by a native representation. For example, + in scripting languages like JS a struct is represented as an object. The details of that + representation are described together with the proto support for the language. 
+ + The JSON representation for `Struct` is JSON object.""" + fields: Optional[List[MapStringValueEntry]] = None """Data entry, corresponding to a row in a vector index.""" @@ -1296,9 +1358,71 @@ def from_dict(cls, d: Dict[str, Any]) -> SyncIndexResponse: @dataclass -class UpsertDataResult: - """Result of the upsert or delete operation.""" +class UpdateEndpointCustomTagsRequest: + custom_tags: List[CustomTag] + """The new custom tags for the vector search endpoint""" + + endpoint_name: Optional[str] = None + """Name of the vector search endpoint""" + def as_dict(self) -> dict: + """Serializes the UpdateEndpointCustomTagsRequest into a dictionary suitable for use as a JSON request body.""" + body = {} + if self.custom_tags: + body["custom_tags"] = [v.as_dict() for v in self.custom_tags] + if self.endpoint_name is not None: + body["endpoint_name"] = self.endpoint_name + return body + + def as_shallow_dict(self) -> dict: + """Serializes the UpdateEndpointCustomTagsRequest into a shallow dictionary of its immediate attributes.""" + body = {} + if self.custom_tags: + body["custom_tags"] = self.custom_tags + if self.endpoint_name is not None: + body["endpoint_name"] = self.endpoint_name + return body + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> UpdateEndpointCustomTagsRequest: + """Deserializes the UpdateEndpointCustomTagsRequest from a dictionary.""" + return cls(custom_tags=_repeated_dict(d, "custom_tags", CustomTag), endpoint_name=d.get("endpoint_name", None)) + + +@dataclass +class UpdateEndpointCustomTagsResponse: + custom_tags: Optional[List[CustomTag]] = None + """All the custom tags that are applied to the vector search endpoint.""" + + name: Optional[str] = None + """The name of the vector search endpoint whose custom tags were updated.""" + + def as_dict(self) -> dict: + """Serializes the UpdateEndpointCustomTagsResponse into a dictionary suitable for use as a JSON request body.""" + body = {} + if self.custom_tags: + body["custom_tags"] = [v.as_dict() for v in self.custom_tags] + if self.name is not None: + body["name"] = self.name + return body + + def as_shallow_dict(self) -> dict: + """Serializes the UpdateEndpointCustomTagsResponse into a shallow dictionary of its immediate attributes.""" + body = {} + if self.custom_tags: + body["custom_tags"] = self.custom_tags + if self.name is not None: + body["name"] = self.name + return body + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> UpdateEndpointCustomTagsResponse: + """Deserializes the UpdateEndpointCustomTagsResponse from a dictionary.""" + return cls(custom_tags=_repeated_dict(d, "custom_tags", CustomTag), name=d.get("name", None)) + + +@dataclass +class UpsertDataResult: failed_primary_keys: Optional[List[str]] = None """List of primary keys for rows that failed to process.""" @@ -1332,7 +1456,6 @@ def from_dict(cls, d: Dict[str, Any]) -> UpsertDataResult: class UpsertDataStatus(Enum): - """Status of the upsert operation.""" FAILURE = "FAILURE" PARTIAL_SUCCESS = "PARTIAL_SUCCESS" @@ -1341,8 +1464,6 @@ class UpsertDataStatus(Enum): @dataclass class UpsertDataVectorIndexRequest: - """Request payload for upserting data into a vector index.""" - inputs_json: str """JSON string representing the data to be upserted.""" @@ -1375,8 +1496,6 @@ def from_dict(cls, d: Dict[str, Any]) -> UpsertDataVectorIndexRequest: @dataclass class UpsertDataVectorIndexResponse: - """Response to an upsert data vector index request.""" - result: Optional[UpsertDataResult] = None """Result of the upsert or delete operation.""" @@ 
-1412,14 +1531,25 @@ class Value: bool_value: Optional[bool] = None list_value: Optional[ListValue] = None - - null_value: Optional[str] = None + """copied from proto3 / Google Well Known Types, source: + https://github.com/protocolbuffers/protobuf/blob/450d24ca820750c5db5112a6f0b0c2efb9758021/src/google/protobuf/struct.proto + `ListValue` is a wrapper around a repeated field of values. + + The JSON representation for `ListValue` is JSON array.""" number_value: Optional[float] = None string_value: Optional[str] = None struct_value: Optional[Struct] = None + """copied from proto3 / Google Well Known Types, source: + https://github.com/protocolbuffers/protobuf/blob/450d24ca820750c5db5112a6f0b0c2efb9758021/src/google/protobuf/struct.proto + `Struct` represents a structured data value, consisting of fields which map to dynamically typed + values. In some languages, `Struct` might be supported by a native representation. For example, + in scripting languages like JS a struct is represented as an object. The details of that + representation are described together with the proto support for the language. + + The JSON representation for `Struct` is JSON object.""" def as_dict(self) -> dict: """Serializes the Value into a dictionary suitable for use as a JSON request body.""" @@ -1428,8 +1558,6 @@ def as_dict(self) -> dict: body["bool_value"] = self.bool_value if self.list_value: body["list_value"] = self.list_value.as_dict() - if self.null_value is not None: - body["null_value"] = self.null_value if self.number_value is not None: body["number_value"] = self.number_value if self.string_value is not None: @@ -1445,8 +1573,6 @@ def as_shallow_dict(self) -> dict: body["bool_value"] = self.bool_value if self.list_value: body["list_value"] = self.list_value - if self.null_value is not None: - body["null_value"] = self.null_value if self.number_value is not None: body["number_value"] = self.number_value if self.string_value is not None: @@ -1461,7 +1587,6 @@ def from_dict(cls, d: Dict[str, Any]) -> Value: return cls( bool_value=d.get("bool_value", None), list_value=_from_dict(d, "list_value", ListValue), - null_value=d.get("null_value", None), number_value=d.get("number_value", None), string_value=d.get("string_value", None), struct_value=_from_dict(d, "struct_value", Struct), @@ -1481,12 +1606,11 @@ class VectorIndex: """Name of the endpoint associated with the index""" index_type: Optional[VectorIndexType] = None - """There are 2 types of Vector Search indexes: - - - `DELTA_SYNC`: An index that automatically syncs with a source Delta Table, automatically and - incrementally updating the index as the underlying data in the Delta Table changes. - - `DIRECT_ACCESS`: An index that supports direct read and write of vectors and metadata through - our REST and SDK APIs. With this model, the user manages index updates.""" + """There are 2 types of Vector Search indexes: - `DELTA_SYNC`: An index that automatically syncs + with a source Delta Table, automatically and incrementally updating the index as the underlying + data in the Delta Table changes. - `DIRECT_ACCESS`: An index that supports direct read and write + of vectors and metadata through our REST and SDK APIs. 
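Because `ResultData.data_array` changed from `List[List[str]]` to `List[ListValue]`, each row is now a proto-style `ListValue` whose cells are typed `Value` objects. A small helper sketch for collapsing them back into plain Python; the `MapStringValueEntry` key/value attributes and the null handling are assumptions:

```python
from typing import Any

from databricks.sdk.vectorsearch.v2.vectorsearch import Value

def unwrap(value: Value) -> Any:
    """Collapse a proto-style Value into a plain Python object (sketch)."""
    if value.string_value is not None:
        return value.string_value
    if value.number_value is not None:
        return value.number_value
    if value.bool_value is not None:
        return value.bool_value
    if value.list_value is not None:
        return [unwrap(v) for v in value.list_value.values or []]
    if value.struct_value is not None:
        # MapStringValueEntry is assumed to expose .key and .value
        return {entry.key: unwrap(entry.value) for entry in value.struct_value.fields or []}
    return None  # null_value was dropped from Value, so absence now means null

# Typical use against a query response (sketch):
# rows = [[unwrap(cell) for cell in row.values or []] for row in result.data_array or []]
```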
With this model, the user manages index + updates.""" name: Optional[str] = None """Name of the index""" @@ -1605,38 +1729,84 @@ def from_dict(cls, d: Dict[str, Any]) -> VectorIndexStatus: class VectorIndexType(Enum): - """There are 2 types of Vector Search indexes: - - - `DELTA_SYNC`: An index that automatically syncs with a source Delta Table, automatically and - incrementally updating the index as the underlying data in the Delta Table changes. - - `DIRECT_ACCESS`: An index that supports direct read and write of vectors and metadata through - our REST and SDK APIs. With this model, the user manages index updates.""" + """There are 2 types of Vector Search indexes: - `DELTA_SYNC`: An index that automatically syncs + with a source Delta Table, automatically and incrementally updating the index as the underlying + data in the Delta Table changes. - `DIRECT_ACCESS`: An index that supports direct read and write + of vectors and metadata through our REST and SDK APIs. With this model, the user manages index + updates.""" DELTA_SYNC = "DELTA_SYNC" DIRECT_ACCESS = "DIRECT_ACCESS" +class VectorSearchEndpointsCreateEndpointWaiter: + raw_response: EndpointInfo + """raw_response is the raw response of the CreateEndpoint call.""" + _service: VectorSearchEndpointsAPI + _endpoint_name: str + + def __init__(self, raw_response: EndpointInfo, service: VectorSearchEndpointsAPI, endpoint_name: str): + self._service = service + self.raw_response = raw_response + self._endpoint_name = endpoint_name + + def WaitUntilDone(self, opts: Optional[WaitUntilDoneOptions] = None) -> EndpointInfo: + if opts is None: + opts = WaitUntilDoneOptions() + deadline = time.time() + opts.timeout.total_seconds() + target_states = (EndpointStatusState.ONLINE,) + failure_states = (EndpointStatusState.OFFLINE,) + status_message = "polling..." + attempt = 1 + while time.time() < deadline: + poll = self._service.get_endpoint(endpoint_name=self._endpoint_name) + status = poll.endpoint_status.state + status_message = f"current status: {status}" + if poll.endpoint_status: + status_message = poll.endpoint_status.message + if status in target_states: + return poll + if status in failure_states: + msg = f"failed to reach ONLINE, got {status}: {status_message}" + raise OperationFailed(msg) + prefix = f"endpoint_name={self._endpoint_name}" + sleep = attempt + if sleep > 10: + # sleep 10s max per attempt + sleep = 10 + _LOG.debug(f"{prefix}: ({status}) {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + raise TimeoutError(f"timed out after {opts.timeout}: {status_message}") + + class VectorSearchEndpointsAPI: """**Endpoint**: Represents the compute resources to host vector search indexes.""" def __init__(self, api_client): self._api = api_client - def create_endpoint(self, name: str, endpoint_type: EndpointType) -> EndpointInfo: + def create_endpoint( + self, name: str, endpoint_type: EndpointType, *, budget_policy_id: Optional[str] = None + ) -> VectorSearchEndpointsCreateEndpointWaiter: """Create an endpoint. Create a new endpoint. :param name: str - Name of endpoint + Name of the vector search endpoint :param endpoint_type: :class:`EndpointType` - Type of endpoint. + Type of endpoint + :param budget_policy_id: str (optional) + The budget policy id to be applied :returns: Long-running operation waiter for :class:`EndpointInfo`. See :method:WaitGetEndpointVectorSearchEndpointOnline for more details. 
""" body = {} + if budget_policy_id is not None: + body["budget_policy_id"] = budget_policy_id if endpoint_type is not None: body["endpoint_type"] = endpoint_type.value if name is not None: @@ -1646,25 +1816,33 @@ def create_endpoint(self, name: str, endpoint_type: EndpointType) -> EndpointInf "Content-Type": "application/json", } - res = self._api.do("POST", "/api/2.0/vector-search/endpoints", body=body, headers=headers) - return EndpointInfo.from_dict(res) + op_response = self._api.do("POST", "/api/2.0/vector-search/endpoints", body=body, headers=headers) + return VectorSearchEndpointsCreateEndpointWaiter( + service=self, raw_response=EndpointInfo.from_dict(op_response), endpoint_name=op_response["name"] + ) def delete_endpoint(self, endpoint_name: str): """Delete an endpoint. + Delete a vector search endpoint. + :param endpoint_name: str - Name of the endpoint + Name of the vector search endpoint """ - headers = {} + headers = { + "Accept": "application/json", + } self._api.do("DELETE", f"/api/2.0/vector-search/endpoints/{endpoint_name}", headers=headers) def get_endpoint(self, endpoint_name: str) -> EndpointInfo: """Get an endpoint. + Get details for a single vector search endpoint. + :param endpoint_name: str Name of the endpoint @@ -1681,6 +1859,8 @@ def get_endpoint(self, endpoint_name: str) -> EndpointInfo: def list_endpoints(self, *, page_token: Optional[str] = None) -> Iterator[EndpointInfo]: """List all endpoints. + List all vector search endpoints in the workspace. + :param page_token: str (optional) Token for pagination @@ -1703,14 +1883,66 @@ def list_endpoints(self, *, page_token: Optional[str] = None) -> Iterator[Endpoi return query["page_token"] = json["next_page_token"] + def update_endpoint_budget_policy( + self, endpoint_name: str, budget_policy_id: str + ) -> PatchEndpointBudgetPolicyResponse: + """Update the budget policy of an endpoint. + + Update the budget policy of an endpoint + + :param endpoint_name: str + Name of the vector search endpoint + :param budget_policy_id: str + The budget policy id to be applied + + :returns: :class:`PatchEndpointBudgetPolicyResponse` + """ + body = {} + if budget_policy_id is not None: + body["budget_policy_id"] = budget_policy_id + headers = { + "Accept": "application/json", + "Content-Type": "application/json", + } + + res = self._api.do( + "PATCH", f"/api/2.0/vector-search/endpoints/{endpoint_name}/budget-policy", body=body, headers=headers + ) + return PatchEndpointBudgetPolicyResponse.from_dict(res) + + def update_endpoint_custom_tags( + self, endpoint_name: str, custom_tags: List[CustomTag] + ) -> UpdateEndpointCustomTagsResponse: + """Update the custom tags of an endpoint. + + :param endpoint_name: str + Name of the vector search endpoint + :param custom_tags: List[:class:`CustomTag`] + The new custom tags for the vector search endpoint + + :returns: :class:`UpdateEndpointCustomTagsResponse` + """ + body = {} + if custom_tags is not None: + body["custom_tags"] = [v.as_dict() for v in custom_tags] + headers = { + "Accept": "application/json", + "Content-Type": "application/json", + } + + res = self._api.do( + "PATCH", f"/api/2.0/vector-search/endpoints/{endpoint_name}/tags", body=body, headers=headers + ) + return UpdateEndpointCustomTagsResponse.from_dict(res) + class VectorSearchIndexesAPI: """**Index**: An efficient representation of your embedding vectors that supports real-time and efficient approximate nearest neighbor (ANN) search queries. 
- There are 2 types of Vector Search indexes: * **Delta Sync Index**: An index that automatically syncs with + There are 2 types of Vector Search indexes: - **Delta Sync Index**: An index that automatically syncs with a source Delta Table, automatically and incrementally updating the index as the underlying data in the - Delta Table changes. * **Direct Vector Access Index**: An index that supports direct read and write of + Delta Table changes. - **Direct Vector Access Index**: An index that supports direct read and write of vectors and metadata through our REST and SDK APIs. With this model, the user manages index updates.""" def __init__(self, api_client): @@ -1725,7 +1957,7 @@ def create_index( *, delta_sync_index_spec: Optional[DeltaSyncVectorIndexSpecRequest] = None, direct_access_index_spec: Optional[DirectAccessVectorIndexSpec] = None, - ) -> CreateVectorIndexResponse: + ) -> VectorIndex: """Create an index. Create a new index. @@ -1737,18 +1969,16 @@ def create_index( :param primary_key: str Primary key of the index :param index_type: :class:`VectorIndexType` - There are 2 types of Vector Search indexes: - - - `DELTA_SYNC`: An index that automatically syncs with a source Delta Table, automatically and - incrementally updating the index as the underlying data in the Delta Table changes. - - `DIRECT_ACCESS`: An index that supports direct read and write of vectors and metadata through our - REST and SDK APIs. With this model, the user manages index updates. + There are 2 types of Vector Search indexes: - `DELTA_SYNC`: An index that automatically syncs with a + source Delta Table, automatically and incrementally updating the index as the underlying data in the + Delta Table changes. - `DIRECT_ACCESS`: An index that supports direct read and write of vectors and + metadata through our REST and SDK APIs. With this model, the user manages index updates. :param delta_sync_index_spec: :class:`DeltaSyncVectorIndexSpecRequest` (optional) Specification for Delta Sync Index. Required if `index_type` is `DELTA_SYNC`. :param direct_access_index_spec: :class:`DirectAccessVectorIndexSpec` (optional) Specification for Direct Vector Access Index. Required if `index_type` is `DIRECT_ACCESS`. - :returns: :class:`CreateVectorIndexResponse` + :returns: :class:`VectorIndex` """ body = {} if delta_sync_index_spec is not None: @@ -1769,7 +1999,7 @@ def create_index( } res = self._api.do("POST", "/api/2.0/vector-search/indexes", body=body, headers=headers) - return CreateVectorIndexResponse.from_dict(res) + return VectorIndex.from_dict(res) def delete_data_vector_index(self, index_name: str, primary_keys: List[str]) -> DeleteDataVectorIndexResponse: """Delete data from index. 
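A minimal sketch of the changed create_index return type in use; illustrative only, with placeholder names (`vsi` stands for a VectorSearchIndexesAPI instance, and the sync spec is assumed to be built elsewhere):

    # create_index now returns a VectorIndex rather than a CreateVectorIndexResponse,
    # so index metadata can be read straight off the result.
    index = vsi.create_index(
        name="catalog.schema.my_index",    # placeholder index name
        endpoint_name="my-endpoint",       # placeholder endpoint name
        primary_key="id",
        index_type=VectorIndexType.DELTA_SYNC,
        delta_sync_index_spec=my_delta_sync_spec,  # a DeltaSyncVectorIndexSpecRequest, required for DELTA_SYNC
    )
    print(index.name, index.index_type)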
@@ -1783,16 +2013,16 @@ def delete_data_vector_index(self, index_name: str, primary_keys: List[str]) -> :returns: :class:`DeleteDataVectorIndexResponse` """ - body = {} + + query = {} if primary_keys is not None: - body["primary_keys"] = [v for v in primary_keys] + query["primary_keys"] = [v for v in primary_keys] headers = { "Accept": "application/json", - "Content-Type": "application/json", } res = self._api.do( - "POST", f"/api/2.0/vector-search/indexes/{index_name}/delete-data", body=body, headers=headers + "DELETE", f"/api/2.0/vector-search/indexes/{index_name}/delete-data", query=query, headers=headers ) return DeleteDataVectorIndexResponse.from_dict(res) @@ -1807,7 +2037,9 @@ def delete_index(self, index_name: str): """ - headers = {} + headers = { + "Accept": "application/json", + } self._api.do("DELETE", f"/api/2.0/vector-search/indexes/{index_name}", headers=headers) @@ -1886,9 +2118,11 @@ def query_index( :param filters_json: str (optional) JSON string representing query filters. - Example filters: - `{"id <": 5}`: Filter for id less than 5. - `{"id >": 5}`: Filter for id greater - than 5. - `{"id <=": 5}`: Filter for id less than equal to 5. - `{"id >=": 5}`: Filter for id - greater than equal to 5. - `{"id": 5}`: Filter for id equal to 5. + Example filters: + + - `{"id <": 5}`: Filter for id less than 5. - `{"id >": 5}`: Filter for id greater than 5. - `{"id + <=": 5}`: Filter for id less than or equal to 5. - `{"id >=": 5}`: Filter for id greater than or equal + to 5. - `{"id": 5}`: Filter for id equal to 5. :param num_results: int (optional) Number of results to return. Defaults to 10. :param query_text: str (optional) @@ -2001,7 +2235,9 @@ def sync_index(self, index_name: str): """ - headers = {} + headers = { + "Accept": "application/json", + } self._api.do("POST", f"/api/2.0/vector-search/indexes/{index_name}/sync", headers=headers) diff --git a/tests/integration/test_jobs.py b/tests/integration/test_jobs.py index d6e835898..6f3508634 100644 --- a/tests/integration/test_jobs.py +++ b/tests/integration/test_jobs.py @@ -1,7 +1,11 @@ +import datetime import logging from databricks.sdk.compute.v2.client import ClustersClient +from databricks.sdk.files.v2.client import DbfsClient +from databricks.sdk.iam.v2.client import CurrentUserClient from databricks.sdk.jobs.v2.client import JobsClient +from databricks.sdk.service._internal import WaitUntilDoneOptions def test_jobs(w): @@ -14,44 +18,41 @@ def test_jobs(w): assert found > 0 -# TODO: Re-enable this after adding waiters to the SDK -# def test_submitting_jobs(w, random, env_or_skip): -# from databricks.sdk.jobs.v2 import jobs -# from databricks.sdk.compute.v2 import compute - -# cuc = CurrentUserClient(config=w) -# jc = JobsClient(config=w) -# dc = DbfsClient(config=w) - -# py_on_dbfs = f"/home/{cuc.me().user_name}/sample.py" -# with dc.open(py_on_dbfs, write=True, overwrite=True) as f: -# f.write(b'import time; time.sleep(10); print("Hello, World!")') - -# waiter = jc.submit( -# run_name=f"py-sdk-{random(8)}", -# tasks=[ -# jobs.SubmitTask( -# task_key="pi", -# new_cluster=jobs.JobsClusterSpec( -# spark_version=w.clusters.select_spark_version(long_term_support=True), -# # node_type_id=w.clusters.select_node_type(local_disk=True), -# instance_pool_id=env_or_skip("TEST_INSTANCE_POOL_ID"), -# num_workers=1, -# ), -# spark_python_task=jobs.SparkPythonTask(python_file=f"dbfs:{py_on_dbfs}"), -# ) -# ], -# ) - -# logging.info(f"starting to poll: {waiter.run_id}") - -# def print_status(run: jobs.Run): -# statuses = [f"{t.task_key}: 
{t.state.life_cycle_state}" for t in run.tasks] -# logging.info(f'workflow intermediate status: {", ".join(statuses)}') - -# run = waiter.result(timeout=datetime.timedelta(minutes=15), callback=print_status) - -# logging.info(f"job finished: {run.run_page_url}") +def test_submitting_jobs(w, random, env_or_skip): + from databricks.sdk.jobs.v2 import jobs + + cuc = CurrentUserClient(config=w) + jc = JobsClient(config=w) + dc = DbfsClient(config=w) + cc = ClustersClient(config=w) + + py_on_dbfs = f"/home/{cuc.me().user_name}/sample.py" + with dc.open(py_on_dbfs, write=True, overwrite=True) as f: + f.write(b'import time; time.sleep(10); print("Hello, World!")') + + waiter = jc.submit( + run_name=f"py-sdk-{random(8)}", + tasks=[ + jobs.SubmitTask( + task_key="pi", + new_cluster=jobs.JobsClusterSpec( + spark_version=cc.select_spark_version(long_term_support=True), + # node_type_id=cc.select_node_type(local_disk=True), + instance_pool_id=env_or_skip("TEST_INSTANCE_POOL_ID"), + num_workers=1, + ), + spark_python_task=jobs.SparkPythonTask(python_file=f"dbfs:{py_on_dbfs}"), + ) + ], + ) + + logging.info(f"starting to poll: {waiter.raw_response.run_id}") + + options = WaitUntilDoneOptions() + options.timeout = datetime.timedelta(minutes=15) + run = waiter.WaitUntilDone(opts=options) + + logging.info(f"job finished: {run.run_page_url}") def test_last_job_runs(w):