[Internal] Switch code formatter to Black #900

Merged · 3 commits · Feb 27, 2025
3 changes: 2 additions & 1 deletion CONTRIBUTING.md
@@ -13,7 +13,8 @@ If it is appropriate to write a design document, the document must be hosted eit
Small patches and bug fixes don't need prior communication.

## Coding Style
Code style is enforced by a formatter check in your pull request. We use [yapf](https://github.com/google/yapf) to format our code. Run `make fmt` to ensure your code is properly formatted prior to raising a pull request.

Code style is enforced by a formatter check in your pull request. We use [Black](https://github.com/psf/black) to format our code. Run `make fmt` to ensure your code is properly formatted prior to raising a pull request.

## Signed Commits
This repo requires all contributors to sign their commits. To configure this, you can follow [Github's documentation](https://docs.github.com/en/authentication/managing-commit-signature-verification/signing-commits) to create a GPG key, upload it to your Github account, and configure your git client to sign commits.
4 changes: 2 additions & 2 deletions Makefile
@@ -11,12 +11,12 @@ install:
pip install .

fmt:
yapf -pri databricks tests
black databricks tests
autoflake -ri databricks tests
isort databricks tests

fmte:
yapf -pri examples
black examples
autoflake -ri examples
isort examples

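Most of the Python changes in the files below are mechanical consequences of this switch: Black normalizes string literals to double quotes and reflows long signatures and call sites into one-argument-per-line with a trailing comma. A small made-up before/after that shows the pattern (hypothetical function and values, not taken from this PR):

```python
# Hypothetical function and values, for illustration only.
def configure(*, host: str = "", timeout_seconds: float = 60.0, pool_block: bool = True) -> dict:
    return {"host": host, "timeout": timeout_seconds, "pool_block": pool_block}

# yapf-era style (hanging indent, single quotes):
#   cfg = configure(host='https://example.cloud.databricks.com',
#                   timeout_seconds=300,
#                   pool_block=True)

# Black style: double quotes, exploded arguments, trailing comma.
cfg = configure(
    host="https://example.cloud.databricks.com",
    timeout_seconds=300,
    pool_block=True,
)
```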
318 changes: 185 additions & 133 deletions databricks/sdk/__init__.py

Large diffs are not rendered by default.

190 changes: 107 additions & 83 deletions databricks/sdk/_base_client.py
@@ -17,43 +17,45 @@
from .logger import RoundTrip
from .retries import retried

logger = logging.getLogger('databricks.sdk')
logger = logging.getLogger("databricks.sdk")


def _fix_host_if_needed(host: Optional[str]) -> Optional[str]:
if not host:
return host

# Add a default scheme if it's missing
if '://' not in host:
host = 'https://' + host
if "://" not in host:
host = "https://" + host

o = urllib.parse.urlparse(host)
# remove trailing slash
path = o.path.rstrip('/')
path = o.path.rstrip("/")
# remove port if 443
netloc = o.netloc
if o.port == 443:
netloc = netloc.split(':')[0]
netloc = netloc.split(":")[0]

return urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment))


class _BaseClient:

def __init__(self,
debug_truncate_bytes: int = None,
retry_timeout_seconds: int = None,
user_agent_base: str = None,
header_factory: Callable[[], dict] = None,
max_connection_pools: int = None,
max_connections_per_pool: int = None,
pool_block: bool = True,
http_timeout_seconds: float = None,
extra_error_customizers: List[_ErrorCustomizer] = None,
debug_headers: bool = False,
clock: Clock = None,
streaming_buffer_size: int = 1024 * 1024): # 1MB
def __init__(
self,
debug_truncate_bytes: int = None,
retry_timeout_seconds: int = None,
user_agent_base: str = None,
header_factory: Callable[[], dict] = None,
max_connection_pools: int = None,
max_connections_per_pool: int = None,
pool_block: bool = True,
http_timeout_seconds: float = None,
extra_error_customizers: List[_ErrorCustomizer] = None,
debug_headers: bool = False,
clock: Clock = None,
streaming_buffer_size: int = 1024 * 1024,
): # 1MB
"""
:param debug_truncate_bytes:
:param retry_timeout_seconds:
@@ -87,9 +89,11 @@ def __init__(self,
# We don't use `max_retries` from HTTPAdapter to align with a more production-ready
# retry strategy established in the Databricks SDK for Go. See _is_retryable and
# @retried for more details.
http_adapter = requests.adapters.HTTPAdapter(pool_connections=max_connections_per_pool or 20,
pool_maxsize=max_connection_pools or 20,
pool_block=pool_block)
http_adapter = requests.adapters.HTTPAdapter(
pool_connections=max_connections_per_pool or 20,
pool_maxsize=max_connection_pools or 20,
pool_block=pool_block,
)
self._session.mount("https://", http_adapter)

# Default to 60 seconds
@@ -110,7 +114,7 @@ def _fix_query_string(query: Optional[dict] = None) -> Optional[dict]:
# See: https://github.com/databricks/databricks-sdk-py/issues/142
if query is None:
return None
with_fixed_bools = {k: v if type(v) != bool else ('true' if v else 'false') for k, v in query.items()}
with_fixed_bools = {k: v if type(v) != bool else ("true" if v else "false") for k, v in query.items()}

# Query parameters may be nested, e.g.
# {'filter_by': {'user_ids': [123, 456]}}
@@ -140,30 +144,34 @@ def _is_seekable_stream(data) -> bool:
return False
return data.seekable()

def do(self,
method: str,
url: str,
query: dict = None,
headers: dict = None,
body: dict = None,
raw: bool = False,
files=None,
data=None,
auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None,
response_headers: List[str] = None) -> Union[dict, list, BinaryIO]:
def do(
self,
method: str,
url: str,
query: dict = None,
headers: dict = None,
body: dict = None,
raw: bool = False,
files=None,
data=None,
auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None,
response_headers: List[str] = None,
) -> Union[dict, list, BinaryIO]:
if headers is None:
headers = {}
headers['User-Agent'] = self._user_agent_base
headers["User-Agent"] = self._user_agent_base

# Wrap strings and bytes in a seekable stream so that we can rewind them.
if isinstance(data, (str, bytes)):
data = io.BytesIO(data.encode('utf-8') if isinstance(data, str) else data)
data = io.BytesIO(data.encode("utf-8") if isinstance(data, str) else data)

if not data:
# The request is not a stream.
call = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
is_retryable=self._is_retryable,
clock=self._clock)(self._perform)
call = retried(
timeout=timedelta(seconds=self._retry_timeout_seconds),
is_retryable=self._is_retryable,
clock=self._clock,
)(self._perform)
elif self._is_seekable_stream(data):
# Keep track of the initial position of the stream so that we can rewind to it
# if we need to retry the request.
@@ -173,25 +181,29 @@ def rewind():
logger.debug(f"Rewinding input data to offset {initial_data_position} before retry")
data.seek(initial_data_position)

call = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
is_retryable=self._is_retryable,
clock=self._clock,
before_retry=rewind)(self._perform)
call = retried(
timeout=timedelta(seconds=self._retry_timeout_seconds),
is_retryable=self._is_retryable,
clock=self._clock,
before_retry=rewind,
)(self._perform)
else:
# Do not retry if the stream is not seekable. This is necessary to avoid bugs
# where the retry doesn't re-read already read data from the stream.
logger.debug(f"Retry disabled for non-seekable stream: type={type(data)}")
call = self._perform

response = call(method,
url,
query=query,
headers=headers,
body=body,
raw=raw,
files=files,
data=data,
auth=auth)
response = call(
method,
url,
query=query,
headers=headers,
body=body,
raw=raw,
files=files,
data=data,
auth=auth,
)

resp = dict()
for header in response_headers if response_headers else []:
@@ -220,6 +232,7 @@ def _is_retryable(err: BaseException) -> Optional[str]:
# and Databricks SDK for Go retries
# (see https://github.com/databricks/databricks-sdk-go/blob/main/apierr/errors.go)
from urllib3.exceptions import ProxyError

if isinstance(err, ProxyError):
err = err.original_error
if isinstance(err, requests.ConnectionError):
@@ -230,48 +243,55 @@ def _is_retryable(err: BaseException) -> Optional[str]:
#
# return a simple string for debug log readability, as `raise TimeoutError(...) from err`
# will bubble up the original exception in case we reach max retries.
return f'cannot connect'
return f"cannot connect"
if isinstance(err, requests.Timeout):
# corresponds to `TLS handshake timeout` and `i/o timeout` in Go.
#
# return a simple string for debug log readability, as `raise TimeoutError(...) from err`
# will bubble up the original exception in case we reach max retries.
return f'timeout'
return f"timeout"
if isinstance(err, DatabricksError):
message = str(err)
transient_error_string_matches = [
"com.databricks.backend.manager.util.UnknownWorkerEnvironmentException",
"does not have any associated worker environments", "There is no worker environment with id",
"Unknown worker environment", "ClusterNotReadyException", "Unexpected error",
"does not have any associated worker environments",
"There is no worker environment with id",
"Unknown worker environment",
"ClusterNotReadyException",
"Unexpected error",
"Please try again later or try a faster operation.",
"RPC token bucket limit has been exceeded",
]
for substring in transient_error_string_matches:
if substring not in message:
continue
return f'matched {substring}'
return f"matched {substring}"
return None

def _perform(self,
method: str,
url: str,
query: dict = None,
headers: dict = None,
body: dict = None,
raw: bool = False,
files=None,
data=None,
auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None):
response = self._session.request(method,
url,
params=self._fix_query_string(query),
json=body,
headers=headers,
files=files,
data=data,
auth=auth,
stream=raw,
timeout=self._http_timeout_seconds)
def _perform(
self,
method: str,
url: str,
query: dict = None,
headers: dict = None,
body: dict = None,
raw: bool = False,
files=None,
data=None,
auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None,
):
response = self._session.request(
method,
url,
params=self._fix_query_string(query),
json=body,
headers=headers,
files=files,
data=data,
auth=auth,
stream=raw,
timeout=self._http_timeout_seconds,
)
self._record_request_log(response, raw=raw or data is not None or files is not None)
error = self._error_parser.get_api_error(response)
if error is not None:
@@ -312,7 +332,7 @@ def flush(self) -> int:

def __init__(self, response: _RawResponse, chunk_size: Union[int, None] = None):
self._response = response
self._buffer = b''
self._buffer = b""
self._content = None
self._chunk_size = chunk_size

@@ -338,14 +358,14 @@ def isatty(self) -> bool:

def read(self, n: int = -1) -> bytes:
"""
Read up to n bytes from the response stream. If n is negative, read
until the end of the stream.
Read up to n bytes from the response stream. If n is negative, read
until the end of the stream.
"""

self._open()
read_everything = n < 0
remaining_bytes = n
res = b''
res = b""
while remaining_bytes > 0 or read_everything:
if len(self._buffer) == 0:
try:
@@ -395,8 +415,12 @@ def __next__(self) -> bytes:
def __iter__(self) -> Iterator[bytes]:
return self._content

def __exit__(self, t: Union[Type[BaseException], None], value: Union[BaseException, None],
traceback: Union[TracebackType, None]) -> None:
def __exit__(
self,
t: Union[Type[BaseException], None],
value: Union[BaseException, None],
traceback: Union[TracebackType, None],
) -> None:
self._content = None
self._buffer = b''
self._buffer = b""
self.close()
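For context on the reflowed `do()` method above: retry behavior depends on whether the request body is a seekable stream, and a `rewind` hook reseeks the stream to its initial offset before each retry; non-seekable streams are not retried at all. A minimal, self-contained sketch of that pattern (illustrative only, not the SDK's actual implementation — `send` is a made-up stand-in):

```python
import io


def send(payload: bytes) -> None:
    """Stand-in for a network call that may fail transiently."""
    ...


def send_with_retry(stream: io.BufferedIOBase, attempts: int = 3) -> None:
    if not stream.seekable():
        # A non-seekable stream cannot be rewound, so a retry would silently
        # skip already-consumed data; make a single attempt instead.
        send(stream.read())
        return
    initial_position = stream.tell()
    for attempt in range(attempts):
        try:
            send(stream.read())
            return
        except ConnectionError:
            if attempt == attempts - 1:
                raise
            # Rewind before retrying, mirroring the `rewind()` hook that the
            # diff above passes to `retried(..., before_retry=rewind)`.
            stream.seek(initial_position)


send_with_retry(io.BytesIO(b"request body"))
```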
19 changes: 12 additions & 7 deletions databricks/sdk/_property.py
@@ -16,8 +16,9 @@ def __set_name__(self, owner, name):
if self.attrname is None:
self.attrname = name
elif name != self.attrname:
raise TypeError("Cannot assign the same cached_property to two different names "
f"({self.attrname!r} and {name!r}).")
raise TypeError(
"Cannot assign the same cached_property to two different names " f"({self.attrname!r} and {name!r})."
)

def __get__(self, instance, owner=None):
if instance is None:
@@ -26,17 +27,21 @@ def __get__(self, instance, owner=None):
raise TypeError("Cannot use cached_property instance without calling __set_name__ on it.")
try:
cache = instance.__dict__
except AttributeError: # not all objects have __dict__ (e.g. class defines slots)
msg = (f"No '__dict__' attribute on {type(instance).__name__!r} "
f"instance to cache {self.attrname!r} property.")
except AttributeError: # not all objects have __dict__ (e.g. class defines slots)
msg = (
f"No '__dict__' attribute on {type(instance).__name__!r} "
f"instance to cache {self.attrname!r} property."
)
raise TypeError(msg) from None
val = cache.get(self.attrname, _NOT_FOUND)
if val is _NOT_FOUND:
val = self.func(instance)
try:
cache[self.attrname] = val
except TypeError:
msg = (f"The '__dict__' attribute on {type(instance).__name__!r} instance "
f"does not support item assignment for caching {self.attrname!r} property.")
msg = (
f"The '__dict__' attribute on {type(instance).__name__!r} instance "
f"does not support item assignment for caching {self.attrname!r} property."
)
raise TypeError(msg) from None
return val
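The `_property.py` changes above only re-wrap the error messages of the SDK's cached-property descriptor. For readers unfamiliar with the pattern, here is how such a descriptor is typically used (hypothetical class, shown with the stdlib `functools.cached_property` rather than the SDK's internal helper):

```python
from functools import cached_property


class WorkspaceInfo:
    def __init__(self, host: str):
        self.host = host

    @cached_property
    def canonical_host(self) -> str:
        # Computed on first access, then stored in the instance's __dict__
        # so later reads skip the function entirely.
        return self.host.rstrip("/").lower()


info = WorkspaceInfo("https://EXAMPLE.cloud.databricks.com/")
assert info.canonical_host == "https://example.cloud.databricks.com"
assert "canonical_host" in info.__dict__  # cached after the first read
```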