Skip to content

Commit 0bffdb2

Browse files
authored
feat: Moving configuration and auth logic (#37)
This is being moved from cloudquery/cli to allow it to be shared by both the `plugin-sdk` and `cloudquery/cli` components. This may not be the best place for this code in the long run, but for now we want to share the code between `cloudquery/cli` and `plugin-sdk` without having `cloudquery/cli` introduce a dependency on the `plugin-sdk`. This PR also modifies the token issuing logic, to only refresh the ID token if the existing token is close to expiry. This is to minimise the API calls needed during the sync upsert process.
1 parent 715a2ce commit 0bffdb2

File tree

10 files changed

+569
-1
lines changed

10 files changed

+569
-1
lines changed

.github/workflows/lint_golang.yml

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
name: Lint
2+
on:
3+
push:
4+
branches:
5+
- main
6+
pull_request:
7+
branches:
8+
- main
9+
10+
jobs:
11+
golangci:
12+
name: Lint with GolangCI
13+
runs-on: ubuntu-latest
14+
timeout-minutes: 10
15+
steps:
16+
- uses: actions/checkout@v3
17+
- uses: actions/setup-go@v4
18+
with:
19+
go-version-file: go.mod
20+
- name: golangci-lint
21+
uses: golangci/golangci-lint-action@v3
22+
with:
23+
version: v1.54.2

.github/workflows/unittest.yml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
name: "Unit tests"
2+
on:
3+
push:
4+
branches:
5+
- main
6+
pull_request:
7+
branches:
8+
- main
9+
10+
jobs:
11+
unitests:
12+
timeout-minutes: 30
13+
runs-on: ${{ matrix.os }}
14+
strategy:
15+
fail-fast: false
16+
matrix:
17+
os: [ubuntu-latest, macos-latest, windows-latest]
18+
steps:
19+
- name: Check out code into the Go module directory
20+
uses: actions/checkout@v3
21+
- name: Set up Go 1.x
22+
uses: actions/setup-go@v4
23+
with:
24+
go-version-file: go.mod
25+
- run: go mod download
26+
- run: go build ./...
27+
- name: Run tests
28+
run: make test

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@
1919

2020
# Go workspace file
2121
go.work
22+
23+
# Intellij IDE file
24+
.idea

Makefile

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
.PHONY: test
2+
test:
3+
go test -tags=assert -race ./...
4+
5+
.PHONY: lint
6+
lint:
7+
golangci-lint run
8+

auth/token.go

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
package auth
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"io"
7+
"net/http"
8+
"net/url"
9+
"os"
10+
"strings"
11+
"time"
12+
13+
"github.com/adrg/xdg"
14+
)
15+
16+
const (
17+
FirebaseAPIKey = "AIzaSyCxsrwjABEF-dWLzUqmwiL-ct02cnG9GCs"
18+
TokenBaseURL = "https://securetoken.googleapis.com"
19+
EnvVarCloudQueryAPIKey = "CLOUDQUERY_API_KEY"
20+
ExpiryBuffer = 60 * time.Second
21+
)
22+
23+
type tokenResponse struct {
24+
AccessToken string `json:"access_token"`
25+
ExpiresIn string `json:"expires_in"`
26+
TokenType string `json:"token_type"`
27+
RefreshToken string `json:"refresh_token"`
28+
IDToken string `json:"id_token"`
29+
UserID string `json:"user_id"`
30+
ProjectID string `json:"project_id"`
31+
}
32+
33+
type TokenClient struct {
34+
url string
35+
apiKey string
36+
idToken string
37+
expiresAt time.Time
38+
}
39+
40+
func NewTokenClient() *TokenClient {
41+
return &TokenClient{
42+
url: TokenBaseURL,
43+
apiKey: FirebaseAPIKey,
44+
}
45+
}
46+
47+
// GetToken returns the ID token
48+
// If CLOUDQUERY_API_KEY is set, it returns that value, otherwise it returns an ID token generated from the refresh token.
49+
func (tc *TokenClient) GetToken() (string, error) {
50+
if token := os.Getenv(EnvVarCloudQueryAPIKey); token != "" {
51+
return token, nil
52+
}
53+
54+
// If the token is not expired, return it
55+
if !tc.expiresAt.IsZero() && tc.expiresAt.Sub(time.Now().UTC()) > ExpiryBuffer {
56+
return tc.idToken, nil
57+
}
58+
59+
refreshToken, err := ReadRefreshToken()
60+
if err != nil {
61+
return "", fmt.Errorf("failed to read refresh token: %w. Hint: You may need to run `cloudquery login` or set %s", err, EnvVarCloudQueryAPIKey)
62+
}
63+
if refreshToken == "" {
64+
return "", fmt.Errorf("authentication token not found. Hint: You may need to run `cloudquery login` or set %s", EnvVarCloudQueryAPIKey)
65+
}
66+
tokenResponse, err := tc.generateToken(refreshToken)
67+
if err != nil {
68+
return "", fmt.Errorf("failed to sign in with custom token: %w", err)
69+
}
70+
71+
if err := SaveRefreshToken(tokenResponse.RefreshToken); err != nil {
72+
return "", fmt.Errorf("failed to save refresh token: %w", err)
73+
}
74+
75+
if err := tc.updateIDToken(tokenResponse); err != nil {
76+
return "", fmt.Errorf("failed to update ID token: %w", err)
77+
}
78+
79+
return tc.idToken, nil
80+
}
81+
82+
func (tc *TokenClient) generateToken(refreshToken string) (*tokenResponse, error) {
83+
data := url.Values{}
84+
data.Set("grant_type", "refresh_token")
85+
data.Set("refresh_token", refreshToken)
86+
87+
resp, err := http.PostForm(fmt.Sprintf("%s/v1/token?key=%s", tc.url, tc.apiKey), data)
88+
if err != nil {
89+
return nil, err
90+
}
91+
defer resp.Body.Close()
92+
if resp.StatusCode != http.StatusOK {
93+
body, readErr := io.ReadAll(resp.Body)
94+
if readErr != nil {
95+
return nil, fmt.Errorf("failed to read response body: %w", readErr)
96+
}
97+
return nil, fmt.Errorf("failed to refresh token: %s: %s", resp.Status, body)
98+
}
99+
100+
var tr tokenResponse
101+
body, err := io.ReadAll(resp.Body)
102+
if err != nil {
103+
return nil, err
104+
}
105+
if err := parseToken(body, &tr); err != nil {
106+
return nil, err
107+
}
108+
109+
return &tr, nil
110+
}
111+
112+
func (tc *TokenClient) updateIDToken(tr *tokenResponse) error {
113+
// Convert string duration in seconds to time.Duration
114+
duration, err := time.ParseDuration(tr.ExpiresIn + "s")
115+
if err != nil {
116+
return err
117+
}
118+
119+
tc.expiresAt = time.Now().UTC().Add(duration)
120+
tc.idToken = tr.IDToken
121+
return nil
122+
}
123+
124+
func parseToken(response []byte, tr *tokenResponse) error {
125+
err := json.Unmarshal(response, tr)
126+
if err != nil {
127+
return err
128+
}
129+
return nil
130+
}
131+
132+
// SaveRefreshToken saves the refresh token to the token file
133+
func SaveRefreshToken(refreshToken string) error {
134+
tokenFilePath, err := xdg.DataFile("cloudquery/token")
135+
if err != nil {
136+
return fmt.Errorf("failed to get token file path: %w", err)
137+
}
138+
tokenFile, err := os.OpenFile(tokenFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
139+
if err != nil {
140+
return fmt.Errorf("failed to open token file %q for writing: %w", tokenFilePath, err)
141+
}
142+
defer func() {
143+
if closeErr := tokenFile.Close(); closeErr != nil {
144+
fmt.Printf("error closing token file: %v", closeErr)
145+
}
146+
}()
147+
if _, err = tokenFile.WriteString(refreshToken); err != nil {
148+
return fmt.Errorf("failed to write token to %q: %w", tokenFilePath, err)
149+
}
150+
return nil
151+
}
152+
153+
// ReadRefreshToken reads the refresh token from the token file
154+
func ReadRefreshToken() (string, error) {
155+
tokenFilePath, err := xdg.DataFile("cloudquery/token")
156+
if err != nil {
157+
return "", fmt.Errorf("failed to get token file path: %w", err)
158+
}
159+
b, err := os.ReadFile(tokenFilePath)
160+
if err != nil {
161+
return "", fmt.Errorf("failed to read token file: %w", err)
162+
}
163+
return strings.TrimSpace(string(b)), nil
164+
}
165+
166+
// RemoveRefreshToken removes the token file
167+
func RemoveRefreshToken() error {
168+
tokenFilePath, err := xdg.DataFile("cloudquery/token")
169+
if err != nil {
170+
return fmt.Errorf("failed to get token file path: %w", err)
171+
}
172+
if err := os.RemoveAll(tokenFilePath); err != nil {
173+
return fmt.Errorf("failed to remove token file %q: %w", tokenFilePath, err)
174+
}
175+
return nil
176+
}

auth/token_test.go

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
package auth
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"github.com/stretchr/testify/require"
7+
"net/http"
8+
"net/http/httptest"
9+
"os"
10+
"testing"
11+
"time"
12+
)
13+
14+
func TestRefreshToken_RoundTrip(t *testing.T) {
15+
token := "my_token"
16+
17+
err := SaveRefreshToken(token)
18+
require.NoError(t, err)
19+
20+
readToken, err := ReadRefreshToken()
21+
require.NoError(t, err)
22+
23+
require.Equal(t, token, readToken)
24+
}
25+
26+
func TestRefreshToken_Removal(t *testing.T) {
27+
token := "my_token"
28+
29+
err := SaveRefreshToken(token)
30+
require.NoError(t, err)
31+
32+
_, err = ReadRefreshToken()
33+
require.NoError(t, err)
34+
35+
err = RemoveRefreshToken()
36+
require.NoError(t, err)
37+
38+
_, err = ReadRefreshToken()
39+
require.Error(t, err)
40+
}
41+
42+
func TestTokenClient_EnvironmentVariable(t *testing.T) {
43+
reset := overrideEnvironmentVariable(t, EnvVarCloudQueryAPIKey, "my_token")
44+
defer reset()
45+
46+
token, err := NewTokenClient().GetToken()
47+
require.NoError(t, err)
48+
49+
require.Equal(t, "my_token", token)
50+
}
51+
52+
func TestTokenClient_GetToken_ShortExpiry(t *testing.T) {
53+
server, closer := fakeAuthServer(t, "0")
54+
defer closer()
55+
56+
err := SaveRefreshToken("my_refresh_token")
57+
require.NoError(t, err)
58+
59+
t0 := time.Now().UTC()
60+
61+
tc := TokenClient{
62+
url: server.URL,
63+
apiKey: "my-api-key",
64+
expiresAt: t0,
65+
}
66+
67+
token, err := tc.GetToken()
68+
require.NoError(t, err)
69+
require.Equal(t, "my_id_token_0", token, "first token")
70+
71+
tc.expiresAt = t0
72+
73+
token, err = tc.GetToken()
74+
require.NoError(t, err)
75+
require.Equal(t, "my_id_token_1", token, "expected to issue new token")
76+
}
77+
78+
func TestTokenClient_GetToken_LongExpiry(t *testing.T) {
79+
server, closer := fakeAuthServer(t, "3600")
80+
defer closer()
81+
82+
err := SaveRefreshToken("my_refresh_token")
83+
require.NoError(t, err)
84+
85+
tc := TokenClient{
86+
url: server.URL,
87+
apiKey: "my-api-key",
88+
}
89+
90+
token, err := tc.GetToken()
91+
require.NoError(t, err)
92+
require.Equal(t, "my_id_token_0", token, "first token")
93+
94+
token, err = tc.GetToken()
95+
require.NoError(t, err)
96+
require.Equal(t, "my_id_token_0", token, "expected to reuse token")
97+
}
98+
99+
func overrideEnvironmentVariable(t *testing.T, key, value string) func() {
100+
originalValue := os.Getenv(key)
101+
resetFn := func() {
102+
err := os.Setenv(key, originalValue)
103+
require.NoError(t, err)
104+
}
105+
106+
err := os.Setenv(key, value)
107+
require.NoError(t, err)
108+
109+
return resetFn
110+
}
111+
112+
func fakeAuthServer(t *testing.T, expiresIn string) (*httptest.Server, func()) {
113+
tokenCount := 0
114+
115+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
116+
require.Equal(t, http.MethodPost, r.Method)
117+
require.Equal(t, "/v1/token?key=my-api-key", r.URL.String())
118+
119+
err := r.ParseForm()
120+
require.NoError(t, err)
121+
122+
require.Equal(t, "my_refresh_token", r.Form.Get("refresh_token"))
123+
require.Equal(t, "refresh_token", r.Form.Get("grant_type"))
124+
125+
w.Header().Set("Content-Type", "application/json")
126+
response := tokenResponse{
127+
AccessToken: "my_access_token",
128+
ExpiresIn: expiresIn,
129+
TokenType: "Bearer",
130+
RefreshToken: "my_refresh_token",
131+
IDToken: fmt.Sprintf("my_id_token_%d", tokenCount),
132+
UserID: "abcd-1234",
133+
ProjectID: "project-1",
134+
}
135+
err = json.NewEncoder(w).Encode(response)
136+
require.NoError(t, err)
137+
138+
tokenCount++
139+
}))
140+
141+
return server, func() {
142+
server.Close()
143+
}
144+
}

0 commit comments

Comments
 (0)