Skip to content

Commit 95dfdaf

Browse files
authored
Fix JWT token refresh to Fusion validation request (#6231) [ci fast]
Signed-off-by: Paolo Di Tommaso <paolo.ditommaso@gmail.com>
1 parent c65955c commit 95dfdaf

File tree

3 files changed

+174
-19
lines changed

3 files changed

+174
-19
lines changed

plugins/nf-tower/src/main/io/seqera/tower/plugin/TowerFactory.groovy

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ class TowerFactory implements TraceObserverFactory {
100100
@Memoized
101101
static TowerClient client(Session session, Map<String,String> env) {
102102
final config = session.config
103-
Boolean isEnabled = config.navigate('tower.enabled') as Boolean || env.get('TOWER_WORKFLOW_ID')
103+
Boolean isEnabled = config.navigate('tower.enabled') as Boolean || env.get('TOWER_WORKFLOW_ID') || config.navigate('fusion.enabled') as Boolean
104104
return isEnabled
105105
? createTowerClient0(session, config, env)
106106
: null

plugins/nf-tower/src/main/io/seqera/tower/plugin/TowerFusionToken.groovy

Lines changed: 85 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import com.google.common.util.concurrent.UncheckedExecutionException
1515
import com.google.gson.Gson
1616
import com.google.gson.JsonSyntaxException
1717
import dev.failsafe.Failsafe
18+
import dev.failsafe.FailsafeException
1819
import dev.failsafe.RetryPolicy
1920
import dev.failsafe.event.EventListener
2021
import dev.failsafe.event.ExecutionAttemptedEvent
@@ -62,6 +63,8 @@ class TowerFusionToken implements FusionToken {
6263
private static final int DEFAULT_RETRY_POLICY_MAX_ATTEMPTS = 10
6364
private static final double DEFAULT_RETRY_POLICY_JITTER = 0.5
6465

66+
private CookieManager cookieManager = new CookieManager()
67+
6568
// The HttpClient instance used to send requests
6669
private final HttpClient httpClient = newDefaultHttpClient()
6770

@@ -80,7 +83,9 @@ class TowerFusionToken implements FusionToken {
8083
private String endpoint
8184

8285
// Platform access token to use for requests
83-
private String accessToken
86+
private volatile String accessToken
87+
88+
private volatile String refreshToken
8489

8590
// Platform workflowId
8691
private String workspaceId
@@ -93,6 +98,7 @@ class TowerFusionToken implements FusionToken {
9398
final env = SysEnv.get()
9499
this.endpoint = PlatformHelper.getEndpoint(config, env)
95100
this.accessToken = PlatformHelper.getAccessToken(config, env)
101+
this.refreshToken = PlatformHelper.getRefreshToken(config, env)
96102
this.workflowId = env.get('TOWER_WORKFLOW_ID')
97103
this.workspaceId = PlatformHelper.getWorkspaceId(config, env)
98104
}
@@ -182,11 +188,11 @@ class TowerFusionToken implements FusionToken {
182188
* Create a new HttpClient instance with default settings
183189
* @return The new HttpClient instance
184190
*/
185-
private static HttpClient newDefaultHttpClient() {
191+
private HttpClient newDefaultHttpClient() {
186192
final builder = HttpClient.newBuilder()
187193
.version(HttpClient.Version.HTTP_1_1)
188194
.followRedirects(HttpClient.Redirect.NEVER)
189-
.cookieHandler(new CookieManager())
195+
.cookieHandler(cookieManager)
190196
.connectTimeout(DEFAULT_CONNECTION_TIMEOUT)
191197
// use virtual threads executor if enabled
192198
if ( Threads.useVirtual() ) {
@@ -236,8 +242,17 @@ class TowerFusionToken implements FusionToken {
236242
* @param req The HttpRequest to send
237243
* @return The HttpResponse received
238244
*/
239-
private <T> HttpResponse<String> safeHttpSend(HttpRequest req, RetryPolicy<T> policy) {
240-
return Failsafe.with(policy).get(
245+
private <T> HttpResponse<String> safeHttpSend(HttpRequest req) {
246+
try {
247+
safeApply(req)
248+
}
249+
catch (FailsafeException e) {
250+
throw e.cause
251+
}
252+
}
253+
254+
private <T> HttpResponse<String> safeApply(HttpRequest req) {
255+
return Failsafe.with(retryPolicy).get(
241256
() -> {
242257
log.debug "Http request: method=${req.method()}; uri=${req.uri()}; request=${req}"
243258
final resp = httpClient.send(req, HttpResponse.BodyHandlers.ofString())
@@ -289,28 +304,35 @@ class TowerFusionToken implements FusionToken {
289304
/**
290305
* Request a license token from Platform.
291306
*
292-
* @param req The LicenseTokenRequest object
307+
* @param request The LicenseTokenRequest object
293308
* @return The LicenseTokenResponse object
294-
*
295-
* @throws AbortOperationException if a Platform access token cannot be found
296-
* @throws UnauthorizedException if the access token is invalid
297-
* @throws BadResponseException if the response is not as expected
298-
* @throws IllegalStateException if the request cannot be sent
299309
*/
300-
private GetLicenseTokenResponse sendRequest(GetLicenseTokenRequest req) throws AbortOperationException, UnauthorizedException, BadResponseException, IllegalStateException {
310+
private GetLicenseTokenResponse sendRequest(GetLicenseTokenRequest request) {
311+
return sendRequest0(request, 1)
312+
}
301313

302-
final httpReq = makeHttpRequest(req)
314+
private GetLicenseTokenResponse sendRequest0(GetLicenseTokenRequest request, int attempt) {
315+
316+
final httpReq = makeHttpRequest(request)
303317

304318
try {
305-
final resp = safeHttpSend(httpReq, retryPolicy)
319+
final resp = safeHttpSend(httpReq)
306320

307321
if( resp.statusCode() == 200 ) {
308322
final ret = parseLicenseTokenResponse(resp.body())
309323
return ret
310324
}
311325

312326
if( resp.statusCode() == 401 ) {
313-
throw new UnauthorizedException("Unauthorized [401] - Verify you have provided a Seqera Platform valid access token")
327+
final shouldRetry = accessToken
328+
&& refreshToken
329+
&& attempt==1
330+
&& refreshJwtToken0(refreshToken)
331+
if( shouldRetry ) {
332+
return sendRequest0(request, attempt+1)
333+
}
334+
else
335+
throw new UnauthorizedException("Unauthorized [401] - Verify you have provided a Seqera Platform valid access token")
314336
}
315337

316338
throw new BadResponseException("Invalid response: ${httpReq.method()} ${httpReq.uri()} [${resp.statusCode()}] ${resp.body()}")
@@ -319,4 +341,52 @@ class TowerFusionToken implements FusionToken {
319341
throw new IllegalStateException("Unable to send request to '${httpReq.uri()}' : ${e.message}")
320342
}
321343
}
344+
345+
protected boolean refreshJwtToken0(String refresh) {
346+
log.debug "Token refresh request >> $refresh"
347+
348+
final req = HttpRequest.newBuilder()
349+
.uri(new URI("${endpoint}/oauth/access_token"))
350+
.headers('Content-Type',"application/x-www-form-urlencoded")
351+
.POST(HttpRequest.BodyPublishers.ofString("grant_type=refresh_token&refresh_token=${URLEncoder.encode(refresh, 'UTF-8')}"))
352+
.build()
353+
354+
final resp = safeHttpSend(req)
355+
final code = resp.statusCode()
356+
final body = resp.body()
357+
log.debug "Refresh cookie response: [${code}] ${body}"
358+
if( resp.statusCode() != 200 )
359+
return false
360+
361+
final authCookie = getCookie('JWT')
362+
final refreshCookie = getCookie('JWT_REFRESH_TOKEN')
363+
364+
// set the new bearer token in the current client session
365+
if( authCookie?.value ) {
366+
log.trace "Updating http client bearer token=$authCookie.value"
367+
accessToken = authCookie.value
368+
}
369+
else {
370+
log.warn "Missing JWT cookie from refresh token response ~ $authCookie"
371+
}
372+
373+
// set the new refresh token
374+
if( refreshCookie?.value ) {
375+
log.trace "Updating http client refresh token=$refreshCookie.value"
376+
refreshToken = refreshCookie.value
377+
}
378+
else {
379+
log.warn "Missing JWT_REFRESH_TOKEN cookie from refresh token response ~ $refreshCookie"
380+
}
381+
382+
return true
383+
}
384+
385+
private HttpCookie getCookie(final String cookieName) {
386+
for( HttpCookie it : cookieManager.cookieStore.cookies ) {
387+
if( it.name == cookieName )
388+
return it
389+
}
390+
return null
391+
}
322392
}

plugins/nf-tower/src/test/io/seqera/tower/plugin/TowerFusionEnvTest.groovy

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import java.time.temporal.ChronoUnit
55

66
import com.github.tomakehurst.wiremock.WireMockServer
77
import com.github.tomakehurst.wiremock.client.WireMock
8+
import com.github.tomakehurst.wiremock.stubbing.Scenario
89
import io.seqera.tower.plugin.exception.UnauthorizedException
910
import nextflow.Global
1011
import nextflow.Session
@@ -43,7 +44,6 @@ class TowerFusionEnvTest extends Specification {
4344
SysEnv.pop() // <-- restore the system host env
4445
}
4546

46-
4747
def 'should return the endpoint from the config'() {
4848
given: 'a session'
4949
Global.session = Mock(Session) {
@@ -287,7 +287,6 @@ class TowerFusionEnvTest extends Specification {
287287
and: 'the request is correct'
288288
wireMockServer.verify(1, WireMock.postRequestedFor(WireMock.urlEqualTo("/license/token/"))
289289
.withHeader('Authorization', WireMock.equalTo('Bearer abc123')))
290-
291290
}
292291

293292
def 'should get a license token with environment'() {
@@ -336,7 +335,94 @@ class TowerFusionEnvTest extends Specification {
336335
SysEnv.pop()
337336
}
338337

338+
def 'should refresh the auth token on 401 and retry the request'() {
339+
given:
340+
SysEnv.push([
341+
TOWER_WORKFLOW_ID: '12345',
342+
TOWER_ACCESS_TOKEN: 'abc-token',
343+
TOWER_REFRESH_TOKEN: 'xyz-refresh',
344+
TOWER_WORKSPACE_ID: '67890',
345+
TOWER_API_ENDPOINT: wireMockServer.baseUrl()
346+
])
347+
def PRODUCT = 'some-product'
348+
def VERSION = 'some-version'
349+
and:
350+
Global.session = Mock(Session) { getConfig() >> [:] }
351+
and:
352+
def provider = new TowerFusionToken()
353+
354+
and: 'prepare stubs'
355+
356+
final now = Instant.now()
357+
final expirationDate = GsonHelper.toJson(now.plus(1, ChronoUnit.DAYS))
358+
359+
// 1️⃣ First attempt: /license/token/ fails with 401
360+
wireMockServer.stubFor(
361+
WireMock.post(urlEqualTo("/license/token/"))
362+
.inScenario("Refresh flow")
363+
.whenScenarioStateIs(Scenario.STARTED)
364+
.willReturn(WireMock.aResponse().withStatus(401))
365+
.willSetStateTo("Token Refreshed")
366+
)
367+
368+
// 2️⃣ Refresh token call
369+
wireMockServer.stubFor(
370+
WireMock.post(urlEqualTo("/oauth/access_token"))
371+
.inScenario("Refresh flow")
372+
.whenScenarioStateIs("Token Refreshed")
373+
.withHeader('Content-Type', equalTo('application/x-www-form-urlencoded'))
374+
.withRequestBody(containing('grant_type=refresh_token'))
375+
.withRequestBody(containing(URLEncoder.encode('xyz-refresh', 'UTF-8')))
376+
.willReturn(
377+
WireMock.aResponse()
378+
.withStatus(200)
379+
.withHeader('Set-Cookie', 'JWT=new-abc-token')
380+
.withHeader('Set-Cookie', 'JWT_REFRESH_TOKEN=new-refresh-456')
381+
)
382+
.willSetStateTo("Retry Ready")
383+
)
339384

385+
// 3️⃣ Retry: /license/token/ succeeds
386+
wireMockServer.stubFor(
387+
WireMock.post(urlEqualTo("/license/token/"))
388+
.inScenario("Refresh flow")
389+
.whenScenarioStateIs("Retry Ready")
390+
.withHeader('Authorization', equalTo('Bearer new-abc-token'))
391+
.willReturn(
392+
WireMock.aResponse()
393+
.withStatus(200)
394+
.withHeader('Content-Type', 'application/json')
395+
.withBody('{"signedToken":"xyz789", "expiresAt":' + expirationDate + '}')
396+
)
397+
)
398+
399+
when:
400+
final token = provider.getLicenseToken(PRODUCT, VERSION)
401+
402+
then:
403+
token == 'xyz789'
404+
405+
and: 'the initial request was sent with the old token'
406+
wireMockServer.verify(1, WireMock.postRequestedFor(urlEqualTo("/license/token/"))
407+
.withHeader('Authorization', equalTo('Bearer abc-token'))
408+
)
409+
410+
and: 'the refresh request was sent with correct form data'
411+
wireMockServer.verify(1, WireMock.postRequestedFor(WireMock.urlEqualTo("/oauth/access_token"))
412+
.withHeader('Content-Type', equalTo('application/x-www-form-urlencoded'))
413+
.withRequestBody(containing('grant_type=refresh_token'))
414+
.withRequestBody(containing(URLEncoder.encode('xyz-refresh', 'UTF-8')))
415+
)
416+
417+
and: 'the retried request was sent with the new token'
418+
wireMockServer.verify(1, WireMock.postRequestedFor(WireMock.urlEqualTo("/license/token/"))
419+
.withHeader('Authorization', equalTo('Bearer new-abc-token'))
420+
)
421+
422+
cleanup:
423+
SysEnv.pop()
424+
}
425+
340426
def 'should throw UnauthorizedException if getting a token fails with 401'() {
341427
given: 'a TowerFusionEnv provider'
342428
Global.session = Mock(Session) {
@@ -369,7 +455,6 @@ class TowerFusionEnvTest extends Specification {
369455
thrown(UnauthorizedException)
370456
}
371457

372-
373458
def 'should deserialize response' () {
374459
given:
375460
def ts = Instant.ofEpochSecond(1738788914)

0 commit comments

Comments
 (0)