Skip to content

Commit dc5cf02

Browse files
authored
feat(auth): Exposing CustomToken() on auth.TenantClient (#371)
1 parent 716247d commit dc5cf02

File tree

5 files changed

+138
-47
lines changed

5 files changed

+138
-47
lines changed

auth/auth.go

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,6 @@ var reservedClaims = []string{
4444
type Client struct {
4545
*baseClient
4646
TenantManager *TenantManager
47-
signer cryptoSigner
48-
clock internal.Clock
4947
}
5048

5149
// NewClient creates a new instance of the Firebase Auth Client.
@@ -116,11 +114,11 @@ func NewClient(ctx context.Context, conf *internal.AuthConfig) (*Client, error)
116114
httpClient: hc,
117115
idTokenVerifier: idTokenVerifier,
118116
cookieVerifier: cookieVerifier,
117+
signer: signer,
118+
clock: internal.SystemClock,
119119
}
120120
return &Client{
121121
baseClient: base,
122-
signer: signer,
123-
clock: internal.SystemClock,
124122
TenantManager: newTenantManager(hc, conf, base),
125123
}, nil
126124
}
@@ -144,13 +142,13 @@ func NewClient(ctx context.Context, conf *internal.AuthConfig) (*Client, error)
144142
// conjunction with the IAM service to sign tokens remotely.
145143
//
146144
// CustomToken returns an error the SDK fails to discover a viable mechanism for signing tokens.
147-
func (c *Client) CustomToken(ctx context.Context, uid string) (string, error) {
145+
func (c *baseClient) CustomToken(ctx context.Context, uid string) (string, error) {
148146
return c.CustomTokenWithClaims(ctx, uid, nil)
149147
}
150148

151149
// CustomTokenWithClaims is similar to CustomToken, but in addition to the user ID, it also encodes
152150
// all the key-value pairs in the provided map as claims in the resulting JWT.
153-
func (c *Client) CustomTokenWithClaims(ctx context.Context, uid string, devClaims map[string]interface{}) (string, error) {
151+
func (c *baseClient) CustomTokenWithClaims(ctx context.Context, uid string, devClaims map[string]interface{}) (string, error) {
154152
iss, err := c.signer.Email(ctx)
155153
if err != nil {
156154
return "", err
@@ -176,13 +174,14 @@ func (c *Client) CustomTokenWithClaims(ctx context.Context, uid string, devClaim
176174
info := &jwtInfo{
177175
header: jwtHeader{Algorithm: "RS256", Type: "JWT"},
178176
payload: &customToken{
179-
Iss: iss,
180-
Sub: iss,
181-
Aud: firebaseAudience,
182-
UID: uid,
183-
Iat: now,
184-
Exp: now + oneHourInSeconds,
185-
Claims: devClaims,
177+
Iss: iss,
178+
Sub: iss,
179+
Aud: firebaseAudience,
180+
UID: uid,
181+
Iat: now,
182+
Exp: now + oneHourInSeconds,
183+
TenantID: c.tenantID,
184+
Claims: devClaims,
186185
},
187186
}
188187
return info.Token(ctx, c.signer)
@@ -235,6 +234,8 @@ type baseClient struct {
235234
httpClient *internal.HTTPClient
236235
idTokenVerifier *tokenVerifier
237236
cookieVerifier *tokenVerifier
237+
signer cryptoSigner
238+
clock internal.Clock
238239
}
239240

240241
func (c *baseClient) withTenantID(tenantID string) *baseClient {

auth/auth_test.go

Lines changed: 67 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -282,20 +282,26 @@ func TestNewClientExplicitNoAuth(t *testing.T) {
282282

283283
func TestCustomToken(t *testing.T) {
284284
client := &Client{
285-
signer: testSigner,
286-
clock: testClock,
285+
baseClient: &baseClient{
286+
signer: testSigner,
287+
clock: testClock,
288+
},
287289
}
288290
token, err := client.CustomToken(context.Background(), "user1")
289291
if err != nil {
290292
t.Fatal(err)
291293
}
292-
verifyCustomToken(context.Background(), token, nil, t)
294+
if err := verifyCustomToken(context.Background(), token, nil, ""); err != nil {
295+
t.Fatal(err)
296+
}
293297
}
294298

295299
func TestCustomTokenWithClaims(t *testing.T) {
296300
client := &Client{
297-
signer: testSigner,
298-
clock: testClock,
301+
baseClient: &baseClient{
302+
signer: testSigner,
303+
clock: testClock,
304+
},
299305
}
300306
claims := map[string]interface{}{
301307
"foo": "bar",
@@ -306,19 +312,46 @@ func TestCustomTokenWithClaims(t *testing.T) {
306312
if err != nil {
307313
t.Fatal(err)
308314
}
309-
verifyCustomToken(context.Background(), token, claims, t)
315+
if err := verifyCustomToken(context.Background(), token, claims, ""); err != nil {
316+
t.Fatal(err)
317+
}
310318
}
311319

312320
func TestCustomTokenWithNilClaims(t *testing.T) {
313321
client := &Client{
314-
signer: testSigner,
315-
clock: testClock,
322+
baseClient: &baseClient{
323+
signer: testSigner,
324+
clock: testClock,
325+
},
316326
}
317327
token, err := client.CustomTokenWithClaims(context.Background(), "user1", nil)
318328
if err != nil {
319329
t.Fatal(err)
320330
}
321-
verifyCustomToken(context.Background(), token, nil, t)
331+
if err := verifyCustomToken(context.Background(), token, nil, ""); err != nil {
332+
t.Fatal(err)
333+
}
334+
}
335+
336+
func TestCustomTokenForTenant(t *testing.T) {
337+
client := &Client{
338+
baseClient: &baseClient{
339+
tenantID: "tenantID",
340+
signer: testSigner,
341+
clock: testClock,
342+
},
343+
}
344+
claims := map[string]interface{}{
345+
"foo": "bar",
346+
"premium": true,
347+
}
348+
token, err := client.CustomTokenWithClaims(context.Background(), "user1", claims)
349+
if err != nil {
350+
t.Fatal(err)
351+
}
352+
if err := verifyCustomToken(context.Background(), token, claims, "tenantID"); err != nil {
353+
t.Fatal(err)
354+
}
322355
}
323356

324357
func TestCustomTokenError(t *testing.T) {
@@ -333,7 +366,7 @@ func TestCustomTokenError(t *testing.T) {
333366
{"ReservedClaims", "uid", map[string]interface{}{"sub": "1234", "aud": "foo"}},
334367
}
335368

336-
client := &Client{
369+
client := &baseClient{
337370
signer: testSigner,
338371
clock: testClock,
339372
}
@@ -628,9 +661,9 @@ func TestCustomTokenVerification(t *testing.T) {
628661
client := &Client{
629662
baseClient: &baseClient{
630663
idTokenVerifier: testIDTokenVerifier,
664+
signer: testSigner,
665+
clock: testClock,
631666
},
632-
signer: testSigner,
633-
clock: testClock,
634667
}
635668
token, err := client.CustomToken(context.Background(), "user1")
636669
if err != nil {
@@ -1137,52 +1170,61 @@ func checkBaseClient(client *Client, wantProjectID string) error {
11371170
return nil
11381171
}
11391172

1140-
func verifyCustomToken(ctx context.Context, token string, expected map[string]interface{}, t *testing.T) {
1173+
func verifyCustomToken(
1174+
ctx context.Context, token string, expected map[string]interface{}, tenantID string) error {
1175+
11411176
if err := testIDTokenVerifier.verifySignature(ctx, token); err != nil {
1142-
t.Fatal(err)
1177+
return err
11431178
}
1179+
11441180
var (
11451181
header jwtHeader
11461182
payload customToken
11471183
)
11481184
segments := strings.Split(token, ".")
11491185
if err := decode(segments[0], &header); err != nil {
1150-
t.Fatal(err)
1186+
return err
11511187
}
11521188
if err := decode(segments[1], &payload); err != nil {
1153-
t.Fatal(err)
1189+
return err
11541190
}
11551191

11561192
email, err := testSigner.Email(ctx)
11571193
if err != nil {
1158-
t.Fatal(err)
1194+
return err
11591195
}
11601196

11611197
if header.Algorithm != "RS256" {
1162-
t.Errorf("Algorithm: %q; want: 'RS256'", header.Algorithm)
1198+
return fmt.Errorf("Algorithm: %q; want: 'RS256'", header.Algorithm)
11631199
} else if header.Type != "JWT" {
1164-
t.Errorf("Type: %q; want: 'JWT'", header.Type)
1200+
return fmt.Errorf("Type: %q; want: 'JWT'", header.Type)
11651201
} else if payload.Aud != firebaseAudience {
1166-
t.Errorf("Audience: %q; want: %q", payload.Aud, firebaseAudience)
1202+
return fmt.Errorf("Audience: %q; want: %q", payload.Aud, firebaseAudience)
11671203
} else if payload.Iss != email {
1168-
t.Errorf("Issuer: %q; want: %q", payload.Iss, email)
1204+
return fmt.Errorf("Issuer: %q; want: %q", payload.Iss, email)
11691205
} else if payload.Sub != email {
1170-
t.Errorf("Subject: %q; want: %q", payload.Sub, email)
1206+
return fmt.Errorf("Subject: %q; want: %q", payload.Sub, email)
11711207
}
11721208

11731209
now := testClock.Now().Unix()
11741210
if payload.Exp != now+3600 {
1175-
t.Errorf("Exp: %d; want: %d", payload.Exp, now+3600)
1211+
return fmt.Errorf("Exp: %d; want: %d", payload.Exp, now+3600)
11761212
}
11771213
if payload.Iat != now {
1178-
t.Errorf("Iat: %d; want: %d", payload.Iat, now)
1214+
return fmt.Errorf("Iat: %d; want: %d", payload.Iat, now)
11791215
}
11801216

11811217
for k, v := range expected {
11821218
if payload.Claims[k] != v {
1183-
t.Errorf("Claim[%q]: %v; want: %v", k, payload.Claims[k], v)
1219+
return fmt.Errorf("Claim[%q]: %v; want: %v", k, payload.Claims[k], v)
11841220
}
11851221
}
1222+
1223+
if payload.TenantID != tenantID {
1224+
return fmt.Errorf("Tenant ID: %q; want: %q", payload.TenantID, tenantID)
1225+
}
1226+
1227+
return nil
11861228
}
11871229

11881230
func logFatal(err error) {

auth/token_generator.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,14 @@ type jwtHeader struct {
4141
}
4242

4343
type customToken struct {
44-
Iss string `json:"iss"`
45-
Aud string `json:"aud"`
46-
Exp int64 `json:"exp"`
47-
Iat int64 `json:"iat"`
48-
Sub string `json:"sub,omitempty"`
49-
UID string `json:"uid,omitempty"`
50-
Claims map[string]interface{} `json:"claims,omitempty"`
44+
Iss string `json:"iss"`
45+
Aud string `json:"aud"`
46+
Exp int64 `json:"exp"`
47+
Iat int64 `json:"iat"`
48+
Sub string `json:"sub,omitempty"`
49+
UID string `json:"uid,omitempty"`
50+
TenantID string `json:"tenant_id,omitempty"`
51+
Claims map[string]interface{} `json:"claims,omitempty"`
5152
}
5253

5354
type jwtInfo struct {

integration/auth/auth_test.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,10 +207,19 @@ func verifyCustomToken(t *testing.T, ct, uid string) *auth.Token {
207207
}
208208

209209
func signInWithCustomToken(token string) (string, error) {
210-
req, err := json.Marshal(map[string]interface{}{
210+
return signInWithCustomTokenForTenant(token, "")
211+
}
212+
213+
func signInWithCustomTokenForTenant(token string, tenantID string) (string, error) {
214+
payload := map[string]interface{}{
211215
"token": token,
212216
"returnSecureToken": true,
213-
})
217+
}
218+
if tenantID != "" {
219+
payload["tenantId"] = tenantID
220+
}
221+
222+
req, err := json.Marshal(payload)
214223
if err != nil {
215224
return "", err
216225
}

integration/auth/tenant_mgt_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ func TestTenantManager(t *testing.T) {
9797
}
9898
})
9999

100+
t.Run("CustomTokens", func(t *testing.T) {
101+
testTenantAwareCustomToken(t, id)
102+
})
103+
100104
t.Run("UserManagement", func(t *testing.T) {
101105
testTenantAwareUserManagement(t, id)
102106
})
@@ -154,6 +158,40 @@ func TestTenantManager(t *testing.T) {
154158
})
155159
}
156160

161+
func testTenantAwareCustomToken(t *testing.T, id string) {
162+
tenantClient, err := client.TenantManager.AuthForTenant(id)
163+
if err != nil {
164+
t.Fatalf("AuthForTenant() = %v", err)
165+
}
166+
167+
uid := randomUID()
168+
ct, err := tenantClient.CustomToken(context.Background(), uid)
169+
if err != nil {
170+
t.Fatal(err)
171+
}
172+
173+
idToken, err := signInWithCustomTokenForTenant(ct, id)
174+
if err != nil {
175+
t.Fatal(err)
176+
}
177+
178+
defer func() {
179+
tenantClient.DeleteUser(context.Background(), uid)
180+
}()
181+
182+
vt, err := tenantClient.VerifyIDToken(context.Background(), idToken)
183+
if err != nil {
184+
t.Fatal(err)
185+
}
186+
187+
if vt.UID != uid {
188+
t.Errorf("UID = %q; want UID = %q", vt.UID, uid)
189+
}
190+
if vt.Firebase.Tenant != id {
191+
t.Errorf("Tenant = %q; want = %q", vt.Firebase.Tenant, id)
192+
}
193+
}
194+
157195
func testTenantAwareUserManagement(t *testing.T, id string) {
158196
tenantClient, err := client.TenantManager.AuthForTenant(id)
159197
if err != nil {

0 commit comments

Comments
 (0)