diff --git a/.env.test b/.env.test index a1f3186..327a0ce 100644 --- a/.env.test +++ b/.env.test @@ -5,6 +5,7 @@ MONGO_DB=switcher-gitops-test GIT_TOKEN_PRIVATE_KEY=SecretSecretSecretSecretSecretSe HANDLER_WAITING_TIME=1m +SWITCHER_API_CA_CERT= SWITCHER_API_URL=https://switcherapi.com/api SWITCHER_API_JWT_SECRET=SecretSecretSecretSecretSecretSe SWITCHER_PATH_GRAPHQL=/gitops-graphql diff --git a/Makefile b/Makefile index 806300a..8050416 100644 --- a/Makefile +++ b/Makefile @@ -2,8 +2,11 @@ build: go build -o ./bin/app ./src/cmd/app/main.go run: - GOOS=windows $env:GO_ENV="test"; go run ./src/cmd/app/main.go - GOOS=linux GO_ENV=test go run ./src/cmd/app/main.go +ifeq ($(OS),Windows_NT) + $env:GO_ENV="test"; go run ./src/cmd/app/main.go +else + GO_ENV=test go run ./src/cmd/app/main.go +endif test: go test -p 1 -coverpkg=./... -v diff --git a/resources/fixtures/api/dummy.pem b/resources/fixtures/api/dummy.pem new file mode 100644 index 0000000..c412709 --- /dev/null +++ b/resources/fixtures/api/dummy.pem @@ -0,0 +1,2 @@ +-----BEGIN CERTIFICATE----- +-----END CERTIFICATE----- diff --git a/src/core/api.go b/src/core/api.go index bc8b8fa..2a94e66 100644 --- a/src/core/api.go +++ b/src/core/api.go @@ -2,16 +2,20 @@ package core import ( "bytes" + "crypto/tls" + "crypto/x509" "encoding/json" "errors" "fmt" "io" "net/http" + "os" "time" "github.com/golang-jwt/jwt" "github.com/switcherapi/switcher-gitops/src/config" "github.com/switcherapi/switcher-gitops/src/model" + "github.com/switcherapi/switcher-gitops/src/utils" ) type GraphQLRequest struct { @@ -32,14 +36,16 @@ type IAPIService interface { } type ApiService struct { - apiKey string - apiUrl string + apiKey string + apiUrl string + caCertPath string } -func NewApiService(apiKey string, apiUrl string) *ApiService { +func NewApiService(apiKey string, apiUrl string, caCertPath string) *ApiService { return &ApiService{ - apiKey: apiKey, - apiUrl: apiUrl, + apiKey: apiKey, + apiUrl: apiUrl, + caCertPath: caCertPath, } } @@ -101,8 +107,7 @@ func (a *ApiService) doGraphQLRequest(domainId string, query string) (string, er setHeaders(req, token) // Send the request - client := &http.Client{} - resp, err := client.Do(req) + resp, err := a.doRequest(req) if err != nil { return "", err } @@ -123,8 +128,7 @@ func (a *ApiService) doPostRequest(url string, domainId string, body []byte) (st setHeaders(req, token) // Send the request - client := &http.Client{} - resp, err := client.Do(req) + resp, err := a.doRequest(req) if err != nil { return "", 0, err } @@ -134,6 +138,35 @@ func (a *ApiService) doPostRequest(url string, domainId string, body []byte) (st return string(responseBody), resp.StatusCode, nil } +func (a *ApiService) doRequest(req *http.Request) (*http.Response, error) { + var client *http.Client + + if a.caCertPath != "" { + caCert, err := os.ReadFile(a.caCertPath) + + if err != nil { + utils.LogError("Error reading CA certificate file: " + err.Error()) + return nil, err + } + + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM([]byte(caCert)) + + utils.LogDebug("Using CA certificate for HTTPS requests") + client = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: caCertPool, + }, + }, + } + } else { + client = &http.Client{} + } + + return client.Do(req) +} + func generateBearerToken(apiKey string, subject string) string { // Define the claims for the JWT token claims := jwt.MapClaims{ diff --git a/src/core/api_test.go b/src/core/api_test.go index 0664d68..7ae521d 100644 --- a/src/core/api_test.go +++ b/src/core/api_test.go @@ -19,7 +19,7 @@ func TestFetchSnapshotVersion(t *testing.T) { fakeApiServer := givenApiResponse(http.StatusOK, responsePayload) defer fakeApiServer.Close() - apiService := NewApiService(SWITCHER_API_JWT_SECRET, fakeApiServer.URL) + apiService := NewApiService(SWITCHER_API_JWT_SECRET, fakeApiServer.URL, "") version, _ := apiService.FetchSnapshotVersion("domainId", "default") assert.Contains(t, version, "version", "Missing version in response") @@ -30,7 +30,7 @@ func TestFetchSnapshotVersion(t *testing.T) { fakeApiServer := givenApiResponse(http.StatusUnauthorized, `{ "error": "Invalid API token" }`) defer fakeApiServer.Close() - apiService := NewApiService("INVALID_KEY", fakeApiServer.URL) + apiService := NewApiService("INVALID_KEY", fakeApiServer.URL, "") version, _ := apiService.FetchSnapshotVersion("domainId", "default") assert.Contains(t, version, "Invalid API token") @@ -41,14 +41,14 @@ func TestFetchSnapshotVersion(t *testing.T) { fakeApiServer := givenApiResponse(http.StatusUnauthorized, responsePayload) defer fakeApiServer.Close() - apiService := NewApiService(SWITCHER_API_JWT_SECRET, fakeApiServer.URL) + apiService := NewApiService(SWITCHER_API_JWT_SECRET, fakeApiServer.URL, "") version, _ := apiService.FetchSnapshotVersion("INVALID_DOMAIN", "default") assert.Contains(t, version, "errors") }) t.Run("Should return error - invalid API URL", func(t *testing.T) { - apiService := NewApiService(config.GetEnv(SWITCHER_API_JWT_SECRET), "http://localhost:8080") + apiService := NewApiService(config.GetEnv(SWITCHER_API_JWT_SECRET), "http://localhost:8080", "") _, err := apiService.FetchSnapshotVersion("domainId", "default") assert.NotNil(t, err) @@ -61,7 +61,7 @@ func TestFetchSnapshot(t *testing.T) { fakeApiServer := givenApiResponse(http.StatusOK, responsePayload) defer fakeApiServer.Close() - apiService := NewApiService(SWITCHER_API_JWT_SECRET, fakeApiServer.URL) + apiService := NewApiService(SWITCHER_API_JWT_SECRET, fakeApiServer.URL, "") snapshot, _ := apiService.FetchSnapshot("domainId", "default") assert.Contains(t, snapshot, "domain", "Missing domain in snapshot") @@ -75,7 +75,7 @@ func TestFetchSnapshot(t *testing.T) { fakeApiServer := givenApiResponse(http.StatusOK, responsePayload) defer fakeApiServer.Close() - apiService := NewApiService(SWITCHER_API_JWT_SECRET, fakeApiServer.URL) + apiService := NewApiService(SWITCHER_API_JWT_SECRET, fakeApiServer.URL, "") snapshot, _ := apiService.FetchSnapshot("domainId", "default") data := apiService.NewDataFromJson([]byte(snapshot)) @@ -88,7 +88,7 @@ func TestFetchSnapshot(t *testing.T) { fakeApiServer := givenApiResponse(http.StatusUnauthorized, `{ "error": "Invalid API token" }`) defer fakeApiServer.Close() - apiService := NewApiService("INVALID_KEY", fakeApiServer.URL) + apiService := NewApiService("INVALID_KEY", fakeApiServer.URL, "") snapshot, _ := apiService.FetchSnapshot("domainId", "default") assert.Contains(t, snapshot, "Invalid API token") @@ -99,14 +99,14 @@ func TestFetchSnapshot(t *testing.T) { fakeApiServer := givenApiResponse(http.StatusUnauthorized, responsePayload) defer fakeApiServer.Close() - apiService := NewApiService(SWITCHER_API_JWT_SECRET, fakeApiServer.URL) + apiService := NewApiService(SWITCHER_API_JWT_SECRET, fakeApiServer.URL, "") snapshot, _ := apiService.FetchSnapshot("INVALID_DOMAIN", "default") assert.Contains(t, snapshot, "errors") }) t.Run("Should return error - invalid API URL", func(t *testing.T) { - apiService := NewApiService(config.GetEnv(SWITCHER_API_JWT_SECRET), "http://localhost:8080") + apiService := NewApiService(config.GetEnv(SWITCHER_API_JWT_SECRET), "http://localhost:8080", "") _, err := apiService.FetchSnapshot("domainId", "default") assert.NotNil(t, err) @@ -123,7 +123,7 @@ func TestPushChangesToAPI(t *testing.T) { }`) defer fakeApiServer.Close() - apiService := NewApiService(SWITCHER_API_JWT_SECRET, fakeApiServer.URL) + apiService := NewApiService(SWITCHER_API_JWT_SECRET, fakeApiServer.URL, "") // Test response, _ := apiService.PushChanges("domainId", diff) @@ -140,7 +140,7 @@ func TestPushChangesToAPI(t *testing.T) { fakeApiServer := givenApiResponse(http.StatusBadRequest, `{ "error": "Config already exists" }`) defer fakeApiServer.Close() - apiService := NewApiService(SWITCHER_API_JWT_SECRET, fakeApiServer.URL) + apiService := NewApiService(SWITCHER_API_JWT_SECRET, fakeApiServer.URL, "") // Test _, err := apiService.PushChanges("domainId", diff) @@ -157,7 +157,7 @@ func TestPushChangesToAPI(t *testing.T) { fakeApiServer := givenApiResponse(http.StatusUnauthorized, `{ "error": "Invalid API token" }`) defer fakeApiServer.Close() - apiService := NewApiService("[INVALID_KEY]", fakeApiServer.URL) + apiService := NewApiService("[INVALID_KEY]", fakeApiServer.URL, "") // Test _, err := apiService.PushChanges("domainId", diff) @@ -170,7 +170,7 @@ func TestPushChangesToAPI(t *testing.T) { t.Run("Should return error - API not accessible", func(t *testing.T) { // Given diff := givenDiffResult("default") - apiService := NewApiService("[SWITCHER_API_JWT_SECRET]", "http://localhost:8080") + apiService := NewApiService("[SWITCHER_API_JWT_SECRET]", "http://localhost:8080", "") // Test _, err := apiService.PushChanges("domainId", diff) @@ -180,6 +180,33 @@ func TestPushChangesToAPI(t *testing.T) { }) } +func TestFetchSnapshotWithCaCert(t *testing.T) { + t.Run("Should return snapshot", func(t *testing.T) { + responsePayload := utils.ReadJsonFromFile("../../resources/fixtures/api/default_snapshot.json") + fakeApiServer := givenApiResponse(http.StatusOK, responsePayload) + defer fakeApiServer.Close() + + apiService := NewApiService(SWITCHER_API_JWT_SECRET, fakeApiServer.URL, "../../resources/fixtures/api/dummy.pem") + snapshot, _ := apiService.FetchSnapshot("domainId", "default") + + assert.Contains(t, snapshot, "domain", "Missing domain in snapshot") + assert.Contains(t, snapshot, "version", "Missing version in snapshot") + assert.Contains(t, snapshot, "group", "Missing groups in snapshot") + assert.Contains(t, snapshot, "config", "Missing config in snapshot") + }) + + t.Run("Should return error - certificate not found", func(t *testing.T) { + responsePayload := utils.ReadJsonFromFile("../../resources/fixtures/api/default_snapshot.json") + fakeApiServer := givenApiResponse(http.StatusOK, responsePayload) + defer fakeApiServer.Close() + + apiService := NewApiService(SWITCHER_API_JWT_SECRET, fakeApiServer.URL, "invalid.pem") + _, err := apiService.FetchSnapshot("domainId", "default") + + assert.NotNil(t, err) + }) +} + // Helpers func givenDiffResult(environment string) model.DiffResult { diff --git a/src/core/core_test.go b/src/core/core_test.go index 522f4ea..dc90871 100644 --- a/src/core/core_test.go +++ b/src/core/core_test.go @@ -28,7 +28,7 @@ func setup() { mongoDb = db.InitDb() accountRepository := repository.NewAccountRepositoryMongo(mongoDb) - apiService := NewApiService("apiKey", "") + apiService := NewApiService("apiKey", "", "") comparatorService := NewComparatorService() coreHandler = NewCoreHandler(accountRepository, apiService, comparatorService) } diff --git a/src/server/app.go b/src/server/app.go index 8dd6ba8..b9fd9c2 100644 --- a/src/server/app.go +++ b/src/server/app.go @@ -79,6 +79,7 @@ func initCoreHandler(db *mongo.Database) *core.CoreHandler { apiService := core.NewApiService( config.GetEnv("SWITCHER_API_JWT_SECRET"), config.GetEnv("SWITCHER_API_URL"), + config.GetEnv("SWITCHER_API_CA_CERT"), ) coreHandler := core.NewCoreHandler(accountRepository, apiService, comparatorService)