diff --git a/pom.xml b/pom.xml
index a2df6ace..5bb16a25 100644
--- a/pom.xml
+++ b/pom.xml
@@ -2,7 +2,7 @@
4.0.0
fi.hsl
transitdata-common
- 2.0.3-RC6
+ 2.0.3-RC10
jar
Common utilities for Transitdata projects
@@ -78,10 +78,16 @@
${pulsar.version}
+
+ com.azure
+ azure-identity
+ 1.11.2
+
+
redis.clients
jedis
- 4.4.3
+ 5.1.0
jar
compile
diff --git a/src/main/java/fi/hsl/common/pulsar/PulsarApplication.java b/src/main/java/fi/hsl/common/pulsar/PulsarApplication.java
index 8213c515..9a50f9fe 100644
--- a/src/main/java/fi/hsl/common/pulsar/PulsarApplication.java
+++ b/src/main/java/fi/hsl/common/pulsar/PulsarApplication.java
@@ -1,7 +1,12 @@
package fi.hsl.common.pulsar;
+import com.azure.core.credential.AccessToken;
+import com.azure.core.credential.TokenRequestContext;
+import com.azure.identity.DefaultAzureCredential;
+import com.azure.identity.DefaultAzureCredentialBuilder;
import com.typesafe.config.Config;
import fi.hsl.common.health.HealthServer;
+import fi.hsl.common.redis.RedisUtils;
import org.apache.pulsar.client.admin.PulsarAdmin;
import org.apache.pulsar.client.api.*;
import org.jetbrains.annotations.NotNull;
@@ -19,6 +24,9 @@
import java.util.regex.Pattern;
import java.util.stream.Collectors;
+import static fi.hsl.common.redis.RedisUtils.createJedisClient;
+import static fi.hsl.common.redis.RedisUtils.extractUsernameFromToken;
+
public class PulsarApplication implements AutoCloseable {
private static final Logger log = LoggerFactory.getLogger(PulsarApplication.class);
@@ -84,14 +92,9 @@ public PulsarApplicationContext initialize(@NotNull Config config) throws Except
}
if (config.getBoolean("redis.enabled")) {
- int connTimeOutSecs = 2;
- if (config.hasPath("redis.connTimeOutSecs")) {
- connTimeOutSecs = config.getInt("redis.connTimeOutSecs");
- }
jedis = createRedisClient(
config.getString("redis.host"),
- config.getInt("redis.port"),
- connTimeOutSecs);
+ config.getInt("redis.port"));
}
if (config.getBoolean("health.enabled")) {
@@ -150,11 +153,28 @@ public PulsarApplicationContext initialize(@NotNull Config config) throws Except
}
@NotNull
- protected Jedis createRedisClient(@NotNull String redisHost, int port, int connTimeOutSecs) {
- log.info("Connecting to Redis at " + redisHost + ":" + port + " with connection timeout of (s): "+ connTimeOutSecs);
- int timeOutMs = connTimeOutSecs * 1000;
- Jedis jedis = new Jedis(redisHost, port, timeOutMs);
- jedis.connect();
+ protected Jedis createRedisClient(@NotNull String redisHost, int port) {
+ log.info("Connecting to Redis at {}:{}", redisHost, port);
+
+ //Construct a Token Credential from Identity library, e.g. DefaultAzureCredential / ClientSecretCredential / Client CertificateCredential / ManagedIdentityCredential etc.
+ DefaultAzureCredential defaultAzureCredential = new DefaultAzureCredentialBuilder().build();
+
+ // Fetch a Microsoft Entra token to be used for authentication. This token will be used as the password.
+ TokenRequestContext trc = new TokenRequestContext().addScopes("https://redis.azure.com/.default");
+ RedisUtils.TokenRefreshCache tokenRefreshCache = new RedisUtils.TokenRefreshCache(defaultAzureCredential, trc);
+ AccessToken accessToken = tokenRefreshCache.getAccessToken();
+
+ // SSL connection is required.
+ boolean useSsl = true;
+ String username = extractUsernameFromToken(accessToken.getToken());
+
+ // Create Jedis client and connect to the Azure Cache for Redis over the TLS/SSL port using the access token as password.
+ // Note: Cache Host Name, Port, Microsoft Entra access token and SSL connections are required below.
+ jedis = createJedisClient(redisHost, port, username, accessToken, useSsl);
+
+ // Configure the jedis instance for proactive authentication before token expires.
+ tokenRefreshCache.setJedisInstanceToAuthenticate(jedis);
+
log.info("Redis connected: " + jedis.isConnected());
return jedis;
}
diff --git a/src/main/java/fi/hsl/common/redis/RedisUtils.java b/src/main/java/fi/hsl/common/redis/RedisUtils.java
index 2a9571af..f52a6c5a 100644
--- a/src/main/java/fi/hsl/common/redis/RedisUtils.java
+++ b/src/main/java/fi/hsl/common/redis/RedisUtils.java
@@ -1,5 +1,11 @@
package fi.hsl.common.redis;
+import com.azure.core.credential.AccessToken;
+import com.azure.core.credential.TokenCredential;
+import com.azure.core.credential.TokenRequestContext;
+import com.azure.core.util.CoreUtils;
+import com.google.gson.JsonObject;
+import com.google.gson.JsonParser;
import fi.hsl.common.pulsar.PulsarApplicationContext;
import fi.hsl.common.transitdata.TransitdataProperties;
import org.jetbrains.annotations.NotNull;
@@ -10,9 +16,12 @@
import redis.clients.jedis.params.ScanParams;
import redis.clients.jedis.resps.ScanResult;
+import java.nio.charset.StandardCharsets;
+import java.time.Duration;
import java.time.OffsetDateTime;
import java.time.format.DateTimeFormatter;
import java.util.*;
+import java.util.concurrent.ThreadLocalRandom;
public class RedisUtils {
private static final Logger log = LoggerFactory.getLogger(RedisUtils.class);
@@ -227,4 +236,106 @@ public boolean checkResponse(@Nullable final String response) {
public boolean checkResponse(@Nullable final Long response) {
return response != null && response == 1;
}
+
+ // Azure Cache for Redis helper code
+ public static Jedis createJedisClient(String cacheHostname, int port, String username, AccessToken accessToken, boolean useSsl) {
+ return new Jedis(cacheHostname, port, DefaultJedisClientConfig.builder()
+ .password(accessToken.getToken())
+ .user(username)
+ .ssl(useSsl)
+ .build());
+ }
+
+ public static String extractUsernameFromToken(String token) {
+ String[] parts = token.split("\\.");
+ String base64 = parts[1];
+
+ base64 = addPaddingToBase64String(base64);
+
+ byte[] jsonBytes = Base64.getDecoder().decode(base64);
+ String json = new String(jsonBytes, StandardCharsets.UTF_8);
+ JsonObject jwt = JsonParser.parseString(json).getAsJsonObject();
+
+ return jwt.get("oid").getAsString();
+ }
+
+ private static String addPaddingToBase64String(String input) {
+ if (input != null && !input.isEmpty()) {
+ int paddingLength = (4 - input.length() % 4) % 4;
+ input += "=".repeat(paddingLength);
+ }
+ return input;
+ }
+
+ /**
+ * The token cache to store and proactively refresh the access token.
+ */
+ public static class TokenRefreshCache {
+ private final TokenCredential tokenCredential;
+ private final TokenRequestContext tokenRequestContext;
+ private final Timer timer;
+ private volatile AccessToken accessToken;
+ private final Duration maxRefreshOffset = Duration.ofMinutes(5);
+ private final Duration baseRefreshOffset = Duration.ofMinutes(2);
+ private Jedis jedisInstanceToAuthenticate;
+ private String username;
+
+ /**
+ * Creates an instance of TokenRefreshCache
+ * @param tokenCredential the token credential to be used for authentication.
+ * @param tokenRequestContext the token request context to be used for authentication.
+ */
+ public TokenRefreshCache(TokenCredential tokenCredential, TokenRequestContext tokenRequestContext) {
+ this.tokenCredential = tokenCredential;
+ this.tokenRequestContext = tokenRequestContext;
+ this.timer = new Timer();
+ }
+
+ /**
+ * Gets the cached access token.
+ * @return the AccessToken
+ */
+ public AccessToken getAccessToken() {
+ if (accessToken != null) {
+ return accessToken;
+ } else {
+ TokenRefreshTask tokenRefreshTask = new TokenRefreshTask();
+ accessToken = tokenCredential.getToken(tokenRequestContext).block();
+ timer.schedule(tokenRefreshTask, getTokenRefreshDelay());
+ return accessToken;
+ }
+ }
+
+ private class TokenRefreshTask extends TimerTask {
+ // Add your task here
+ public void run() {
+ accessToken = tokenCredential.getToken(tokenRequestContext).block();
+ username = extractUsernameFromToken(accessToken.getToken());
+ log.info("Refreshed Token with Expiry: " + accessToken.getExpiresAt().toEpochSecond());
+
+ if (jedisInstanceToAuthenticate != null && !CoreUtils.isNullOrEmpty(username)) {
+ jedisInstanceToAuthenticate.auth(username, accessToken.getToken());
+ log.info("Refreshed Jedis Connection with fresh access token, token expires at : "
+ + accessToken.getExpiresAt().toEpochSecond());
+ }
+ timer.schedule(new TokenRefreshTask(), getTokenRefreshDelay());
+ }
+ }
+
+ private long getTokenRefreshDelay() {
+ return ((accessToken.getExpiresAt()
+ .minusSeconds(ThreadLocalRandom.current().nextLong(baseRefreshOffset.getSeconds(), maxRefreshOffset.getSeconds()))
+ .toEpochSecond() - OffsetDateTime.now().toEpochSecond()) * 1000);
+ }
+
+ /**
+ * Sets the Jedis to proactively authenticate before token expiry.
+ * @param jedisInstanceToAuthenticate the instance to authenticate
+ * @return the updated instance
+ */
+ public TokenRefreshCache setJedisInstanceToAuthenticate(Jedis jedisInstanceToAuthenticate) {
+ this.jedisInstanceToAuthenticate = jedisInstanceToAuthenticate;
+ return this;
+ }
+ }
}
diff --git a/src/test/java/fi/hsl/common/redis/RedisUtilsTest.java b/src/test/java/fi/hsl/common/redis/RedisUtilsTest.java
index f7ec2aa4..00f9535a 100644
--- a/src/test/java/fi/hsl/common/redis/RedisUtilsTest.java
+++ b/src/test/java/fi/hsl/common/redis/RedisUtilsTest.java
@@ -8,6 +8,8 @@
import org.testcontainers.utility.DockerImageName;
import redis.clients.jedis.Jedis;
+import java.util.Base64;
+
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
@@ -51,4 +53,20 @@ public void testSetGetExpiringValue() throws InterruptedException {
assertFalse(redisUtils.getValue("test").isPresent());
}
+
+ @Test
+ public void extractUsernameFromTokenBase64PaddingWorks() {
+ // Payload with length not a multiple of 4 (e.g., 2 or 3 mod 4)
+ String header = Base64.getUrlEncoder().withoutPadding().encodeToString("{\"alg\":\"none\"}".getBytes());
+
+ // 2 mod 4 length
+ String payload2 = Base64.getUrlEncoder().withoutPadding().encodeToString("{\"oid\":\"ab\"}".getBytes());
+ String token2 = header + "." + payload2 + ".sig";
+ assertEquals("ab", RedisUtils.extractUsernameFromToken(token2));
+
+ // 3 mod 4 length
+ String payload3 = Base64.getUrlEncoder().withoutPadding().encodeToString("{\"oid\":\"abc\"}".getBytes());
+ String token3 = header + "." + payload3 + ".sig";
+ assertEquals("abc", RedisUtils.extractUsernameFromToken(token3));
+ }
}