Streaming PUT - integration tests #648

Merged
36 changes: 36 additions & 0 deletions examples/streaming_put.py
@@ -0,0 +1,36 @@
#!/usr/bin/env python3
"""
Simple example of streaming PUT operations.

This demonstrates the basic usage of streaming PUT with the __input_stream__ token.
"""

import io
import os
from databricks import sql

def main():
"""Simple streaming PUT example."""

# Connect to Databricks
connection = sql.connect(
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
access_token=os.getenv("DATABRICKS_TOKEN"),
)

with connection.cursor() as cursor:
# Create a simple data stream
data = b"Hello, streaming world!"
stream = io.BytesIO(data)

# Upload to Unity Catalog volume
cursor.execute(
"PUT '__input_stream__' INTO '/Volumes/my_catalog/my_schema/my_volume/hello.txt'",
input_stream=stream
)

print("File uploaded successfully!")

if __name__ == "__main__":
main()
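
The example above streams an in-memory buffer. The same __input_stream__ token should also accept any binary stream that exposes a read() method, such as a file opened in "rb" mode, which avoids loading large files into memory. A minimal sketch under that assumption; the local path, volume path, and upload_local_file helper are illustrative and not part of this PR:

import os
from databricks import sql


def upload_local_file(local_path: str, volume_path: str) -> None:
    """Upload a local file to a Unity Catalog volume without reading it fully into memory."""
    connection = sql.connect(
        server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
        http_path=os.getenv("DATABRICKS_HTTP_PATH"),
        access_token=os.getenv("DATABRICKS_TOKEN"),
    )
    try:
        with connection.cursor() as cursor:
            # Any object with a read() method is accepted; an open binary file qualifies.
            with open(local_path, "rb") as stream:
                cursor.execute(
                    f"PUT '__input_stream__' INTO '{volume_path}'",
                    input_stream=stream,
                )
    finally:
        connection.close()


if __name__ == "__main__":
    upload_local_file("data.csv", "/Volumes/my_catalog/my_schema/my_volume/data.csv")
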
183 changes: 136 additions & 47 deletions src/databricks/sql/client.py
@@ -1,5 +1,5 @@
import time
from typing import Dict, Tuple, List, Optional, Any, Union, Sequence
from typing import Dict, Tuple, List, Optional, Any, Union, Sequence, BinaryIO
import pandas

try:
@@ -455,6 +455,7 @@ def __init__(
self.active_command_id = None
self.escaper = ParamEscaper()
self.lastrowid = None
self._input_stream_data: Optional[BinaryIO] = None

self.ASYNC_DEFAULT_POLLING_INTERVAL = 2

@@ -625,6 +626,33 @@ def _handle_staging_operation(
is not descended from staging_allowed_local_path.
"""

assert self.active_result_set is not None
row = self.active_result_set.fetchone()
assert row is not None

# Parse headers
headers = (
json.loads(row.headers) if isinstance(row.headers, str) else row.headers
)
headers = dict(headers) if headers else {}

# Handle __input_stream__ token for PUT operations
if (
row.operation == "PUT"
and getattr(row, "localFile", None) == "__input_stream__"
):
if not self._input_stream_data:
raise ProgrammingError(
"No input stream provided for streaming operation",
session_id_hex=self.connection.get_session_id_hex(),
)
return self._handle_staging_put_stream(
presigned_url=row.presignedUrl,
stream=self._input_stream_data,
headers=headers,
)

# For non-streaming operations, validate staging_allowed_local_path
if isinstance(staging_allowed_local_path, type(str())):
_staging_allowed_local_paths = [staging_allowed_local_path]
elif isinstance(staging_allowed_local_path, type(list())):
@@ -639,10 +667,6 @@ def _handle_staging_operation(
os.path.abspath(i) for i in _staging_allowed_local_paths
]

assert self.active_result_set is not None
row = self.active_result_set.fetchone()
assert row is not None

# Must set to None in cases where server response does not include localFile
abs_localFile = None

@@ -665,15 +689,10 @@ def _handle_staging_operation(
session_id_hex=self.connection.get_session_id_hex(),
)

# May be real headers, or could be json string
headers = (
json.loads(row.headers) if isinstance(row.headers, str) else row.headers
)

handler_args = {
"presigned_url": row.presignedUrl,
"local_file": abs_localFile,
"headers": dict(headers) or {},
"headers": headers,
}

logger.debug(
@@ -696,6 +715,60 @@ def _handle_staging_operation(
session_id_hex=self.connection.get_session_id_hex(),
)

@log_latency(StatementType.SQL)
def _handle_staging_put_stream(
self,
presigned_url: str,
stream: BinaryIO,
headers: Optional[dict] = None,
) -> None:
"""Handle PUT operation with streaming data.

Args:
presigned_url: The presigned URL for upload
stream: Binary stream to upload
headers: Optional HTTP headers

Raises:
OperationalError: If the upload fails
"""

# Prepare headers
http_headers = dict(headers) if headers else {}

try:
# Stream directly to presigned URL
response = requests.put(
url=presigned_url,
data=stream,
headers=http_headers,
timeout=300, # 5 minute timeout
)

# Check response codes
OK = requests.codes.ok # 200
CREATED = requests.codes.created # 201
ACCEPTED = requests.codes.accepted # 202
NO_CONTENT = requests.codes.no_content # 204

if response.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]:
raise OperationalError(
f"Staging operation over HTTP was unsuccessful: {response.status_code}-{response.text}",
session_id_hex=self.connection.get_session_id_hex(),
)

if response.status_code == ACCEPTED:
logger.debug(
f"Response code {ACCEPTED} from server indicates upload was accepted "
"but not yet applied on the server. It's possible this command may fail later."
)

except requests.exceptions.RequestException as e:
raise OperationalError(
f"HTTP request failed during stream upload: {str(e)}",
session_id_hex=self.connection.get_session_id_hex(),
) from e

@log_latency(StatementType.SQL)
def _handle_staging_put(
self, presigned_url: str, local_file: str, headers: Optional[dict] = None
@@ -783,6 +856,7 @@ def execute(
self,
operation: str,
parameters: Optional[TParameterCollection] = None,
input_stream: Optional[BinaryIO] = None,
enforce_embedded_schema_correctness=False,
) -> "Cursor":
"""
@@ -820,47 +894,62 @@ def execute(
logger.debug(
"Cursor.execute(operation=%s, parameters=%s)", operation, parameters
)
try:
# Store stream data if provided
self._input_stream_data = None
if input_stream is not None:
# Validate stream has required methods
if not hasattr(input_stream, "read"):
raise TypeError(
"input_stream must be a binary stream with read() method"
)
self._input_stream_data = input_stream

param_approach = self._determine_parameter_approach(parameters)
if param_approach == ParameterApproach.NONE:
prepared_params = NO_NATIVE_PARAMS
prepared_operation = operation

elif param_approach == ParameterApproach.INLINE:
prepared_operation, prepared_params = self._prepare_inline_parameters(
operation, parameters
)
elif param_approach == ParameterApproach.NATIVE:
normalized_parameters = self._normalize_tparametercollection(parameters)
param_structure = self._determine_parameter_structure(normalized_parameters)
transformed_operation = transform_paramstyle(
operation, normalized_parameters, param_structure
)
prepared_operation, prepared_params = self._prepare_native_parameters(
transformed_operation, normalized_parameters, param_structure
)
param_approach = self._determine_parameter_approach(parameters)
if param_approach == ParameterApproach.NONE:
prepared_params = NO_NATIVE_PARAMS
prepared_operation = operation

self._check_not_closed()
self._close_and_clear_active_result_set()
self.active_result_set = self.backend.execute_command(
operation=prepared_operation,
session_id=self.connection.session.session_id,
max_rows=self.arraysize,
max_bytes=self.buffer_size_bytes,
lz4_compression=self.connection.lz4_compression,
cursor=self,
use_cloud_fetch=self.connection.use_cloud_fetch,
parameters=prepared_params,
async_op=False,
enforce_embedded_schema_correctness=enforce_embedded_schema_correctness,
)
elif param_approach == ParameterApproach.INLINE:
prepared_operation, prepared_params = self._prepare_inline_parameters(
operation, parameters
)
elif param_approach == ParameterApproach.NATIVE:
normalized_parameters = self._normalize_tparametercollection(parameters)
param_structure = self._determine_parameter_structure(
normalized_parameters
)
transformed_operation = transform_paramstyle(
operation, normalized_parameters, param_structure
)
prepared_operation, prepared_params = self._prepare_native_parameters(
transformed_operation, normalized_parameters, param_structure
)

if self.active_result_set and self.active_result_set.is_staging_operation:
self._handle_staging_operation(
staging_allowed_local_path=self.connection.staging_allowed_local_path
self._check_not_closed()
self._close_and_clear_active_result_set()
self.active_result_set = self.backend.execute_command(
operation=prepared_operation,
session_id=self.connection.session.session_id,
max_rows=self.arraysize,
max_bytes=self.buffer_size_bytes,
lz4_compression=self.connection.lz4_compression,
cursor=self,
use_cloud_fetch=self.connection.use_cloud_fetch,
parameters=prepared_params,
async_op=False,
enforce_embedded_schema_correctness=enforce_embedded_schema_correctness,
)

return self
if self.active_result_set and self.active_result_set.is_staging_operation:
self._handle_staging_operation(
staging_allowed_local_path=self.connection.staging_allowed_local_path
)

return self
finally:
# Clean up stream data
self._input_stream_data = None

@log_latency(StatementType.QUERY)
def execute_async(
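
A caller-side note on the implementation above: _handle_staging_put_stream hands the stream straight to requests.put(data=stream), so a single execute() call consumes it, and the cursor drops its reference to the stream in the finally block. The sketch below shows one way to retry a failed streaming PUT when the stream is seekable; put_with_retry, the attempt count, and the backoff policy are illustrative and not part of this PR:

import time
from typing import BinaryIO

from databricks.sql.exc import OperationalError


def put_with_retry(cursor, volume_path: str, stream: BinaryIO, attempts: int = 3) -> None:
    """Retry a streaming PUT, rewinding the seekable stream before each attempt."""
    for attempt in range(1, attempts + 1):
        try:
            stream.seek(0)  # a previous attempt may have partially consumed the stream
            cursor.execute(
                f"PUT '__input_stream__' INTO '{volume_path}'",
                input_stream=stream,
            )
            return
        except OperationalError:
            if attempt == attempts:
                raise
            time.sleep(2**attempt)  # simple exponential backoff between attempts
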
51 changes: 51 additions & 0 deletions tests/e2e/common/streaming_put_tests.py
@@ -0,0 +1,51 @@
#!/usr/bin/env python3
"""
E2E tests for streaming PUT operations.
"""

import io
import pytest
from datetime import datetime


class PySQLStreamingPutTestSuiteMixin:
"""Test suite for streaming PUT operations."""

def test_streaming_put_basic(self, catalog, schema):
"""Test basic streaming PUT functionality."""

# Create test data
test_data = b"Hello, streaming world! This is test data."
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"stream_test_{timestamp}.txt"

with self.connection() as conn:
with conn.cursor() as cursor:
with io.BytesIO(test_data) as stream:
cursor.execute(
f"PUT '__input_stream__' INTO '/Volumes/{catalog}/{schema}/e2etests/{filename}'",
input_stream=stream
)

# Verify file exists
cursor.execute(f"LIST '/Volumes/{catalog}/{schema}/e2etests/'")
files = cursor.fetchall()

# Check if our file is in the list
file_paths = [row[0] for row in files]
expected_path = f"/Volumes/{catalog}/{schema}/e2etests/{filename}"

assert expected_path in file_paths, f"File {expected_path} not found in {file_paths}"


def test_streaming_put_missing_stream(self, catalog, schema):
"""Test that missing stream raises appropriate error."""

with self.connection() as conn:
with conn.cursor() as cursor:
# Test without providing stream
with pytest.raises(Exception): # Should fail
cursor.execute(
f"PUT '__input_stream__' INTO '/Volumes/{catalog}/{schema}/e2etests/test.txt'"
# Note: No input_stream parameter
)
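
A further case the mixin could cover is the stream-type validation that execute() now performs before any upload starts. A possible sketch, not part of this PR, assuming the same connection fixture and catalog/schema arguments as the tests above:

    def test_streaming_put_rejects_non_stream(self, catalog, schema):
        """Objects without a read() method should be rejected before any upload starts."""
        with self.connection() as conn:
            with conn.cursor() as cursor:
                with pytest.raises(TypeError):
                    cursor.execute(
                        f"PUT '__input_stream__' INTO '/Volumes/{catalog}/{schema}/e2etests/bad.txt'",
                        input_stream=b"not a stream",  # bytes expose no read() method
                    )
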
3 changes: 2 additions & 1 deletion tests/e2e/test_driver.py
@@ -47,8 +47,8 @@
)
from tests.e2e.common.staging_ingestion_tests import PySQLStagingIngestionTestSuiteMixin
from tests.e2e.common.retry_test_mixins import PySQLRetryTestsMixin

from tests.e2e.common.uc_volume_tests import PySQLUCVolumeTestSuiteMixin
from tests.e2e.common.streaming_put_tests import PySQLStreamingPutTestSuiteMixin

from databricks.sql.exc import SessionAlreadyClosedError

@@ -256,6 +256,7 @@ class TestPySQLCoreSuite(
PySQLStagingIngestionTestSuiteMixin,
PySQLRetryTestsMixin,
PySQLUCVolumeTestSuiteMixin,
PySQLStreamingPutTestSuiteMixin,
):
validate_row_value_type = True
validate_result = True
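
With PySQLStreamingPutTestSuiteMixin mixed into TestPySQLCoreSuite, the new cases can be selected by keyword once the usual e2e environment variables are configured, for example: python -m pytest tests/e2e/test_driver.py -k streaming_put -v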