diff --git a/src/java.net.http/share/classes/java/net/http/HttpRequest.java b/src/java.net.http/share/classes/java/net/http/HttpRequest.java index 7ba6ed25b41e3..61366b4a695d6 100644 --- a/src/java.net.http/share/classes/java/net/http/HttpRequest.java +++ b/src/java.net.http/share/classes/java/net/http/HttpRequest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -29,10 +29,9 @@ import java.io.InputStream; import java.net.URI; import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; -import java.nio.file.Files; -import java.nio.file.OpenOption; import java.nio.file.Path; import java.time.Duration; import java.util.Iterator; @@ -720,6 +719,34 @@ public static BodyPublisher ofFile(Path path) throws FileNotFoundException { return RequestPublishers.FilePublisher.create(path); } + /** + * {@return a request body publisher whose body is the {@code length} + * content bytes read from the provided file {@code channel} starting + * from the specified {@code offset}} + *

+ * The {@linkplain FileChannel file channel} will be read using + * {@link FileChannel#read(ByteBuffer, long) FileChannel.read(ByteBuffer buffer, long position)}, + * which does not modify the channel's position. Thus, the same file + * channel may be shared between several publishers passed to + * concurrent requests. + *

+ * The file channel will not be closed upon completion. The caller is + * expected to manage the life cycle of the channel, and close it + * appropriately when not needed anymore. + * + * @param channel a file channel + * @param offset the offset of the first byte + * @param length the number of bytes to read from the file channel + * + * @throws IndexOutOfBoundsException if the specified byte range is + * found to be {@linkplain Objects.checkFromIndexSize(long, long, long) out of bounds} + * compared with the size of the file referred by the channel + */ + public static BodyPublisher ofFileChannel(FileChannel channel, long offset, long length) { + Objects.requireNonNull(channel, "channel"); + return new RequestPublishers.FileChannelPublisher(channel, offset, length); + } + /** * A request body publisher that takes data from an {@code Iterable} * of byte arrays. An {@link Iterable} is provided which supplies diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/RequestPublishers.java b/src/java.net.http/share/classes/jdk/internal/net/http/RequestPublishers.java index dd5443c503567..e90acdc609df1 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/RequestPublishers.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/RequestPublishers.java @@ -32,6 +32,7 @@ import java.lang.reflect.UndeclaredThrowableException; import java.net.http.HttpRequest.BodyPublisher; import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; import java.nio.charset.Charset; import java.nio.file.Files; import java.nio.file.NoSuchFileException; @@ -418,6 +419,90 @@ public long contentLength() { } } + public static final class FileChannelPublisher implements BodyPublisher { + + private final FileChannel channel; + + private final long position; + + private final long limit; + + public FileChannelPublisher(FileChannel channel, long offset, long length) { + this.channel = Objects.requireNonNull(channel, "channel"); + long fileSize = fileSize(channel); + Objects.checkFromIndexSize(offset, length, fileSize); + this.position = offset; + this.limit = offset + length; + } + + private static long fileSize(FileChannel channel) { + try { + return channel.size(); + } catch (IOException ioe) { + throw new UncheckedIOException(ioe); + } + } + + @Override + public long contentLength() { + return limit - position; + } + + @Override + public void subscribe(Flow.Subscriber subscriber) { + Iterable iterable = () -> new FileChannelIterator(channel, position, limit); + new PullPublisher<>(iterable).subscribe(subscriber); + } + + } + + private static final class FileChannelIterator implements Iterator { + + private final FileChannel channel; + + private final long limit; + + private long position; + + private boolean terminated; + + private FileChannelIterator(FileChannel channel, long position, long limit) { + this.channel = channel; + this.position = position; + this.limit = limit; + } + + @Override + public synchronized boolean hasNext() { + return position < limit && !terminated; + } + + @Override + public synchronized ByteBuffer next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + long remaining = limit - position; + ByteBuffer buffer = Utils.getBufferWithAtMost(remaining); + try { + int readLength = channel.read(buffer, position); + // Short-circuit if `read()` has failed, e.g., due to file content being changed in the meantime + if (readLength < 0) { + // We *must* throw to signal that the request needs to be cancelled. + // Otherwise, the server will continue waiting data. + throw new IOException("Unexpected EOF (position=%s)".formatted(position)); + } else { + position += readLength; + } + } catch (IOException ioe) { + terminated = true; + throw new UncheckedIOException(ioe); + } + return buffer.flip(); + } + + } + public static final class PublisherAdapter implements BodyPublisher { private final Publisher publisher; @@ -430,12 +515,12 @@ public PublisherAdapter(Publisher publisher, } @Override - public final long contentLength() { + public long contentLength() { return contentLength; } @Override - public final void subscribe(Flow.Subscriber subscriber) { + public void subscribe(Flow.Subscriber subscriber) { publisher.subscribe(subscriber); } } diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/common/Utils.java b/src/java.net.http/share/classes/jdk/internal/net/http/common/Utils.java index 8aefa0ee5baf4..ea4e429f40209 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/common/Utils.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/common/Utils.java @@ -367,10 +367,32 @@ public static String describeOps(int interestOps) { public static IllegalArgumentException newIAE(String message, Object... args) { return new IllegalArgumentException(format(message, args)); } + + /** + * {@return a new {@link ByteBuffer} instance of configured capacity for the HTTP Client} + */ public static ByteBuffer getBuffer() { return ByteBuffer.allocate(BUFSIZE); } + /** + * {@return a new {@link ByteBuffer} instance whose capacity is set to the + * smaller of the specified {@code maxCapacity} and the default + * ({@value BUFSIZE})} + * + * @param maxCapacity a buffer capacity, in bytes + * @throws IllegalArgumentException if {@code capacity < 0} + */ + public static ByteBuffer getBufferWithAtMost(long maxCapacity) { + if (maxCapacity < 0) { + throw new IllegalArgumentException( + // Match the message produced by `ByteBuffer::createCapacityException` + "capacity < 0: (%s < 0)".formatted(maxCapacity)); + } + int effectiveCapacity = (int) Math.min(maxCapacity, BUFSIZE); + return ByteBuffer.allocate(effectiveCapacity); + } + public static Throwable getCompletionCause(Throwable x) { Throwable cause = x; while ((cause instanceof CompletionException) diff --git a/test/jdk/java/net/httpclient/FileChannelPublisherTest.java b/test/jdk/java/net/httpclient/FileChannelPublisherTest.java new file mode 100644 index 0000000000000..7350e66b52cb4 --- /dev/null +++ b/test/jdk/java/net/httpclient/FileChannelPublisherTest.java @@ -0,0 +1,693 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @summary Verifies `HttpRequest.BodyPublishers::ofFileChannel` + * @library /test/lib + * /test/jdk/java/net/httpclient/lib + * @build jdk.httpclient.test.lib.common.HttpServerAdapters + * jdk.test.lib.net.SimpleSSLContext + * @run junit FileChannelPublisherTest + */ + +import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestHandler; +import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestServer; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.Utils; +import jdk.test.lib.net.SimpleSSLContext; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledOnOs; +import org.junit.jupiter.api.condition.OS; +import org.junit.jupiter.api.io.CleanupMode; +import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.MethodSource; + +import javax.net.ssl.SSLContext; +import java.io.IOException; +import java.io.InputStream; +import java.io.UncheckedIOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.FileChannel; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; + +import static java.net.http.HttpClient.Builder.NO_PROXY; +import static java.net.http.HttpRequest.BodyPublishers.ofFileChannel; +import static java.net.http.HttpResponse.BodyHandlers.discarding; +import static java.net.http.HttpResponse.BodyHandlers.ofInputStream; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class FileChannelPublisherTest { + + private static final String CLASS_NAME = FileChannelPublisherTest.class.getSimpleName(); + + private static final Logger LOGGER = Utils.getDebugLogger(CLASS_NAME::toString, Utils.DEBUG); + + private static final int DEFAULT_BUFFER_SIZE = Utils.getBuffer().capacity(); + + private static final SSLContext SSL_CONTEXT = createSslContext(); + + private static final HttpClient CLIENT = HttpClient.newBuilder().sslContext(SSL_CONTEXT).proxy(NO_PROXY).build(); + + private static final ExecutorService EXECUTOR = Executors.newCachedThreadPool(); + + private static final ServerRequestPair HTTP1 = ServerRequestPair.of(Version.HTTP_1_1, false); + + private static final ServerRequestPair HTTPS1 = ServerRequestPair.of(Version.HTTP_1_1, true); + + private static final ServerRequestPair HTTP2 = ServerRequestPair.of(Version.HTTP_2, false); + + private static final ServerRequestPair HTTPS2 = ServerRequestPair.of(Version.HTTP_2, true); + + private static SSLContext createSslContext() { + try { + return new SimpleSSLContext().get(); + } catch (IOException exception) { + throw new UncheckedIOException(exception); + } + } + + private record ServerRequestPair( + String serverName, + HttpTestServer server, + BlockingQueue serverReadRequestBodyBytes, + HttpRequest.Builder requestBuilder, + boolean secure) { + + private static CountDownLatch SERVER_REQUEST_RECEIVED_SIGNAL = null; + + private static CountDownLatch SERVER_READ_PERMISSION = null; + + private static ServerRequestPair of(Version version, boolean secure) { + + // Create the server + SSLContext sslContext = secure ? SSL_CONTEXT : null; + HttpTestServer server = createServer(version, sslContext); + String serverName = secure ? version.toString().replaceFirst("_", "S_") : version.toString(); + + // Add the handler + String handlerPath = "/%s/".formatted(CLASS_NAME); + BlockingQueue serverReadRequestBodyBytes = + addRequestBodyConsumingServerHandler(serverName, server, handlerPath); + + // Create the request builder + String requestUriScheme = secure ? "https" : "http"; + // `x` suffix in the URI is not a typo, but ensures that *only* the parent handler path is matched + URI requestUri = URI.create("%s://%s%sx".formatted(requestUriScheme, server.serverAuthority(), handlerPath)); + HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(requestUri).version(version); + + // Create the pair + ServerRequestPair pair = new ServerRequestPair(serverName, server, serverReadRequestBodyBytes, requestBuilder, secure); + pair.server.start(); + LOGGER.log("Server[%s] is started at `%s`", pair, server.serverAuthority()); + + return pair; + + } + + private static HttpTestServer createServer(Version version, SSLContext sslContext) { + try { + // The default HTTP/1.1 test server processes requests sequentially. + // This causes a deadlock for concurrent tests such as `testSlicedUpload()`. + // Hence, explicitly providing a multithreaded executor for HTTP/1.1. + ExecutorService executor = Version.HTTP_1_1.equals(version) ? EXECUTOR : null; + return HttpTestServer.create(version, sslContext, executor); + } catch (IOException ioe) { + throw new UncheckedIOException(ioe); + } + } + + private static BlockingQueue addRequestBodyConsumingServerHandler( + String serverName, HttpTestServer server, String handlerPath) { + BlockingQueue readRequestBodyBytes = new LinkedBlockingQueue<>(); + HttpTestHandler handler = exchange -> { + // `HttpTestExchange::toString` changes on failure, pin it + String exchangeName = exchange.toString(); + try (exchange) { + + // Discard `HEAD` requests used for initial connection admission + if ("HEAD".equals(exchange.getRequestMethod())) { + exchange.sendResponseHeaders(200, -1L); + return; + } + + signalServerRequestReceived(serverName, exchangeName); + awaitServerReadPermission(serverName, exchangeName); + + LOGGER.log("Server[%s] is reading the request body (exchange=%s)", serverName, exchangeName); + byte[] requestBodyBytes = exchange.getRequestBody().readAllBytes(); + LOGGER.log("Server[%s] has read %s bytes (exchange=%s)", serverName, requestBodyBytes.length, exchangeName); + readRequestBodyBytes.add(requestBodyBytes); + + LOGGER.log("Server[%s] is writing the response (exchange=%s)", serverName, exchangeName); + exchange.sendResponseHeaders(200, requestBodyBytes.length); + exchange.getResponseBody().write(requestBodyBytes); + + } catch (Exception exception) { + LOGGER.log( + "Server[%s] failed to process the request (exchange=%s)".formatted(serverName, exception), + exception); + readRequestBodyBytes.add(new byte[0]); + } finally { + LOGGER.log("Server[%s] completed processing the request (exchange=%s)", serverName, exchangeName); + } + }; + server.addHandler(handler, handlerPath); + return readRequestBodyBytes; + } + + private static void signalServerRequestReceived(String serverName, String exchangeName) { + if (SERVER_REQUEST_RECEIVED_SIGNAL != null) { + LOGGER.log("Server[%s] is signaling that the request is received (exchange=%s)", serverName, exchangeName); + SERVER_REQUEST_RECEIVED_SIGNAL.countDown(); + } + } + + private static void awaitServerReadPermission(String serverName, String exchangeName) { + if (SERVER_READ_PERMISSION != null) { + LOGGER.log("Server[%s] is waiting for the read permission (exchange=%s)", serverName, exchangeName); + try { + SERVER_READ_PERMISSION.await(); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); // Restore the `interrupted` flag + throw new RuntimeException(ie); + } + } + } + + @Override + public String toString() { + return serverName; + } + + } + + @AfterAll + static void shutDown() { + LOGGER.log("Closing the client"); + CLIENT.close(); + LOGGER.log("Closing servers"); + closeServers(); + LOGGER.log("Closing the executor"); + EXECUTOR.shutdownNow(); + } + + private static void closeServers() { + Exception[] exceptionRef = {null}; + Stream + .of(HTTP1, HTTPS1, HTTP2, HTTPS2) + .map(pair -> (Runnable) pair.server::stop) + .forEach(terminator -> { + try { + terminator.run(); + } catch (Exception exception) { + if (exceptionRef[0] == null) { + exceptionRef[0] = exception; + } else { + exceptionRef[0].addSuppressed(exception); + } + } + }); + if (exceptionRef[0] != null) { + throw new RuntimeException("failed closing one or more server resources", exceptionRef[0]); + } + } + + /** + * Resets {@link ServerRequestPair#serverReadRequestBodyBytes()} to avoid leftover state from a test leaking to the next. + */ + @BeforeEach + void resetServerHandlerResults() { + Stream + .of(HTTP1, HTTPS1, HTTP2, HTTPS2) + .forEach(pair -> pair.serverReadRequestBodyBytes.clear()); + } + + static ServerRequestPair[] serverRequestPairs() { + return new ServerRequestPair[]{ + HTTP1, + HTTPS1, + HTTP2, + HTTPS2 + }; + } + + @Test + void testNullFileChannel() { + assertThrows(NullPointerException.class, () -> ofFileChannel(null, 0, 1)); + } + + @ParameterizedTest + @CsvSource({ + "6,-1,1", // offset < 0 + "6,7,1", // offset > fileSize + "6,0,-1", // length < 0 + "6,0,7", // length > fileSize + "6,2,5" // (offset + length) > fileSize + }) + void testIllegalOffset( + int fileLength, + int fileChannelOffset, + int fileChannelLength, + @TempDir(cleanup = CleanupMode.ON_SUCCESS) Path tempDir) throws Exception { + withFileChannel(tempDir.resolve("data.txt"), fileLength, (_, fileChannel) -> + assertThrows( + IndexOutOfBoundsException.class, + () -> ofFileChannel(fileChannel, fileChannelOffset, fileChannelLength))); + } + + @ParameterizedTest + @MethodSource("serverRequestPairs") + void testContentLessThanBufferSize( + ServerRequestPair pair, + @TempDir(cleanup = CleanupMode.ON_SUCCESS) Path tempDir) throws Exception { + + int fileLength = 6; + assertTrue(fileLength < DEFAULT_BUFFER_SIZE); + + testSuccessfulContentDelivery( + "Complete content", + pair, tempDir, fileLength, 0, fileLength); + + { + int fileChannelOffset = 1; + int fileChannelLength = fileLength - 1; + String debuggingContext = debuggingContext(fileLength, fileChannelOffset, fileChannelLength); + assertEquals( + fileLength - fileChannelOffset, fileChannelLength, + "must be until EOF " + debuggingContext); + testSuccessfulContentDelivery( + "Partial content until the EOF " + debuggingContext, + pair, tempDir, fileLength, fileChannelOffset, fileChannelLength); + } + + { + int fileChannelOffset = 1; + int fileChannelLength = fileLength - 2; + String debuggingContext = debuggingContext(fileLength, fileChannelOffset, fileChannelLength); + assertTrue( + fileLength - fileChannelOffset > fileChannelLength, + "must end before EOF " + debuggingContext); + testSuccessfulContentDelivery( + "Partial content *before* the EOF " + debuggingContext, + pair, tempDir, fileLength, fileChannelOffset, fileChannelLength); + } + + } + + @ParameterizedTest + @MethodSource("serverRequestPairs") + void testContentMoreThanBufferSize( + ServerRequestPair pair, + @TempDir(cleanup = CleanupMode.ON_SUCCESS) Path tempDir) throws Exception { + + int fileLength = 1 + 3 * DEFAULT_BUFFER_SIZE; + + testSuccessfulContentDelivery( + "Complete content", + pair, tempDir, fileLength, 0, fileLength); + + { + int fileChannelOffset = 1; + int fileChannelLength = 3 * DEFAULT_BUFFER_SIZE; + String debuggingContext = debuggingContext(fileLength, fileChannelOffset, fileChannelLength); + assertEquals( + fileLength - fileChannelOffset, fileChannelLength, + "must be until EOF " + debuggingContext); + testSuccessfulContentDelivery( + "Partial content until the EOF. Occupies exactly 3 buffers. " + debuggingContext, + pair, tempDir, fileLength, fileChannelOffset, fileChannelLength); + } + + { + int fileChannelOffset = 2; + int fileChannelLength = 3 * DEFAULT_BUFFER_SIZE - 1; + String debuggingContext = debuggingContext(fileLength, fileChannelOffset, fileChannelLength); + assertEquals( + fileLength - fileChannelOffset, fileChannelLength, + "must be until EOF " + debuggingContext); + testSuccessfulContentDelivery( + "Partial content until the EOF. Occupies 3 buffers, the last is custom sized. " + debuggingContext, + pair, tempDir, fileLength, fileChannelOffset, fileChannelLength); + } + + { + int fileChannelOffset = 2; + int fileChannelLength = 2 * DEFAULT_BUFFER_SIZE; + String debuggingContext = debuggingContext(fileLength, fileChannelOffset, fileChannelLength); + assertTrue( + fileLength - fileChannelOffset > fileChannelLength, + "must end before EOF " + debuggingContext); + testSuccessfulContentDelivery( + "Partial content *before* the EOF. Occupies exactly 2 buffers. " + debuggingContext, + pair, tempDir, fileLength, fileChannelOffset, fileChannelLength); + } + + { + int fileChannelOffset = 2; + int fileChannelLength = 3 * DEFAULT_BUFFER_SIZE - 2; + String debuggingContext = debuggingContext(fileLength, fileChannelOffset, fileChannelLength); + assertTrue( + fileLength - fileChannelOffset > fileChannelLength, + "must end before EOF " + debuggingContext); + testSuccessfulContentDelivery( + "Partial content *before* the EOF. Occupies 3 buffers, the last is custom sized. "+ debuggingContext, + pair, tempDir, fileLength, fileChannelOffset, fileChannelLength); + } + + } + + private static String debuggingContext(int fileLength, int fileChannelOffset, int fileChannelLength) { + Map context = new LinkedHashMap<>(); // Using `LHM` to preserve the insertion order + context.put("DEFAULT_BUFFER_SIZE", DEFAULT_BUFFER_SIZE); + context.put("fileLength", fileLength); + context.put("fileChannelOffset", fileChannelOffset); + context.put("fileChannelLength", fileChannelLength); + boolean customSizedBuffer = fileChannelLength % DEFAULT_BUFFER_SIZE == 0; + context.put("customSizedBuffer", customSizedBuffer); + return context.toString(); + } + + private void testSuccessfulContentDelivery( + String caseDescription, + ServerRequestPair pair, + Path tempDir, + int fileLength, + int fileChannelOffset, + int fileChannelLength) throws Exception { + + // Case names come handy even when no debug logging is enabled. + // Hence, intentionally avoiding `Logger`. + System.err.printf("Case: %s%n", caseDescription); + + // Create the file to upload + String fileName = "data-%d-%d-%d.txt".formatted(fileLength, fileChannelOffset, fileChannelLength); + Path filePath = tempDir.resolve(fileName); + withFileChannel(filePath, fileLength, (fileBytes, fileChannel) -> { + + // Upload the file + HttpRequest request = pair + .requestBuilder + .POST(ofFileChannel(fileChannel, fileChannelOffset, fileChannelLength)) + .build(); + CLIENT.send(request, discarding()); + + // Verify the received request body + byte[] expectedRequestBodyBytes = new byte[fileChannelLength]; + System.arraycopy(fileBytes, fileChannelOffset, expectedRequestBodyBytes, 0, fileChannelLength); + byte[] actualRequestBodyBytes = pair.serverReadRequestBodyBytes.take(); + assertArrayEquals(expectedRequestBodyBytes, actualRequestBodyBytes); + + }); + + } + + /** + * Big enough file length to observe the effects of publisher state corruption while uploading. + *

+ * Certain tests follow below steps: + *

+ *
    + *
  1. Issue the request
  2. + *
  3. Wait for the server's signal that the request (not the body!) is received
  4. + *
  5. Corrupt the publisher's state; modify the file, close the file channel, etc.
  6. + *
  7. Signal the server to proceed with reading
  8. + *
+ *

+ * With small files, even before we permit the server to read (step 4), file gets already uploaded. + * This voids the effect of state corruption (step 3). + * To circumvent this, use this big enough file size. + *

+ * + * @see #testChannelCloseDuringPublisherRead(ServerRequestPair, Path) + * @see #testFileModificationDuringPublisherRead(ServerRequestPair, Path) + */ + private static final int BIG_FILE_LENGTH = 8 * 1024 * 1024; // 8 MiB + + @ParameterizedTest + @MethodSource("serverRequestPairs") + void testChannelCloseDuringPublisherRead( + ServerRequestPair pair, + @TempDir(cleanup = CleanupMode.ON_SUCCESS) Path tempDir) + throws Exception { + establishInitialConnection(pair); + ServerRequestPair.SERVER_REQUEST_RECEIVED_SIGNAL = new CountDownLatch(1); + ServerRequestPair.SERVER_READ_PERMISSION = new CountDownLatch(1); + try { + + int fileLength = BIG_FILE_LENGTH; + AtomicReference>> responseFutureRef = new AtomicReference<>(); + withFileChannel(tempDir.resolve("data.txt"), fileLength, ((_, fileChannel) -> { + + // Issue the request + LOGGER.log("Issuing the request"); + HttpRequest request = pair + .requestBuilder + .POST(ofFileChannel(fileChannel, 0, fileLength)) + .build(); + responseFutureRef.set(CLIENT.sendAsync(request, discarding())); + + // Wait for server to receive the request + LOGGER.log("Waiting for the request to be received"); + ServerRequestPair.SERVER_REQUEST_RECEIVED_SIGNAL.await(); + + })); + + LOGGER.log("File channel is closed"); + + // Let the server proceed + LOGGER.log("Permitting the server to proceed"); + ServerRequestPair.SERVER_READ_PERMISSION.countDown(); + + // Verifying the client failure + LOGGER.log("Verifying the client failure"); + Exception requestFailure = assertThrows(ExecutionException.class, () -> responseFutureRef.get().get()); + assertInstanceOf(UncheckedIOException.class, requestFailure.getCause()); + assertInstanceOf(ClosedChannelException.class, requestFailure.getCause().getCause()); + + verifyServerIncompleteRead(pair, fileLength); + + } finally { + ServerRequestPair.SERVER_REQUEST_RECEIVED_SIGNAL = null; + ServerRequestPair.SERVER_READ_PERMISSION = null; + } + } + + @ParameterizedTest + @MethodSource("serverRequestPairs") + // On Windows, modification while reading is not possible. + // Recall the infamous `The process cannot access the file because it is being used by another process`. + @DisabledOnOs(OS.WINDOWS) + void testFileModificationDuringPublisherRead( + ServerRequestPair pair, + @TempDir(cleanup = CleanupMode.ON_SUCCESS) Path tempDir) + throws Exception { + establishInitialConnection(pair); + ServerRequestPair.SERVER_REQUEST_RECEIVED_SIGNAL = new CountDownLatch(1); + ServerRequestPair.SERVER_READ_PERMISSION = new CountDownLatch(1); + try { + + int fileLength = BIG_FILE_LENGTH; + Path filePath = tempDir.resolve("data.txt"); + withFileChannel(filePath, fileLength, ((_, fileChannel) -> { + + // Issue the request + LOGGER.log("Issuing the request"); + HttpRequest request = pair + .requestBuilder + .POST(ofFileChannel(fileChannel, 0, fileLength)) + .build(); + CompletableFuture> responseFuture = CLIENT.sendAsync(request, discarding()); + + // Wait for server to receive the request + LOGGER.log("Waiting for the request to be received"); + ServerRequestPair.SERVER_REQUEST_RECEIVED_SIGNAL.await(); + + // Modify the file + LOGGER.log("Modifying the file"); + Files.write(filePath, generateFileBytes(1)); + + // Let the server proceed + LOGGER.log("Permitting the server to proceed"); + ServerRequestPair.SERVER_READ_PERMISSION.countDown(); + + // Verifying the client failure + LOGGER.log("Verifying the client failure"); + Exception requestFailure = assertThrows(ExecutionException.class, responseFuture::get); + String requestFailureMessage = requestFailure.getMessage(); + assertTrue( + requestFailureMessage.contains("Unexpected EOF"), + "unexpected message: " + requestFailureMessage); + + verifyServerIncompleteRead(pair, fileLength); + + })); + + } finally { + ServerRequestPair.SERVER_REQUEST_RECEIVED_SIGNAL = null; + ServerRequestPair.SERVER_READ_PERMISSION = null; + } + } + + private static void verifyServerIncompleteRead(ServerRequestPair pair, int fileLength) throws InterruptedException { + LOGGER.log("Verifying the server's incomplete read"); + byte[] readRequestBodyBytes = pair.serverReadRequestBodyBytes.take(); + assertTrue( + readRequestBodyBytes.length < fileLength, + "was expecting `readRequestBodyBytes < fileLength` (%s < %s)".formatted( + readRequestBodyBytes.length, fileLength)); + } + + @ParameterizedTest + @MethodSource("serverRequestPairs") + void testSlicedUpload( + ServerRequestPair pair, + @TempDir(cleanup = CleanupMode.ON_SUCCESS) Path tempDir) + throws Exception { + + // Populate the file + int sliceCount = 4; + int sliceLength = 14_281; // Intentionally using a prime number to increase the chances of hitting corner cases + int fileLength = sliceCount * sliceLength; + byte[] fileBytes = generateFileBytes(fileLength); + Path filePath = tempDir.resolve("data.txt"); + Files.write(filePath, fileBytes, StandardOpenOption.CREATE); + + List responseBodyStreams = new ArrayList<>(sliceCount); + try (FileChannel fileChannel = FileChannel.open(filePath)) { + + // Upload the complete file in mutually exclusive slices + List>> responseFutures = new ArrayList<>(sliceCount); + for (int sliceIndex = 0; sliceIndex < sliceCount; sliceIndex++) { + LOGGER.log("Issuing request %d/%d", (sliceIndex + 1), sliceCount); + HttpRequest request = pair + .requestBuilder + .POST(ofFileChannel(fileChannel, sliceIndex * sliceLength, sliceLength)) + .build(); + responseFutures.add(CLIENT.sendAsync( + request, + // Intentionally using an `InputStream` response + // handler to defer consuming the response body + // until after the file channel is closed: + ofInputStream())); + } + + // Collect response body `InputStream`s from all requests + for (int sliceIndex = 0; sliceIndex < sliceCount; sliceIndex++) { + LOGGER.log("Collecting response body `InputStream` for request %d/%d", (sliceIndex + 1), sliceCount); + HttpResponse response = responseFutures.get(sliceIndex).get(); + assertEquals(200, response.statusCode()); + responseBodyStreams.add(response.body()); + } + + } + + LOGGER.log("File channel is closed"); + + // Verify response bodies + for (int sliceIndex = 0; sliceIndex < sliceCount; sliceIndex++) { + LOGGER.log("Consuming response body %d/%d", (sliceIndex + 1), sliceCount); + byte[] expectedResponseBodyBytes = new byte[sliceLength]; + System.arraycopy(fileBytes, sliceIndex * sliceLength, expectedResponseBodyBytes, 0, sliceLength); + try (InputStream responseBodyStream = responseBodyStreams.get(sliceIndex)) { + byte[] responseBodyBytes = responseBodyStream.readAllBytes(); + assertArrayEquals(expectedResponseBodyBytes, responseBodyBytes); + } + } + + } + + /** + * Performs the initial {@code HEAD} request to the specified server. This + * effectively admits a connection to the client's pool, where all protocol + * upgrades, handshakes, etc. are already performed. + *

+ * HTTP/2 test server consumes the complete request payload in the very + * first upgrade frame. That is, if a client sends 100 MiB of data, all + * of it will be consumed first before the configured handler is + * invoked. Though certain tests expect the data to be consumed + * piecemeal. To accommodate this, we ensure client has an upgraded + * connection in the pool. + *

+ */ + private static void establishInitialConnection(ServerRequestPair pair) { + LOGGER.log("Server[%s] is getting queried for the initial connection pool admission", pair); + try { + CLIENT.send(pair.requestBuilder.HEAD().build(), discarding()); + } catch (Exception exception) { + throw new RuntimeException(exception); + } + } + + private static void withFileChannel(Path filePath, int fileLength, FileChannelConsumer fileChannelConsumer) throws Exception { + byte[] fileBytes = generateFileBytes(fileLength); + Files.write(filePath, fileBytes, StandardOpenOption.CREATE); + try (FileChannel fileChannel = FileChannel.open(filePath)) { + fileChannelConsumer.consume(fileBytes, fileChannel); + } + } + + @FunctionalInterface + private interface FileChannelConsumer { + + void consume(byte[] fileBytes, FileChannel fileChannel) throws Exception; + + } + + private static byte[] generateFileBytes(int length) { + byte[] bytes = new byte[length]; + for (int i = 0; i < length; i++) { + bytes[i] = (byte) i; + } + return bytes; + } + +}