Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ testing {
implementation "org.apache.sshd:sshd-sftp:$sshdVersion"
implementation "org.apache.sshd:sshd-scp:$sshdVersion"
implementation "ch.qos.logback:logback-classic:1.5.18"
implementation 'org.glassfish.grizzly:grizzly-http-server:3.0.1'
}

targets {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import com.hierynomus.sshj.test.HttpServer;
import com.hierynomus.sshj.test.SshServerExtension;
import com.hierynomus.sshj.test.util.FileUtil;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.connection.channel.direct.LocalPortForwarder;
import net.schmizz.sshj.connection.channel.direct.Parameters;
Expand All @@ -29,35 +28,30 @@

import java.io.*;
import java.net.*;
import java.nio.file.Files;

import static org.junit.jupiter.api.Assertions.assertEquals;

public class LocalPortForwarderTest {
private static final String LOCALHOST_URL = "http://127.0.0.1:8080";

@RegisterExtension
public SshServerExtension fixture = new SshServerExtension();

@RegisterExtension
public HttpServer httpServer = new HttpServer();

@BeforeEach
public void setUp() throws IOException {
public void setUp() {
fixture.getServer().setForwardingFilter(new AcceptAllForwardingFilter());
File file = Files.createFile(httpServer.getDocRoot().toPath().resolve("index.html")).toFile();
FileUtil.writeToFile(file, "<html><head/><body><h1>Hi!</h1></body></html>");
}

@Test
public void shouldHaveWorkingHttpServer() throws IOException {
assertEquals(200, httpGet());
assertEquals(HttpURLConnection.HTTP_NOT_FOUND, httpGet());
}

@Test
public void shouldHaveHttpServerThatClosesConnectionAfterResponse() throws IOException {
// Just to check that the test server does close connections before we try through the forwarder...
httpGetAndAssertConnectionClosedByServer(8080);
httpGetAndAssertConnectionClosedByServer(httpServer.getServerUrl().getPort());
}

@Test
Expand All @@ -68,7 +62,8 @@ public void shouldCloseConnectionWhenRemoteServerClosesConnection() throws IOExc
ServerSocket serverSocket = new ServerSocket();
serverSocket.setReuseAddress(true);
serverSocket.bind(new InetSocketAddress("0.0.0.0", 12345));
LocalPortForwarder localPortForwarder = sshClient.newLocalPortForwarder(new Parameters("0.0.0.0", 12345, "localhost", 8080), serverSocket);
final int serverPort = httpServer.getServerUrl().getPort();
LocalPortForwarder localPortForwarder = sshClient.newLocalPortForwarder(new Parameters("0.0.0.0", 12345, "localhost", serverPort), serverSocket);
new Thread(() -> {
try {
localPortForwarder.listen();
Expand All @@ -90,7 +85,7 @@ public static void httpGetAndAssertConnectionClosedByServer(int port) throws IOE
// It returns 400 Bad Request because it's missing a bunch of info, but the HTTP response doesn't matter, we just want to test the connection closing.
OutputStream outputStream = socket.getOutputStream();
PrintWriter writer = new PrintWriter(outputStream);
writer.println("GET / HTTP/1.1");
writer.println("GET / HTTP/1.1\r\n");
writer.println("");
writer.flush();

Expand All @@ -111,7 +106,7 @@ public static void httpGetAndAssertConnectionClosedByServer(int port) throws IOE
}

private int httpGet() throws IOException {
final URL url = new URL(LOCALHOST_URL);
final URL url = httpServer.getServerUrl().toURL();
final HttpURLConnection urlConnection = (HttpURLConnection) url.openConnection();
urlConnection.setConnectTimeout(3000);
urlConnection.setRequestMethod("GET");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import com.hierynomus.sshj.test.HttpServer;
import com.hierynomus.sshj.test.SshServerExtension;
import com.hierynomus.sshj.test.util.FileUtil;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.connection.ConnectionException;
import net.schmizz.sshj.connection.channel.forwarded.RemotePortForwarder;
Expand All @@ -27,20 +26,18 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import java.io.File;
import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URL;
import java.nio.file.Files;

import static org.junit.jupiter.api.Assertions.assertEquals;

public class RemotePortForwarderTest {
private static final PortRange RANGE = new PortRange(9000, 9999);
private static final String LOCALHOST = "127.0.0.1";
private static final String LOCALHOST_URL_FORMAT = "http://127.0.0.1:%d";
private static final InetSocketAddress HTTP_SERVER_SOCKET_ADDR = new InetSocketAddress(LOCALHOST, 8080);
private static final String URL_FORMAT = "http://%s:%d";

@RegisterExtension
public SshServerExtension fixture = new SshServerExtension();
Expand All @@ -49,21 +46,21 @@ public class RemotePortForwarderTest {
public HttpServer httpServer = new HttpServer();

@BeforeEach
public void setUp() throws IOException {
public void setUp() {
fixture.getServer().setForwardingFilter(new AcceptAllForwardingFilter());
File file = Files.createFile(httpServer.getDocRoot().toPath().resolve("index.html")).toFile();
FileUtil.writeToFile(file, "<html><head/><body><h1>Hi!</h1></body></html>");
}

@Test
public void shouldHaveWorkingHttpServer() throws IOException {
assertEquals(200, httpGet(8080));
final URI serverUrl = httpServer.getServerUrl();

assertEquals(HttpURLConnection.HTTP_NOT_FOUND, httpGet(serverUrl.getHost(), serverUrl.getPort()));
}

@Test
public void shouldDynamicallyForwardPortForLocalhost() throws IOException {
SSHClient sshClient = getFixtureClient();
RemotePortForwarder.Forward bind = forwardPort(sshClient, "127.0.0.1", new SinglePort(0));
RemotePortForwarder.Forward bind = forwardPort(sshClient, LOCALHOST, new SinglePort(0));
assertHttpGetSuccess(bind);
}

Expand All @@ -84,7 +81,7 @@ public void shouldDynamicallyForwardPortForAllProtocols() throws IOException {
@Test
public void shouldForwardPortForLocalhost() throws IOException {
SSHClient sshClient = getFixtureClient();
RemotePortForwarder.Forward bind = forwardPort(sshClient, "127.0.0.1", RANGE);
RemotePortForwarder.Forward bind = forwardPort(sshClient, LOCALHOST, RANGE);
assertHttpGetSuccess(bind);
}

Expand All @@ -103,17 +100,22 @@ public void shouldForwardPortForAllProtocols() throws IOException {
}

private void assertHttpGetSuccess(final RemotePortForwarder.Forward bind) throws IOException {
assertEquals(200, httpGet(bind.getPort()));
final String bindAddress = bind.getAddress();
final String address = bindAddress.isEmpty() ? LOCALHOST : bindAddress;
final int port = bind.getPort();
assertEquals(HttpURLConnection.HTTP_NOT_FOUND, httpGet(address, port));
}

private RemotePortForwarder.Forward forwardPort(SSHClient sshClient, String address, PortRange portRange) throws IOException {
while (true) {
final URI serverUrl = httpServer.getServerUrl();
final InetSocketAddress serverAddress = new InetSocketAddress(serverUrl.getHost(), serverUrl.getPort());
try {
return sshClient.getRemotePortForwarder().bind(
// where the server should listen
new RemotePortForwarder.Forward(address, portRange.nextPort()),
// what we do with incoming connections that are forwarded to us
new SocketForwardingConnectListener(HTTP_SERVER_SOCKET_ADDR));
new SocketForwardingConnectListener(serverAddress));
} catch (ConnectionException ce) {
if (!portRange.hasNext()) {
throw ce;
Expand All @@ -122,8 +124,8 @@ private RemotePortForwarder.Forward forwardPort(SSHClient sshClient, String addr
}
}

private int httpGet(final int port) throws IOException {
final URL url = new URL(String.format(LOCALHOST_URL_FORMAT, port));
private int httpGet(final String address, final int port) throws IOException {
final URL url = new URL(String.format(URL_FORMAT, address, port));
final HttpURLConnection urlConnection = (HttpURLConnection) url.openConnection();
urlConnection.setConnectTimeout(3000);
urlConnection.setRequestMethod("GET");
Expand Down
34 changes: 14 additions & 20 deletions src/test/java/com/hierynomus/sshj/test/HttpServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,42 +19,36 @@
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;

import java.io.File;
import java.nio.file.Files;
import java.net.InetSocketAddress;
import java.net.URI;

/**
* Can be used to setup a test HTTP server
*/
public class HttpServer implements BeforeEachCallback, AfterEachCallback {

private org.glassfish.grizzly.http.server.HttpServer httpServer;
private static final String BIND_ADDRESS = "127.0.0.1";


private File docRoot ;
private com.sun.net.httpserver.HttpServer httpServer;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Jikes... com.sun.net... As far as I remember, wasn't it an antipattern to depend on anything in com.sun as this is not guaranteed to be present in other JVMs.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scratch that, I've read up on it... Seems that I'm confusing sun.* and com.sun.*... Nothing to see here, thanks for the PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for reviewing and merging @hierynomus! Yes, the package name for the JDK HttpServer causes a natural double take, so thanks for giving it a closer look, as it is part of the public JDK.


@Override
public void afterEach(ExtensionContext context) throws Exception {
try {
httpServer.shutdownNow();
} catch (Exception e) {}
public void afterEach(ExtensionContext context) {
try {
docRoot.delete();
} catch (Exception e) {}

httpServer.stop(0);
} catch (Exception ignored) {}
}

@Override
public void beforeEach(ExtensionContext context) throws Exception {
docRoot = Files.createTempDirectory("sshj").toFile();
httpServer = org.glassfish.grizzly.http.server.HttpServer.createSimpleServer(docRoot.getAbsolutePath());
httpServer = com.sun.net.httpserver.HttpServer.create();
final InetSocketAddress socketAddress = new InetSocketAddress(BIND_ADDRESS, 0);
httpServer.bind(socketAddress, 10);
httpServer.start();
}

public org.glassfish.grizzly.http.server.HttpServer getHttpServer() {
return httpServer;
}

public File getDocRoot() {
return docRoot;
public URI getServerUrl() {
final InetSocketAddress bindAddress = httpServer.getAddress();
final String serverUrl = String.format("http://%s:%d", BIND_ADDRESS, bindAddress.getPort());
return URI.create(serverUrl);
}
}