diff --git a/trino-aws-proxy-spark3/src/main/java/io/trino/aws/proxy/spark3/PresignAwareAmazonS3.java b/trino-aws-proxy-spark3/src/main/java/io/trino/aws/proxy/spark3/PresignAwareAmazonS3.java
index f7faab7a..e5bc6944 100644
--- a/trino-aws-proxy-spark3/src/main/java/io/trino/aws/proxy/spark3/PresignAwareAmazonS3.java
+++ b/trino-aws-proxy-spark3/src/main/java/io/trino/aws/proxy/spark3/PresignAwareAmazonS3.java
@@ -626,6 +626,8 @@ public S3Object getObject(GetObjectRequest getObjectRequest)
         return getPresignedUrl("GET", getObjectRequest.getBucketName(), getObjectRequest.getKey())
                 .map(presigned -> {
                     PresignedUrlDownloadRequest presignedUrlDownloadRequest = new PresignedUrlDownloadRequest(presigned.url);
+                    Optional.ofNullable(getObjectRequest.getRange())
+                            .ifPresent(range -> presignedUrlDownloadRequest.withRange(range[0], range[1]));
                     return delegate.download(presignedUrlDownloadRequest).getS3Object();
                 })
                 .orElseGet(() -> delegate.getObject(getObjectRequest));
@@ -638,6 +640,8 @@ public ObjectMetadata getObject(GetObjectRequest getObjectRequest, File destinat
         return getPresignedUrl("GET", getObjectRequest.getBucketName(), getObjectRequest.getKey())
                 .map(presigned -> {
                     PresignedUrlDownloadRequest presignedUrlDownloadRequest = new PresignedUrlDownloadRequest(presigned.url);
+                    Optional.ofNullable(getObjectRequest.getRange())
+                            .ifPresent(range -> presignedUrlDownloadRequest.withRange(range[0], range[1]));
                     delegate.download(presignedUrlDownloadRequest, destinationFile);
                     return presigned.objectMetadata;
                 })
diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestDatabaseSecurity.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestDatabaseSecurity.java
index f087d2c8..522655af 100644
--- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestDatabaseSecurity.java
+++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestDatabaseSecurity.java
@@ -103,6 +103,9 @@ public TestDatabaseSecurity(S3Client s3Client, PySparkContainer pySparkContainer
     public void testDatabaseSecurity()
             throws Exception
     {
+        // create the test bucket
+        s3Client.createBucket(r -> r.bucket("test"));
+
         createDatabaseAndTable(s3Client, pySparkContainer);
 
         clearInputStreamAndClose(inputToContainerStdin(pySparkContainer.containerId(), "spark.sql(\"select * from %s.%s\").show()".formatted(DATABASE_NAME, TABLE_NAME)), line -> line.equals("| John Galt| 28|"));
diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestPySparkSql.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestPySparkSql.java
index f18e0368..9f581180 100644
--- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestPySparkSql.java
+++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestPySparkSql.java
@@ -19,6 +19,7 @@
 import io.trino.aws.proxy.server.testing.containers.PySparkContainer;
 import io.trino.aws.proxy.server.testing.harness.BuilderFilter;
 import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest;
+import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Test;
 import software.amazon.awssdk.services.s3.S3Client;
 
@@ -54,6 +55,13 @@ public TestPySparkSql(S3Client s3Client, PySparkContainer pySparkContainer)
         this.pySparkContainer = requireNonNull(pySparkContainer, "pySparkContainer is null");
     }
 
+    @BeforeAll
+    public void setupBucket()
+    {
+        // create the test bucket
+        s3Client.createBucket(r -> r.bucket("test"));
+    }
+
     @Test
     public void testSql()
             throws Exception
@@ -75,12 +83,29 @@ public void testSql()
                 "spark.sql(\"select * from %s.%s\").show()".formatted(DATABASE_NAME, TABLE_NAME)), line -> line.equals("| c| 30|"));
     }
 
-    public static void createDatabaseAndTable(S3Client s3Client, PySparkContainer container)
+    @Test
+    public void testParquet()
             throws Exception
     {
-        // create the test bucket
-        s3Client.createBucket(r -> r.bucket("test"));
+        // upload a CSV file
+        s3Client.putObject(r -> r.bucket("test").key("test_parquet/file.csv"), Path.of(Resources.getResource("test.csv").toURI()));
 
+        // read the CSV file and write it as Parquet
+        clearInputStreamAndClose(inputToContainerStdin(pySparkContainer.containerId(), """
+                df = spark.read.csv("s3a://test/test_parquet/file.csv")
+                df.write.parquet("s3a://test/test_parquet/file.parquet")
+                """), line -> line.equals(">>> ") || line.matches(".*Write Job [\\w-]+ committed.*"));
+
+        // read the Parquet file
+        clearInputStreamAndClose(inputToContainerStdin(pySparkContainer.containerId(), """
+                parquetDF = spark.read.parquet("s3a://test/test_parquet/file.parquet")
+                parquetDF.show()
+                """), line -> line.equals("| John Galt| 28|"));
+    }
+
+    public static void createDatabaseAndTable(S3Client s3Client, PySparkContainer container)
+            throws Exception
+    {
         // upload a CSV file as a potential table
         s3Client.putObject(r -> r.bucket("test").key("table/file.csv"), Path.of(Resources.getResource("test.csv").toURI()));