diff --git a/pom.xml b/pom.xml index a2df6ace..5bb16a25 100644 --- a/pom.xml +++ b/pom.xml @@ -2,7 +2,7 @@ 4.0.0 fi.hsl transitdata-common - 2.0.3-RC6 + 2.0.3-RC10 jar Common utilities for Transitdata projects @@ -78,10 +78,16 @@ ${pulsar.version} + + com.azure + azure-identity + 1.11.2 + + redis.clients jedis - 4.4.3 + 5.1.0 jar compile diff --git a/src/main/java/fi/hsl/common/pulsar/PulsarApplication.java b/src/main/java/fi/hsl/common/pulsar/PulsarApplication.java index 8213c515..9a50f9fe 100644 --- a/src/main/java/fi/hsl/common/pulsar/PulsarApplication.java +++ b/src/main/java/fi/hsl/common/pulsar/PulsarApplication.java @@ -1,7 +1,12 @@ package fi.hsl.common.pulsar; +import com.azure.core.credential.AccessToken; +import com.azure.core.credential.TokenRequestContext; +import com.azure.identity.DefaultAzureCredential; +import com.azure.identity.DefaultAzureCredentialBuilder; import com.typesafe.config.Config; import fi.hsl.common.health.HealthServer; +import fi.hsl.common.redis.RedisUtils; import org.apache.pulsar.client.admin.PulsarAdmin; import org.apache.pulsar.client.api.*; import org.jetbrains.annotations.NotNull; @@ -19,6 +24,9 @@ import java.util.regex.Pattern; import java.util.stream.Collectors; +import static fi.hsl.common.redis.RedisUtils.createJedisClient; +import static fi.hsl.common.redis.RedisUtils.extractUsernameFromToken; + public class PulsarApplication implements AutoCloseable { private static final Logger log = LoggerFactory.getLogger(PulsarApplication.class); @@ -84,14 +92,9 @@ public PulsarApplicationContext initialize(@NotNull Config config) throws Except } if (config.getBoolean("redis.enabled")) { - int connTimeOutSecs = 2; - if (config.hasPath("redis.connTimeOutSecs")) { - connTimeOutSecs = config.getInt("redis.connTimeOutSecs"); - } jedis = createRedisClient( config.getString("redis.host"), - config.getInt("redis.port"), - connTimeOutSecs); + config.getInt("redis.port")); } if (config.getBoolean("health.enabled")) { @@ -150,11 +153,28 @@ public PulsarApplicationContext initialize(@NotNull Config config) throws Except } @NotNull - protected Jedis createRedisClient(@NotNull String redisHost, int port, int connTimeOutSecs) { - log.info("Connecting to Redis at " + redisHost + ":" + port + " with connection timeout of (s): "+ connTimeOutSecs); - int timeOutMs = connTimeOutSecs * 1000; - Jedis jedis = new Jedis(redisHost, port, timeOutMs); - jedis.connect(); + protected Jedis createRedisClient(@NotNull String redisHost, int port) { + log.info("Connecting to Redis at {}:{}", redisHost, port); + + //Construct a Token Credential from Identity library, e.g. DefaultAzureCredential / ClientSecretCredential / Client CertificateCredential / ManagedIdentityCredential etc. + DefaultAzureCredential defaultAzureCredential = new DefaultAzureCredentialBuilder().build(); + + // Fetch a Microsoft Entra token to be used for authentication. This token will be used as the password. + TokenRequestContext trc = new TokenRequestContext().addScopes("https://redis.azure.com/.default"); + RedisUtils.TokenRefreshCache tokenRefreshCache = new RedisUtils.TokenRefreshCache(defaultAzureCredential, trc); + AccessToken accessToken = tokenRefreshCache.getAccessToken(); + + // SSL connection is required. + boolean useSsl = true; + String username = extractUsernameFromToken(accessToken.getToken()); + + // Create Jedis client and connect to the Azure Cache for Redis over the TLS/SSL port using the access token as password. + // Note: Cache Host Name, Port, Microsoft Entra access token and SSL connections are required below. + jedis = createJedisClient(redisHost, port, username, accessToken, useSsl); + + // Configure the jedis instance for proactive authentication before token expires. + tokenRefreshCache.setJedisInstanceToAuthenticate(jedis); + log.info("Redis connected: " + jedis.isConnected()); return jedis; } diff --git a/src/main/java/fi/hsl/common/redis/RedisUtils.java b/src/main/java/fi/hsl/common/redis/RedisUtils.java index 2a9571af..f52a6c5a 100644 --- a/src/main/java/fi/hsl/common/redis/RedisUtils.java +++ b/src/main/java/fi/hsl/common/redis/RedisUtils.java @@ -1,5 +1,11 @@ package fi.hsl.common.redis; +import com.azure.core.credential.AccessToken; +import com.azure.core.credential.TokenCredential; +import com.azure.core.credential.TokenRequestContext; +import com.azure.core.util.CoreUtils; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; import fi.hsl.common.pulsar.PulsarApplicationContext; import fi.hsl.common.transitdata.TransitdataProperties; import org.jetbrains.annotations.NotNull; @@ -10,9 +16,12 @@ import redis.clients.jedis.params.ScanParams; import redis.clients.jedis.resps.ScanResult; +import java.nio.charset.StandardCharsets; +import java.time.Duration; import java.time.OffsetDateTime; import java.time.format.DateTimeFormatter; import java.util.*; +import java.util.concurrent.ThreadLocalRandom; public class RedisUtils { private static final Logger log = LoggerFactory.getLogger(RedisUtils.class); @@ -227,4 +236,106 @@ public boolean checkResponse(@Nullable final String response) { public boolean checkResponse(@Nullable final Long response) { return response != null && response == 1; } + + // Azure Cache for Redis helper code + public static Jedis createJedisClient(String cacheHostname, int port, String username, AccessToken accessToken, boolean useSsl) { + return new Jedis(cacheHostname, port, DefaultJedisClientConfig.builder() + .password(accessToken.getToken()) + .user(username) + .ssl(useSsl) + .build()); + } + + public static String extractUsernameFromToken(String token) { + String[] parts = token.split("\\."); + String base64 = parts[1]; + + base64 = addPaddingToBase64String(base64); + + byte[] jsonBytes = Base64.getDecoder().decode(base64); + String json = new String(jsonBytes, StandardCharsets.UTF_8); + JsonObject jwt = JsonParser.parseString(json).getAsJsonObject(); + + return jwt.get("oid").getAsString(); + } + + private static String addPaddingToBase64String(String input) { + if (input != null && !input.isEmpty()) { + int paddingLength = (4 - input.length() % 4) % 4; + input += "=".repeat(paddingLength); + } + return input; + } + + /** + * The token cache to store and proactively refresh the access token. + */ + public static class TokenRefreshCache { + private final TokenCredential tokenCredential; + private final TokenRequestContext tokenRequestContext; + private final Timer timer; + private volatile AccessToken accessToken; + private final Duration maxRefreshOffset = Duration.ofMinutes(5); + private final Duration baseRefreshOffset = Duration.ofMinutes(2); + private Jedis jedisInstanceToAuthenticate; + private String username; + + /** + * Creates an instance of TokenRefreshCache + * @param tokenCredential the token credential to be used for authentication. + * @param tokenRequestContext the token request context to be used for authentication. + */ + public TokenRefreshCache(TokenCredential tokenCredential, TokenRequestContext tokenRequestContext) { + this.tokenCredential = tokenCredential; + this.tokenRequestContext = tokenRequestContext; + this.timer = new Timer(); + } + + /** + * Gets the cached access token. + * @return the AccessToken + */ + public AccessToken getAccessToken() { + if (accessToken != null) { + return accessToken; + } else { + TokenRefreshTask tokenRefreshTask = new TokenRefreshTask(); + accessToken = tokenCredential.getToken(tokenRequestContext).block(); + timer.schedule(tokenRefreshTask, getTokenRefreshDelay()); + return accessToken; + } + } + + private class TokenRefreshTask extends TimerTask { + // Add your task here + public void run() { + accessToken = tokenCredential.getToken(tokenRequestContext).block(); + username = extractUsernameFromToken(accessToken.getToken()); + log.info("Refreshed Token with Expiry: " + accessToken.getExpiresAt().toEpochSecond()); + + if (jedisInstanceToAuthenticate != null && !CoreUtils.isNullOrEmpty(username)) { + jedisInstanceToAuthenticate.auth(username, accessToken.getToken()); + log.info("Refreshed Jedis Connection with fresh access token, token expires at : " + + accessToken.getExpiresAt().toEpochSecond()); + } + timer.schedule(new TokenRefreshTask(), getTokenRefreshDelay()); + } + } + + private long getTokenRefreshDelay() { + return ((accessToken.getExpiresAt() + .minusSeconds(ThreadLocalRandom.current().nextLong(baseRefreshOffset.getSeconds(), maxRefreshOffset.getSeconds())) + .toEpochSecond() - OffsetDateTime.now().toEpochSecond()) * 1000); + } + + /** + * Sets the Jedis to proactively authenticate before token expiry. + * @param jedisInstanceToAuthenticate the instance to authenticate + * @return the updated instance + */ + public TokenRefreshCache setJedisInstanceToAuthenticate(Jedis jedisInstanceToAuthenticate) { + this.jedisInstanceToAuthenticate = jedisInstanceToAuthenticate; + return this; + } + } } diff --git a/src/test/java/fi/hsl/common/redis/RedisUtilsTest.java b/src/test/java/fi/hsl/common/redis/RedisUtilsTest.java index f7ec2aa4..00f9535a 100644 --- a/src/test/java/fi/hsl/common/redis/RedisUtilsTest.java +++ b/src/test/java/fi/hsl/common/redis/RedisUtilsTest.java @@ -8,6 +8,8 @@ import org.testcontainers.utility.DockerImageName; import redis.clients.jedis.Jedis; +import java.util.Base64; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -51,4 +53,20 @@ public void testSetGetExpiringValue() throws InterruptedException { assertFalse(redisUtils.getValue("test").isPresent()); } + + @Test + public void extractUsernameFromTokenBase64PaddingWorks() { + // Payload with length not a multiple of 4 (e.g., 2 or 3 mod 4) + String header = Base64.getUrlEncoder().withoutPadding().encodeToString("{\"alg\":\"none\"}".getBytes()); + + // 2 mod 4 length + String payload2 = Base64.getUrlEncoder().withoutPadding().encodeToString("{\"oid\":\"ab\"}".getBytes()); + String token2 = header + "." + payload2 + ".sig"; + assertEquals("ab", RedisUtils.extractUsernameFromToken(token2)); + + // 3 mod 4 length + String payload3 = Base64.getUrlEncoder().withoutPadding().encodeToString("{\"oid\":\"abc\"}".getBytes()); + String token3 = header + "." + payload3 + ".sig"; + assertEquals("abc", RedisUtils.extractUsernameFromToken(token3)); + } }