Skip to content

Authenticate to Redis with Microsoft Entra ID using Token Cache #372

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>fi.hsl</groupId>
<artifactId>transitdata-common</artifactId>
<version>2.0.3-RC6</version>
<version>2.0.3-RC10</version>
<packaging>jar</packaging>
<name>Common utilities for Transitdata projects</name>
<properties>
Expand Down Expand Up @@ -78,10 +78,16 @@
<version>${pulsar.version}</version>
</dependency>

<dependency>
<groupId>com.azure</groupId>
<artifactId>azure-identity</artifactId>
<version>1.11.2</version>
</dependency>

<dependency>
<groupId>redis.clients</groupId>
<artifactId>jedis</artifactId>
<version>4.4.3</version>
<version>5.1.0</version>
<type>jar</type>
<scope>compile</scope>
</dependency>
Expand Down
42 changes: 31 additions & 11 deletions src/main/java/fi/hsl/common/pulsar/PulsarApplication.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
package fi.hsl.common.pulsar;

import com.azure.core.credential.AccessToken;
import com.azure.core.credential.TokenRequestContext;
import com.azure.identity.DefaultAzureCredential;
import com.azure.identity.DefaultAzureCredentialBuilder;
import com.typesafe.config.Config;
import fi.hsl.common.health.HealthServer;
import fi.hsl.common.redis.RedisUtils;
import org.apache.pulsar.client.admin.PulsarAdmin;
import org.apache.pulsar.client.api.*;
import org.jetbrains.annotations.NotNull;
Expand All @@ -19,6 +24,9 @@
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static fi.hsl.common.redis.RedisUtils.createJedisClient;
import static fi.hsl.common.redis.RedisUtils.extractUsernameFromToken;

public class PulsarApplication implements AutoCloseable {
private static final Logger log = LoggerFactory.getLogger(PulsarApplication.class);

Expand Down Expand Up @@ -84,14 +92,9 @@ public PulsarApplicationContext initialize(@NotNull Config config) throws Except
}

if (config.getBoolean("redis.enabled")) {
int connTimeOutSecs = 2;
if (config.hasPath("redis.connTimeOutSecs")) {
connTimeOutSecs = config.getInt("redis.connTimeOutSecs");
}
jedis = createRedisClient(
config.getString("redis.host"),
config.getInt("redis.port"),
connTimeOutSecs);
config.getInt("redis.port"));
}

if (config.getBoolean("health.enabled")) {
Expand Down Expand Up @@ -150,11 +153,28 @@ public PulsarApplicationContext initialize(@NotNull Config config) throws Except
}

@NotNull
protected Jedis createRedisClient(@NotNull String redisHost, int port, int connTimeOutSecs) {
log.info("Connecting to Redis at " + redisHost + ":" + port + " with connection timeout of (s): "+ connTimeOutSecs);
int timeOutMs = connTimeOutSecs * 1000;
Jedis jedis = new Jedis(redisHost, port, timeOutMs);
jedis.connect();
protected Jedis createRedisClient(@NotNull String redisHost, int port) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer having the opportunity to use both our own Redis or Redis managed by Azure. Please rather create a separate method or class.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, sorry, I have to write this somewhere:

Some of the methods in this class do not really use private instance variables of the object, apart from the logger. Would a cleaner approach aim for static methods, perhaps split over several classes? Why is something named PulsarApplication creating Jedis clients with intricate code and health checks?

If possible, I would rather follow "clean as you code" instead of waiting for separate refactoring issues that might not get prioritized high enough, unless we need a total rewrite issue.

log.info("Connecting to Redis at {}:{}", redisHost, port);

//Construct a Token Credential from Identity library, e.g. DefaultAzureCredential / ClientSecretCredential / Client CertificateCredential / ManagedIdentityCredential etc.
DefaultAzureCredential defaultAzureCredential = new DefaultAzureCredentialBuilder().build();

// Fetch a Microsoft Entra token to be used for authentication. This token will be used as the password.
TokenRequestContext trc = new TokenRequestContext().addScopes("https://redis.azure.com/.default");
RedisUtils.TokenRefreshCache tokenRefreshCache = new RedisUtils.TokenRefreshCache(defaultAzureCredential, trc);
AccessToken accessToken = tokenRefreshCache.getAccessToken();

// SSL connection is required.
boolean useSsl = true;
String username = extractUsernameFromToken(accessToken.getToken());

// Create Jedis client and connect to the Azure Cache for Redis over the TLS/SSL port using the access token as password.
// Note: Cache Host Name, Port, Microsoft Entra access token and SSL connections are required below.
jedis = createJedisClient(redisHost, port, username, accessToken, useSsl);

// Configure the jedis instance for proactive authentication before token expires.
tokenRefreshCache.setJedisInstanceToAuthenticate(jedis);

log.info("Redis connected: " + jedis.isConnected());
return jedis;
}
Expand Down
111 changes: 111 additions & 0 deletions src/main/java/fi/hsl/common/redis/RedisUtils.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
package fi.hsl.common.redis;

import com.azure.core.credential.AccessToken;
import com.azure.core.credential.TokenCredential;
import com.azure.core.credential.TokenRequestContext;
import com.azure.core.util.CoreUtils;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import fi.hsl.common.pulsar.PulsarApplicationContext;
import fi.hsl.common.transitdata.TransitdataProperties;
import org.jetbrains.annotations.NotNull;
Expand All @@ -10,9 +16,12 @@
import redis.clients.jedis.params.ScanParams;
import redis.clients.jedis.resps.ScanResult;

import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.OffsetDateTime;
import java.time.format.DateTimeFormatter;
import java.util.*;
import java.util.concurrent.ThreadLocalRandom;

public class RedisUtils {
private static final Logger log = LoggerFactory.getLogger(RedisUtils.class);
Expand Down Expand Up @@ -227,4 +236,106 @@ public boolean checkResponse(@Nullable final String response) {
public boolean checkResponse(@Nullable final Long response) {
return response != null && response == 1;
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For all code copied from the Azure code samples, please refer to the exact commit and mention the license under which it is copied in a comment. E.g. The following method is copied or adapted from https://github.com/Azure/azure-sdk-for-java/blob/e7103f9019669032c7ffc3b51f1bd30c6ad8655f/sdk/identity/azure-identity/src/samples/Azure-Cache-For-Redis/Jedis/Azure-AAD-Authentication-With-Jedis.md under the MIT license.

// Azure Cache for Redis helper code
public static Jedis createJedisClient(String cacheHostname, int port, String username, AccessToken accessToken, boolean useSsl) {
return new Jedis(cacheHostname, port, DefaultJedisClientConfig.builder()
.password(accessToken.getToken())
.user(username)
.ssl(useSsl)
.build());
}

public static String extractUsernameFromToken(String token) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method looks nasty (and is from the Azure code samples). I would bury these Azure-specific things in a separate Azure kludge class if we choose to keep both kinds of Jedis clients.

String[] parts = token.split("\\.");
String base64 = parts[1];

base64 = addPaddingToBase64String(base64);

byte[] jsonBytes = Base64.getDecoder().decode(base64);
String json = new String(jsonBytes, StandardCharsets.UTF_8);
JsonObject jwt = JsonParser.parseString(json).getAsJsonObject();

return jwt.get("oid").getAsString();
}

private static String addPaddingToBase64String(String input) {
if (input != null && !input.isEmpty()) {
int paddingLength = (4 - input.length() % 4) % 4;
input += "=".repeat(paddingLength);
}
return input;
}

/**
* The token cache to store and proactively refresh the access token.
*/
public static class TokenRefreshCache {
private final TokenCredential tokenCredential;
private final TokenRequestContext tokenRequestContext;
private final Timer timer;
private volatile AccessToken accessToken;
private final Duration maxRefreshOffset = Duration.ofMinutes(5);
private final Duration baseRefreshOffset = Duration.ofMinutes(2);
private Jedis jedisInstanceToAuthenticate;
private String username;

/**
* Creates an instance of TokenRefreshCache
* @param tokenCredential the token credential to be used for authentication.
* @param tokenRequestContext the token request context to be used for authentication.
*/
public TokenRefreshCache(TokenCredential tokenCredential, TokenRequestContext tokenRequestContext) {
this.tokenCredential = tokenCredential;
this.tokenRequestContext = tokenRequestContext;
this.timer = new Timer();
}

/**
* Gets the cached access token.
* @return the AccessToken
*/
public AccessToken getAccessToken() {
if (accessToken != null) {
return accessToken;
} else {
TokenRefreshTask tokenRefreshTask = new TokenRefreshTask();
accessToken = tokenCredential.getToken(tokenRequestContext).block();
timer.schedule(tokenRefreshTask, getTokenRefreshDelay());
return accessToken;
}
}

private class TokenRefreshTask extends TimerTask {
// Add your task here
public void run() {
accessToken = tokenCredential.getToken(tokenRequestContext).block();
username = extractUsernameFromToken(accessToken.getToken());
log.info("Refreshed Token with Expiry: " + accessToken.getExpiresAt().toEpochSecond());

if (jedisInstanceToAuthenticate != null && !CoreUtils.isNullOrEmpty(username)) {
jedisInstanceToAuthenticate.auth(username, accessToken.getToken());
log.info("Refreshed Jedis Connection with fresh access token, token expires at : "
+ accessToken.getExpiresAt().toEpochSecond());
}
timer.schedule(new TokenRefreshTask(), getTokenRefreshDelay());
}
}

private long getTokenRefreshDelay() {
return ((accessToken.getExpiresAt()
.minusSeconds(ThreadLocalRandom.current().nextLong(baseRefreshOffset.getSeconds(), maxRefreshOffset.getSeconds()))
.toEpochSecond() - OffsetDateTime.now().toEpochSecond()) * 1000);
}

/**
* Sets the Jedis to proactively authenticate before token expiry.
* @param jedisInstanceToAuthenticate the instance to authenticate
* @return the updated instance
*/
public TokenRefreshCache setJedisInstanceToAuthenticate(Jedis jedisInstanceToAuthenticate) {
this.jedisInstanceToAuthenticate = jedisInstanceToAuthenticate;
return this;
}
}
}
18 changes: 18 additions & 0 deletions src/test/java/fi/hsl/common/redis/RedisUtilsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import org.testcontainers.utility.DockerImageName;
import redis.clients.jedis.Jedis;

import java.util.Base64;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;

Expand Down Expand Up @@ -51,4 +53,20 @@ public void testSetGetExpiringValue() throws InterruptedException {

assertFalse(redisUtils.getValue("test").isPresent());
}

@Test
public void extractUsernameFromTokenBase64PaddingWorks() {
// Payload with length not a multiple of 4 (e.g., 2 or 3 mod 4)
String header = Base64.getUrlEncoder().withoutPadding().encodeToString("{\"alg\":\"none\"}".getBytes());

// 2 mod 4 length
String payload2 = Base64.getUrlEncoder().withoutPadding().encodeToString("{\"oid\":\"ab\"}".getBytes());
String token2 = header + "." + payload2 + ".sig";
assertEquals("ab", RedisUtils.extractUsernameFromToken(token2));

// 3 mod 4 length
String payload3 = Base64.getUrlEncoder().withoutPadding().encodeToString("{\"oid\":\"abc\"}".getBytes());
String token3 = header + "." + payload3 + ".sig";
assertEquals("abc", RedisUtils.extractUsernameFromToken(token3));
}
}