From 5f5f032435e6af6ea47e63aa2285bca151e4b335 Mon Sep 17 00:00:00 2001 From: yazansalti Date: Mon, 6 Oct 2025 13:43:33 +0400 Subject: [PATCH 1/4] feat: Add security audit logging --- cmd/start.go | 7 +- internal/config/config.go | 47 +- internal/config/config_test.go | 8 +- internal/config/types.go | 12 +- internal/server/audit_logger.go | 421 ++++++++++++++++++ internal/server/audit_middleware_test.go | 153 +++++++ internal/server/audit_options.go | 110 +++++ internal/server/authorization_test.go | 53 ++- internal/server/handlers_accounts.go | 148 ++++-- internal/server/handlers_accounts_test.go | 61 ++- .../handlers_certificate_authorities.go | 172 +++++-- .../handlers_certificate_authorities_test.go | 88 +++- .../server/handlers_certificate_requests.go | 193 +++++--- .../handlers_certificate_requests_test.go | 93 +++- internal/server/handlers_config.go | 2 +- internal/server/handlers_config_test.go | 18 +- internal/server/handlers_login.go | 21 +- internal/server/handlers_login_test.go | 77 +++- internal/server/handlers_status.go | 4 +- internal/server/handlers_status_test.go | 2 +- internal/server/middleware.go | 131 +++++- internal/server/router.go | 60 +-- internal/server/server.go | 8 +- internal/server/server_test.go | 9 +- internal/server/types.go | 3 +- internal/testutils/server_test_utils.go | 29 +- 26 files changed, 1620 insertions(+), 310 deletions(-) create mode 100644 internal/server/audit_logger.go create mode 100644 internal/server/audit_middleware_test.go create mode 100644 internal/server/audit_options.go diff --git a/cmd/start.go b/cmd/start.go index 51da7fda..104dddbd 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -28,14 +28,14 @@ https://canonical-notary.readthedocs-hosted.com/en/latest/reference/config_file/ if err != nil { log.Fatalf("couldn't create app context: %s", err) } - l := appContext.Logger + l := appContext.SystemLogger // Initialize the database connection db, err := db.NewDatabase(&db.DatabaseOpts{ DatabasePath: appContext.DBPath, ApplyMigrations: appContext.ApplyMigrations, Backend: appContext.EncryptionBackend, - Logger: appContext.Logger, + Logger: appContext.SystemLogger, }) if err != nil { l.Fatal("couldn't initialize database", zap.Error(err)) @@ -49,7 +49,8 @@ https://canonical-notary.readthedocs-hosted.com/en/latest/reference/config_file/ Database: db, ExternalHostname: appContext.ExternalHostname, EnablePebbleNotifications: appContext.PebbleNotificationsEnabled, - Logger: appContext.Logger, + SystemLogger: appContext.SystemLogger, + AuditLogger: appContext.AuditLogger, PublicConfig: appContext.PublicConfig, }) if err != nil { diff --git a/internal/config/config.go b/internal/config/config.go index 8cbac6e6..ff628f0f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -36,13 +36,28 @@ func CreateAppContext(cmdFlags *pflag.FlagSet, configFilePath string) (*NotaryAp return nil, err } - // initialize logger - logger, err := initializeLogger(cfg.Sub("logging")) + // initialize system logger + systemLogger, err := initializeLogger( + cfg.GetString("logging.system.level"), + cfg.GetString("logging.system.output"), + cfg.GetString("logging.system.path"), + ) if err != nil { - return nil, fmt.Errorf("couldn't initialize logger: %w", err) + return nil, fmt.Errorf("couldn't initialize system logger: %w", err) + } + + // initialize audit logger + // Audit logs are always at INFO level + auditLogger, err := initializeLogger( + "info", + cfg.GetString("logging.audit.output"), + cfg.GetString("logging.audit.path"), + ) + if err != nil { + return nil, fmt.Errorf("couldn't initialize audit logger: %w", err) } // initialize encryption backend - backendType, backend, err := initializeEncryptionBackend(cfg.Sub("encryption_backend"), logger) + backendType, backend, err := initializeEncryptionBackend(cfg.Sub("encryption_backend"), systemLogger) if err != nil { return nil, fmt.Errorf("couldn't initialize encryption backend: %w", err) } @@ -55,7 +70,8 @@ func CreateAppContext(cmdFlags *pflag.FlagSet, configFilePath string) (*NotaryAp appContext.TLSCertificate = cert appContext.TLSPrivateKey = key - appContext.Logger = logger + appContext.SystemLogger = systemLogger + appContext.AuditLogger = auditLogger appContext.EncryptionBackend = backend appContext.EncryptionBackendType = backendType appContext.PublicConfig = &PublicConfigData{ @@ -85,6 +101,7 @@ func initializeServerConfig(cmdFlags *pflag.FlagSet, configFilePath string) (*vi v.SetDefault("external_hostname", "localhost") v.SetDefault("logging.system.level", "debug") v.SetDefault("logging.system.output", "stdout") + v.SetDefault("logging.audit.output", "stdout") if configFilePath == "" { return nil, errors.New("config file path not provided") @@ -209,17 +226,27 @@ func initializeEncryptionBackend(encryptionCfg *viper.Viper, logger *zap.Logger) } } -// initializeLogger creates and configures a logger based on the configuration. -func initializeLogger(cfg *viper.Viper) (*zap.Logger, error) { +// initializeLogger creates and configures a logger based on the provided parameters. +// output can be "stdout", "stderr", or "file" +// path is required when output is "file" +func initializeLogger(level, output, path string) (*zap.Logger, error) { zapConfig := zap.NewProductionConfig() - logLevel, err := zapcore.ParseLevel(cfg.GetString("system.level")) + logLevel, err := zapcore.ParseLevel(level) if err != nil { return nil, fmt.Errorf("invalid log level: %w", err) } - - zapConfig.OutputPaths = []string{cfg.GetString("system.output")} zapConfig.Level.SetLevel(logLevel) + + if output == "file" { + if path == "" { + return nil, fmt.Errorf("path is required when output is 'file'") + } + zapConfig.OutputPaths = []string{path} + } else { + zapConfig.OutputPaths = []string{output} + } + zapConfig.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder logger, err := zapConfig.Build() diff --git a/internal/config/config_test.go b/internal/config/config_test.go index ebc3475d..795f6f81 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -34,7 +34,8 @@ func TestValidConfig(t *testing.T) { DBPath: "./notary.db", Port: 8000, PebbleNotificationsEnabled: false, - Logger: nil, + SystemLogger: nil, + AuditLogger: nil, EncryptionBackend: encryption_backend.NoEncryptionBackend{}, EncryptionBackendType: config.EncryptionBackendTypeNone, }}, // This case tests the expected default values for missing fields are filled correctly @@ -51,7 +52,8 @@ func TestValidConfig(t *testing.T) { DBPath: "./notary.db", Port: 8000, PebbleNotificationsEnabled: false, - Logger: nil, + SystemLogger: nil, + AuditLogger: nil, EncryptionBackend: encryption_backend.NoEncryptionBackend{}, EncryptionBackendType: config.EncryptionBackendTypeNone, }}, // This case tests that the variables from the yaml are correctly copied to the final config @@ -67,7 +69,7 @@ func TestValidConfig(t *testing.T) { t.Errorf("ValidateConfig(%q) = %v, want nil", "config.yaml", err) return } - if !cmp.Equal(gotCfg, tc.wantCfg, cmpopts.IgnoreFields(config.NotaryAppContext{}, "Logger")) { + if !cmp.Equal(gotCfg, tc.wantCfg, cmpopts.IgnoreFields(config.NotaryAppContext{}, "SystemLogger", "AuditLogger")) { t.Errorf("ValidateConfig returned unexpected diff (-want+got):\n%v", cmp.Diff(tc.wantCfg, gotCfg)) } }) diff --git a/internal/config/types.go b/internal/config/types.go index e2718ae3..f5965459 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -43,10 +43,17 @@ type EncryptionBackendConfigYaml map[string]NamedBackendConfigYaml type SystemLoggingConfigYaml struct { Level string `yaml:"level"` Output string `yaml:"output"` + Path string `yaml:"path"` +} + +type AuditLoggingConfigYaml struct { + Output string `yaml:"output"` + Path string `yaml:"path"` } type LoggingConfigYaml struct { System SystemLoggingConfigYaml `yaml:"system"` + Audit AuditLoggingConfigYaml `yaml:"audit"` } type ConfigYAML struct { @@ -111,8 +118,9 @@ type NotaryAppContext struct { // Send pebble notifications if enabled. Read more at github.com/canonical/pebble PebbleNotificationsEnabled bool - // Options for the logger - Logger *zap.Logger + // Options for the loggers + SystemLogger *zap.Logger + AuditLogger *zap.Logger // Encryption backend to be used for encrypting and decrypting sensitive data EncryptionBackendType diff --git a/internal/server/audit_logger.go b/internal/server/audit_logger.go new file mode 100644 index 00000000..ee2ac166 --- /dev/null +++ b/internal/server/audit_logger.go @@ -0,0 +1,421 @@ +package server + +import ( + "fmt" + + "go.uber.org/zap" +) + +// AuditLogger provides structured logging for security audit events. +type AuditLogger struct { + logger *zap.Logger +} + +// NewAuditLogger creates a new audit logger with a named logger for easy filtering. +func NewAuditLogger(logger *zap.Logger) *AuditLogger { + return &AuditLogger{ + logger: logger.Named("audit"), + } +} + +// Authentication Events + +// LoginSuccess logs a successful user authentication. +func (a *AuditLogger) LoginSuccess(username string, opts ...AuditOption) { + ctx := &auditContext{severity: SeverityInfo} + for _, opt := range opts { + opt(ctx) + } + + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", fmt.Sprintf("authn_login_success:%s", username)), + zap.String("username", username), + } + fields = append(fields, ctx.toZapFields()...) + + a.logger.Info(fmt.Sprintf("User %s login successfully", username), fields...) +} + +// LoginFailed logs a failed authentication attempt. +func (a *AuditLogger) LoginFailed(username string, opts ...AuditOption) { + ctx := &auditContext{severity: SeverityWarn} + for _, opt := range opts { + opt(ctx) + } + + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", fmt.Sprintf("authn_login_fail:%s", username)), + zap.String("username", username), + } + fields = append(fields, ctx.toZapFields()...) + + a.logger.Warn(fmt.Sprintf("User %s login failed", username), fields...) +} + +// TokenCreated logs when a JWT authentication token is created. +func (a *AuditLogger) TokenCreated(username string, opts ...AuditOption) { + ctx := &auditContext{severity: SeverityInfo} + for _, opt := range opts { + opt(ctx) + } + + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", fmt.Sprintf("authn_token_created:%s", username)), + zap.String("username", username), + } + fields = append(fields, ctx.toZapFields()...) + + a.logger.Info(fmt.Sprintf("A token has been created for %s", username), fields...) +} + +// PasswordChanged logs when a user's password is successfully changed. +func (a *AuditLogger) PasswordChanged(username string, opts ...AuditOption) { + ctx := &auditContext{severity: SeverityInfo} + for _, opt := range opts { + opt(ctx) + } + + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", fmt.Sprintf("authn_password_change:%s", username)), + zap.String("username", username), + } + fields = append(fields, ctx.toZapFields()...) + + a.logger.Info(fmt.Sprintf("User %s has successfully changed their password", username), fields...) +} + +// PasswordChangeFailed logs when a password change attempt fails. +func (a *AuditLogger) PasswordChangeFailed(username string, opts ...AuditOption) { + ctx := &auditContext{severity: SeverityCritical} + for _, opt := range opts { + opt(ctx) + } + + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", fmt.Sprintf("authn_password_change_fail:%s", username)), + zap.String("username", username), + } + fields = append(fields, ctx.toZapFields()...) + + a.logger.Error(fmt.Sprintf("User %s failed to change their password", username), fields...) +} + +// Certificate Events + +// CertificateRequested logs when a certificate signing request is created. +func (a *AuditLogger) CertificateRequested(csrID string, caID int, opts ...AuditOption) { + ctx := &auditContext{severity: SeverityInfo} + for _, opt := range opts { + opt(ctx) + } + + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", "cert_requested"), + zap.String("csr_id", csrID), + zap.Int("ca_id", caID), + } + fields = append(fields, ctx.toZapFields()...) + + a.logger.Info("Certificate signing request created", fields...) +} + +// CertificateIssued logs when a certificate is successfully issued. +func (a *AuditLogger) CertificateIssued(csrID string, caID int, opts ...AuditOption) { + ctx := &auditContext{severity: SeverityInfo} + for _, opt := range opts { + opt(ctx) + } + + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", "cert_issued"), + zap.String("csr_id", csrID), + zap.Int("ca_id", caID), + } + fields = append(fields, ctx.toZapFields()...) + + a.logger.Info("Certificate issued", fields...) +} + +// CertificateRejected logs when a certificate request is rejected. +func (a *AuditLogger) CertificateRejected(csrID string, caID int, opts ...AuditOption) { + ctx := &auditContext{severity: SeverityWarn} + for _, opt := range opts { + opt(ctx) + } + + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", "cert_rejected"), + zap.String("csr_id", csrID), + zap.Int("ca_id", caID), + } + fields = append(fields, ctx.toZapFields()...) + + a.logger.Warn("Certificate request rejected", fields...) +} + +// Certificate Authority Events + +// CACreated logs when a new certificate authority is created. +func (a *AuditLogger) CACreated(caID int, commonName string, opts ...AuditOption) { + ctx := &auditContext{severity: SeverityInfo} + for _, opt := range opts { + opt(ctx) + } + + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", "ca_created"), + zap.Int("ca_id", caID), + zap.String("common_name", commonName), + } + fields = append(fields, ctx.toZapFields()...) + + a.logger.Info("Certificate Authority created", fields...) +} + +// CADeleted logs when a certificate authority is deleted. +func (a *AuditLogger) CADeleted(caID int, commonName string, opts ...AuditOption) { + ctx := &auditContext{severity: SeverityWarn} + for _, opt := range opts { + opt(ctx) + } + + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", "ca_deleted"), + zap.Int("ca_id", caID), + zap.String("common_name", commonName), + } + fields = append(fields, ctx.toZapFields()...) + + a.logger.Warn("Certificate Authority deleted", fields...) +} + +// CAUpdated logs when a certificate authority enabled status is changed. +func (a *AuditLogger) CAUpdated(caID string, enabled bool, opts ...AuditOption) { + ctx := &auditContext{severity: SeverityWarn} + for _, opt := range opts { + opt(ctx) + } + + status := "disabled" + if enabled { + status = "enabled" + } + + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", "ca_updated"), + zap.String("ca_id", caID), + zap.String("status", status), + } + fields = append(fields, ctx.toZapFields()...) + + a.logger.Warn("Certificate Authority updated", fields...) +} + +// CACertificateUploaded logs when a CA certificate chain is uploaded. +func (a *AuditLogger) CACertificateUploaded(caID string, opts ...AuditOption) { + ctx := &auditContext{severity: SeverityInfo} + for _, opt := range opts { + opt(ctx) + } + + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", "ca_cert_uploaded"), + zap.String("ca_id", caID), + } + fields = append(fields, ctx.toZapFields()...) + + a.logger.Info("Certificate uploaded to Certificate Authority", fields...) +} + +// CACertificateRevoked logs when a CA certificate is revoked. +func (a *AuditLogger) CACertificateRevoked(caID string, opts ...AuditOption) { + ctx := &auditContext{severity: SeverityWarn} + for _, opt := range opts { + opt(ctx) + } + + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", "ca_cert_revoked"), + zap.String("ca_id", caID), + } + fields = append(fields, ctx.toZapFields()...) + + a.logger.Warn("Certificate Authority certificate revoked", fields...) +} + +// User Management Events + +// UserCreated logs when a new user account is created. +func (a *AuditLogger) UserCreated(username string, roleID int, opts ...AuditOption) { + ctx := &auditContext{severity: SeverityWarn} + for _, opt := range opts { + opt(ctx) + } + + roleName := fmt.Sprintf("role_%d", roleID) + if roleID == 1 { + roleName = "admin" + } else if roleID == 2 { + roleName = "user" + } + + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", fmt.Sprintf("user_created:%s,%s", username, roleName)), + zap.String("username", username), + zap.Int("role_id", roleID), + zap.String("role_name", roleName), + } + fields = append(fields, ctx.toZapFields()...) + + description := fmt.Sprintf("User account %s created with role %s", username, roleName) + if ctx.actor != "" { + description = fmt.Sprintf("User %s created user %s with role %s", ctx.actor, username, roleName) + } + a.logger.Warn(description, fields...) +} + +// UserDeleted logs when a user account is deleted. +func (a *AuditLogger) UserDeleted(username string, opts ...AuditOption) { + ctx := &auditContext{severity: SeverityWarn} + for _, opt := range opts { + opt(ctx) + } + + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", "user_deleted"), + zap.String("username", username), + } + fields = append(fields, ctx.toZapFields()...) + + a.logger.Warn("User account deleted", fields...) +} + + +// UserUpdated logs when a user account is updated (e.g., password changed). +func (a *AuditLogger) UserUpdated(username, updateType string, opts ...AuditOption) { + ctx := &auditContext{severity: SeverityWarn} + for _, opt := range opts { + opt(ctx) + } + + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", fmt.Sprintf("user_updated:%s,%s", username, updateType)), + zap.String("username", username), + zap.String("update_type", updateType), + } + fields = append(fields, ctx.toZapFields()...) + + description := fmt.Sprintf("User %s updated with %s", username, updateType) + if ctx.actor != "" { + description = fmt.Sprintf("User %s updated user %s with %s", ctx.actor, username, updateType) + } + a.logger.Warn(description, fields...) +} + +// Access Control Events + +// AccessDenied logs when a user is denied access to a resource. +func (a *AuditLogger) AccessDenied(username, resource, action string, opts ...AuditOption) { + ctx := &auditContext{severity: SeverityCritical} + for _, opt := range opts { + opt(ctx) + } + + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", fmt.Sprintf("authz_fail:%s,%s", username, resource)), + zap.String("username", username), + zap.String("resource", resource), + zap.String("action", action), + } + fields = append(fields, ctx.toZapFields()...) + + a.logger.Error("Access denied", fields...) +} + +// UnauthorizedAccess logs when an unauthorized access attempt is detected. +func (a *AuditLogger) UnauthorizedAccess(opts ...AuditOption) { + ctx := &auditContext{severity: SeverityCritical} + for _, opt := range opts { + opt(ctx) + } + + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", "authz_fail"), + } + fields = append(fields, ctx.toZapFields()...) + + a.logger.Error("Unauthorized access attempt", fields...) +} + +// API Action Events + +// APIAction logs any action performed against the API. +func (a *AuditLogger) APIAction(action string, opts ...AuditOption) { + ctx := &auditContext{severity: SeverityInfo} + for _, opt := range opts { + opt(ctx) + } + + fields := []zap.Field{ + zap.String("type", "audit"), + zap.String("event", "api_action"), + zap.String("action", action), + } + fields = append(fields, ctx.toZapFields()...) + + a.logger.Info("API action performed", fields...) +} + +// CSR and Certificate Request lifecycle events (deletions and revocations) + +// CertificateRequestDeleted logs when a CSR is deleted. +func (a *AuditLogger) CertificateRequestDeleted(csrID string, opts ...AuditOption) { + ctx := &auditContext{severity: SeverityWarn} + for _, opt := range opts { + opt(ctx) + } + + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", "cert_request_deleted"), + zap.String("csr_id", csrID), + } + fields = append(fields, ctx.toZapFields()...) + + a.logger.Warn("Certificate request deleted", fields...) +} + +// CertificateRevoked logs when a certificate (for a CSR) is revoked. +func (a *AuditLogger) CertificateRevoked(csrID string, opts ...AuditOption) { + ctx := &auditContext{severity: SeverityWarn} + for _, opt := range opts { + opt(ctx) + } + + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", "cert_revoked"), + zap.String("csr_id", csrID), + } + fields = append(fields, ctx.toZapFields()...) + + a.logger.Warn("Certificate revoked", fields...) +} diff --git a/internal/server/audit_middleware_test.go b/internal/server/audit_middleware_test.go new file mode 100644 index 00000000..54e8cc5f --- /dev/null +++ b/internal/server/audit_middleware_test.go @@ -0,0 +1,153 @@ +package server_test + +import ( + "bytes" + "encoding/json" + "net/http" + "testing" + + tu "github.com/canonical/notary/internal/testutils" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest/observer" +) + +// Helper to build a router with an observed audit logger +// Use testutils helper for observed server, keep function name for local tests +// Deprecated local helper; kept for compatibility in this file. +func newObservedRouter(t *testing.T) *observer.ObservedLogs { t.Helper(); _, logs := tu.MustPrepareServer(t); return logs } + +func findStringField(entry observer.LoggedEntry, key string) string { + for _, f := range entry.Context { + if f.Key == key { + switch f.Type { + case zapcore.StringType: + return f.String + } + } + } + return "" +} + +func TestAuditMiddleware_LogsFailureAndReason(t *testing.T) { + ts, logs := tu.MustPrepareServer(t) + // Clear any initialization noise + _ = logs.TakeAll() + + // Unauthorized GET (no token) + req, err := http.NewRequest("GET", ts.URL+"/api/v1/certificate_requests", nil) + if err != nil { + t.Fatalf("new request: %v", err) + } + res, err := ts.Client().Do(req) + if err != nil { + t.Fatalf("do request: %v", err) + } + if res.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected %d, got %d", http.StatusUnauthorized, res.StatusCode) + } + + entries := logs.TakeAll() + var haveAuthzFail, haveAPIFailed bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + switch findStringField(e, "event") { + case "authz_fail": + haveAuthzFail = true + case "api_action": + if findStringField(e, "action") == "GET certificate_requests (failed)" { + haveAPIFailed = true + } + } + } + if !haveAuthzFail { + t.Fatalf("expected UnauthorizedAccess audit entry (event=authz_fail)") + } + if !haveAPIFailed { + t.Fatalf("expected APIAction failure audit entry for GET certificate_requests") + } +} + +func TestAuditMiddleware_LogsSuccessfulRead(t *testing.T) { + ts, logs := tu.MustPrepareServer(t) + + // Create first user (open route: first user doesn't require token) + createBody := map[string]any{ + "email": "admin@example.com", + "password": "Admin123", + "role_id": 0, + } + payload, _ := json.Marshal(createBody) + req, err := http.NewRequest("POST", ts.URL+"/api/v1/accounts", bytes.NewReader(payload)) + if err != nil { + t.Fatalf("new request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + res, err := ts.Client().Do(req) + if err != nil { + t.Fatalf("do request: %v", err) + } + if res.StatusCode != http.StatusCreated { + t.Fatalf("expected %d, got %d", http.StatusCreated, res.StatusCode) + } + + // Login to obtain JWT + loginBody := map[string]any{ + "email": "admin@example.com", + "password": "Admin123", + } + loginPayload, _ := json.Marshal(loginBody) + req, err = http.NewRequest("POST", ts.URL+"/login", bytes.NewReader(loginPayload)) + if err != nil { + t.Fatalf("new request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + res, err = ts.Client().Do(req) + if err != nil { + t.Fatalf("do request: %v", err) + } + if res.StatusCode != http.StatusOK { + t.Fatalf("expected %d, got %d", http.StatusOK, res.StatusCode) + } + var loginResp struct { + Result struct{ Token string `json:"token"` } + } + if err := json.NewDecoder(res.Body).Decode(&loginResp); err != nil { + t.Fatalf("decode login response: %v", err) + } + + // Clear logs so we only capture the read success + _ = logs.TakeAll() + + // Authenticated GET (should log api_action success) + req, err = http.NewRequest("GET", ts.URL+"/api/v1/certificate_requests", nil) + if err != nil { + t.Fatalf("new request: %v", err) + } + req.Header.Set("Authorization", "Bearer "+loginResp.Result.Token) + res, err = ts.Client().Do(req) + if err != nil { + t.Fatalf("do request: %v", err) + } + if res.StatusCode != http.StatusOK { + t.Fatalf("expected %d, got %d", http.StatusOK, res.StatusCode) + } + + entries := logs.TakeAll() + var haveAPISuccess bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + if findStringField(e, "event") == "api_action" && findStringField(e, "action") == "GET certificate_requests" { + haveAPISuccess = true + break + } + } + if !haveAPISuccess { + t.Fatalf("expected APIAction success audit entry for GET certificate_requests") + } +} + + diff --git a/internal/server/audit_options.go b/internal/server/audit_options.go new file mode 100644 index 00000000..3261d06b --- /dev/null +++ b/internal/server/audit_options.go @@ -0,0 +1,110 @@ +package server + +import ( + "net/http" + + "go.uber.org/zap" +) + +// AuditOption is a functional option for adding context to audit log events. +type AuditOption func(*auditContext) + +// SecuritySeverity represents the severity level for audit events using standard log levels +type SecuritySeverity string + +const ( + SeverityDebug SecuritySeverity = "DEBUG" + SeverityInfo SecuritySeverity = "INFO" + SeverityWarn SecuritySeverity = "WARN" + SeverityError SecuritySeverity = "ERROR" + SeverityCritical SecuritySeverity = "CRITICAL" +) + +// auditContext holds optional contextual information for audit events. +type auditContext struct { + actor string + ipAddress string + reason string + userAgent string + path string + method string + resourceType string + resourceID string + severity SecuritySeverity +} + +// WithActor specifies who performed the action (e.g., username, email). +func WithActor(actor string) AuditOption { + return func(ctx *auditContext) { + ctx.actor = actor + } +} + +// WithReason specifies the reason for an action (typically used for failures). +func WithReason(reason string) AuditOption { + return func(ctx *auditContext) { + ctx.reason = reason + } +} + +// WithResourceType specifies the type of resource being acted upon (e.g., "certificate", "user", "ca"). +func WithResourceType(resourceType string) AuditOption { + return func(ctx *auditContext) { + ctx.resourceType = resourceType + } +} + +// WithResourceID specifies the ID of the resource being acted upon. +func WithResourceID(id string) AuditOption { + return func(ctx *auditContext) { + ctx.resourceID = id + } +} + +// WithRequest is a convenience function that extracts multiple fields from an HTTP request. +// It captures: remote IP, user agent, path, and method. Kept simple by design. +func WithRequest(r *http.Request) AuditOption { + return func(ctx *auditContext) { + ctx.ipAddress = r.RemoteAddr + ctx.userAgent = r.UserAgent() + ctx.path = r.URL.Path + ctx.method = r.Method + } +} + +// toZapFields converts the audit context into zap fields. +// Only non-empty fields are included in the output. +func (ctx *auditContext) toZapFields() []zap.Field { + fields := []zap.Field{} + + if ctx.actor != "" { + fields = append(fields, zap.String("actor", ctx.actor)) + } + if ctx.ipAddress != "" { + fields = append(fields, zap.String("ip_address", ctx.ipAddress)) + } + if ctx.reason != "" { + fields = append(fields, zap.String("reason", ctx.reason)) + } + if ctx.userAgent != "" { + fields = append(fields, zap.String("user_agent", ctx.userAgent)) + } + if ctx.path != "" { + fields = append(fields, zap.String("path", ctx.path)) + } + if ctx.method != "" { + fields = append(fields, zap.String("method", ctx.method)) + } + if ctx.resourceType != "" { + fields = append(fields, zap.String("resource_type", ctx.resourceType)) + } + if ctx.resourceID != "" { + fields = append(fields, zap.String("resource_id", ctx.resourceID)) + } + if ctx.severity != "" { + fields = append(fields, zap.String("severity", string(ctx.severity))) + } + + return fields +} + diff --git a/internal/server/authorization_test.go b/internal/server/authorization_test.go index 0cbef5ba..7e1ba55b 100644 --- a/internal/server/authorization_test.go +++ b/internal/server/authorization_test.go @@ -10,7 +10,7 @@ import ( ) func TestAuthorizationNoAuth(t *testing.T) { - ts := tu.MustPrepareServer(t) + ts, _ := tu.MustPrepareServer(t) client := ts.Client() testCases := []struct { @@ -47,7 +47,7 @@ func TestAuthorizationNoAuth(t *testing.T) { } func TestAuthorizationAdminAuthorized(t *testing.T) { - ts := tu.MustPrepareServer(t) + ts, _ := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") tu.MustPrepareAccount(t, ts, "whatever@canonical.com", tu.RoleCertificateManager, adminToken) client := ts.Client() @@ -91,7 +91,7 @@ func TestAuthorizationAdminAuthorized(t *testing.T) { } func TestAuthorizationAdminUnAuthorized(t *testing.T) { - ts := tu.MustPrepareServer(t) +ts, _ := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") nonAdminToken := tu.MustPrepareAccount(t, ts, "whatever@canonical.com", tu.RoleCertificateManager, adminToken) client := ts.Client() @@ -111,7 +111,7 @@ func TestAuthorizationAdminUnAuthorized(t *testing.T) { } func TestAuthorizationCertificateManagerAuthorized(t *testing.T) { - ts := tu.MustPrepareServer(t) +ts, logs := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") certManagerToken := tu.MustPrepareAccount(t, ts, "testuser@canonical.com", tu.RoleCertificateManager, adminToken) client := ts.Client() @@ -165,8 +165,9 @@ func TestAuthorizationCertificateManagerAuthorized(t *testing.T) { status: http.StatusAccepted, }, } - for _, tC := range testCases { + for _, tC := range testCases { t.Run(tC.desc, func(t *testing.T) { + if tC.desc == "certificate manager can change self password with /me" { _ = logs.TakeAll() } req, err := http.NewRequest(tC.method, ts.URL+tC.path, strings.NewReader(tC.data)) if err != nil { t.Fatal(err) @@ -179,12 +180,27 @@ func TestAuthorizationCertificateManagerAuthorized(t *testing.T) { if res.StatusCode != tC.status { t.Errorf("expected status code %d, got %d", tC.status, res.StatusCode) } + if tC.desc == "certificate manager can change self password with /me" { + entries := logs.TakeAll() + var havePwdChanged, haveUserUpdated bool + for _, e := range entries { + if e.LoggerName != "audit" { continue } + switch findStringField(e, "event") { + case "authn_password_change:testuser@canonical.com": + havePwdChanged = true + case "user_updated:testuser@canonical.com,password_change": + haveUserUpdated = true + } + } + if !havePwdChanged { t.Errorf("expected PasswordChanged audit entry for self change") } + if !haveUserUpdated { t.Errorf("expected UserUpdated audit entry for self change") } + } }) } } func TestAuthorizationCertificateManagerUnauthorized(t *testing.T) { - ts := tu.MustPrepareServer(t) + ts, logs := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") certManagerToken := tu.MustPrepareAccount(t, ts, "whatever@canonical.com", tu.RoleCertificateManager, adminToken) client := ts.Client() @@ -218,8 +234,9 @@ func TestAuthorizationCertificateManagerUnauthorized(t *testing.T) { status: http.StatusForbidden, }, } - for _, tC := range testCases { + for _, tC := range testCases { t.Run(tC.desc, func(t *testing.T) { + _ = logs.TakeAll() req, err := http.NewRequest(tC.method, ts.URL+tC.path, strings.NewReader(tC.data)) if err != nil { t.Fatal(err) @@ -232,12 +249,26 @@ func TestAuthorizationCertificateManagerUnauthorized(t *testing.T) { if res.StatusCode != tC.status { t.Errorf("expected status code %d, got %d", tC.status, res.StatusCode) } + if tC.status == http.StatusForbidden { + entries := logs.TakeAll() + var haveAuthzFail bool + for _, e := range entries { + if e.LoggerName != "audit" { continue } + if strings.HasPrefix(findStringField(e, "event"), "authz_fail:") { + haveAuthzFail = true + break + } + } + if !haveAuthzFail { + t.Errorf("expected audit authz_fail for %s %s", tC.method, tC.path) + } + } }) } } func TestAuthorizationCertificateRequestorAuthorized(t *testing.T) { - ts := tu.MustPrepareServer(t) +ts, _ := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") certRequestorToken := tu.MustPrepareAccount(t, ts, "testuser@canonical.com", tu.RoleCertificateRequestor, adminToken) client := ts.Client() @@ -315,7 +346,7 @@ func TestAuthorizationCertificateRequestorAuthorized(t *testing.T) { } func TestAuthorizationCertificateRequestorUnauthorized(t *testing.T) { - ts := tu.MustPrepareServer(t) +ts, _ := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") certRequestorToken := tu.MustPrepareAccount(t, ts, "testuser@canonical.com", tu.RoleCertificateRequestor, adminToken) client := ts.Client() @@ -436,7 +467,7 @@ func TestAuthorizationCertificateRequestorUnauthorized(t *testing.T) { } func TestAuthorizationReadOnlyAuthorized(t *testing.T) { - ts := tu.MustPrepareServer(t) +ts, _ := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") readOnlyToken := tu.MustPrepareAccount(t, ts, "testuser@canonical.com", tu.RoleReadOnly, adminToken) client := ts.Client() @@ -531,7 +562,7 @@ func TestAuthorizationReadOnlyAuthorized(t *testing.T) { } func TestAuthorizationReadOnlyUnauthorized(t *testing.T) { - ts := tu.MustPrepareServer(t) +ts, _ := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") readOnlyToken := tu.MustPrepareAccount(t, ts, "testuser@canonical.com", tu.RoleReadOnly, adminToken) client := ts.Client() diff --git a/internal/server/handlers_accounts.go b/internal/server/handlers_accounts.go index 06dc65e8..bf409594 100644 --- a/internal/server/handlers_accounts.go +++ b/internal/server/handlers_accounts.go @@ -84,7 +84,7 @@ func ListAccounts(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { accounts, err := env.DB.ListUsers() if err != nil { - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } accountsResponse := make([]GetAccountResponse, len(accounts)) @@ -97,7 +97,7 @@ func ListAccounts(env *HandlerConfig) http.HandlerFunc { } err = writeResponse(w, accountsResponse, http.StatusOK) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } @@ -113,24 +113,24 @@ func GetAccount(env *HandlerConfig) http.HandlerFunc { if id == "me" { claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) if headerErr != nil { - writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.Logger) + writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.SystemLogger) } account, err = env.DB.GetUser(db.ByEmail(claims.Email)) } else { var idNum int64 idNum, err = strconv.ParseInt(id, 10, 64) if err != nil { - writeError(w, http.StatusBadRequest, "Invalid ID", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid ID", err, env.SystemLogger) return } account, err = env.DB.GetUser(db.ByUserID(idNum)) } if err != nil { if errors.Is(err, db.ErrNotFound) { - writeError(w, http.StatusNotFound, "Not Found", err, env.Logger) + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } accountResponse := GetAccountResponse{ @@ -140,7 +140,7 @@ func GetAccount(env *HandlerConfig) http.HandlerFunc { } err = writeResponse(w, accountResponse, http.StatusOK) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } @@ -151,27 +151,40 @@ func CreateAccount(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var createAccountParams CreateAccountParams if err := json.NewDecoder(r.Body).Decode(&createAccountParams); err != nil { - writeError(w, http.StatusBadRequest, "Invalid JSON format", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid JSON format", err, env.SystemLogger) return } valid, err := createAccountParams.IsValid() if !valid { - writeError(w, http.StatusBadRequest, fmt.Errorf("Invalid request: %s", err).Error(), err, env.Logger) + writeError(w, http.StatusBadRequest, fmt.Errorf("Invalid request: %s", err).Error(), err, env.SystemLogger) return } newUserID, err := env.DB.CreateUser(createAccountParams.Email, createAccountParams.Password, db.RoleID(createAccountParams.RoleID)) if err != nil { if errors.Is(err, db.ErrAlreadyExists) { - writeError(w, http.StatusBadRequest, "account with given email already exists", err, env.Logger) + writeError(w, http.StatusBadRequest, "account with given email already exists", err, env.SystemLogger) return } - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } + + var actor string + claims, claimsErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) + if claimsErr == nil { + actor = claims.Email + } + + opts := []AuditOption{WithRequest(r)} + if actor != "" { + opts = append(opts, WithActor(actor)) + } + env.AuditLogger.UserCreated(createAccountParams.Email, int(createAccountParams.RoleID), opts...) + successResponse := CreateSuccessResponse{Message: "success", ID: newUserID} err = writeResponse(w, successResponse, http.StatusCreated) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } @@ -184,36 +197,49 @@ func DeleteAccount(env *HandlerConfig) http.HandlerFunc { id := r.PathValue("id") idInt, err := strconv.ParseInt(id, 10, 64) if err != nil { - writeError(w, http.StatusBadRequest, "Invalid ID", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid ID", err, env.SystemLogger) return } account, err := env.DB.GetUser(db.ByUserID(idInt)) if err != nil { if errors.Is(err, db.ErrNotFound) { - writeError(w, http.StatusNotFound, "Not Found", err, env.Logger) + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } if account.RoleID == db.RoleID(RoleAdmin) { err = errors.New("deleting an Admin account is not allowed") - writeError(w, http.StatusBadRequest, "deleting an Admin account is not allowed.", err, env.Logger) + writeError(w, http.StatusBadRequest, "deleting an Admin account is not allowed.", err, env.SystemLogger) + return + } + + claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) + if err != nil { + writeError(w, http.StatusUnauthorized, "Unauthorized", err, env.SystemLogger) return } + err = env.DB.DeleteUser(db.ByUserID(idInt)) if err != nil { if errors.Is(err, db.ErrNotFound) { - writeError(w, http.StatusNotFound, "Not Found", err, env.Logger) + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } + + env.AuditLogger.UserDeleted(account.Email, + WithActor(claims.Email), + WithRequest(r), + ) + successResponse := SuccessResponse{Message: "success"} err = writeResponse(w, successResponse, http.StatusAccepted) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } @@ -225,33 +251,75 @@ func ChangeAccountPassword(env *HandlerConfig) http.HandlerFunc { var idNum int64 idInt, err := strconv.ParseInt(id, 10, 64) if err != nil { - writeError(w, http.StatusBadRequest, "Invalid ID", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid ID", err, env.SystemLogger) return } idNum = idInt + + targetAccount, err := env.DB.GetUser(db.ByUserID(idNum)) + if err != nil { + if errors.Is(err, db.ErrNotFound) { + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) + return + } + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) + return + } + + claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) + if err != nil { + writeError(w, http.StatusUnauthorized, "Unauthorized", err, env.SystemLogger) + return + } + var changeAccountParams ChangeAccountParams if err := json.NewDecoder(r.Body).Decode(&changeAccountParams); err != nil { - writeError(w, http.StatusBadRequest, "Invalid JSON format", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid JSON format", err, env.SystemLogger) return } valid, err := changeAccountParams.IsValid() if !valid { - writeError(w, http.StatusBadRequest, fmt.Errorf("Invalid request: %s", err).Error(), err, env.Logger) + env.AuditLogger.PasswordChangeFailed(targetAccount.Email, + WithActor(claims.Email), + WithRequest(r), + WithReason(err.Error()), + ) + writeError(w, http.StatusBadRequest, fmt.Errorf("Invalid request: %s", err).Error(), err, env.SystemLogger) return } err = env.DB.UpdateUserPassword(db.ByUserID(idNum), changeAccountParams.Password) if err != nil { if errors.Is(err, db.ErrNotFound) { - writeError(w, http.StatusNotFound, "Not Found", err, env.Logger) + env.AuditLogger.PasswordChangeFailed(targetAccount.Email, + WithActor(claims.Email), + WithRequest(r), + WithReason("user not found"), + ) + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + env.AuditLogger.PasswordChangeFailed(targetAccount.Email, + WithActor(claims.Email), + WithRequest(r), + WithReason("database error"), + ) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } + + env.AuditLogger.PasswordChanged(targetAccount.Email, + WithActor(claims.Email), + WithRequest(r), + ) + env.AuditLogger.UserUpdated(targetAccount.Email, "password_change", + WithActor(claims.Email), + WithRequest(r), + ) + successResponse := SuccessResponse{Message: "success"} err = writeResponse(w, successResponse, http.StatusCreated) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } @@ -262,38 +330,54 @@ func ChangeMyPassword(env *HandlerConfig) http.HandlerFunc { var idNum int64 claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) if err != nil { - writeError(w, http.StatusUnauthorized, "Unauthorized", err, env.Logger) + writeError(w, http.StatusUnauthorized, "Unauthorized", err, env.SystemLogger) return } account, err := env.DB.GetUser(db.ByEmail(claims.Email)) if err != nil { - writeError(w, http.StatusUnauthorized, "Unauthorized", err, env.Logger) + writeError(w, http.StatusUnauthorized, "Unauthorized", err, env.SystemLogger) return } idNum = account.ID var changeAccountParams ChangeAccountParams if err := json.NewDecoder(r.Body).Decode(&changeAccountParams); err != nil { - writeError(w, http.StatusBadRequest, "Invalid JSON format", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid JSON format", err, env.SystemLogger) return } valid, err := changeAccountParams.IsValid() if !valid { - writeError(w, http.StatusBadRequest, fmt.Errorf("Invalid request: %s", err).Error(), err, env.Logger) + env.AuditLogger.PasswordChangeFailed(account.Email, + WithRequest(r), + WithReason(err.Error()), + ) + writeError(w, http.StatusBadRequest, fmt.Errorf("Invalid request: %s", err).Error(), err, env.SystemLogger) return } err = env.DB.UpdateUserPassword(db.ByUserID(idNum), changeAccountParams.Password) if err != nil { if errors.Is(err, db.ErrNotFound) { - writeError(w, http.StatusNotFound, "Not Found", err, env.Logger) + env.AuditLogger.PasswordChangeFailed(account.Email, + WithRequest(r), + WithReason("user not found"), + ) + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + env.AuditLogger.PasswordChangeFailed(account.Email, + WithRequest(r), + WithReason("database error"), + ) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } + + env.AuditLogger.PasswordChanged(account.Email, WithRequest(r)) + env.AuditLogger.UserUpdated(account.Email, "password_change", WithRequest(r)) + successResponse := SuccessResponse{Message: "success"} err = writeResponse(w, successResponse, http.StatusCreated) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } diff --git a/internal/server/handlers_accounts_test.go b/internal/server/handlers_accounts_test.go index 6604c035..7aeba69f 100644 --- a/internal/server/handlers_accounts_test.go +++ b/internal/server/handlers_accounts_test.go @@ -11,7 +11,7 @@ import ( // The order of the tests is important, as some tests depend on // the state of the server after previous tests. func TestAccountsEndToEnd(t *testing.T) { - ts := tu.MustPrepareServer(t) + ts, logs := tu.MustPrepareServer(t) client := ts.Client() adminToken := tu.MustPrepareAccount(t, ts, "testadmin@canonical.com", tu.RoleAdmin, "") nonAdminToken := tu.MustPrepareAccount(t, ts, "whatever@canonical.com", tu.RoleCertificateManager, adminToken) @@ -51,7 +51,8 @@ func TestAccountsEndToEnd(t *testing.T) { } }) - t.Run("3. Create account", func(t *testing.T) { + t.Run("3. Create account", func(t *testing.T) { + _ = logs.TakeAll() createAccountParams := &tu.CreateAccountParams{ Email: "nopass@canonical.com", Password: "myPassword123!", @@ -67,6 +68,21 @@ func TestAccountsEndToEnd(t *testing.T) { if response.Error != "" { t.Fatalf("unexpected error :%q", response.Error) } + + entries := logs.TakeAll() + var haveUserCreated bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + if findStringField(e, "event") == ("user_created:"+createAccountParams.Email+",admin") { + haveUserCreated = true + break + } + } + if !haveUserCreated { + t.Fatalf("expected UserCreated audit entry for %s", createAccountParams.Email) + } }) t.Run("4. Get account", func(t *testing.T) { @@ -104,7 +120,8 @@ func TestAccountsEndToEnd(t *testing.T) { } }) - t.Run("6. Change account password - success", func(t *testing.T) { + t.Run("6. Change account password - success", func(t *testing.T) { + _ = logs.TakeAll() changeAccountPasswordParams := &tu.ChangeAccountPasswordParams{ Password: "newPassword1", } @@ -118,6 +135,24 @@ func TestAccountsEndToEnd(t *testing.T) { if response.Error != "" { t.Fatalf("unexpected error :%q", response.Error) } + + entries := logs.TakeAll() + var havePwdChanged, haveUserUpdated bool + for _, e := range entries { + if e.LoggerName != "audit" { continue } + switch findStringField(e, "event") { + case "authn_password_change:testadmin@canonical.com": + havePwdChanged = true + case "user_updated:testadmin@canonical.com,password_change": + haveUserUpdated = true + } + } + if !havePwdChanged { + t.Fatalf("expected PasswordChanged audit entry") + } + if !haveUserUpdated { + t.Fatalf("expected UserUpdated audit entry for password_change") + } }) t.Run("7. Change account password - no user", func(t *testing.T) { @@ -136,7 +171,8 @@ func TestAccountsEndToEnd(t *testing.T) { } }) - t.Run("8. Delete account - success", func(t *testing.T) { + t.Run("8. Delete account - success", func(t *testing.T) { + _ = logs.TakeAll() statusCode, response, err := tu.DeleteAccount(ts.URL, client, adminToken, 2) if err != nil { t.Fatalf("couldn't delete account: %s", err) @@ -147,6 +183,19 @@ func TestAccountsEndToEnd(t *testing.T) { if response.Error != "" { t.Fatalf("expected error %q, got %q", "", response.Error) } + + entries := logs.TakeAll() + var haveUserDeleted bool + for _, e := range entries { + if e.LoggerName != "audit" { continue } + if findStringField(e, "event") == "user_deleted" && findStringField(e, "username") == "whatever@canonical.com" { + haveUserDeleted = true + break + } + } + if !haveUserDeleted { + t.Fatalf("expected UserDeleted audit entry for whatever@canonical.com") + } }) t.Run("9. Delete account - no user", func(t *testing.T) { @@ -186,7 +235,7 @@ func TestAccountsEndToEnd(t *testing.T) { } func TestCreateAccountInvalidInputs(t *testing.T) { - ts := tu.MustPrepareServer(t) + ts, _ := tu.MustPrepareServer(t) client := ts.Client() adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") @@ -292,7 +341,7 @@ func TestCreateAccountInvalidInputs(t *testing.T) { } func TestChangeAccountPasswordInvalidInputs(t *testing.T) { - ts := tu.MustPrepareServer(t) +ts, _ := tu.MustPrepareServer(t) client := ts.Client() adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") diff --git a/internal/server/handlers_certificate_authorities.go b/internal/server/handlers_certificate_authorities.go index 82d0b1ad..6cf0c977 100644 --- a/internal/server/handlers_certificate_authorities.go +++ b/internal/server/handlers_certificate_authorities.go @@ -22,6 +22,19 @@ import ( const nextUpdateYears = 1 +// extractCommonName extracts the CN from a certificate PEM string, returns "unknown" if it fails +func extractCommonName(certPEM string) string { + block, _ := pem.Decode([]byte(certPEM)) + if block == nil { + return "unknown" + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return "unknown" + } + return cert.Subject.CommonName +} + type CertificateAuthority struct { ID int64 `json:"id"` Enabled bool `json:"enabled"` @@ -222,7 +235,7 @@ func ListCertificateAuthorities(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { cas, err := env.DB.ListDenormalizedCertificateAuthorities() if err != nil { - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } caResponse := make([]CertificateAuthority, len(cas)) @@ -238,7 +251,7 @@ func ListCertificateAuthorities(env *HandlerConfig) http.HandlerFunc { } err = writeResponse(w, caResponse, http.StatusOK) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } @@ -250,22 +263,22 @@ func CreateCertificateAuthority(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var params CreateCertificateAuthorityParams if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil { - writeError(w, http.StatusBadRequest, "Invalid JSON format", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid JSON format", err, env.SystemLogger) return } valid, err := params.IsValid() if !valid { - writeError(w, http.StatusBadRequest, fmt.Errorf("Invalid request: %s", err).Error(), err, env.Logger) + writeError(w, http.StatusBadRequest, fmt.Errorf("Invalid request: %s", err).Error(), err, env.SystemLogger) return } claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) if headerErr != nil { - writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.Logger) + writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.SystemLogger) return } csrPEM, privPEM, crlPEM, certPEM, err := createCertificateAuthority(params) if err != nil { - writeError(w, http.StatusInternalServerError, "Failed to create certificate authority", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Failed to create certificate authority", err, env.SystemLogger) return } var newCAID int64 @@ -275,14 +288,19 @@ func CreateCertificateAuthority(env *HandlerConfig) http.HandlerFunc { newCAID, err = env.DB.CreateCertificateAuthority(strings.TrimSpace(csrPEM), strings.TrimSpace(privPEM), "", "", claims.ID) } if err != nil { - writeError(w, http.StatusInternalServerError, "Failed to create certificate authority", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Failed to create certificate authority", err, env.SystemLogger) return } + env.AuditLogger.CACreated(int(newCAID), params.CommonName, + WithActor(claims.Email), + WithRequest(r), + ) + successResponse := CreateSuccessResponse{Message: "Certificate Authority created successfully", ID: newCAID} err = writeResponse(w, successResponse, http.StatusCreated) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } @@ -295,17 +313,17 @@ func GetCertificateAuthority(env *HandlerConfig) http.HandlerFunc { id := r.PathValue("id") idNum, err := strconv.ParseInt(id, 10, 64) if err != nil { - writeError(w, http.StatusBadRequest, "Invalid ID", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid ID", err, env.SystemLogger) return } ca, err := env.DB.GetDenormalizedCertificateAuthority(db.ByCertificateAuthorityDenormalizedID(idNum)) if err != nil { if errors.Is(err, db.ErrNotFound) { - writeError(w, http.StatusNotFound, "Not Found", err, env.Logger) + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } caResponse := CertificateAuthority{ @@ -319,7 +337,7 @@ func GetCertificateAuthority(env *HandlerConfig) http.HandlerFunc { err = writeResponse(w, caResponse, http.StatusOK) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } @@ -332,29 +350,40 @@ func UpdateCertificateAuthority(env *HandlerConfig) http.HandlerFunc { id := r.PathValue("id") idNum, err := strconv.ParseInt(id, 10, 64) if err != nil { - writeError(w, http.StatusBadRequest, "Invalid ID", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid ID", err, env.SystemLogger) return } var params UpdateCertificateAuthorityParams if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil { - writeError(w, http.StatusBadRequest, "Invalid JSON format", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid JSON format", err, env.SystemLogger) + return + } + + claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) + if headerErr != nil { + writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.SystemLogger) return } err = env.DB.UpdateCertificateAuthorityEnabledStatus(db.ByCertificateAuthorityID(idNum), params.Enabled) if err != nil { if errors.Is(err, db.ErrNotFound) { - writeError(w, http.StatusNotFound, "Not Found", err, env.Logger) + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } + env.AuditLogger.CAUpdated(id, params.Enabled, + WithActor(claims.Email), + WithRequest(r), + ) + successResponse := SuccessResponse{Message: "success"} err = writeResponse(w, successResponse, http.StatusOK) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } @@ -367,23 +396,45 @@ func DeleteCertificateAuthority(env *HandlerConfig) http.HandlerFunc { id := r.PathValue("id") idNum, err := strconv.ParseInt(id, 10, 64) if err != nil { - writeError(w, http.StatusBadRequest, "Invalid ID", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid ID", err, env.SystemLogger) + return + } + + claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) + if headerErr != nil { + writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.SystemLogger) + return + } + + ca, err := env.DB.GetDenormalizedCertificateAuthority(db.ByCertificateAuthorityDenormalizedID(idNum)) + if err != nil { + if errors.Is(err, db.ErrNotFound) { + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) + return + } + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } err = env.DB.DeleteCertificateAuthority(db.ByCertificateAuthorityID(idNum)) if err != nil { if errors.Is(err, db.ErrNotFound) { - writeError(w, http.StatusNotFound, "Not Found", err, env.Logger) + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } + + env.AuditLogger.CADeleted(int(idNum), extractCommonName(ca.CertificateChain), + WithActor(claims.Email), + WithRequest(r), + ) + successResponse := SuccessResponse{Message: "success"} err = writeResponse(w, successResponse, http.StatusOK) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } @@ -396,32 +447,44 @@ func PostCertificateAuthorityCertificate(env *HandlerConfig) http.HandlerFunc { id := r.PathValue("id") idNum, err := strconv.ParseInt(id, 10, 64) if err != nil { - writeError(w, http.StatusBadRequest, "Invalid ID", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid ID", err, env.SystemLogger) return } var UploadCertificateToCertificateAuthorityParams UploadCertificateToCertificateAuthorityParams if err := json.NewDecoder(r.Body).Decode(&UploadCertificateToCertificateAuthorityParams); err != nil { - writeError(w, http.StatusBadRequest, "Invalid JSON format", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid JSON format", err, env.SystemLogger) return } valid, err := UploadCertificateToCertificateAuthorityParams.IsValid() if !valid { - writeError(w, http.StatusBadRequest, fmt.Errorf("Invalid request: %s", err).Error(), err, env.Logger) + writeError(w, http.StatusBadRequest, fmt.Errorf("Invalid request: %s", err).Error(), err, env.SystemLogger) + return + } + + claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) + if headerErr != nil { + writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.SystemLogger) return } err = env.DB.UpdateCertificateAuthorityCertificate(db.ByCertificateAuthorityDenormalizedID(idNum), UploadCertificateToCertificateAuthorityParams.CertificateChain) if err != nil { if errors.Is(err, db.ErrNotFound) { - writeError(w, http.StatusNotFound, "Not Found", err, env.Logger) + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } + + env.AuditLogger.CACertificateUploaded(id, + WithActor(claims.Email), + WithRequest(r), + ) + err = writeResponse(w, SuccessResponse{Message: "success"}, http.StatusCreated) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } @@ -435,43 +498,43 @@ func SignCertificateAuthority(env *HandlerConfig) http.HandlerFunc { id := r.PathValue("id") idNum, err := strconv.ParseInt(id, 10, 64) if err != nil { - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } var signCertificateAuthorityParams SignCertificateAuthorityParams if err := json.NewDecoder(r.Body).Decode(&signCertificateAuthorityParams); err != nil { - writeError(w, http.StatusBadRequest, "Invalid JSON format", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid JSON format", err, env.SystemLogger) return } caIDInt, err := strconv.ParseInt(signCertificateAuthorityParams.CertificateAuthorityID, 10, 64) if err != nil { - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } caToBeSigned, err := env.DB.GetCertificateAuthority(db.ByCertificateAuthorityID(idNum)) if err != nil { - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } err = env.DB.SignCertificateRequest(db.ByCSRID(caToBeSigned.CSRID), db.ByCertificateAuthorityDenormalizedID(caIDInt), env.ExternalHostname) if err != nil { if errors.Is(err, db.ErrNotFound) { - writeError(w, http.StatusNotFound, "Not Found", err, env.Logger) + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } if env.SendPebbleNotifications { err := SendPebbleNotification(CertificateUpdate, idNum) if err != nil { - env.Logger.Warn("pebble notify failed", zap.Error(err)) + env.SystemLogger.Warn("pebble notify failed", zap.Error(err)) } } successResponse := SuccessResponse{Message: "success"} err = writeResponse(w, successResponse, http.StatusAccepted) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } @@ -484,24 +547,24 @@ func GetCertificateAuthorityCRL(env *HandlerConfig) http.HandlerFunc { id := r.PathValue("id") idNum, err := strconv.ParseInt(id, 10, 64) if err != nil { - writeError(w, http.StatusBadRequest, "Invalid ID", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid ID", err, env.SystemLogger) return } ca, err := env.DB.GetDenormalizedCertificateAuthority(db.ByCertificateAuthorityDenormalizedID(idNum)) if err != nil { if errors.Is(err, db.ErrNotFound) { - writeError(w, http.StatusNotFound, "Not Found", err, env.Logger) + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } crlResponse := CRL{CRL: ca.CRL} err = writeResponse(w, crlResponse, http.StatusOK) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } @@ -515,39 +578,52 @@ func RevokeCertificateAuthorityCertificate(env *HandlerConfig) http.HandlerFunc id := r.PathValue("id") idNum, err := strconv.ParseInt(id, 10, 64) if err != nil { - writeError(w, http.StatusBadRequest, "Invalid ID", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid ID", err, env.SystemLogger) + return + } + + claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) + if headerErr != nil { + writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.SystemLogger) return } + ca, err := env.DB.GetCertificateAuthority(db.ByCertificateAuthorityID(idNum)) if err != nil { - env.Logger.Info("could not get certificate authority", zap.Error(err)) + env.SystemLogger.Info("could not get certificate authority", zap.Error(err)) if errors.Is(err, db.ErrNotFound) { - writeError(w, http.StatusNotFound, "Not Found", err, env.Logger) + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } err = env.DB.RevokeCertificate(db.ByCSRID(ca.CSRID)) if err != nil { - env.Logger.Warn("could not revoke certificate", zap.Error(err)) + env.SystemLogger.Warn("could not revoke certificate", zap.Error(err)) if errors.Is(err, db.ErrNotFound) { - writeError(w, http.StatusNotFound, "Not Found", err, env.Logger) + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } + + env.AuditLogger.CACertificateRevoked(id, + WithActor(claims.Email), + WithRequest(r), + ) + if env.SendPebbleNotifications { err := SendPebbleNotification(CertificateUpdate, idNum) if err != nil { - env.Logger.Warn("pebble notify failed", zap.Error(err)) + env.SystemLogger.Warn("pebble notify failed", zap.Error(err)) } } successResponse := SuccessResponse{Message: "success"} err = writeResponse(w, successResponse, http.StatusAccepted) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } diff --git a/internal/server/handlers_certificate_authorities_test.go b/internal/server/handlers_certificate_authorities_test.go index b2e6c451..21d7d0b2 100644 --- a/internal/server/handlers_certificate_authorities_test.go +++ b/internal/server/handlers_certificate_authorities_test.go @@ -15,7 +15,7 @@ import ( // The order of the tests is important, as some tests depend on the state of the server after previous tests. func TestSelfSignedCertificateAuthorityEndToEnd(t *testing.T) { - ts := tu.MustPrepareServer(t) + ts, logs := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") client := ts.Client() @@ -35,7 +35,8 @@ func TestSelfSignedCertificateAuthorityEndToEnd(t *testing.T) { } }) - t.Run("2. Create self signed certificate authority", func(t *testing.T) { + t.Run("2. Create self signed certificate authority", func(t *testing.T) { + _ = logs.TakeAll() createCertificatAuthorityParams := tu.CreateCertificateAuthorityParams{ SelfSigned: true, @@ -55,9 +56,20 @@ func TestSelfSignedCertificateAuthorityEndToEnd(t *testing.T) { if statusCode != http.StatusCreated { t.Fatalf("expected status %d, got %d", http.StatusCreated, statusCode) } - if createCAResponse.Error != "" { + if createCAResponse.Error != "" { t.Fatalf("expected success, got %s", createCAResponse.Error) } + + entries := logs.TakeAll() + var haveCACreated bool + for _, e := range entries { + if e.LoggerName != "audit" { continue } + if findStringField(e, "event") == "ca_created" { + haveCACreated = true + break + } + } + if !haveCACreated { t.Fatalf("expected CACreated audit entry") } }) t.Run("3. Get all CA's - 1 should be there and enabled", func(t *testing.T) { @@ -155,7 +167,8 @@ func TestSelfSignedCertificateAuthorityEndToEnd(t *testing.T) { IntermediateCACSR = getCAResponse.Result.CSRPEM }) - t.Run("7. Sign the intermediate CA's CSR", func(t *testing.T) { + t.Run("7. Sign the intermediate CA's CSR", func(t *testing.T) { + _ = logs.TakeAll() signedCert := tu.SignCSR(IntermediateCACSR) statusCode, uploadCertificateResponse, err := tu.UploadCertificateToCertificateAuthority(ts.URL, client, adminToken, 2, server.UploadCertificateToCertificateAuthorityParams{CertificateChain: signedCert + tu.SelfSignedCACertificate}) if err != nil { @@ -164,9 +177,17 @@ func TestSelfSignedCertificateAuthorityEndToEnd(t *testing.T) { if statusCode != http.StatusCreated { t.Fatalf("expected status %d, got %d", http.StatusCreated, statusCode) } - if uploadCertificateResponse.Error != "" { + if uploadCertificateResponse.Error != "" { t.Fatalf("expected success, got %s", uploadCertificateResponse.Error) } + + entries := logs.TakeAll() + var haveCertUploaded bool + for _, e := range entries { + if e.LoggerName != "audit" { continue } + if findStringField(e, "event") == "ca_cert_uploaded" { haveCertUploaded = true; break } + } + if !haveCertUploaded { t.Fatalf("expected CACertificateUploaded audit entry") } }) t.Run("8. Get all CA's - 2 should be there and both enabled", func(t *testing.T) { statusCode, listCAsResponse, err := tu.ListCertificateAuthorities(ts.URL, client, adminToken) @@ -189,8 +210,8 @@ func TestSelfSignedCertificateAuthorityEndToEnd(t *testing.T) { t.Fatalf("expected second CA to be enabled") } }) - t.Run("9. Make first CA legacy", func(t *testing.T) { - statusCode, makeLegacyResponse, err := tu.UpdateCertificateAuthority(ts.URL, client, adminToken, 1, tu.UpdateCertificateAuthorityParams{Status: "legacy"}) + t.Run("9. Make first CA legacy", func(t *testing.T) { + statusCode, makeLegacyResponse, err := tu.UpdateCertificateAuthority(ts.URL, client, adminToken, 1, tu.UpdateCertificateAuthorityParams{Status: "legacy"}) if err != nil { t.Fatal("expected no error, got: ", err) } @@ -222,14 +243,22 @@ func TestSelfSignedCertificateAuthorityEndToEnd(t *testing.T) { t.Fatalf("expected second CA to be enabled") } }) - t.Run("11. Delete first CA", func(t *testing.T) { - statusCode, err := tu.DeleteCertificateAuthority(ts.URL, client, adminToken, 1) + t.Run("11. Delete first CA", func(t *testing.T) { + _ = logs.TakeAll() + statusCode, err := tu.DeleteCertificateAuthority(ts.URL, client, adminToken, 1) if err != nil { t.Fatal("expected no error, got: ", err) } if statusCode != http.StatusOK { t.Fatalf("expected status %d, got %d", http.StatusOK, statusCode) } + entries := logs.TakeAll() + var haveCADeleted bool + for _, e := range entries { + if e.LoggerName != "audit" { continue } + if findStringField(e, "event") == "ca_deleted" { haveCADeleted = true; break } + } + if !haveCADeleted { t.Fatalf("expected CADeleted audit entry") } }) t.Run("12. Get all CA's - 1 enabled should be there", func(t *testing.T) { statusCode, listCAsResponse, err := tu.ListCertificateAuthorities(ts.URL, client, adminToken) @@ -252,7 +281,7 @@ func TestSelfSignedCertificateAuthorityEndToEnd(t *testing.T) { } func TestCreateCertificateAuthorityInvalidInputs(t *testing.T) { - ts := tu.MustPrepareServer(t) + ts, _ := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") client := ts.Client() @@ -341,7 +370,7 @@ func TestCreateCertificateAuthorityInvalidInputs(t *testing.T) { } func TestUploadCertificateToCertificateAuthorityInvalidInputs(t *testing.T) { - ts := tu.MustPrepareServer(t) +ts, _ := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") client := ts.Client() @@ -415,7 +444,7 @@ invalid } func TestSignCertificatesEndToEnd(t *testing.T) { - ts := tu.MustPrepareServer(t) +ts, logs := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") client := ts.Client() @@ -587,7 +616,7 @@ func TestSignCertificatesEndToEnd(t *testing.T) { t.Fatalf("expected success, got %s", uploadCertificateResponse.Error) } }) - t.Run("9. Get all CA's - 2 should be there and both active", func(t *testing.T) { + t.Run("9. Get all CA's - 2 should be there and both active", func(t *testing.T) { statusCode, listCAsResponse, err := tu.ListCertificateAuthorities(ts.URL, client, adminToken) if err != nil { t.Fatal("expected no error, got: ", err) @@ -611,6 +640,21 @@ func TestSignCertificatesEndToEnd(t *testing.T) { t.Fatalf("expected second CA to have a chain with 2 certificates") } }) + + t.Run("10. Update CA enabled status and assert audit", func(t *testing.T) { + _ = logs.TakeAll() + statusCode, updateResp, err := tu.UpdateCertificateAuthority(ts.URL, client, adminToken, 1, tu.UpdateCertificateAuthorityParams{Status: "active"}) + if err != nil { t.Fatal(err) } + if statusCode != http.StatusOK { t.Fatalf("expected %d, got %d", http.StatusOK, statusCode) } + if updateResp.Error != "" { t.Fatalf("expected success, got %s", updateResp.Error) } + entries := logs.TakeAll() + var haveCAUpdated bool + for _, e := range entries { + if e.LoggerName != "audit" { continue } + if findStringField(e, "event") == "ca_updated" { haveCAUpdated = true; break } + } + if !haveCAUpdated { t.Fatalf("expected CAUpdated audit entry") } + }) t.Run("10. Create 2nd CSR's", func(t *testing.T) { createCertificateRequestRequest := tu.CreateCertificateRequestParams{CSR: tu.StrawberryCSR} statusCode, createCertResponse, err := tu.CreateCertificateRequest(ts.URL, client, adminToken, createCertificateRequestRequest) @@ -672,9 +716,9 @@ func TestSignCertificatesEndToEnd(t *testing.T) { if len(listCSRsResponse.Result) != 2 { t.Fatalf("expected 2 certificates, got %d", len(listCSRsResponse.Result)) } - if listCSRsResponse.Result[0].Status != "Active" { - t.Fatalf("expected first csr to be active, got %s", listCSRsResponse.Result[3].Status) - } + if listCSRsResponse.Result[0].Status != "Active" { + t.Fatalf("expected first csr to be active, got %s", listCSRsResponse.Result[0].Status) + } if strings.Count(listCSRsResponse.Result[0].CertificateChain, "BEGIN CERTIFICATE") != 2 { t.Fatalf("expected first csr to have a chain with 2 certificates") } @@ -688,7 +732,7 @@ func TestSignCertificatesEndToEnd(t *testing.T) { } func TestUnsuccessfulRequestsMadeToCACSRs(t *testing.T) { - ts := tu.MustPrepareServer(t) +ts, _ := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") client := ts.Client() @@ -819,7 +863,7 @@ func TestUnsuccessfulRequestsMadeToCACSRs(t *testing.T) { } func TestCertificateRevocationListsEndToEnd(t *testing.T) { - ts := tu.MustPrepareServer(t) + ts, logs := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") client := ts.Client() @@ -1074,6 +1118,7 @@ func TestCertificateRevocationListsEndToEnd(t *testing.T) { }) t.Run("10. Revoke Intermediate CA", func(t *testing.T) { + _ = logs.TakeAll() statusCode, response, err := tu.RevokeCertificateAuthority(ts.URL, client, adminToken, 2) if err != nil { t.Fatalf("expected no error, got: %s", err) @@ -1084,6 +1129,13 @@ func TestCertificateRevocationListsEndToEnd(t *testing.T) { if response.Error != "" { t.Fatalf("expected success, got %s", response.Error) } + entries := logs.TakeAll() + var haveCARevoked bool + for _, e := range entries { + if e.LoggerName != "audit" { continue } + if findStringField(e, "event") == "ca_cert_revoked" { haveCARevoked = true; break } + } + if !haveCARevoked { t.Fatalf("expected CACertificateRevoked audit entry") } statusCode, cas, err := tu.ListCertificateAuthorities(ts.URL, client, adminToken) if statusCode != http.StatusOK { t.Fatalf("expected status %d, got %d", http.StatusOK, statusCode) diff --git a/internal/server/handlers_certificate_requests.go b/internal/server/handlers_certificate_requests.go index 0838abb4..9af95894 100644 --- a/internal/server/handlers_certificate_requests.go +++ b/internal/server/handlers_certificate_requests.go @@ -72,7 +72,7 @@ func ListCertificateRequests(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) if headerErr != nil { - writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.Logger) + writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.SystemLogger) return } @@ -86,7 +86,7 @@ func ListCertificateRequests(env *HandlerConfig) http.HandlerFunc { csrs, err = env.DB.ListCertificateRequestWithCertificatesWithoutCAS(filter) if err != nil { - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } @@ -96,10 +96,10 @@ func ListCertificateRequests(env *HandlerConfig) http.HandlerFunc { user, err := env.DB.GetUser(db.ByUserID(csr.UserID)) if err != nil { if errors.Is(err, db.ErrNotFound) { - env.Logger.Warn("user not found for certificate request", zap.Int64("user_id", csr.UserID)) + env.SystemLogger.Warn("user not found for certificate request", zap.Int64("user_id", csr.UserID)) email = "unknown" } else { - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } } else { @@ -115,7 +115,7 @@ func ListCertificateRequests(env *HandlerConfig) http.HandlerFunc { } err = writeResponse(w, certificateRequestsResponse, http.StatusOK) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } @@ -126,38 +126,44 @@ func CreateCertificateRequest(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var createCertificateRequestParams CreateCertificateRequestParams if err := json.NewDecoder(r.Body).Decode(&createCertificateRequestParams); err != nil { - writeError(w, http.StatusBadRequest, "Invalid JSON format", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid JSON format", err, env.SystemLogger) return } valid, err := createCertificateRequestParams.IsValid() if !valid { - writeError(w, http.StatusBadRequest, fmt.Errorf("Invalid request: %s", err).Error(), err, env.Logger) + writeError(w, http.StatusBadRequest, fmt.Errorf("Invalid request: %s", err).Error(), err, env.SystemLogger) return } claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) if headerErr != nil { - writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.Logger) + writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.SystemLogger) return } newCSRID, err := env.DB.CreateCertificateRequest(createCertificateRequestParams.CSR, claims.ID) if err != nil { if errors.Is(err, db.ErrAlreadyExists) { - writeError(w, http.StatusBadRequest, "given csr already recorded", err, env.Logger) + writeError(w, http.StatusBadRequest, "given csr already recorded", err, env.SystemLogger) return } if errors.Is(err, db.ErrInvalidCertificateRequest) { - writeError(w, http.StatusBadRequest, "csr validation failed", err, env.Logger) + writeError(w, http.StatusBadRequest, "csr validation failed", err, env.SystemLogger) return } - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } + + env.AuditLogger.CertificateRequested(strconv.FormatInt(newCSRID, 10), 0, + WithActor(claims.Email), + WithRequest(r), + ) + successResponse := CreateSuccessResponse{Message: "success", ID: newCSRID} err = writeResponse(w, successResponse, http.StatusCreated) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } @@ -169,40 +175,40 @@ func GetCertificateRequest(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) if headerErr != nil { - writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.Logger) + writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.SystemLogger) return } id := r.PathValue("id") idNum, err := strconv.ParseInt(id, 10, 64) if err != nil { - writeError(w, http.StatusBadRequest, "Invalid ID", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid ID", err, env.SystemLogger) return } csr, err := env.DB.GetCertificateRequestAndChain(db.ByCSRID(idNum)) if err != nil { if errors.Is(err, db.ErrNotFound) { - writeError(w, http.StatusNotFound, "Not Found", err, env.Logger) + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } // Restrict access to certificate requestors' own requests if claims.RoleID == RoleCertificateRequestor && claims.ID != csr.UserID { - writeError(w, http.StatusForbidden, "Access denied", fmt.Errorf("user does not have permission to access this certificate request"), env.Logger) + writeError(w, http.StatusForbidden, "Access denied", fmt.Errorf("user does not have permission to access this certificate request"), env.SystemLogger) return } _, err = env.DB.GetCertificateAuthority(db.ByCertificateAuthorityCSRID(csr.CSR_ID)) if rowFound(err) { - writeError(w, http.StatusNotFound, "Not Found", fmt.Errorf("not found"), env.Logger) + writeError(w, http.StatusNotFound, "Not Found", fmt.Errorf("not found"), env.SystemLogger) return } if realError(err) { - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } @@ -210,10 +216,10 @@ func GetCertificateRequest(env *HandlerConfig) http.HandlerFunc { user, err := env.DB.GetUser(db.ByUserID(csr.UserID)) if err != nil { if errors.Is(err, db.ErrNotFound) { - env.Logger.Warn("user not found for certificate request", zap.Int64("user_id", csr.UserID)) + env.SystemLogger.Warn("user not found for certificate request", zap.Int64("user_id", csr.UserID)) email = "unknown" } else { - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } } else { @@ -230,7 +236,7 @@ func GetCertificateRequest(env *HandlerConfig) http.HandlerFunc { err = writeResponse(w, certificateRequestResponse, http.StatusOK) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } @@ -243,31 +249,44 @@ func DeleteCertificateRequest(env *HandlerConfig) http.HandlerFunc { id := r.PathValue("id") idNum, err := strconv.ParseInt(id, 10, 64) if err != nil { - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } + + claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) + if headerErr != nil { + writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.SystemLogger) + return + } + _, err = env.DB.GetCertificateAuthority(db.ByCertificateAuthorityCSRID(idNum)) if rowFound(err) { - writeError(w, http.StatusNotFound, "Not Found", fmt.Errorf("not found"), env.Logger) + writeError(w, http.StatusNotFound, "Not Found", fmt.Errorf("not found"), env.SystemLogger) return } if realError(err) { - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } err = env.DB.DeleteCertificateRequest(db.ByCSRID(idNum)) if err != nil { if errors.Is(err, db.ErrNotFound) { - writeError(w, http.StatusNotFound, "Not Found", err, env.Logger) + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } + + env.AuditLogger.CertificateRequestDeleted(id, + WithActor(claims.Email), + WithRequest(r), + ) + successResponse := SuccessResponse{Message: "success"} err = writeResponse(w, successResponse, http.StatusAccepted) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } @@ -279,49 +298,62 @@ func PostCertificateRequestCertificate(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var createCertificateParams CreateCertificateParams if err := json.NewDecoder(r.Body).Decode(&createCertificateParams); err != nil { - writeError(w, http.StatusBadRequest, "Invalid JSON format", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid JSON format", err, env.SystemLogger) return } valid, err := createCertificateParams.IsValid() if !valid { - writeError(w, http.StatusBadRequest, fmt.Errorf("Invalid request: %s", err).Error(), err, env.Logger) + writeError(w, http.StatusBadRequest, fmt.Errorf("Invalid request: %s", err).Error(), err, env.SystemLogger) return } id := r.PathValue("id") idNum, err := strconv.ParseInt(id, 10, 64) if err != nil { - writeError(w, http.StatusBadRequest, "Invalid ID", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid ID", err, env.SystemLogger) + return + } + + claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) + if headerErr != nil { + writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.SystemLogger) return } + _, err = env.DB.GetCertificateAuthority(db.ByCertificateAuthorityCSRID(idNum)) if rowFound(err) { - writeError(w, http.StatusNotFound, "Not Found", err, env.Logger) + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } if realError(err) { - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } newCertID, err := env.DB.AddCertificateChainToCertificateRequest(db.ByCSRID(idNum), createCertificateParams.CertificateChain) if err != nil { if errors.Is(err, db.ErrNotFound) || errors.Is(err, db.ErrInvalidCertificate) { - writeError(w, http.StatusBadRequest, "Bad Request", err, env.Logger) + writeError(w, http.StatusBadRequest, "Bad Request", err, env.SystemLogger) return } - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } + + env.AuditLogger.CertificateIssued(id, 0, + WithActor(claims.Email), + WithRequest(r), + ) + if env.SendPebbleNotifications { err := SendPebbleNotification(CertificateUpdate, idNum) if err != nil { - env.Logger.Warn("pebble notify failed", zap.Error(err)) + env.SystemLogger.Warn("pebble notify failed", zap.Error(err)) } } successResponse := CreateSuccessResponse{Message: "success", ID: newCertID} err = writeResponse(w, successResponse, http.StatusCreated) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } @@ -334,38 +366,52 @@ func RejectCertificateRequest(env *HandlerConfig) http.HandlerFunc { id := r.PathValue("id") idNum, err := strconv.ParseInt(id, 10, 64) if err != nil { - writeError(w, http.StatusBadRequest, "Invalid ID", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid ID", err, env.SystemLogger) return } + + claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) + if headerErr != nil { + writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.SystemLogger) + return + } + _, err = env.DB.GetCertificateAuthority(db.ByCertificateAuthorityCSRID(idNum)) if rowFound(err) { err = fmt.Errorf("certificate request %d not found", idNum) - writeError(w, http.StatusNotFound, "Not Found", err, env.Logger) + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } if realError(err) { - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } err = env.DB.RejectCertificateRequest(db.ByCSRID(idNum)) if err != nil { if errors.Is(err, db.ErrNotFound) { - writeError(w, http.StatusNotFound, "Not Found", err, env.Logger) + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } + + env.AuditLogger.CertificateRejected(id, 0, + WithActor(claims.Email), + WithRequest(r), + WithReason("rejected by administrator"), + ) + if env.SendPebbleNotifications { err := SendPebbleNotification(CertificateUpdate, idNum) if err != nil { - env.Logger.Warn("pebble notify failed", zap.Error(err)) + env.SystemLogger.Warn("pebble notify failed", zap.Error(err)) } } successResponse := SuccessResponse{Message: "success"} err = writeResponse(w, successResponse, http.StatusAccepted) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } @@ -378,37 +424,37 @@ func DeleteCertificate(env *HandlerConfig) http.HandlerFunc { id := r.PathValue("id") idNum, err := strconv.ParseInt(id, 10, 64) if err != nil { - writeError(w, http.StatusBadRequest, "Invalid ID", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid ID", err, env.SystemLogger) return } _, err = env.DB.GetCertificateAuthority(db.ByCertificateAuthorityCSRID(idNum)) if rowFound(err) { - writeError(w, http.StatusNotFound, "Not Found", err, env.Logger) + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } if realError(err) { - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } err = env.DB.DeleteCertificateRequest(db.ByCSRID(idNum)) if err != nil { if errors.Is(err, db.ErrNotFound) { - writeError(w, http.StatusBadRequest, "Bad Request", err, env.Logger) + writeError(w, http.StatusBadRequest, "Bad Request", err, env.SystemLogger) return } - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } if env.SendPebbleNotifications { err := SendPebbleNotification(CertificateUpdate, idNum) if err != nil { - env.Logger.Warn("pebble notify failed", zap.Error(err)) + env.SystemLogger.Warn("pebble notify failed", zap.Error(err)) } } successResponse := SuccessResponse{Message: "success"} err = writeResponse(w, successResponse, http.StatusOK) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } @@ -422,37 +468,50 @@ func RevokeCertificate(env *HandlerConfig) http.HandlerFunc { id := r.PathValue("id") idNum, err := strconv.ParseInt(id, 10, 64) if err != nil { - writeError(w, http.StatusBadRequest, "Invalid ID", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid ID", err, env.SystemLogger) + return + } + + claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) + if headerErr != nil { + writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.SystemLogger) return } + _, err = env.DB.GetCertificateAuthority(db.ByCertificateAuthorityCSRID(idNum)) if rowFound(err) { - writeError(w, http.StatusNotFound, "Not Found", err, env.Logger) + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } if realError(err) { - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } err = env.DB.RevokeCertificate(db.ByCSRID(idNum)) if err != nil { if errors.Is(err, db.ErrNotFound) { - writeError(w, http.StatusNotFound, "Not Found", err, env.Logger) + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } + + env.AuditLogger.CertificateRevoked(id, + WithActor(claims.Email), + WithRequest(r), + ) + if env.SendPebbleNotifications { err := SendPebbleNotification(CertificateUpdate, idNum) if err != nil { - env.Logger.Warn("pebble notify failed", zap.Error(err)) + env.SystemLogger.Warn("pebble notify failed", zap.Error(err)) } } successResponse := SuccessResponse{Message: "success"} err = writeResponse(w, successResponse, http.StatusAccepted) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } @@ -466,48 +525,48 @@ func SignCertificateRequest(env *HandlerConfig) http.HandlerFunc { id := r.PathValue("id") idNum, err := strconv.ParseInt(id, 10, 64) if err != nil { - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } var signCertificateRequestParams SignCertificateRequestParams if err := json.NewDecoder(r.Body).Decode(&signCertificateRequestParams); err != nil { - writeError(w, http.StatusBadRequest, "Invalid JSON format", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid JSON format", err, env.SystemLogger) return } _, err = env.DB.GetCertificateAuthority(db.ByCertificateAuthorityCSRID(idNum)) if rowFound(err) { err = fmt.Errorf("certificate authority %d not found", idNum) - writeError(w, http.StatusNotFound, "Not Found", err, env.Logger) + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } if realError(err) { - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } caIDInt, err := strconv.ParseInt(signCertificateRequestParams.CertificateAuthorityID, 10, 64) if err != nil { - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } err = env.DB.SignCertificateRequest(db.ByCSRID(idNum), db.ByCertificateAuthorityDenormalizedID(caIDInt), env.ExternalHostname) if err != nil { if errors.Is(err, db.ErrNotFound) { - writeError(w, http.StatusNotFound, "Not Found", err, env.Logger) + writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } if env.SendPebbleNotifications { err := SendPebbleNotification(CertificateUpdate, idNum) if err != nil { - env.Logger.Warn("pebble notify failed", zap.Error(err)) + env.SystemLogger.Warn("pebble notify failed", zap.Error(err)) } } successResponse := SuccessResponse{Message: "success"} err = writeResponse(w, successResponse, http.StatusAccepted) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } diff --git a/internal/server/handlers_certificate_requests_test.go b/internal/server/handlers_certificate_requests_test.go index bf75b421..6470d07f 100644 --- a/internal/server/handlers_certificate_requests_test.go +++ b/internal/server/handlers_certificate_requests_test.go @@ -12,7 +12,7 @@ import ( // The order of the tests is important, as some tests depend on the // state of the server after previous tests. func TestCertificateRequestsEndToEnd(t *testing.T) { - ts := tu.MustPrepareServer(t) + ts, logs := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "testadmin@canonical.com", tu.RoleAdmin, "") client := ts.Client() @@ -32,7 +32,8 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { } }) - t.Run("2. Create certificate request", func(t *testing.T) { + t.Run("2. Create certificate request", func(t *testing.T) { + _ = logs.TakeAll() createCertificateRequestRequest := tu.CreateCertificateRequestParams{ CSR: tu.AppleCSR, @@ -47,7 +48,19 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { if createCertResponse.Error != "" { t.Fatalf("expected no error, got %s", createCertResponse.Error) } - }) + entries := logs.TakeAll() + var haveRequested bool + for _, e := range entries { + if e.LoggerName != "audit" { continue } + if findStringField(e, "event") == "cert_requested" { + haveRequested = true + break + } + } + if !haveRequested { + t.Fatalf("expected CertificateRequested audit entry") + } + }) t.Run("3. List certificate requests - 1 Certificate", func(t *testing.T) { statusCode, listCertRequestsResponse, err := tu.ListCertificateRequests(ts.URL, client, adminToken) @@ -71,7 +84,7 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { } }) - t.Run("4. Get certificate request", func(t *testing.T) { + t.Run("4. Get certificate request", func(t *testing.T) { statusCode, getCertRequestResponse, err := tu.GetCertificateRequest(ts.URL, client, adminToken, 1) if err != nil { t.Fatal(err) @@ -112,7 +125,7 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { } }) - t.Run("6. List certificate requests - 1 Certificate", func(t *testing.T) { + t.Run("6. List certificate requests - 1 Certificate", func(t *testing.T) { statusCode, listCertRequestsResponse, err := tu.ListCertificateRequests(ts.URL, client, adminToken) if err != nil { t.Fatal(err) @@ -128,7 +141,7 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { } }) - t.Run("7. Create another certificate request", func(t *testing.T) { + t.Run("7. Create another certificate request", func(t *testing.T) { createCertificateRequestRequest := tu.CreateCertificateRequestParams{ CSR: tu.StrawberryCSR, } @@ -144,7 +157,7 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { } }) - t.Run("8. List certificate requests - 2 Certificates", func(t *testing.T) { + t.Run("8. List certificate requests - 2 Certificates", func(t *testing.T) { statusCode, listCertRequestsResponse, err := tu.ListCertificateRequests(ts.URL, client, adminToken) if err != nil { t.Fatal(err) @@ -160,7 +173,7 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { } }) - t.Run("9. Get certificate request 2", func(t *testing.T) { + t.Run("9. Get certificate request 2", func(t *testing.T) { statusCode, getCertRequestResponse, err := tu.GetCertificateRequest(ts.URL, client, adminToken, 2) if err != nil { t.Fatal(err) @@ -185,7 +198,8 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { } }) - t.Run("10. Delete certificate request 1", func(t *testing.T) { + t.Run("10. Delete certificate request 1", func(t *testing.T) { + _ = logs.TakeAll() statusCode, err := tu.DeleteCertificateRequest(ts.URL, client, adminToken, 1) if err != nil { t.Fatal(err) @@ -193,6 +207,18 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { if statusCode != http.StatusAccepted { t.Fatalf("expected status %d, got %d", http.StatusAccepted, statusCode) } + entries := logs.TakeAll() + var haveDeleted bool + for _, e := range entries { + if e.LoggerName != "audit" { continue } + if findStringField(e, "event") == "cert_request_deleted" { + haveDeleted = true + break + } + } + if !haveDeleted { + t.Fatalf("expected CertificateRequestDeleted audit entry") + } }) t.Run("11. List certificate requests - 1 Certificate", func(t *testing.T) { @@ -211,7 +237,8 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { } }) - t.Run("12. Delete certificate request 2", func(t *testing.T) { + t.Run("12. Delete certificate request 2 and assert revoke audit", func(t *testing.T) { + _ = logs.TakeAll() statusCode, err := tu.DeleteCertificateRequest(ts.URL, client, adminToken, 2) if err != nil { t.Fatal(err) @@ -219,12 +246,19 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { if statusCode != http.StatusAccepted { t.Fatalf("expected status %d, got %d", http.StatusAccepted, statusCode) } + entries := logs.TakeAll() + var haveCertDeleted bool + for _, e := range entries { + if e.LoggerName != "audit" { continue } + if findStringField(e, "event") == "cert_request_deleted" { haveCertDeleted = true; break } + } + if !haveCertDeleted { t.Fatalf("expected CertificateRequestDeleted audit entry") } }) } // TestListCertificateRequestsRequestorRole tests that a certificate requestor can only view their own requests. func TestListCertificateRequestsRequestorRole(t *testing.T) { - ts := tu.MustPrepareServer(t) +ts, _ := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "testadmin@canonical.com", tu.RoleAdmin, "") client := ts.Client() @@ -293,7 +327,7 @@ func TestListCertificateRequestsRequestorRole(t *testing.T) { // The order of the tests is important, as some tests depend on the // state of the server after previous tests. func TestCertificatesEndToEnd(t *testing.T) { - ts := tu.MustPrepareServer(t) +ts, logs := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") client := ts.Client() @@ -313,7 +347,8 @@ func TestCertificatesEndToEnd(t *testing.T) { } }) - t.Run("2. Create Certificate", func(t *testing.T) { + t.Run("2. Create Certificate", func(t *testing.T) { + _ = logs.TakeAll() createCertificateRequest := tu.CreateCertificateParams{ Certificate: fmt.Sprintf("%s\n%s", tu.ExampleCSRCertificate, tu.ExampleCSRIssuerCertificate), } @@ -327,6 +362,13 @@ func TestCertificatesEndToEnd(t *testing.T) { if createCertResponse.Error != "" { t.Fatalf("expected no error, got %s", createCertResponse.Error) } + entries := logs.TakeAll() + var haveIssued bool + for _, e := range entries { + if e.LoggerName != "audit" { continue } + if findStringField(e, "event") == "cert_issued" { haveIssued = true; break } + } + if !haveIssued { t.Fatalf("expected CertificateIssued audit entry") } }) t.Run("3. Get Certificate", func(t *testing.T) { @@ -345,7 +387,8 @@ func TestCertificatesEndToEnd(t *testing.T) { } }) - t.Run("4. Reject Certificate", func(t *testing.T) { + t.Run("4. Reject Certificate", func(t *testing.T) { + _ = logs.TakeAll() statusCode, err := tu.RejectCertificate(ts.URL, client, adminToken, 1) if err != nil { t.Fatal(err) @@ -353,6 +396,13 @@ func TestCertificatesEndToEnd(t *testing.T) { if statusCode != http.StatusAccepted { t.Fatalf("expected status %d, got %d", http.StatusAccepted, statusCode) } + entries := logs.TakeAll() + var haveRejected bool + for _, e := range entries { + if e.LoggerName != "audit" { continue } + if findStringField(e, "event") == "cert_rejected" { haveRejected = true; break } + } + if !haveRejected { t.Fatalf("expected CertificateRejected audit entry") } }) t.Run("5. Get Certificate", func(t *testing.T) { @@ -371,7 +421,8 @@ func TestCertificatesEndToEnd(t *testing.T) { } }) - t.Run("6. Delete Certificate", func(t *testing.T) { + t.Run("6. Delete Certificate (revocation)", func(t *testing.T) { + _ = logs.TakeAll() statusCode, err := tu.DeleteCertificateRequest(ts.URL, client, adminToken, 1) if err != nil { t.Fatal(err) @@ -379,6 +430,14 @@ func TestCertificatesEndToEnd(t *testing.T) { if statusCode != http.StatusAccepted { t.Fatalf("expected status %d, got %d", http.StatusAccepted, statusCode) } + entries := logs.TakeAll() + var haveDeletedOrRevoked bool + for _, e := range entries { + if e.LoggerName != "audit" { continue } + ev := findStringField(e, "event") + if ev == "cert_request_deleted" || ev == "cert_revoked" { haveDeletedOrRevoked = true; break } + } + if !haveDeletedOrRevoked { t.Fatalf("expected CertificateRequestDeleted or CertificateRevoked audit entry") } }) t.Run("7. Get Certificate", func(t *testing.T) { @@ -397,7 +456,7 @@ func TestCertificatesEndToEnd(t *testing.T) { } func TestCreateCertificateRequestInvalidInputs(t *testing.T) { - ts := tu.MustPrepareServer(t) +ts, _ := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") client := ts.Client() @@ -452,7 +511,7 @@ MIIBVwIBADANBgkqhkiG9w0BAQEFAASCAT8wggE7AgEAAkEAuQ== } func TestCreateCertificateInvalidInputs(t *testing.T) { - ts := tu.MustPrepareServer(t) +ts, _ := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") client := ts.Client() diff --git a/internal/server/handlers_config.go b/internal/server/handlers_config.go index 9f3c021e..b51da98a 100644 --- a/internal/server/handlers_config.go +++ b/internal/server/handlers_config.go @@ -23,7 +23,7 @@ func GetConfigContent(env *HandlerConfig) http.HandlerFunc { } err := writeResponse(w, configContent, http.StatusOK) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } diff --git a/internal/server/handlers_config_test.go b/internal/server/handlers_config_test.go index deb4e0dc..07cbf293 100644 --- a/internal/server/handlers_config_test.go +++ b/internal/server/handlers_config_test.go @@ -42,7 +42,7 @@ func getConfig(url string, client *http.Client, token string) (int, *GetConfigRe } func TestConfigEndToEnd(t *testing.T) { - ts := tu.MustPrepareServer(t) + ts, logs := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") nonAdminToken := tu.MustPrepareAccount(t, ts, "whatever@canonical.com", tu.RoleCertificateManager, adminToken) client := ts.Client() @@ -60,7 +60,8 @@ func TestConfigEndToEnd(t *testing.T) { } }) - t.Run("2. Get config - admin token", func(t *testing.T) { + t.Run("2. Get config - admin token", func(t *testing.T) { + _ = logs.TakeAll() statusCode, response, err := getConfig(ts.URL, client, adminToken) if err != nil { t.Fatalf("couldn't get config: %s", err) @@ -84,6 +85,19 @@ func TestConfigEndToEnd(t *testing.T) { if response.Result.EncryptionBackendType == "" { t.Fatalf("expected encryption backend type to be set, got %q", response.Result.EncryptionBackendType) } + + entries := logs.TakeAll() + var haveAPISuccess bool + for _, e := range entries { + if e.LoggerName != "audit" { continue } + if findStringField(e, "event") == "api_action" && findStringField(e, "action") == "GET config" { + haveAPISuccess = true + break + } + } + if !haveAPISuccess { + t.Fatalf("expected APIAction success audit entry for GET config") + } }) t.Run("3. Get config - non-admin token", func(t *testing.T) { diff --git a/internal/server/handlers_login.go b/internal/server/handlers_login.go index 8f863098..3ed8cd33 100644 --- a/internal/server/handlers_login.go +++ b/internal/server/handlers_login.go @@ -54,23 +54,23 @@ func Login(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var loginParams LoginParams if err := json.NewDecoder(r.Body).Decode(&loginParams); err != nil { - writeError(w, http.StatusBadRequest, "Invalid JSON format", err, env.Logger) + writeError(w, http.StatusBadRequest, "Invalid JSON format", err, env.SystemLogger) return } if loginParams.Email == "" { err := errors.New("email is required") - writeError(w, http.StatusBadRequest, "Email is required", err, env.Logger) + writeError(w, http.StatusBadRequest, "Email is required", err, env.SystemLogger) return } if loginParams.Password == "" { err := errors.New("password is required") - writeError(w, http.StatusBadRequest, "Password is required", err, env.Logger) + writeError(w, http.StatusBadRequest, "Password is required", err, env.SystemLogger) return } userAccount, err := env.DB.GetUser(db.ByEmail(loginParams.Email)) if err != nil { if !errors.Is(err, db.ErrNotFound) && !errors.Is(err, db.ErrInvalidFilter) { - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } } @@ -79,12 +79,16 @@ func Login(env *HandlerConfig) http.HandlerFunc { hashedPassword = userAccount.HashedPassword } if err := hashing.CompareHashAndPassword(hashedPassword, loginParams.Password); err != nil { - writeError(w, http.StatusUnauthorized, "The email or password is incorrect", err, env.Logger) + env.AuditLogger.LoginFailed(loginParams.Email, + WithRequest(r), + WithReason("invalid credentials"), + ) + writeError(w, http.StatusUnauthorized, "The email or password is incorrect", err, env.SystemLogger) return } jwt, err := generateJWT(userAccount.ID, userAccount.Email, env.JWTSecret, RoleID(userAccount.RoleID)) if err != nil { - writeError(w, http.StatusInternalServerError, "Internal Error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } loginResponse := LoginResponse{ @@ -92,8 +96,11 @@ func Login(env *HandlerConfig) http.HandlerFunc { } err = writeResponse(w, loginResponse, http.StatusOK) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } + + env.AuditLogger.LoginSuccess(userAccount.Email, WithRequest(r)) + env.AuditLogger.TokenCreated(userAccount.Email, WithRequest(r)) } } diff --git a/internal/server/handlers_login_test.go b/internal/server/handlers_login_test.go index e5a7e370..7c7e7340 100644 --- a/internal/server/handlers_login_test.go +++ b/internal/server/handlers_login_test.go @@ -9,8 +9,8 @@ import ( ) func TestLoginEndToEnd(t *testing.T) { - ts := tu.MustPrepareServer(t) - client := ts.Client() + ts, logs := tu.MustPrepareServer(t) + client := ts.Client() t.Run("Create admin user", func(t *testing.T) { adminUser := &tu.CreateAccountParams{ @@ -27,7 +27,8 @@ func TestLoginEndToEnd(t *testing.T) { } }) - t.Run("Login success", func(t *testing.T) { + t.Run("Login success", func(t *testing.T) { + _ = logs.TakeAll() adminUser := &tu.LoginParams{ Email: "testadmin@canonical.com", Password: "Admin123", @@ -46,11 +47,31 @@ func TestLoginEndToEnd(t *testing.T) { if err != nil { t.Fatalf("couldn't parse token: %s", err) } - if claims, ok := token.Claims.(jwt.MapClaims); ok { + if claims, ok := token.Claims.(jwt.MapClaims); ok { if claims["email"] != "testadmin@canonical.com" { t.Fatalf("expected email %q, got %q", "testadmin@canonical.com", claims["email"]) } } + + entries := logs.TakeAll() + var haveLoginSuccess, haveTokenCreated bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + switch findStringField(e, "event") { + case "authn_login_success:testadmin@canonical.com": + haveLoginSuccess = true + case "authn_token_created:testadmin@canonical.com": + haveTokenCreated = true + } + } + if !haveLoginSuccess { + t.Fatalf("expected LoginSuccess audit entry") + } + if !haveTokenCreated { + t.Fatalf("expected TokenCreated audit entry") + } }) t.Run("Login failure missing email", func(t *testing.T) { @@ -87,23 +108,37 @@ func TestLoginEndToEnd(t *testing.T) { } }) - t.Run("Login failure invalid password", func(t *testing.T) { - invalidUser := &tu.LoginParams{ - Email: "testadmin@canonical.com", - Password: "a-wrong-password", - } - statusCode, loginResponse, err := tu.Login(ts.URL, client, invalidUser) - if err != nil { - t.Fatalf("couldn't login admin user: %s", err) - } - if statusCode != http.StatusUnauthorized { - t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, statusCode) - } - - if loginResponse.Error != "The email or password is incorrect" { - t.Fatalf("expected error %q, got %q", "The email or password is incorrect", loginResponse.Error) - } - }) + t.Run("Login failure invalid password (with audit)", func(t *testing.T) { + _ = logs.TakeAll() + invalidUser := &tu.LoginParams{ + Email: "testadmin@canonical.com", + Password: "a-wrong-password", + } + statusCode, loginResponse, err := tu.Login(ts.URL, client, invalidUser) + if err != nil { + t.Fatalf("couldn't login admin user: %s", err) + } + if statusCode != http.StatusUnauthorized { + t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, statusCode) + } + if loginResponse.Error != "The email or password is incorrect" { + t.Fatalf("expected error %q, got %q", "The email or password is incorrect", loginResponse.Error) + } + entries := logs.TakeAll() + var haveLoginFailed bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + if findStringField(e, "event") == "authn_login_fail:testadmin@canonical.com" && findStringField(e, "reason") == "invalid credentials" { + haveLoginFailed = true + break + } + } + if !haveLoginFailed { + t.Fatalf("expected LoginFailed audit entry with reason 'invalid credentials'") + } + }) t.Run("Login failure invalid email", func(t *testing.T) { invalidUser := &tu.LoginParams{ diff --git a/internal/server/handlers_status.go b/internal/server/handlers_status.go index 32930668..a1dafbb5 100644 --- a/internal/server/handlers_status.go +++ b/internal/server/handlers_status.go @@ -17,7 +17,7 @@ func GetStatus(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { numUsers, err := env.DB.NumUsers() if err != nil { - writeError(w, http.StatusInternalServerError, "couldn't generate status", err, env.Logger) + writeError(w, http.StatusInternalServerError, "couldn't generate status", err, env.SystemLogger) return } statusResponse := StatusResponse{ @@ -26,7 +26,7 @@ func GetStatus(env *HandlerConfig) http.HandlerFunc { } err = writeResponse(w, statusResponse, http.StatusOK) if err != nil { - writeError(w, http.StatusInternalServerError, "internal error", err, env.Logger) + writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } } diff --git a/internal/server/handlers_status_test.go b/internal/server/handlers_status_test.go index f2308049..8e6b68c8 100644 --- a/internal/server/handlers_status_test.go +++ b/internal/server/handlers_status_test.go @@ -9,7 +9,7 @@ import ( ) func TestStatus(t *testing.T) { - ts := tu.MustPrepareServer(t) + ts, _ := tu.MustPrepareServer(t) client := ts.Client() t.Run("status not initialized", func(t *testing.T) { diff --git a/internal/server/middleware.go b/internal/server/middleware.go index 8648406a..61af3a13 100644 --- a/internal/server/middleware.go +++ b/internal/server/middleware.go @@ -26,7 +26,8 @@ const ( type middlewareContext struct { responseStatusCode int jwtSecret []byte - logger *zap.Logger + systemLogger *zap.Logger + auditLogger *AuditLogger } // createMiddlewareStack chains the given middleware functions to wrap the api. @@ -89,7 +90,7 @@ func loggingMiddleware(ctx *middlewareContext) middleware { // Suppress logging for static files if !strings.HasPrefix(r.URL.Path, "/_next") { - ctx.logger.Info("Request", zap.String("method", r.Method), zap.String("path", r.URL.Path), zap.Int("status_code", clonedWriter.statusCode), zap.String("status_text", http.StatusText(clonedWriter.statusCode))) + ctx.systemLogger.Info("HTTP request completed", zap.String("method", r.Method), zap.String("path", r.URL.Path), zap.Int("status_code", clonedWriter.statusCode), zap.String("status_text", http.StatusText(clonedWriter.statusCode))) } ctx.responseStatusCode = clonedWriter.statusCode @@ -97,23 +98,119 @@ func loggingMiddleware(ctx *middlewareContext) middleware { } } -func requirePermission(permission string, jwtSecret []byte, handler http.HandlerFunc, logger *zap.Logger) http.HandlerFunc { +// auditLoggingMiddleware logs API requests to the audit log. +// It logs all failed requests, and also successful read-only (GET/HEAD) requests. +func auditLoggingMiddleware(ctx *middlewareContext) middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r) + var actor string + claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), ctx.jwtSecret) + if err == nil { + actor = claims.Email + } + + action := buildActionDescription(r.Method, r.URL.Path) + resourceID := extractResourceID(r.URL.Path) + resourceType := extractResourceType(r.URL.Path) + + opts := []AuditOption{WithRequest(r)} + if actor != "" { + opts = append(opts, WithActor(actor)) + } + if resourceID != "" { + opts = append(opts, WithResourceID(resourceID)) + } + if resourceType != "" { + opts = append(opts, WithResourceType(resourceType)) + } + + if ctx.responseStatusCode >= 400 { + opts = append(opts, WithReason(fmt.Sprintf("HTTP %d: %s", ctx.responseStatusCode, http.StatusText(ctx.responseStatusCode)))) + ctx.auditLogger.APIAction(action+" (failed)", opts...) + } + if ctx.responseStatusCode < 400 && (r.Method == http.MethodGet || r.Method == http.MethodHead) { + ctx.auditLogger.APIAction(action, opts...) + } + }) + } +} + +// buildActionDescription returns a minimal deterministic description from HTTP method and path. +// It returns "METHOD path" where the leading slash is trimmed from the path. +// Examples: +// - GET /certificate_requests -> "GET certificate_requests" +// - POST /users -> "POST users" +// - DELETE /certificate_authorities/5 -> "DELETE certificate_authorities/5" +func buildActionDescription(method, path string) string { + // Minimal, deterministic: "METHOD path-without-leading-slash" + cleanPath := strings.Trim(path, "/") + if cleanPath == "" { + return method + } + return method + " " + cleanPath +} + +// extractResourceID extracts the resource ID from the URL path if present. +// Examples: +// - /users/123 -> "123" +// - /certificate_authorities/5 -> "5" +func extractResourceID(path string) string { + // Expect formats like: /{resource}/{id} or /{resource}/{id}/{subresource} + cleanPath := strings.Trim(path, "/") + parts := strings.Split(cleanPath, "/") + if len(parts) > 1 { + if _, err := strconv.ParseInt(parts[1], 10, 64); err == nil { + return parts[1] + } + } + return "" +} + +// extractResourceType returns the first path segment as the resource type. +// No singularization is performed. +// Examples: +// - /users -> "users" +// - /certificate_requests/123 -> "certificate_requests" +func extractResourceType(path string) string { + cleanPath := strings.Trim(path, "/") + parts := strings.Split(cleanPath, "/") + if len(parts) > 0 && parts[0] != "" { + return parts[0] + } + return "" +} + +func requirePermission(permission string, jwtSecret []byte, handler http.HandlerFunc, systemLogger *zap.Logger, auditLogger *AuditLogger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), jwtSecret) if err != nil { - writeError(w, http.StatusUnauthorized, "Unauthorized", err, logger) + auditLogger.UnauthorizedAccess( + WithRequest(r), + WithReason("invalid or missing JWT token"), + ) + writeError(w, http.StatusUnauthorized, "Unauthorized", err, systemLogger) return } roleID := claims.RoleID permissions, ok := PermissionsByRole[roleID] if !ok { - writeError(w, http.StatusForbidden, "forbidden: unknown role", errors.New("role not found"), logger) + auditLogger.UnauthorizedAccess( + WithActor(claims.Email), + WithRequest(r), + WithReason("unknown role"), + ) + writeError(w, http.StatusForbidden, "forbidden: unknown role", errors.New("role not found"), systemLogger) return } - if !hasPermission(permissions, permission) { - writeError(w, http.StatusForbidden, "forbidden: insufficient permissions", errors.New("missing permission"), logger) + if !hasPermission(permissions, permission) { + auditLogger.AccessDenied(claims.Email, r.URL.Path, permission, + WithRequest(r), + WithReason("insufficient permissions"), + ) + writeError(w, http.StatusForbidden, "forbidden: insufficient permissions", errors.New("missing permission"), systemLogger) return } @@ -130,30 +227,36 @@ func hasPermission(userPermissions []string, required string) bool { return false } -func requirePermissionOrFirstUser(permission string, jwtSecret []byte, db *db.Database, handler http.HandlerFunc, logger *zap.Logger) http.HandlerFunc { +func requirePermissionOrFirstUser(permission string, jwtSecret []byte, db *db.Database, handler http.HandlerFunc, systemLogger *zap.Logger, auditLogger *AuditLogger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { numUsers, err := db.NumUsers() if err != nil { - writeError(w, http.StatusInternalServerError, "Internal Error", err, logger) + writeError(w, http.StatusInternalServerError, "Internal Error", err, systemLogger) return } - // If no users exist, allow the request through (initial setup case) if numUsers == 0 { handler(w, r) return } - // Otherwise validate permissions claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), jwtSecret) if err != nil { - writeError(w, http.StatusUnauthorized, "Unauthorized", err, logger) + auditLogger.UnauthorizedAccess( + WithRequest(r), + WithReason("invalid or missing JWT token"), + ) + writeError(w, http.StatusUnauthorized, "Unauthorized", err, systemLogger) return } permissions, ok := PermissionsByRole[claims.RoleID] - if !ok || !hasPermission(permissions, permission) { - writeError(w, http.StatusForbidden, "forbidden: insufficient permissions", errors.New("missing required permission"), logger) + if !ok || !hasPermission(permissions, permission) { + auditLogger.AccessDenied(claims.Email, r.URL.Path, permission, + WithRequest(r), + WithReason("insufficient permissions"), + ) + writeError(w, http.StatusForbidden, "forbidden: insufficient permissions", errors.New("missing required permission"), systemLogger) return } diff --git a/internal/server/router.go b/internal/server/router.go index f101fc14..3d0f0585 100644 --- a/internal/server/router.go +++ b/internal/server/router.go @@ -12,47 +12,49 @@ import ( // then builds and returns it for a server to consume func NewRouter(config *HandlerConfig) http.Handler { apiV1Router := http.NewServeMux() - apiV1Router.HandleFunc("GET /certificate_requests", requirePermission(PermListCertificateRequests, config.JWTSecret, ListCertificateRequests(config), config.Logger)) - apiV1Router.HandleFunc("POST /certificate_requests", requirePermission(PermCreateCertificateRequest, config.JWTSecret, CreateCertificateRequest(config), config.Logger)) - apiV1Router.HandleFunc("GET /certificate_requests/{id}", requirePermission(PermReadCertificateRequest, config.JWTSecret, GetCertificateRequest(config), config.Logger)) - apiV1Router.HandleFunc("DELETE /certificate_requests/{id}", requirePermission(PermDeleteCertificateRequest, config.JWTSecret, DeleteCertificateRequest(config), config.Logger)) - apiV1Router.HandleFunc("POST /certificate_requests/{id}/reject", requirePermission(PermRejectCertificateRequest, config.JWTSecret, RejectCertificateRequest(config), config.Logger)) - apiV1Router.HandleFunc("POST /certificate_requests/{id}/sign", requirePermission(PermSignCertificateRequest, config.JWTSecret, SignCertificateRequest(config), config.Logger)) - apiV1Router.HandleFunc("POST /certificate_requests/{id}/certificate", requirePermission(PermCreateCertificateRequestCertificate, config.JWTSecret, PostCertificateRequestCertificate(config), config.Logger)) - apiV1Router.HandleFunc("DELETE /certificate_requests/{id}/certificate", requirePermission(PermDeleteCertificateRequestCertificate, config.JWTSecret, DeleteCertificate(config), config.Logger)) - apiV1Router.HandleFunc("POST /certificate_requests/{id}/certificate/revoke", requirePermission(PermRevokeCertificateRequestCertificate, config.JWTSecret, RevokeCertificate(config), config.Logger)) + apiV1Router.HandleFunc("GET /certificate_requests", requirePermission(PermListCertificateRequests, config.JWTSecret, ListCertificateRequests(config), config.SystemLogger, config.AuditLogger)) + apiV1Router.HandleFunc("POST /certificate_requests", requirePermission(PermCreateCertificateRequest, config.JWTSecret, CreateCertificateRequest(config), config.SystemLogger, config.AuditLogger)) + apiV1Router.HandleFunc("GET /certificate_requests/{id}", requirePermission(PermReadCertificateRequest, config.JWTSecret, GetCertificateRequest(config), config.SystemLogger, config.AuditLogger)) + apiV1Router.HandleFunc("DELETE /certificate_requests/{id}", requirePermission(PermDeleteCertificateRequest, config.JWTSecret, DeleteCertificateRequest(config), config.SystemLogger, config.AuditLogger)) + apiV1Router.HandleFunc("POST /certificate_requests/{id}/reject", requirePermission(PermRejectCertificateRequest, config.JWTSecret, RejectCertificateRequest(config), config.SystemLogger, config.AuditLogger)) + apiV1Router.HandleFunc("POST /certificate_requests/{id}/sign", requirePermission(PermSignCertificateRequest, config.JWTSecret, SignCertificateRequest(config), config.SystemLogger, config.AuditLogger)) + apiV1Router.HandleFunc("POST /certificate_requests/{id}/certificate", requirePermission(PermCreateCertificateRequestCertificate, config.JWTSecret, PostCertificateRequestCertificate(config), config.SystemLogger, config.AuditLogger)) + apiV1Router.HandleFunc("DELETE /certificate_requests/{id}/certificate", requirePermission(PermDeleteCertificateRequestCertificate, config.JWTSecret, DeleteCertificate(config), config.SystemLogger, config.AuditLogger)) + apiV1Router.HandleFunc("POST /certificate_requests/{id}/certificate/revoke", requirePermission(PermRevokeCertificateRequestCertificate, config.JWTSecret, RevokeCertificate(config), config.SystemLogger, config.AuditLogger)) - apiV1Router.HandleFunc("GET /certificate_authorities", requirePermission(PermListCertificateAuthorities, config.JWTSecret, ListCertificateAuthorities(config), config.Logger)) - apiV1Router.HandleFunc("POST /certificate_authorities", requirePermission(PermCreateCertificateAuthority, config.JWTSecret, CreateCertificateAuthority(config), config.Logger)) - apiV1Router.HandleFunc("GET /certificate_authorities/{id}", requirePermission(PermReadCertificateAuthority, config.JWTSecret, GetCertificateAuthority(config), config.Logger)) - apiV1Router.HandleFunc("PUT /certificate_authorities/{id}", requirePermission(PermUpdateCertificateAuthority, config.JWTSecret, UpdateCertificateAuthority(config), config.Logger)) - apiV1Router.HandleFunc("DELETE /certificate_authorities/{id}", requirePermission(PermDeleteCertificateAuthority, config.JWTSecret, DeleteCertificateAuthority(config), config.Logger)) - apiV1Router.HandleFunc("POST /certificate_authorities/{id}/sign", requirePermission(PermSignCertificateAuthorityCertificate, config.JWTSecret, SignCertificateAuthority(config), config.Logger)) - apiV1Router.HandleFunc("POST /certificate_authorities/{id}/certificate", requirePermission(PermCreateCertificateAuthorityCertificate, config.JWTSecret, PostCertificateAuthorityCertificate(config), config.Logger)) + apiV1Router.HandleFunc("GET /certificate_authorities", requirePermission(PermListCertificateAuthorities, config.JWTSecret, ListCertificateAuthorities(config), config.SystemLogger, config.AuditLogger)) + apiV1Router.HandleFunc("POST /certificate_authorities", requirePermission(PermCreateCertificateAuthority, config.JWTSecret, CreateCertificateAuthority(config), config.SystemLogger, config.AuditLogger)) + apiV1Router.HandleFunc("GET /certificate_authorities/{id}", requirePermission(PermReadCertificateAuthority, config.JWTSecret, GetCertificateAuthority(config), config.SystemLogger, config.AuditLogger)) + apiV1Router.HandleFunc("PUT /certificate_authorities/{id}", requirePermission(PermUpdateCertificateAuthority, config.JWTSecret, UpdateCertificateAuthority(config), config.SystemLogger, config.AuditLogger)) + apiV1Router.HandleFunc("DELETE /certificate_authorities/{id}", requirePermission(PermDeleteCertificateAuthority, config.JWTSecret, DeleteCertificateAuthority(config), config.SystemLogger, config.AuditLogger)) + apiV1Router.HandleFunc("POST /certificate_authorities/{id}/sign", requirePermission(PermSignCertificateAuthorityCertificate, config.JWTSecret, SignCertificateAuthority(config), config.SystemLogger, config.AuditLogger)) + apiV1Router.HandleFunc("POST /certificate_authorities/{id}/certificate", requirePermission(PermCreateCertificateAuthorityCertificate, config.JWTSecret, PostCertificateAuthorityCertificate(config), config.SystemLogger, config.AuditLogger)) apiV1Router.HandleFunc("GET /certificate_authorities/{id}/crl", GetCertificateAuthorityCRL(config)) - apiV1Router.HandleFunc("POST /certificate_authorities/{id}/revoke", requirePermission(PermRevokeCertificateAuthorityCertificate, config.JWTSecret, RevokeCertificateAuthorityCertificate(config), config.Logger)) + apiV1Router.HandleFunc("POST /certificate_authorities/{id}/revoke", requirePermission(PermRevokeCertificateAuthorityCertificate, config.JWTSecret, RevokeCertificateAuthorityCertificate(config), config.SystemLogger, config.AuditLogger)) - apiV1Router.HandleFunc("GET /accounts", requirePermission(PermListUsers, config.JWTSecret, ListAccounts(config), config.Logger)) - apiV1Router.HandleFunc("POST /accounts", requirePermissionOrFirstUser(PermCreateUser, config.JWTSecret, config.DB, CreateAccount(config), config.Logger)) - apiV1Router.HandleFunc("GET /accounts/{id}", requirePermission(PermReadUser, config.JWTSecret, GetAccount(config), config.Logger)) - apiV1Router.HandleFunc("DELETE /accounts/{id}", requirePermission(PermDeleteUser, config.JWTSecret, DeleteAccount(config), config.Logger)) - apiV1Router.HandleFunc("POST /accounts/{id}/change_password", requirePermission(PermUpdateUserPassword, config.JWTSecret, ChangeAccountPassword(config), config.Logger)) - apiV1Router.HandleFunc("POST /accounts/me/change_password", requirePermission(PermUpdateMyPassword, config.JWTSecret, ChangeMyPassword(config), config.Logger)) + apiV1Router.HandleFunc("GET /accounts", requirePermission(PermListUsers, config.JWTSecret, ListAccounts(config), config.SystemLogger, config.AuditLogger)) + apiV1Router.HandleFunc("POST /accounts", requirePermissionOrFirstUser(PermCreateUser, config.JWTSecret, config.DB, CreateAccount(config), config.SystemLogger, config.AuditLogger)) + apiV1Router.HandleFunc("GET /accounts/{id}", requirePermission(PermReadUser, config.JWTSecret, GetAccount(config), config.SystemLogger, config.AuditLogger)) + apiV1Router.HandleFunc("DELETE /accounts/{id}", requirePermission(PermDeleteUser, config.JWTSecret, DeleteAccount(config), config.SystemLogger, config.AuditLogger)) + apiV1Router.HandleFunc("POST /accounts/{id}/change_password", requirePermission(PermUpdateUserPassword, config.JWTSecret, ChangeAccountPassword(config), config.SystemLogger, config.AuditLogger)) + apiV1Router.HandleFunc("POST /accounts/me/change_password", requirePermission(PermUpdateMyPassword, config.JWTSecret, ChangeMyPassword(config), config.SystemLogger, config.AuditLogger)) - apiV1Router.HandleFunc("GET /config", requirePermission(PermReadConfig, config.JWTSecret, GetConfigContent(config), config.Logger)) + apiV1Router.HandleFunc("GET /config", requirePermission(PermReadConfig, config.JWTSecret, GetConfigContent(config), config.SystemLogger, config.AuditLogger)) - m := metrics.NewMetricsSubsystem(config.DB, config.Logger) + m := metrics.NewMetricsSubsystem(config.DB, config.SystemLogger) frontendHandler, err := newFrontendFileServer() if err != nil { - config.Logger.Fatal("Failed to create frontend file server", zap.Error(err)) + config.SystemLogger.Fatal("Failed to create frontend file server", zap.Error(err)) } ctx := middlewareContext{ - jwtSecret: config.JWTSecret, - logger: config.Logger, + jwtSecret: config.JWTSecret, + systemLogger: config.SystemLogger, + auditLogger: config.AuditLogger, } apiMiddlewareStack := createMiddlewareStack( - limitRequestSize(MAX_KILOBYTES, config.Logger), + limitRequestSize(MAX_KILOBYTES, config.SystemLogger), metricsMiddleware(m), + auditLoggingMiddleware(&ctx), loggingMiddleware(&ctx), ) metricsMiddlewareStack := createMiddlewareStack( diff --git a/internal/server/server.go b/internal/server/server.go index 17fb7b03..aa0151bc 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -15,7 +15,8 @@ import ( type HandlerConfig struct { DB *db.Database - Logger *zap.Logger + SystemLogger *zap.Logger + AuditLogger *AuditLogger ExternalHostname string JWTSecret []byte SendPebbleNotifications bool @@ -28,7 +29,7 @@ func New(opts *ServerOpts) (*Server, error) { if err != nil { return nil, err } - stdErrLog, err := zap.NewStdLogAt(opts.Logger, zapcore.ErrorLevel) + stdErrLog, err := zap.NewStdLogAt(opts.SystemLogger, zapcore.ErrorLevel) if err != nil { return nil, fmt.Errorf("failed to create logger for http server: %w", err) } @@ -37,7 +38,8 @@ func New(opts *ServerOpts) (*Server, error) { cfg.SendPebbleNotifications = opts.EnablePebbleNotifications cfg.JWTSecret = opts.Database.JWTSecret cfg.ExternalHostname = opts.ExternalHostname - cfg.Logger = opts.Logger + cfg.SystemLogger = opts.SystemLogger + cfg.AuditLogger = NewAuditLogger(opts.AuditLogger) cfg.PublicConfig = *opts.PublicConfig cfg.DB = opts.Database diff --git a/internal/server/server_test.go b/internal/server/server_test.go index a9b940aa..a38d95ee 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -17,14 +17,15 @@ func TestNewSuccess(t *testing.T) { if err != nil { t.Fatalf("cannot create logger: %s", err) } - s, err := server.New(&server.ServerOpts{ + s, err := server.New(&server.ServerOpts{ Port: 8000, TLSCertificate: []byte(tu.TestServerCertificate), TLSPrivateKey: []byte(tu.TestServerKey), Database: db, ExternalHostname: "example.com", EnablePebbleNotifications: false, - Logger: l, + SystemLogger: l, + AuditLogger: l, PublicConfig: &tu.PublicConfig, }) if err != nil { @@ -48,7 +49,7 @@ func TestInvalidKeyFailure(t *testing.T) { TLSPrivateKey: []byte{}, ExternalHostname: "example.com", EnablePebbleNotifications: false, - Logger: l, + SystemLogger: l, PublicConfig: &tu.PublicConfig, }) if err == nil { @@ -57,7 +58,7 @@ func TestInvalidKeyFailure(t *testing.T) { } func TestRequestOverload(t *testing.T) { - ts := tu.MustPrepareServer(t) + ts, _ := tu.MustPrepareServer(t) client := ts.Client() adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") diff --git a/internal/server/types.go b/internal/server/types.go index 694a2bcf..ef0f49b8 100644 --- a/internal/server/types.go +++ b/internal/server/types.go @@ -24,7 +24,8 @@ type ServerOpts struct { // Database object to run SQL queries on Database *db.Database - Logger *zap.Logger + SystemLogger *zap.Logger // For operational/system logs + AuditLogger *zap.Logger // For audit/compliance logs } type Server struct { diff --git a/internal/testutils/server_test_utils.go b/internal/testutils/server_test_utils.go index b6ab8aa4..15e80252 100644 --- a/internal/testutils/server_test_utils.go +++ b/internal/testutils/server_test_utils.go @@ -15,20 +15,29 @@ import ( "time" "github.com/canonical/notary/internal/server" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest/observer" ) -func MustPrepareServer(t *testing.T) *httptest.Server { +// MustPrepareServer starts a test server and returns it along with observed audit logs. +func MustPrepareServer(t *testing.T) (*httptest.Server, *observer.ObservedLogs) { t.Helper() db := MustPrepareEmptyDB(t) - srv, err := server.New(&server.ServerOpts{ + // Attach observed audit logger + core, logs := observer.New(zapcore.InfoLevel) + auditZap := zap.New(core) + + srv, err := server.New(&server.ServerOpts{ Port: 8000, TLSCertificate: []byte(TestServerCertificate), TLSPrivateKey: []byte(TestServerKey), Database: db, ExternalHostname: "example.com", EnablePebbleNotifications: false, - Logger: logger, + SystemLogger: logger, + AuditLogger: auditZap, PublicConfig: &PublicConfig, }) if err != nil { @@ -38,7 +47,7 @@ func MustPrepareServer(t *testing.T) *httptest.Server { t.Cleanup(func() { testServer.Close() }) - return testServer + return testServer, logs } func MustPrepareAccount(t *testing.T, ts *httptest.Server, email string, roleID RoleID, token string) string { @@ -527,10 +536,14 @@ type UpdateCertificateAuthorityResponse struct { } func UpdateCertificateAuthority(url string, client *http.Client, token string, id int, status UpdateCertificateAuthorityParams) (int, *UpdateCertificateAuthorityResponse, error) { - reqData, err := json.Marshal(status) - if err != nil { - return 0, nil, err - } + enabled := status.Status == "active" + payload := struct{ + Enabled bool `json:"enabled"` + }{Enabled: enabled} + reqData, err := json.Marshal(payload) + if err != nil { + return 0, nil, err + } req, err := http.NewRequest("PUT", url+"/api/v1/certificate_authorities/"+strconv.Itoa(id), bytes.NewReader(reqData)) if err != nil { return 0, nil, err From d4258f231d68234f5d775837151dfc002f9b46e2 Mon Sep 17 00:00:00 2001 From: yazansalti Date: Mon, 13 Oct 2025 17:35:03 +0200 Subject: [PATCH 2/4] Address review comments --- internal/config/config.go | 41 +++++----- internal/config/types.go | 76 ------------------- internal/{server => logging}/audit_logger.go | 3 +- internal/{server => logging}/audit_options.go | 2 +- internal/server/handlers_accounts.go | 51 +++++++------ .../handlers_certificate_authorities.go | 21 ++--- .../server/handlers_certificate_requests.go | 23 +++--- internal/server/handlers_login.go | 9 ++- internal/server/middleware.go | 39 +++++----- internal/server/server.go | 5 +- 10 files changed, 99 insertions(+), 171 deletions(-) rename internal/{server => logging}/audit_logger.go (99%) rename internal/{server => logging}/audit_options.go (99%) diff --git a/internal/config/config.go b/internal/config/config.go index ff628f0f..5c4bcd98 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -37,22 +37,14 @@ func CreateAppContext(cmdFlags *pflag.FlagSet, configFilePath string) (*NotaryAp } // initialize system logger - systemLogger, err := initializeLogger( - cfg.GetString("logging.system.level"), - cfg.GetString("logging.system.output"), - cfg.GetString("logging.system.path"), - ) + systemLogger, err := initializeLogger(cfg.Sub("logging.system"), "") if err != nil { return nil, fmt.Errorf("couldn't initialize system logger: %w", err) } // initialize audit logger // Audit logs are always at INFO level - auditLogger, err := initializeLogger( - "info", - cfg.GetString("logging.audit.output"), - cfg.GetString("logging.audit.path"), - ) + auditLogger, err := initializeLogger(cfg.Sub("logging.audit"), "info") if err != nil { return nil, fmt.Errorf("couldn't initialize audit logger: %w", err) } @@ -226,26 +218,31 @@ func initializeEncryptionBackend(encryptionCfg *viper.Viper, logger *zap.Logger) } } -// initializeLogger creates and configures a logger based on the provided parameters. -// output can be "stdout", "stderr", or "file" -// path is required when output is "file" -func initializeLogger(level, output, path string) (*zap.Logger, error) { +// initializeLogger creates and configures a logger based on the provided configuration. +// cfg is the logger configuration subsection (e.g., logging.system or logging.audit). +// levelOverride allows overriding the configured level (e.g., "info" for audit logs). +// If levelOverride is empty, the level from cfg is used. +// output can be "stdout", "stderr", or a file path. +func initializeLogger(cfg *viper.Viper, levelOverride string) (*zap.Logger, error) { + if cfg == nil { + return nil, fmt.Errorf("logger configuration is not defined") + } + zapConfig := zap.NewProductionConfig() + level := levelOverride + if level == "" { + level = cfg.GetString("level") + } + logLevel, err := zapcore.ParseLevel(level) if err != nil { return nil, fmt.Errorf("invalid log level: %w", err) } zapConfig.Level.SetLevel(logLevel) - if output == "file" { - if path == "" { - return nil, fmt.Errorf("path is required when output is 'file'") - } - zapConfig.OutputPaths = []string{path} - } else { - zapConfig.OutputPaths = []string{output} - } + output := cfg.GetString("output") + zapConfig.OutputPaths = []string{output} zapConfig.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder diff --git a/internal/config/types.go b/internal/config/types.go index f5965459..6d260113 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -13,80 +13,6 @@ const ( EncryptionBackendTypeNone = "none" ) -// VaultBackendConfigYaml BackendConfig for Vault-specific fields. -type VaultBackendConfigYaml struct { - Endpoint string `yaml:"endpoint"` - Mount string `yaml:"mount"` - KeyName string `yaml:"key_name"` - Token string `yaml:"token"` - AppRoleID string `yaml:"approle_role_id"` - AppRoleSecretID string `yaml:"approle_secret_id"` - TlsCaCertificate string `yaml:"tls_ca_cert,omitempty"` // Optional path to a CA file for Vault TLS verification - TlsSkipVerify bool `yaml:"tls_skip_verify,omitempty"` // Optional flag to skip TLS verification -} - -// PKCS11BackendConfigYaml BackendConfig for PKCS11-specific fields. -type PKCS11BackendConfigYaml struct { - LibPath string `yaml:"lib_path"` - KeyID uint16 `yaml:"aes_encryption_key_id"` - Pin string `yaml:"pin"` -} - -// NamedBackendConfigYaml represents a single named backend configuration -type NamedBackendConfigYaml struct { - PKCS11 *PKCS11BackendConfigYaml `yaml:"pkcs11,omitempty"` - Vault *VaultBackendConfigYaml `yaml:"vault,omitempty"` -} - -type EncryptionBackendConfigYaml map[string]NamedBackendConfigYaml - -type SystemLoggingConfigYaml struct { - Level string `yaml:"level"` - Output string `yaml:"output"` - Path string `yaml:"path"` -} - -type AuditLoggingConfigYaml struct { - Output string `yaml:"output"` - Path string `yaml:"path"` -} - -type LoggingConfigYaml struct { - System SystemLoggingConfigYaml `yaml:"system"` - Audit AuditLoggingConfigYaml `yaml:"audit"` -} - -type ConfigYAML struct { - KeyPath string `yaml:"key_path"` - CertPath string `yaml:"cert_path"` - ExternalHostname string `yaml:"external_hostname"` - DBPath string `yaml:"db_path"` - Port int `yaml:"port"` - PebbleNotifications bool `yaml:"pebble_notifications"` - Logging LoggingConfigYaml `yaml:"logging"` - EncryptionBackend EncryptionBackendConfigYaml `yaml:"encryption_backend"` -} - -type LoggingLevel string - -const ( - Debug LoggingLevel = "debug" - Info LoggingLevel = "info" - Warn LoggingLevel = "warn" - Error LoggingLevel = "error" - Fatal LoggingLevel = "fatal" - Panic LoggingLevel = "panic" -) - -type SystemLoggingOptions struct { - Level LoggingLevel - Output string -} - -type LoggerOptions struct { - System SystemLoggingOptions -} - // PublicConfigData contains non-sensitive configuration fields that are safe to expose type PublicConfigData struct { Port int @@ -97,8 +23,6 @@ type PublicConfigData struct { } type NotaryAppContext struct { - // The YAML configuration file content - Config *ConfigYAML PublicConfig *PublicConfigData // TLSPrivateKey and Certificate for the webserver and the listener port diff --git a/internal/server/audit_logger.go b/internal/logging/audit_logger.go similarity index 99% rename from internal/server/audit_logger.go rename to internal/logging/audit_logger.go index ee2ac166..db6188f9 100644 --- a/internal/server/audit_logger.go +++ b/internal/logging/audit_logger.go @@ -1,4 +1,4 @@ -package server +package logging import ( "fmt" @@ -419,3 +419,4 @@ func (a *AuditLogger) CertificateRevoked(csrID string, opts ...AuditOption) { a.logger.Warn("Certificate revoked", fields...) } + diff --git a/internal/server/audit_options.go b/internal/logging/audit_options.go similarity index 99% rename from internal/server/audit_options.go rename to internal/logging/audit_options.go index 3261d06b..8a1a13c5 100644 --- a/internal/server/audit_options.go +++ b/internal/logging/audit_options.go @@ -1,4 +1,4 @@ -package server +package logging import ( "net/http" diff --git a/internal/server/handlers_accounts.go b/internal/server/handlers_accounts.go index bf409594..2d89f2d9 100644 --- a/internal/server/handlers_accounts.go +++ b/internal/server/handlers_accounts.go @@ -10,6 +10,7 @@ import ( "strconv" "github.com/canonical/notary/internal/db" + "github.com/canonical/notary/internal/logging" ) type CreateAccountParams struct { @@ -175,9 +176,9 @@ func CreateAccount(env *HandlerConfig) http.HandlerFunc { actor = claims.Email } - opts := []AuditOption{WithRequest(r)} + opts := []logging.AuditOption{logging.WithRequest(r)} if actor != "" { - opts = append(opts, WithActor(actor)) + opts = append(opts, logging.WithActor(actor)) } env.AuditLogger.UserCreated(createAccountParams.Email, int(createAccountParams.RoleID), opts...) @@ -232,8 +233,8 @@ func DeleteAccount(env *HandlerConfig) http.HandlerFunc { } env.AuditLogger.UserDeleted(account.Email, - WithActor(claims.Email), - WithRequest(r), + logging.WithActor(claims.Email), + logging.WithRequest(r), ) successResponse := SuccessResponse{Message: "success"} @@ -280,9 +281,9 @@ func ChangeAccountPassword(env *HandlerConfig) http.HandlerFunc { valid, err := changeAccountParams.IsValid() if !valid { env.AuditLogger.PasswordChangeFailed(targetAccount.Email, - WithActor(claims.Email), - WithRequest(r), - WithReason(err.Error()), + logging.WithActor(claims.Email), + logging.WithRequest(r), + logging.WithReason(err.Error()), ) writeError(w, http.StatusBadRequest, fmt.Errorf("Invalid request: %s", err).Error(), err, env.SystemLogger) return @@ -291,29 +292,29 @@ func ChangeAccountPassword(env *HandlerConfig) http.HandlerFunc { if err != nil { if errors.Is(err, db.ErrNotFound) { env.AuditLogger.PasswordChangeFailed(targetAccount.Email, - WithActor(claims.Email), - WithRequest(r), - WithReason("user not found"), + logging.WithActor(claims.Email), + logging.WithRequest(r), + logging.WithReason("user not found"), ) writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } env.AuditLogger.PasswordChangeFailed(targetAccount.Email, - WithActor(claims.Email), - WithRequest(r), - WithReason("database error"), + logging.WithActor(claims.Email), + logging.WithRequest(r), + logging.WithReason("database error"), ) writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } env.AuditLogger.PasswordChanged(targetAccount.Email, - WithActor(claims.Email), - WithRequest(r), + logging.WithActor(claims.Email), + logging.WithRequest(r), ) env.AuditLogger.UserUpdated(targetAccount.Email, "password_change", - WithActor(claims.Email), - WithRequest(r), + logging.WithActor(claims.Email), + logging.WithRequest(r), ) successResponse := SuccessResponse{Message: "success"} @@ -347,8 +348,8 @@ func ChangeMyPassword(env *HandlerConfig) http.HandlerFunc { valid, err := changeAccountParams.IsValid() if !valid { env.AuditLogger.PasswordChangeFailed(account.Email, - WithRequest(r), - WithReason(err.Error()), + logging.WithRequest(r), + logging.WithReason(err.Error()), ) writeError(w, http.StatusBadRequest, fmt.Errorf("Invalid request: %s", err).Error(), err, env.SystemLogger) return @@ -357,22 +358,22 @@ func ChangeMyPassword(env *HandlerConfig) http.HandlerFunc { if err != nil { if errors.Is(err, db.ErrNotFound) { env.AuditLogger.PasswordChangeFailed(account.Email, - WithRequest(r), - WithReason("user not found"), + logging.WithRequest(r), + logging.WithReason("user not found"), ) writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) return } env.AuditLogger.PasswordChangeFailed(account.Email, - WithRequest(r), - WithReason("database error"), + logging.WithRequest(r), + logging.WithReason("database error"), ) writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } - env.AuditLogger.PasswordChanged(account.Email, WithRequest(r)) - env.AuditLogger.UserUpdated(account.Email, "password_change", WithRequest(r)) + env.AuditLogger.PasswordChanged(account.Email, logging.WithRequest(r)) + env.AuditLogger.UserUpdated(account.Email, "password_change", logging.WithRequest(r)) successResponse := SuccessResponse{Message: "success"} err = writeResponse(w, successResponse, http.StatusCreated) diff --git a/internal/server/handlers_certificate_authorities.go b/internal/server/handlers_certificate_authorities.go index 6cf0c977..52153d2f 100644 --- a/internal/server/handlers_certificate_authorities.go +++ b/internal/server/handlers_certificate_authorities.go @@ -17,6 +17,7 @@ import ( "time" "github.com/canonical/notary/internal/db" + "github.com/canonical/notary/internal/logging" "go.uber.org/zap" ) @@ -293,8 +294,8 @@ func CreateCertificateAuthority(env *HandlerConfig) http.HandlerFunc { } env.AuditLogger.CACreated(int(newCAID), params.CommonName, - WithActor(claims.Email), - WithRequest(r), + logging.WithActor(claims.Email), + logging.WithRequest(r), ) successResponse := CreateSuccessResponse{Message: "Certificate Authority created successfully", ID: newCAID} @@ -376,8 +377,8 @@ func UpdateCertificateAuthority(env *HandlerConfig) http.HandlerFunc { } env.AuditLogger.CAUpdated(id, params.Enabled, - WithActor(claims.Email), - WithRequest(r), + logging.WithActor(claims.Email), + logging.WithRequest(r), ) successResponse := SuccessResponse{Message: "success"} @@ -427,8 +428,8 @@ func DeleteCertificateAuthority(env *HandlerConfig) http.HandlerFunc { } env.AuditLogger.CADeleted(int(idNum), extractCommonName(ca.CertificateChain), - WithActor(claims.Email), - WithRequest(r), + logging.WithActor(claims.Email), + logging.WithRequest(r), ) successResponse := SuccessResponse{Message: "success"} @@ -478,8 +479,8 @@ func PostCertificateAuthorityCertificate(env *HandlerConfig) http.HandlerFunc { } env.AuditLogger.CACertificateUploaded(id, - WithActor(claims.Email), - WithRequest(r), + logging.WithActor(claims.Email), + logging.WithRequest(r), ) err = writeResponse(w, SuccessResponse{Message: "success"}, http.StatusCreated) @@ -610,8 +611,8 @@ func RevokeCertificateAuthorityCertificate(env *HandlerConfig) http.HandlerFunc } env.AuditLogger.CACertificateRevoked(id, - WithActor(claims.Email), - WithRequest(r), + logging.WithActor(claims.Email), + logging.WithRequest(r), ) if env.SendPebbleNotifications { diff --git a/internal/server/handlers_certificate_requests.go b/internal/server/handlers_certificate_requests.go index 9af95894..c9d5445c 100644 --- a/internal/server/handlers_certificate_requests.go +++ b/internal/server/handlers_certificate_requests.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/canonical/notary/internal/db" + "github.com/canonical/notary/internal/logging" "go.uber.org/zap" ) @@ -156,8 +157,8 @@ func CreateCertificateRequest(env *HandlerConfig) http.HandlerFunc { } env.AuditLogger.CertificateRequested(strconv.FormatInt(newCSRID, 10), 0, - WithActor(claims.Email), - WithRequest(r), + logging.WithActor(claims.Email), + logging.WithRequest(r), ) successResponse := CreateSuccessResponse{Message: "success", ID: newCSRID} @@ -279,8 +280,8 @@ func DeleteCertificateRequest(env *HandlerConfig) http.HandlerFunc { } env.AuditLogger.CertificateRequestDeleted(id, - WithActor(claims.Email), - WithRequest(r), + logging.WithActor(claims.Email), + logging.WithRequest(r), ) successResponse := SuccessResponse{Message: "success"} @@ -340,8 +341,8 @@ func PostCertificateRequestCertificate(env *HandlerConfig) http.HandlerFunc { } env.AuditLogger.CertificateIssued(id, 0, - WithActor(claims.Email), - WithRequest(r), + logging.WithActor(claims.Email), + logging.WithRequest(r), ) if env.SendPebbleNotifications { @@ -397,9 +398,9 @@ func RejectCertificateRequest(env *HandlerConfig) http.HandlerFunc { } env.AuditLogger.CertificateRejected(id, 0, - WithActor(claims.Email), - WithRequest(r), - WithReason("rejected by administrator"), + logging.WithActor(claims.Email), + logging.WithRequest(r), + logging.WithReason("rejected by administrator"), ) if env.SendPebbleNotifications { @@ -498,8 +499,8 @@ func RevokeCertificate(env *HandlerConfig) http.HandlerFunc { } env.AuditLogger.CertificateRevoked(id, - WithActor(claims.Email), - WithRequest(r), + logging.WithActor(claims.Email), + logging.WithRequest(r), ) if env.SendPebbleNotifications { diff --git a/internal/server/handlers_login.go b/internal/server/handlers_login.go index 3ed8cd33..6dcce4de 100644 --- a/internal/server/handlers_login.go +++ b/internal/server/handlers_login.go @@ -8,6 +8,7 @@ import ( "github.com/canonical/notary/internal/db" "github.com/canonical/notary/internal/hashing" + "github.com/canonical/notary/internal/logging" "github.com/golang-jwt/jwt/v5" ) @@ -80,8 +81,8 @@ func Login(env *HandlerConfig) http.HandlerFunc { } if err := hashing.CompareHashAndPassword(hashedPassword, loginParams.Password); err != nil { env.AuditLogger.LoginFailed(loginParams.Email, - WithRequest(r), - WithReason("invalid credentials"), + logging.WithRequest(r), + logging.WithReason("invalid credentials"), ) writeError(w, http.StatusUnauthorized, "The email or password is incorrect", err, env.SystemLogger) return @@ -100,7 +101,7 @@ func Login(env *HandlerConfig) http.HandlerFunc { return } - env.AuditLogger.LoginSuccess(userAccount.Email, WithRequest(r)) - env.AuditLogger.TokenCreated(userAccount.Email, WithRequest(r)) + env.AuditLogger.LoginSuccess(userAccount.Email, logging.WithRequest(r)) + env.AuditLogger.TokenCreated(userAccount.Email, logging.WithRequest(r)) } } diff --git a/internal/server/middleware.go b/internal/server/middleware.go index 61af3a13..ea9e05eb 100644 --- a/internal/server/middleware.go +++ b/internal/server/middleware.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/canonical/notary/internal/db" + "github.com/canonical/notary/internal/logging" "github.com/canonical/notary/internal/metrics" "github.com/golang-jwt/jwt/v5" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -27,7 +28,7 @@ type middlewareContext struct { responseStatusCode int jwtSecret []byte systemLogger *zap.Logger - auditLogger *AuditLogger + auditLogger *logging.AuditLogger } // createMiddlewareStack chains the given middleware functions to wrap the api. @@ -114,19 +115,19 @@ func auditLoggingMiddleware(ctx *middlewareContext) middleware { resourceID := extractResourceID(r.URL.Path) resourceType := extractResourceType(r.URL.Path) - opts := []AuditOption{WithRequest(r)} + opts := []logging.AuditOption{logging.WithRequest(r)} if actor != "" { - opts = append(opts, WithActor(actor)) + opts = append(opts, logging.WithActor(actor)) } if resourceID != "" { - opts = append(opts, WithResourceID(resourceID)) + opts = append(opts, logging.WithResourceID(resourceID)) } if resourceType != "" { - opts = append(opts, WithResourceType(resourceType)) + opts = append(opts, logging.WithResourceType(resourceType)) } if ctx.responseStatusCode >= 400 { - opts = append(opts, WithReason(fmt.Sprintf("HTTP %d: %s", ctx.responseStatusCode, http.StatusText(ctx.responseStatusCode)))) + opts = append(opts, logging.WithReason(fmt.Sprintf("HTTP %d: %s", ctx.responseStatusCode, http.StatusText(ctx.responseStatusCode)))) ctx.auditLogger.APIAction(action+" (failed)", opts...) } if ctx.responseStatusCode < 400 && (r.Method == http.MethodGet || r.Method == http.MethodHead) { @@ -181,13 +182,13 @@ func extractResourceType(path string) string { return "" } -func requirePermission(permission string, jwtSecret []byte, handler http.HandlerFunc, systemLogger *zap.Logger, auditLogger *AuditLogger) http.HandlerFunc { +func requirePermission(permission string, jwtSecret []byte, handler http.HandlerFunc, systemLogger *zap.Logger, auditLogger *logging.AuditLogger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), jwtSecret) if err != nil { auditLogger.UnauthorizedAccess( - WithRequest(r), - WithReason("invalid or missing JWT token"), + logging.WithRequest(r), + logging.WithReason("invalid or missing JWT token"), ) writeError(w, http.StatusUnauthorized, "Unauthorized", err, systemLogger) return @@ -197,9 +198,9 @@ func requirePermission(permission string, jwtSecret []byte, handler http.Handler permissions, ok := PermissionsByRole[roleID] if !ok { auditLogger.UnauthorizedAccess( - WithActor(claims.Email), - WithRequest(r), - WithReason("unknown role"), + logging.WithActor(claims.Email), + logging.WithRequest(r), + logging.WithReason("unknown role"), ) writeError(w, http.StatusForbidden, "forbidden: unknown role", errors.New("role not found"), systemLogger) return @@ -207,8 +208,8 @@ func requirePermission(permission string, jwtSecret []byte, handler http.Handler if !hasPermission(permissions, permission) { auditLogger.AccessDenied(claims.Email, r.URL.Path, permission, - WithRequest(r), - WithReason("insufficient permissions"), + logging.WithRequest(r), + logging.WithReason("insufficient permissions"), ) writeError(w, http.StatusForbidden, "forbidden: insufficient permissions", errors.New("missing permission"), systemLogger) return @@ -227,7 +228,7 @@ func hasPermission(userPermissions []string, required string) bool { return false } -func requirePermissionOrFirstUser(permission string, jwtSecret []byte, db *db.Database, handler http.HandlerFunc, systemLogger *zap.Logger, auditLogger *AuditLogger) http.HandlerFunc { +func requirePermissionOrFirstUser(permission string, jwtSecret []byte, db *db.Database, handler http.HandlerFunc, systemLogger *zap.Logger, auditLogger *logging.AuditLogger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { numUsers, err := db.NumUsers() if err != nil { @@ -243,8 +244,8 @@ func requirePermissionOrFirstUser(permission string, jwtSecret []byte, db *db.Da claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), jwtSecret) if err != nil { auditLogger.UnauthorizedAccess( - WithRequest(r), - WithReason("invalid or missing JWT token"), + logging.WithRequest(r), + logging.WithReason("invalid or missing JWT token"), ) writeError(w, http.StatusUnauthorized, "Unauthorized", err, systemLogger) return @@ -253,8 +254,8 @@ func requirePermissionOrFirstUser(permission string, jwtSecret []byte, db *db.Da permissions, ok := PermissionsByRole[claims.RoleID] if !ok || !hasPermission(permissions, permission) { auditLogger.AccessDenied(claims.Email, r.URL.Path, permission, - WithRequest(r), - WithReason("insufficient permissions"), + logging.WithRequest(r), + logging.WithReason("insufficient permissions"), ) writeError(w, http.StatusForbidden, "forbidden: insufficient permissions", errors.New("missing required permission"), systemLogger) return diff --git a/internal/server/server.go b/internal/server/server.go index aa0151bc..c7e9e44a 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -9,6 +9,7 @@ import ( "github.com/canonical/notary/internal/config" "github.com/canonical/notary/internal/db" + "github.com/canonical/notary/internal/logging" "go.uber.org/zap" "go.uber.org/zap/zapcore" ) @@ -16,7 +17,7 @@ import ( type HandlerConfig struct { DB *db.Database SystemLogger *zap.Logger - AuditLogger *AuditLogger + AuditLogger *logging.AuditLogger ExternalHostname string JWTSecret []byte SendPebbleNotifications bool @@ -39,7 +40,7 @@ func New(opts *ServerOpts) (*Server, error) { cfg.JWTSecret = opts.Database.JWTSecret cfg.ExternalHostname = opts.ExternalHostname cfg.SystemLogger = opts.SystemLogger - cfg.AuditLogger = NewAuditLogger(opts.AuditLogger) + cfg.AuditLogger = logging.NewAuditLogger(opts.AuditLogger) cfg.PublicConfig = *opts.PublicConfig cfg.DB = opts.Database From 3d1eb02dcc8995501ae6cf3dd9c6f30611eaf567 Mon Sep 17 00:00:00 2001 From: yazansalti Date: Mon, 13 Oct 2025 18:58:34 +0200 Subject: [PATCH 3/4] Formatting --- cmd/migrate.go | 2 +- cmd/start.go | 8 +- internal/config/config_test.go | 10 +- internal/db/db_init.go | 2 +- internal/db/types.go | 12 +- internal/logging/audit_logger.go | 342 +++++++++--------- internal/logging/audit_options.go | 3 +- internal/server/audit_middleware_test.go | 239 ++++++------ internal/server/authorization_test.go | 90 +++-- internal/server/handlers_accounts.go | 6 +- internal/server/handlers_accounts_test.go | 108 +++--- .../handlers_certificate_authorities.go | 24 +- .../handlers_certificate_authorities_test.go | 170 +++++---- .../server/handlers_certificate_requests.go | 36 +- .../handlers_certificate_requests_test.go | 184 ++++++---- internal/server/handlers_config_test.go | 32 +- internal/server/handlers_login.go | 2 +- internal/server/handlers_login_test.go | 110 +++--- internal/server/handlers_status_test.go | 2 +- internal/server/middleware.go | 35 +- internal/server/server.go | 2 +- internal/server/server_test.go | 6 +- internal/server/types.go | 2 +- internal/server/utils.go | 2 +- internal/testutils/db_test_utils.go | 2 +- internal/testutils/server_test_utils.go | 26 +- 26 files changed, 764 insertions(+), 693 deletions(-) diff --git a/cmd/migrate.go b/cmd/migrate.go index e3def59f..7ba79444 100644 --- a/cmd/migrate.go +++ b/cmd/migrate.go @@ -121,7 +121,7 @@ func init() { migrateUpCmd.Flags().StringVarP(&dsn, "database-path", "d", "./notary.db", "A DSN for connecting to the database. Also accepts a path to a file, and will assume that the database is SQLite.") migrateDownCmd.Flags().StringVarP(&dsn, "database-path", "d", "./notary.db", "A DSN for connecting to the database. Also accepts a path to a file, and will assume that the database is SQLite.") migrateStatusCmd.Flags().StringVarP(&dsn, "database-path", "d", "./notary.db", "A DSN for connecting to the database. Also accepts a path to a file, and will assume that the database is SQLite.") - + if err := migrateUpCmd.MarkFlagRequired("database-path"); err != nil { log.Fatalf("Error marking database-path flag as required: %v", err) } diff --git a/cmd/start.go b/cmd/start.go index 104dddbd..4057f72f 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -32,10 +32,10 @@ https://canonical-notary.readthedocs-hosted.com/en/latest/reference/config_file/ // Initialize the database connection db, err := db.NewDatabase(&db.DatabaseOpts{ - DatabasePath: appContext.DBPath, + DatabasePath: appContext.DBPath, ApplyMigrations: appContext.ApplyMigrations, - Backend: appContext.EncryptionBackend, - Logger: appContext.SystemLogger, + Backend: appContext.EncryptionBackend, + Logger: appContext.SystemLogger, }) if err != nil { l.Fatal("couldn't initialize database", zap.Error(err)) @@ -83,7 +83,7 @@ func init() { startCmd.Flags().StringVarP(&configFilePath, "config", "c", "", "path to the configuration file") startCmd.Flags().BoolP("migrate-database", "m", false, "automatically apply database migrations if needed (use with caution)") - + err := startCmd.MarkFlagRequired("config") if err != nil { log.Fatalf("couldn't mark flag required: %s", err) diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 795f6f81..2e43db33 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -34,8 +34,8 @@ func TestValidConfig(t *testing.T) { DBPath: "./notary.db", Port: 8000, PebbleNotificationsEnabled: false, - SystemLogger: nil, - AuditLogger: nil, + SystemLogger: nil, + AuditLogger: nil, EncryptionBackend: encryption_backend.NoEncryptionBackend{}, EncryptionBackendType: config.EncryptionBackendTypeNone, }}, // This case tests the expected default values for missing fields are filled correctly @@ -52,8 +52,8 @@ func TestValidConfig(t *testing.T) { DBPath: "./notary.db", Port: 8000, PebbleNotificationsEnabled: false, - SystemLogger: nil, - AuditLogger: nil, + SystemLogger: nil, + AuditLogger: nil, EncryptionBackend: encryption_backend.NoEncryptionBackend{}, EncryptionBackendType: config.EncryptionBackendTypeNone, }}, // This case tests that the variables from the yaml are correctly copied to the final config @@ -69,7 +69,7 @@ func TestValidConfig(t *testing.T) { t.Errorf("ValidateConfig(%q) = %v, want nil", "config.yaml", err) return } - if !cmp.Equal(gotCfg, tc.wantCfg, cmpopts.IgnoreFields(config.NotaryAppContext{}, "SystemLogger", "AuditLogger")) { + if !cmp.Equal(gotCfg, tc.wantCfg, cmpopts.IgnoreFields(config.NotaryAppContext{}, "SystemLogger", "AuditLogger")) { t.Errorf("ValidateConfig returned unexpected diff (-want+got):\n%v", cmp.Diff(tc.wantCfg, gotCfg)) } }) diff --git a/internal/db/db_init.go b/internal/db/db_init.go index d1354688..86d97967 100644 --- a/internal/db/db_init.go +++ b/internal/db/db_init.go @@ -46,7 +46,7 @@ func NewDatabase(dbOpts *DatabaseOpts) (*Database, error) { if err != nil { return nil, err } - if version < 1 { + if version < 1 { if dbOpts.ApplyMigrations { goose.SetBaseFS(migrations.EmbedMigrations) if err := goose.Up(sqlConnection, ".", goose.WithNoColor(true)); err != nil { diff --git a/internal/db/types.go b/internal/db/types.go index 2b8339fb..cae9335c 100644 --- a/internal/db/types.go +++ b/internal/db/types.go @@ -7,19 +7,19 @@ import ( ) type DatabaseOpts struct { - DatabasePath string + DatabasePath string ApplyMigrations bool - Backend encryption_backend.EncryptionBackend - Logger *zap.Logger + Backend encryption_backend.EncryptionBackend + Logger *zap.Logger } // Database is the object used to communicate with the established repository. type Database struct { - Conn *sqlair.DB - stmts *Statements + Conn *sqlair.DB + stmts *Statements EncryptionKey []byte - JWTSecret []byte + JWTSecret []byte } const CAMaxExpiryYears = 1 diff --git a/internal/logging/audit_logger.go b/internal/logging/audit_logger.go index db6188f9..a116a726 100644 --- a/internal/logging/audit_logger.go +++ b/internal/logging/audit_logger.go @@ -27,11 +27,11 @@ func (a *AuditLogger) LoginSuccess(username string, opts ...AuditOption) { opt(ctx) } - fields := []zap.Field{ - zap.String("type", "security"), - zap.String("event", fmt.Sprintf("authn_login_success:%s", username)), - zap.String("username", username), - } + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", fmt.Sprintf("authn_login_success:%s", username)), + zap.String("username", username), + } fields = append(fields, ctx.toZapFields()...) a.logger.Info(fmt.Sprintf("User %s login successfully", username), fields...) @@ -44,11 +44,11 @@ func (a *AuditLogger) LoginFailed(username string, opts ...AuditOption) { opt(ctx) } - fields := []zap.Field{ - zap.String("type", "security"), - zap.String("event", fmt.Sprintf("authn_login_fail:%s", username)), - zap.String("username", username), - } + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", fmt.Sprintf("authn_login_fail:%s", username)), + zap.String("username", username), + } fields = append(fields, ctx.toZapFields()...) a.logger.Warn(fmt.Sprintf("User %s login failed", username), fields...) @@ -61,11 +61,11 @@ func (a *AuditLogger) TokenCreated(username string, opts ...AuditOption) { opt(ctx) } - fields := []zap.Field{ - zap.String("type", "security"), - zap.String("event", fmt.Sprintf("authn_token_created:%s", username)), - zap.String("username", username), - } + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", fmt.Sprintf("authn_token_created:%s", username)), + zap.String("username", username), + } fields = append(fields, ctx.toZapFields()...) a.logger.Info(fmt.Sprintf("A token has been created for %s", username), fields...) @@ -78,11 +78,11 @@ func (a *AuditLogger) PasswordChanged(username string, opts ...AuditOption) { opt(ctx) } - fields := []zap.Field{ - zap.String("type", "security"), - zap.String("event", fmt.Sprintf("authn_password_change:%s", username)), - zap.String("username", username), - } + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", fmt.Sprintf("authn_password_change:%s", username)), + zap.String("username", username), + } fields = append(fields, ctx.toZapFields()...) a.logger.Info(fmt.Sprintf("User %s has successfully changed their password", username), fields...) @@ -95,11 +95,11 @@ func (a *AuditLogger) PasswordChangeFailed(username string, opts ...AuditOption) opt(ctx) } - fields := []zap.Field{ - zap.String("type", "security"), - zap.String("event", fmt.Sprintf("authn_password_change_fail:%s", username)), - zap.String("username", username), - } + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", fmt.Sprintf("authn_password_change_fail:%s", username)), + zap.String("username", username), + } fields = append(fields, ctx.toZapFields()...) a.logger.Error(fmt.Sprintf("User %s failed to change their password", username), fields...) @@ -109,17 +109,17 @@ func (a *AuditLogger) PasswordChangeFailed(username string, opts ...AuditOption) // CertificateRequested logs when a certificate signing request is created. func (a *AuditLogger) CertificateRequested(csrID string, caID int, opts ...AuditOption) { - ctx := &auditContext{severity: SeverityInfo} + ctx := &auditContext{severity: SeverityInfo} for _, opt := range opts { opt(ctx) } - fields := []zap.Field{ - zap.String("type", "security"), - zap.String("event", "cert_requested"), - zap.String("csr_id", csrID), - zap.Int("ca_id", caID), - } + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", "cert_requested"), + zap.String("csr_id", csrID), + zap.Int("ca_id", caID), + } fields = append(fields, ctx.toZapFields()...) a.logger.Info("Certificate signing request created", fields...) @@ -127,17 +127,17 @@ func (a *AuditLogger) CertificateRequested(csrID string, caID int, opts ...Audit // CertificateIssued logs when a certificate is successfully issued. func (a *AuditLogger) CertificateIssued(csrID string, caID int, opts ...AuditOption) { - ctx := &auditContext{severity: SeverityInfo} + ctx := &auditContext{severity: SeverityInfo} for _, opt := range opts { opt(ctx) } - fields := []zap.Field{ - zap.String("type", "security"), - zap.String("event", "cert_issued"), - zap.String("csr_id", csrID), - zap.Int("ca_id", caID), - } + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", "cert_issued"), + zap.String("csr_id", csrID), + zap.Int("ca_id", caID), + } fields = append(fields, ctx.toZapFields()...) a.logger.Info("Certificate issued", fields...) @@ -145,17 +145,17 @@ func (a *AuditLogger) CertificateIssued(csrID string, caID int, opts ...AuditOpt // CertificateRejected logs when a certificate request is rejected. func (a *AuditLogger) CertificateRejected(csrID string, caID int, opts ...AuditOption) { - ctx := &auditContext{severity: SeverityWarn} + ctx := &auditContext{severity: SeverityWarn} for _, opt := range opts { opt(ctx) } - fields := []zap.Field{ - zap.String("type", "security"), - zap.String("event", "cert_rejected"), - zap.String("csr_id", csrID), - zap.Int("ca_id", caID), - } + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", "cert_rejected"), + zap.String("csr_id", csrID), + zap.Int("ca_id", caID), + } fields = append(fields, ctx.toZapFields()...) a.logger.Warn("Certificate request rejected", fields...) @@ -165,17 +165,17 @@ func (a *AuditLogger) CertificateRejected(csrID string, caID int, opts ...AuditO // CACreated logs when a new certificate authority is created. func (a *AuditLogger) CACreated(caID int, commonName string, opts ...AuditOption) { - ctx := &auditContext{severity: SeverityInfo} + ctx := &auditContext{severity: SeverityInfo} for _, opt := range opts { opt(ctx) } - fields := []zap.Field{ - zap.String("type", "security"), - zap.String("event", "ca_created"), - zap.Int("ca_id", caID), - zap.String("common_name", commonName), - } + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", "ca_created"), + zap.Int("ca_id", caID), + zap.String("common_name", commonName), + } fields = append(fields, ctx.toZapFields()...) a.logger.Info("Certificate Authority created", fields...) @@ -183,17 +183,17 @@ func (a *AuditLogger) CACreated(caID int, commonName string, opts ...AuditOption // CADeleted logs when a certificate authority is deleted. func (a *AuditLogger) CADeleted(caID int, commonName string, opts ...AuditOption) { - ctx := &auditContext{severity: SeverityWarn} + ctx := &auditContext{severity: SeverityWarn} for _, opt := range opts { opt(ctx) } - fields := []zap.Field{ - zap.String("type", "security"), - zap.String("event", "ca_deleted"), - zap.Int("ca_id", caID), - zap.String("common_name", commonName), - } + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", "ca_deleted"), + zap.Int("ca_id", caID), + zap.String("common_name", commonName), + } fields = append(fields, ctx.toZapFields()...) a.logger.Warn("Certificate Authority deleted", fields...) @@ -201,59 +201,59 @@ func (a *AuditLogger) CADeleted(caID int, commonName string, opts ...AuditOption // CAUpdated logs when a certificate authority enabled status is changed. func (a *AuditLogger) CAUpdated(caID string, enabled bool, opts ...AuditOption) { - ctx := &auditContext{severity: SeverityWarn} - for _, opt := range opts { - opt(ctx) - } - - status := "disabled" - if enabled { - status = "enabled" - } - - fields := []zap.Field{ - zap.String("type", "security"), - zap.String("event", "ca_updated"), - zap.String("ca_id", caID), - zap.String("status", status), - } - fields = append(fields, ctx.toZapFields()...) - - a.logger.Warn("Certificate Authority updated", fields...) + ctx := &auditContext{severity: SeverityWarn} + for _, opt := range opts { + opt(ctx) + } + + status := "disabled" + if enabled { + status = "enabled" + } + + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", "ca_updated"), + zap.String("ca_id", caID), + zap.String("status", status), + } + fields = append(fields, ctx.toZapFields()...) + + a.logger.Warn("Certificate Authority updated", fields...) } // CACertificateUploaded logs when a CA certificate chain is uploaded. func (a *AuditLogger) CACertificateUploaded(caID string, opts ...AuditOption) { - ctx := &auditContext{severity: SeverityInfo} - for _, opt := range opts { - opt(ctx) - } - - fields := []zap.Field{ - zap.String("type", "security"), - zap.String("event", "ca_cert_uploaded"), - zap.String("ca_id", caID), - } - fields = append(fields, ctx.toZapFields()...) - - a.logger.Info("Certificate uploaded to Certificate Authority", fields...) + ctx := &auditContext{severity: SeverityInfo} + for _, opt := range opts { + opt(ctx) + } + + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", "ca_cert_uploaded"), + zap.String("ca_id", caID), + } + fields = append(fields, ctx.toZapFields()...) + + a.logger.Info("Certificate uploaded to Certificate Authority", fields...) } // CACertificateRevoked logs when a CA certificate is revoked. func (a *AuditLogger) CACertificateRevoked(caID string, opts ...AuditOption) { - ctx := &auditContext{severity: SeverityWarn} - for _, opt := range opts { - opt(ctx) - } - - fields := []zap.Field{ - zap.String("type", "security"), - zap.String("event", "ca_cert_revoked"), - zap.String("ca_id", caID), - } - fields = append(fields, ctx.toZapFields()...) - - a.logger.Warn("Certificate Authority certificate revoked", fields...) + ctx := &auditContext{severity: SeverityWarn} + for _, opt := range opts { + opt(ctx) + } + + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", "ca_cert_revoked"), + zap.String("ca_id", caID), + } + fields = append(fields, ctx.toZapFields()...) + + a.logger.Warn("Certificate Authority certificate revoked", fields...) } // User Management Events @@ -272,13 +272,13 @@ func (a *AuditLogger) UserCreated(username string, roleID int, opts ...AuditOpti roleName = "user" } - fields := []zap.Field{ - zap.String("type", "security"), - zap.String("event", fmt.Sprintf("user_created:%s,%s", username, roleName)), - zap.String("username", username), - zap.Int("role_id", roleID), - zap.String("role_name", roleName), - } + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", fmt.Sprintf("user_created:%s,%s", username, roleName)), + zap.String("username", username), + zap.Int("role_id", roleID), + zap.String("role_name", roleName), + } fields = append(fields, ctx.toZapFields()...) description := fmt.Sprintf("User account %s created with role %s", username, roleName) @@ -290,22 +290,21 @@ func (a *AuditLogger) UserCreated(username string, roleID int, opts ...AuditOpti // UserDeleted logs when a user account is deleted. func (a *AuditLogger) UserDeleted(username string, opts ...AuditOption) { - ctx := &auditContext{severity: SeverityWarn} + ctx := &auditContext{severity: SeverityWarn} for _, opt := range opts { opt(ctx) } - fields := []zap.Field{ - zap.String("type", "security"), - zap.String("event", "user_deleted"), - zap.String("username", username), - } + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", "user_deleted"), + zap.String("username", username), + } fields = append(fields, ctx.toZapFields()...) a.logger.Warn("User account deleted", fields...) } - // UserUpdated logs when a user account is updated (e.g., password changed). func (a *AuditLogger) UserUpdated(username, updateType string, opts ...AuditOption) { ctx := &auditContext{severity: SeverityWarn} @@ -313,12 +312,12 @@ func (a *AuditLogger) UserUpdated(username, updateType string, opts ...AuditOpti opt(ctx) } - fields := []zap.Field{ - zap.String("type", "security"), - zap.String("event", fmt.Sprintf("user_updated:%s,%s", username, updateType)), - zap.String("username", username), - zap.String("update_type", updateType), - } + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", fmt.Sprintf("user_updated:%s,%s", username, updateType)), + zap.String("username", username), + zap.String("update_type", updateType), + } fields = append(fields, ctx.toZapFields()...) description := fmt.Sprintf("User %s updated with %s", username, updateType) @@ -332,37 +331,37 @@ func (a *AuditLogger) UserUpdated(username, updateType string, opts ...AuditOpti // AccessDenied logs when a user is denied access to a resource. func (a *AuditLogger) AccessDenied(username, resource, action string, opts ...AuditOption) { - ctx := &auditContext{severity: SeverityCritical} + ctx := &auditContext{severity: SeverityCritical} for _, opt := range opts { opt(ctx) } - fields := []zap.Field{ - zap.String("type", "security"), - zap.String("event", fmt.Sprintf("authz_fail:%s,%s", username, resource)), - zap.String("username", username), - zap.String("resource", resource), - zap.String("action", action), - } + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", fmt.Sprintf("authz_fail:%s,%s", username, resource)), + zap.String("username", username), + zap.String("resource", resource), + zap.String("action", action), + } fields = append(fields, ctx.toZapFields()...) - a.logger.Error("Access denied", fields...) + a.logger.Error("Access denied", fields...) } // UnauthorizedAccess logs when an unauthorized access attempt is detected. func (a *AuditLogger) UnauthorizedAccess(opts ...AuditOption) { - ctx := &auditContext{severity: SeverityCritical} + ctx := &auditContext{severity: SeverityCritical} for _, opt := range opts { opt(ctx) } - fields := []zap.Field{ - zap.String("type", "security"), - zap.String("event", "authz_fail"), - } + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", "authz_fail"), + } fields = append(fields, ctx.toZapFields()...) - a.logger.Error("Unauthorized access attempt", fields...) + a.logger.Error("Unauthorized access attempt", fields...) } // API Action Events @@ -374,11 +373,11 @@ func (a *AuditLogger) APIAction(action string, opts ...AuditOption) { opt(ctx) } - fields := []zap.Field{ - zap.String("type", "audit"), - zap.String("event", "api_action"), - zap.String("action", action), - } + fields := []zap.Field{ + zap.String("type", "audit"), + zap.String("event", "api_action"), + zap.String("action", action), + } fields = append(fields, ctx.toZapFields()...) a.logger.Info("API action performed", fields...) @@ -388,35 +387,34 @@ func (a *AuditLogger) APIAction(action string, opts ...AuditOption) { // CertificateRequestDeleted logs when a CSR is deleted. func (a *AuditLogger) CertificateRequestDeleted(csrID string, opts ...AuditOption) { - ctx := &auditContext{severity: SeverityWarn} - for _, opt := range opts { - opt(ctx) - } - - fields := []zap.Field{ - zap.String("type", "security"), - zap.String("event", "cert_request_deleted"), - zap.String("csr_id", csrID), - } - fields = append(fields, ctx.toZapFields()...) - - a.logger.Warn("Certificate request deleted", fields...) + ctx := &auditContext{severity: SeverityWarn} + for _, opt := range opts { + opt(ctx) + } + + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", "cert_request_deleted"), + zap.String("csr_id", csrID), + } + fields = append(fields, ctx.toZapFields()...) + + a.logger.Warn("Certificate request deleted", fields...) } // CertificateRevoked logs when a certificate (for a CSR) is revoked. func (a *AuditLogger) CertificateRevoked(csrID string, opts ...AuditOption) { - ctx := &auditContext{severity: SeverityWarn} - for _, opt := range opts { - opt(ctx) - } - - fields := []zap.Field{ - zap.String("type", "security"), - zap.String("event", "cert_revoked"), - zap.String("csr_id", csrID), - } - fields = append(fields, ctx.toZapFields()...) - - a.logger.Warn("Certificate revoked", fields...) -} + ctx := &auditContext{severity: SeverityWarn} + for _, opt := range opts { + opt(ctx) + } + fields := []zap.Field{ + zap.String("type", "security"), + zap.String("event", "cert_revoked"), + zap.String("csr_id", csrID), + } + fields = append(fields, ctx.toZapFields()...) + + a.logger.Warn("Certificate revoked", fields...) +} diff --git a/internal/logging/audit_options.go b/internal/logging/audit_options.go index 8a1a13c5..63970b92 100644 --- a/internal/logging/audit_options.go +++ b/internal/logging/audit_options.go @@ -65,7 +65,7 @@ func WithResourceID(id string) AuditOption { // It captures: remote IP, user agent, path, and method. Kept simple by design. func WithRequest(r *http.Request) AuditOption { return func(ctx *auditContext) { - ctx.ipAddress = r.RemoteAddr + ctx.ipAddress = r.RemoteAddr ctx.userAgent = r.UserAgent() ctx.path = r.URL.Path ctx.method = r.Method @@ -107,4 +107,3 @@ func (ctx *auditContext) toZapFields() []zap.Field { return fields } - diff --git a/internal/server/audit_middleware_test.go b/internal/server/audit_middleware_test.go index 54e8cc5f..995bd672 100644 --- a/internal/server/audit_middleware_test.go +++ b/internal/server/audit_middleware_test.go @@ -11,143 +11,132 @@ import ( "go.uber.org/zap/zaptest/observer" ) -// Helper to build a router with an observed audit logger -// Use testutils helper for observed server, keep function name for local tests -// Deprecated local helper; kept for compatibility in this file. -func newObservedRouter(t *testing.T) *observer.ObservedLogs { t.Helper(); _, logs := tu.MustPrepareServer(t); return logs } - func findStringField(entry observer.LoggedEntry, key string) string { - for _, f := range entry.Context { - if f.Key == key { - switch f.Type { - case zapcore.StringType: - return f.String - } - } - } - return "" + for _, f := range entry.Context { + if f.Key == key { + switch f.Type { + case zapcore.StringType: + return f.String + } + } + } + return "" } func TestAuditMiddleware_LogsFailureAndReason(t *testing.T) { - ts, logs := tu.MustPrepareServer(t) - // Clear any initialization noise - _ = logs.TakeAll() + ts, logs := tu.MustPrepareServer(t) + _ = logs.TakeAll() - // Unauthorized GET (no token) - req, err := http.NewRequest("GET", ts.URL+"/api/v1/certificate_requests", nil) - if err != nil { - t.Fatalf("new request: %v", err) - } - res, err := ts.Client().Do(req) - if err != nil { - t.Fatalf("do request: %v", err) - } - if res.StatusCode != http.StatusUnauthorized { - t.Fatalf("expected %d, got %d", http.StatusUnauthorized, res.StatusCode) - } + req, err := http.NewRequest("GET", ts.URL+"/api/v1/certificate_requests", nil) + if err != nil { + t.Fatalf("new request: %v", err) + } + res, err := ts.Client().Do(req) + if err != nil { + t.Fatalf("do request: %v", err) + } + if res.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected %d, got %d", http.StatusUnauthorized, res.StatusCode) + } - entries := logs.TakeAll() - var haveAuthzFail, haveAPIFailed bool - for _, e := range entries { - if e.LoggerName != "audit" { - continue - } - switch findStringField(e, "event") { - case "authz_fail": - haveAuthzFail = true - case "api_action": - if findStringField(e, "action") == "GET certificate_requests (failed)" { - haveAPIFailed = true - } - } - } - if !haveAuthzFail { - t.Fatalf("expected UnauthorizedAccess audit entry (event=authz_fail)") - } - if !haveAPIFailed { - t.Fatalf("expected APIAction failure audit entry for GET certificate_requests") - } + entries := logs.TakeAll() + var haveAuthzFail, haveAPIFailed bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + switch findStringField(e, "event") { + case "authz_fail": + haveAuthzFail = true + case "api_action": + if findStringField(e, "action") == "GET certificate_requests (failed)" { + haveAPIFailed = true + } + } + } + if !haveAuthzFail { + t.Fatalf("expected UnauthorizedAccess audit entry (event=authz_fail)") + } + if !haveAPIFailed { + t.Fatalf("expected APIAction failure audit entry for GET certificate_requests") + } } func TestAuditMiddleware_LogsSuccessfulRead(t *testing.T) { - ts, logs := tu.MustPrepareServer(t) + ts, logs := tu.MustPrepareServer(t) - // Create first user (open route: first user doesn't require token) - createBody := map[string]any{ - "email": "admin@example.com", - "password": "Admin123", - "role_id": 0, - } - payload, _ := json.Marshal(createBody) - req, err := http.NewRequest("POST", ts.URL+"/api/v1/accounts", bytes.NewReader(payload)) - if err != nil { - t.Fatalf("new request: %v", err) - } - req.Header.Set("Content-Type", "application/json") - res, err := ts.Client().Do(req) - if err != nil { - t.Fatalf("do request: %v", err) - } - if res.StatusCode != http.StatusCreated { - t.Fatalf("expected %d, got %d", http.StatusCreated, res.StatusCode) - } + createBody := map[string]any{ + "email": "admin@example.com", + "password": "Admin123", + "role_id": 0, + } + payload, _ := json.Marshal(createBody) + req, err := http.NewRequest("POST", ts.URL+"/api/v1/accounts", bytes.NewReader(payload)) + if err != nil { + t.Fatalf("new request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + res, err := ts.Client().Do(req) + if err != nil { + t.Fatalf("do request: %v", err) + } + if res.StatusCode != http.StatusCreated { + t.Fatalf("expected %d, got %d", http.StatusCreated, res.StatusCode) + } - // Login to obtain JWT - loginBody := map[string]any{ - "email": "admin@example.com", - "password": "Admin123", - } - loginPayload, _ := json.Marshal(loginBody) - req, err = http.NewRequest("POST", ts.URL+"/login", bytes.NewReader(loginPayload)) - if err != nil { - t.Fatalf("new request: %v", err) - } - req.Header.Set("Content-Type", "application/json") - res, err = ts.Client().Do(req) - if err != nil { - t.Fatalf("do request: %v", err) - } - if res.StatusCode != http.StatusOK { - t.Fatalf("expected %d, got %d", http.StatusOK, res.StatusCode) - } - var loginResp struct { - Result struct{ Token string `json:"token"` } - } - if err := json.NewDecoder(res.Body).Decode(&loginResp); err != nil { - t.Fatalf("decode login response: %v", err) - } + loginBody := map[string]any{ + "email": "admin@example.com", + "password": "Admin123", + } + loginPayload, _ := json.Marshal(loginBody) + req, err = http.NewRequest("POST", ts.URL+"/login", bytes.NewReader(loginPayload)) + if err != nil { + t.Fatalf("new request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + res, err = ts.Client().Do(req) + if err != nil { + t.Fatalf("do request: %v", err) + } + if res.StatusCode != http.StatusOK { + t.Fatalf("expected %d, got %d", http.StatusOK, res.StatusCode) + } + var loginResp struct { + Result struct { + Token string `json:"token"` + } + } + if err := json.NewDecoder(res.Body).Decode(&loginResp); err != nil { + t.Fatalf("decode login response: %v", err) + } - // Clear logs so we only capture the read success - _ = logs.TakeAll() + _ = logs.TakeAll() - // Authenticated GET (should log api_action success) - req, err = http.NewRequest("GET", ts.URL+"/api/v1/certificate_requests", nil) - if err != nil { - t.Fatalf("new request: %v", err) - } - req.Header.Set("Authorization", "Bearer "+loginResp.Result.Token) - res, err = ts.Client().Do(req) - if err != nil { - t.Fatalf("do request: %v", err) - } - if res.StatusCode != http.StatusOK { - t.Fatalf("expected %d, got %d", http.StatusOK, res.StatusCode) - } + req, err = http.NewRequest("GET", ts.URL+"/api/v1/certificate_requests", nil) + if err != nil { + t.Fatalf("new request: %v", err) + } + req.Header.Set("Authorization", "Bearer "+loginResp.Result.Token) + res, err = ts.Client().Do(req) + if err != nil { + t.Fatalf("do request: %v", err) + } + if res.StatusCode != http.StatusOK { + t.Fatalf("expected %d, got %d", http.StatusOK, res.StatusCode) + } - entries := logs.TakeAll() - var haveAPISuccess bool - for _, e := range entries { - if e.LoggerName != "audit" { - continue - } - if findStringField(e, "event") == "api_action" && findStringField(e, "action") == "GET certificate_requests" { - haveAPISuccess = true - break - } - } - if !haveAPISuccess { - t.Fatalf("expected APIAction success audit entry for GET certificate_requests") - } + entries := logs.TakeAll() + var haveAPISuccess bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + if findStringField(e, "event") == "api_action" && findStringField(e, "action") == "GET certificate_requests" { + haveAPISuccess = true + break + } + } + if !haveAPISuccess { + t.Fatalf("expected APIAction success audit entry for GET certificate_requests") + } } - - diff --git a/internal/server/authorization_test.go b/internal/server/authorization_test.go index 7e1ba55b..26d5ebad 100644 --- a/internal/server/authorization_test.go +++ b/internal/server/authorization_test.go @@ -10,7 +10,7 @@ import ( ) func TestAuthorizationNoAuth(t *testing.T) { - ts, _ := tu.MustPrepareServer(t) + ts, _ := tu.MustPrepareServer(t) client := ts.Client() testCases := []struct { @@ -91,7 +91,7 @@ func TestAuthorizationAdminAuthorized(t *testing.T) { } func TestAuthorizationAdminUnAuthorized(t *testing.T) { -ts, _ := tu.MustPrepareServer(t) + ts, _ := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") nonAdminToken := tu.MustPrepareAccount(t, ts, "whatever@canonical.com", tu.RoleCertificateManager, adminToken) client := ts.Client() @@ -111,7 +111,7 @@ ts, _ := tu.MustPrepareServer(t) } func TestAuthorizationCertificateManagerAuthorized(t *testing.T) { -ts, logs := tu.MustPrepareServer(t) + ts, logs := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") certManagerToken := tu.MustPrepareAccount(t, ts, "testuser@canonical.com", tu.RoleCertificateManager, adminToken) client := ts.Client() @@ -165,9 +165,11 @@ ts, logs := tu.MustPrepareServer(t) status: http.StatusAccepted, }, } - for _, tC := range testCases { + for _, tC := range testCases { t.Run(tC.desc, func(t *testing.T) { - if tC.desc == "certificate manager can change self password with /me" { _ = logs.TakeAll() } + if tC.desc == "certificate manager can change self password with /me" { + _ = logs.TakeAll() + } req, err := http.NewRequest(tC.method, ts.URL+tC.path, strings.NewReader(tC.data)) if err != nil { t.Fatal(err) @@ -180,21 +182,27 @@ ts, logs := tu.MustPrepareServer(t) if res.StatusCode != tC.status { t.Errorf("expected status code %d, got %d", tC.status, res.StatusCode) } - if tC.desc == "certificate manager can change self password with /me" { - entries := logs.TakeAll() - var havePwdChanged, haveUserUpdated bool - for _, e := range entries { - if e.LoggerName != "audit" { continue } - switch findStringField(e, "event") { - case "authn_password_change:testuser@canonical.com": - havePwdChanged = true - case "user_updated:testuser@canonical.com,password_change": - haveUserUpdated = true - } - } - if !havePwdChanged { t.Errorf("expected PasswordChanged audit entry for self change") } - if !haveUserUpdated { t.Errorf("expected UserUpdated audit entry for self change") } - } + if tC.desc == "certificate manager can change self password with /me" { + entries := logs.TakeAll() + var havePwdChanged, haveUserUpdated bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + switch findStringField(e, "event") { + case "authn_password_change:testuser@canonical.com": + havePwdChanged = true + case "user_updated:testuser@canonical.com,password_change": + haveUserUpdated = true + } + } + if !havePwdChanged { + t.Errorf("expected PasswordChanged audit entry for self change") + } + if !haveUserUpdated { + t.Errorf("expected UserUpdated audit entry for self change") + } + } }) } } @@ -234,9 +242,9 @@ func TestAuthorizationCertificateManagerUnauthorized(t *testing.T) { status: http.StatusForbidden, }, } - for _, tC := range testCases { + for _, tC := range testCases { t.Run(tC.desc, func(t *testing.T) { - _ = logs.TakeAll() + _ = logs.TakeAll() req, err := http.NewRequest(tC.method, ts.URL+tC.path, strings.NewReader(tC.data)) if err != nil { t.Fatal(err) @@ -249,26 +257,28 @@ func TestAuthorizationCertificateManagerUnauthorized(t *testing.T) { if res.StatusCode != tC.status { t.Errorf("expected status code %d, got %d", tC.status, res.StatusCode) } - if tC.status == http.StatusForbidden { - entries := logs.TakeAll() - var haveAuthzFail bool - for _, e := range entries { - if e.LoggerName != "audit" { continue } - if strings.HasPrefix(findStringField(e, "event"), "authz_fail:") { - haveAuthzFail = true - break - } - } - if !haveAuthzFail { - t.Errorf("expected audit authz_fail for %s %s", tC.method, tC.path) - } - } + if tC.status == http.StatusForbidden { + entries := logs.TakeAll() + var haveAuthzFail bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + if strings.HasPrefix(findStringField(e, "event"), "authz_fail:") { + haveAuthzFail = true + break + } + } + if !haveAuthzFail { + t.Errorf("expected audit authz_fail for %s %s", tC.method, tC.path) + } + } }) } } func TestAuthorizationCertificateRequestorAuthorized(t *testing.T) { -ts, _ := tu.MustPrepareServer(t) + ts, _ := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") certRequestorToken := tu.MustPrepareAccount(t, ts, "testuser@canonical.com", tu.RoleCertificateRequestor, adminToken) client := ts.Client() @@ -346,7 +356,7 @@ ts, _ := tu.MustPrepareServer(t) } func TestAuthorizationCertificateRequestorUnauthorized(t *testing.T) { -ts, _ := tu.MustPrepareServer(t) + ts, _ := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") certRequestorToken := tu.MustPrepareAccount(t, ts, "testuser@canonical.com", tu.RoleCertificateRequestor, adminToken) client := ts.Client() @@ -467,7 +477,7 @@ ts, _ := tu.MustPrepareServer(t) } func TestAuthorizationReadOnlyAuthorized(t *testing.T) { -ts, _ := tu.MustPrepareServer(t) + ts, _ := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") readOnlyToken := tu.MustPrepareAccount(t, ts, "testuser@canonical.com", tu.RoleReadOnly, adminToken) client := ts.Client() @@ -562,7 +572,7 @@ ts, _ := tu.MustPrepareServer(t) } func TestAuthorizationReadOnlyUnauthorized(t *testing.T) { -ts, _ := tu.MustPrepareServer(t) + ts, _ := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") readOnlyToken := tu.MustPrepareAccount(t, ts, "testuser@canonical.com", tu.RoleReadOnly, adminToken) client := ts.Client() diff --git a/internal/server/handlers_accounts.go b/internal/server/handlers_accounts.go index 2d89f2d9..0c03e6ef 100644 --- a/internal/server/handlers_accounts.go +++ b/internal/server/handlers_accounts.go @@ -175,13 +175,13 @@ func CreateAccount(env *HandlerConfig) http.HandlerFunc { if claimsErr == nil { actor = claims.Email } - + opts := []logging.AuditOption{logging.WithRequest(r)} if actor != "" { opts = append(opts, logging.WithActor(actor)) } env.AuditLogger.UserCreated(createAccountParams.Email, int(createAccountParams.RoleID), opts...) - + successResponse := CreateSuccessResponse{Message: "success", ID: newUserID} err = writeResponse(w, successResponse, http.StatusCreated) if err != nil { @@ -374,7 +374,7 @@ func ChangeMyPassword(env *HandlerConfig) http.HandlerFunc { env.AuditLogger.PasswordChanged(account.Email, logging.WithRequest(r)) env.AuditLogger.UserUpdated(account.Email, "password_change", logging.WithRequest(r)) - + successResponse := SuccessResponse{Message: "success"} err = writeResponse(w, successResponse, http.StatusCreated) if err != nil { diff --git a/internal/server/handlers_accounts_test.go b/internal/server/handlers_accounts_test.go index 7aeba69f..6e3c82e6 100644 --- a/internal/server/handlers_accounts_test.go +++ b/internal/server/handlers_accounts_test.go @@ -11,7 +11,7 @@ import ( // The order of the tests is important, as some tests depend on // the state of the server after previous tests. func TestAccountsEndToEnd(t *testing.T) { - ts, logs := tu.MustPrepareServer(t) + ts, logs := tu.MustPrepareServer(t) client := ts.Client() adminToken := tu.MustPrepareAccount(t, ts, "testadmin@canonical.com", tu.RoleAdmin, "") nonAdminToken := tu.MustPrepareAccount(t, ts, "whatever@canonical.com", tu.RoleCertificateManager, adminToken) @@ -51,8 +51,8 @@ func TestAccountsEndToEnd(t *testing.T) { } }) - t.Run("3. Create account", func(t *testing.T) { - _ = logs.TakeAll() + t.Run("3. Create account", func(t *testing.T) { + _ = logs.TakeAll() createAccountParams := &tu.CreateAccountParams{ Email: "nopass@canonical.com", Password: "myPassword123!", @@ -69,20 +69,20 @@ func TestAccountsEndToEnd(t *testing.T) { t.Fatalf("unexpected error :%q", response.Error) } - entries := logs.TakeAll() - var haveUserCreated bool - for _, e := range entries { - if e.LoggerName != "audit" { - continue - } - if findStringField(e, "event") == ("user_created:"+createAccountParams.Email+",admin") { - haveUserCreated = true - break - } - } - if !haveUserCreated { - t.Fatalf("expected UserCreated audit entry for %s", createAccountParams.Email) - } + entries := logs.TakeAll() + var haveUserCreated bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + if findStringField(e, "event") == ("user_created:" + createAccountParams.Email + ",admin") { + haveUserCreated = true + break + } + } + if !haveUserCreated { + t.Fatalf("expected UserCreated audit entry for %s", createAccountParams.Email) + } }) t.Run("4. Get account", func(t *testing.T) { @@ -120,8 +120,8 @@ func TestAccountsEndToEnd(t *testing.T) { } }) - t.Run("6. Change account password - success", func(t *testing.T) { - _ = logs.TakeAll() + t.Run("6. Change account password - success", func(t *testing.T) { + _ = logs.TakeAll() changeAccountPasswordParams := &tu.ChangeAccountPasswordParams{ Password: "newPassword1", } @@ -136,23 +136,25 @@ func TestAccountsEndToEnd(t *testing.T) { t.Fatalf("unexpected error :%q", response.Error) } - entries := logs.TakeAll() - var havePwdChanged, haveUserUpdated bool - for _, e := range entries { - if e.LoggerName != "audit" { continue } - switch findStringField(e, "event") { - case "authn_password_change:testadmin@canonical.com": - havePwdChanged = true - case "user_updated:testadmin@canonical.com,password_change": - haveUserUpdated = true - } - } - if !havePwdChanged { - t.Fatalf("expected PasswordChanged audit entry") - } - if !haveUserUpdated { - t.Fatalf("expected UserUpdated audit entry for password_change") - } + entries := logs.TakeAll() + var havePwdChanged, haveUserUpdated bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + switch findStringField(e, "event") { + case "authn_password_change:testadmin@canonical.com": + havePwdChanged = true + case "user_updated:testadmin@canonical.com,password_change": + haveUserUpdated = true + } + } + if !havePwdChanged { + t.Fatalf("expected PasswordChanged audit entry") + } + if !haveUserUpdated { + t.Fatalf("expected UserUpdated audit entry for password_change") + } }) t.Run("7. Change account password - no user", func(t *testing.T) { @@ -171,8 +173,8 @@ func TestAccountsEndToEnd(t *testing.T) { } }) - t.Run("8. Delete account - success", func(t *testing.T) { - _ = logs.TakeAll() + t.Run("8. Delete account - success", func(t *testing.T) { + _ = logs.TakeAll() statusCode, response, err := tu.DeleteAccount(ts.URL, client, adminToken, 2) if err != nil { t.Fatalf("couldn't delete account: %s", err) @@ -184,18 +186,20 @@ func TestAccountsEndToEnd(t *testing.T) { t.Fatalf("expected error %q, got %q", "", response.Error) } - entries := logs.TakeAll() - var haveUserDeleted bool - for _, e := range entries { - if e.LoggerName != "audit" { continue } - if findStringField(e, "event") == "user_deleted" && findStringField(e, "username") == "whatever@canonical.com" { - haveUserDeleted = true - break - } - } - if !haveUserDeleted { - t.Fatalf("expected UserDeleted audit entry for whatever@canonical.com") - } + entries := logs.TakeAll() + var haveUserDeleted bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + if findStringField(e, "event") == "user_deleted" && findStringField(e, "username") == "whatever@canonical.com" { + haveUserDeleted = true + break + } + } + if !haveUserDeleted { + t.Fatalf("expected UserDeleted audit entry for whatever@canonical.com") + } }) t.Run("9. Delete account - no user", func(t *testing.T) { @@ -235,7 +239,7 @@ func TestAccountsEndToEnd(t *testing.T) { } func TestCreateAccountInvalidInputs(t *testing.T) { - ts, _ := tu.MustPrepareServer(t) + ts, _ := tu.MustPrepareServer(t) client := ts.Client() adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") @@ -341,7 +345,7 @@ func TestCreateAccountInvalidInputs(t *testing.T) { } func TestChangeAccountPasswordInvalidInputs(t *testing.T) { -ts, _ := tu.MustPrepareServer(t) + ts, _ := tu.MustPrepareServer(t) client := ts.Client() adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") diff --git a/internal/server/handlers_certificate_authorities.go b/internal/server/handlers_certificate_authorities.go index 52153d2f..858a8bed 100644 --- a/internal/server/handlers_certificate_authorities.go +++ b/internal/server/handlers_certificate_authorities.go @@ -359,7 +359,7 @@ func UpdateCertificateAuthority(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusBadRequest, "Invalid JSON format", err, env.SystemLogger) return } - + claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) if headerErr != nil { writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.SystemLogger) @@ -400,13 +400,13 @@ func DeleteCertificateAuthority(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusBadRequest, "Invalid ID", err, env.SystemLogger) return } - + claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) if headerErr != nil { writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.SystemLogger) return } - + ca, err := env.DB.GetDenormalizedCertificateAuthority(db.ByCertificateAuthorityDenormalizedID(idNum)) if err != nil { if errors.Is(err, db.ErrNotFound) { @@ -426,12 +426,12 @@ func DeleteCertificateAuthority(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } - + env.AuditLogger.CADeleted(int(idNum), extractCommonName(ca.CertificateChain), logging.WithActor(claims.Email), logging.WithRequest(r), ) - + successResponse := SuccessResponse{Message: "success"} err = writeResponse(w, successResponse, http.StatusOK) if err != nil { @@ -461,7 +461,7 @@ func PostCertificateAuthorityCertificate(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusBadRequest, fmt.Errorf("Invalid request: %s", err).Error(), err, env.SystemLogger) return } - + claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) if headerErr != nil { writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.SystemLogger) @@ -477,12 +477,12 @@ func PostCertificateAuthorityCertificate(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } - + env.AuditLogger.CACertificateUploaded(id, logging.WithActor(claims.Email), logging.WithRequest(r), ) - + err = writeResponse(w, SuccessResponse{Message: "success"}, http.StatusCreated) if err != nil { writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) @@ -582,13 +582,13 @@ func RevokeCertificateAuthorityCertificate(env *HandlerConfig) http.HandlerFunc writeError(w, http.StatusBadRequest, "Invalid ID", err, env.SystemLogger) return } - + claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) if headerErr != nil { writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.SystemLogger) return } - + ca, err := env.DB.GetCertificateAuthority(db.ByCertificateAuthorityID(idNum)) if err != nil { env.SystemLogger.Info("could not get certificate authority", zap.Error(err)) @@ -609,12 +609,12 @@ func RevokeCertificateAuthorityCertificate(env *HandlerConfig) http.HandlerFunc writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } - + env.AuditLogger.CACertificateRevoked(id, logging.WithActor(claims.Email), logging.WithRequest(r), ) - + if env.SendPebbleNotifications { err := SendPebbleNotification(CertificateUpdate, idNum) if err != nil { diff --git a/internal/server/handlers_certificate_authorities_test.go b/internal/server/handlers_certificate_authorities_test.go index 21d7d0b2..c363bddd 100644 --- a/internal/server/handlers_certificate_authorities_test.go +++ b/internal/server/handlers_certificate_authorities_test.go @@ -15,7 +15,7 @@ import ( // The order of the tests is important, as some tests depend on the state of the server after previous tests. func TestSelfSignedCertificateAuthorityEndToEnd(t *testing.T) { - ts, logs := tu.MustPrepareServer(t) + ts, logs := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") client := ts.Client() @@ -35,8 +35,8 @@ func TestSelfSignedCertificateAuthorityEndToEnd(t *testing.T) { } }) - t.Run("2. Create self signed certificate authority", func(t *testing.T) { - _ = logs.TakeAll() + t.Run("2. Create self signed certificate authority", func(t *testing.T) { + _ = logs.TakeAll() createCertificatAuthorityParams := tu.CreateCertificateAuthorityParams{ SelfSigned: true, @@ -56,20 +56,24 @@ func TestSelfSignedCertificateAuthorityEndToEnd(t *testing.T) { if statusCode != http.StatusCreated { t.Fatalf("expected status %d, got %d", http.StatusCreated, statusCode) } - if createCAResponse.Error != "" { + if createCAResponse.Error != "" { t.Fatalf("expected success, got %s", createCAResponse.Error) } - entries := logs.TakeAll() - var haveCACreated bool - for _, e := range entries { - if e.LoggerName != "audit" { continue } - if findStringField(e, "event") == "ca_created" { - haveCACreated = true - break - } - } - if !haveCACreated { t.Fatalf("expected CACreated audit entry") } + entries := logs.TakeAll() + var haveCACreated bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + if findStringField(e, "event") == "ca_created" { + haveCACreated = true + break + } + } + if !haveCACreated { + t.Fatalf("expected CACreated audit entry") + } }) t.Run("3. Get all CA's - 1 should be there and enabled", func(t *testing.T) { @@ -167,8 +171,8 @@ func TestSelfSignedCertificateAuthorityEndToEnd(t *testing.T) { IntermediateCACSR = getCAResponse.Result.CSRPEM }) - t.Run("7. Sign the intermediate CA's CSR", func(t *testing.T) { - _ = logs.TakeAll() + t.Run("7. Sign the intermediate CA's CSR", func(t *testing.T) { + _ = logs.TakeAll() signedCert := tu.SignCSR(IntermediateCACSR) statusCode, uploadCertificateResponse, err := tu.UploadCertificateToCertificateAuthority(ts.URL, client, adminToken, 2, server.UploadCertificateToCertificateAuthorityParams{CertificateChain: signedCert + tu.SelfSignedCACertificate}) if err != nil { @@ -177,17 +181,24 @@ func TestSelfSignedCertificateAuthorityEndToEnd(t *testing.T) { if statusCode != http.StatusCreated { t.Fatalf("expected status %d, got %d", http.StatusCreated, statusCode) } - if uploadCertificateResponse.Error != "" { + if uploadCertificateResponse.Error != "" { t.Fatalf("expected success, got %s", uploadCertificateResponse.Error) } - entries := logs.TakeAll() - var haveCertUploaded bool - for _, e := range entries { - if e.LoggerName != "audit" { continue } - if findStringField(e, "event") == "ca_cert_uploaded" { haveCertUploaded = true; break } - } - if !haveCertUploaded { t.Fatalf("expected CACertificateUploaded audit entry") } + entries := logs.TakeAll() + var haveCertUploaded bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + if findStringField(e, "event") == "ca_cert_uploaded" { + haveCertUploaded = true + break + } + } + if !haveCertUploaded { + t.Fatalf("expected CACertificateUploaded audit entry") + } }) t.Run("8. Get all CA's - 2 should be there and both enabled", func(t *testing.T) { statusCode, listCAsResponse, err := tu.ListCertificateAuthorities(ts.URL, client, adminToken) @@ -210,8 +221,8 @@ func TestSelfSignedCertificateAuthorityEndToEnd(t *testing.T) { t.Fatalf("expected second CA to be enabled") } }) - t.Run("9. Make first CA legacy", func(t *testing.T) { - statusCode, makeLegacyResponse, err := tu.UpdateCertificateAuthority(ts.URL, client, adminToken, 1, tu.UpdateCertificateAuthorityParams{Status: "legacy"}) + t.Run("9. Make first CA legacy", func(t *testing.T) { + statusCode, makeLegacyResponse, err := tu.UpdateCertificateAuthority(ts.URL, client, adminToken, 1, tu.UpdateCertificateAuthorityParams{Status: "legacy"}) if err != nil { t.Fatal("expected no error, got: ", err) } @@ -243,22 +254,29 @@ func TestSelfSignedCertificateAuthorityEndToEnd(t *testing.T) { t.Fatalf("expected second CA to be enabled") } }) - t.Run("11. Delete first CA", func(t *testing.T) { - _ = logs.TakeAll() - statusCode, err := tu.DeleteCertificateAuthority(ts.URL, client, adminToken, 1) + t.Run("11. Delete first CA", func(t *testing.T) { + _ = logs.TakeAll() + statusCode, err := tu.DeleteCertificateAuthority(ts.URL, client, adminToken, 1) if err != nil { t.Fatal("expected no error, got: ", err) } if statusCode != http.StatusOK { t.Fatalf("expected status %d, got %d", http.StatusOK, statusCode) } - entries := logs.TakeAll() - var haveCADeleted bool - for _, e := range entries { - if e.LoggerName != "audit" { continue } - if findStringField(e, "event") == "ca_deleted" { haveCADeleted = true; break } - } - if !haveCADeleted { t.Fatalf("expected CADeleted audit entry") } + entries := logs.TakeAll() + var haveCADeleted bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + if findStringField(e, "event") == "ca_deleted" { + haveCADeleted = true + break + } + } + if !haveCADeleted { + t.Fatalf("expected CADeleted audit entry") + } }) t.Run("12. Get all CA's - 1 enabled should be there", func(t *testing.T) { statusCode, listCAsResponse, err := tu.ListCertificateAuthorities(ts.URL, client, adminToken) @@ -281,7 +299,7 @@ func TestSelfSignedCertificateAuthorityEndToEnd(t *testing.T) { } func TestCreateCertificateAuthorityInvalidInputs(t *testing.T) { - ts, _ := tu.MustPrepareServer(t) + ts, _ := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") client := ts.Client() @@ -370,7 +388,7 @@ func TestCreateCertificateAuthorityInvalidInputs(t *testing.T) { } func TestUploadCertificateToCertificateAuthorityInvalidInputs(t *testing.T) { -ts, _ := tu.MustPrepareServer(t) + ts, _ := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") client := ts.Client() @@ -444,7 +462,7 @@ invalid } func TestSignCertificatesEndToEnd(t *testing.T) { -ts, logs := tu.MustPrepareServer(t) + ts, logs := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") client := ts.Client() @@ -616,7 +634,7 @@ ts, logs := tu.MustPrepareServer(t) t.Fatalf("expected success, got %s", uploadCertificateResponse.Error) } }) - t.Run("9. Get all CA's - 2 should be there and both active", func(t *testing.T) { + t.Run("9. Get all CA's - 2 should be there and both active", func(t *testing.T) { statusCode, listCAsResponse, err := tu.ListCertificateAuthorities(ts.URL, client, adminToken) if err != nil { t.Fatal("expected no error, got: ", err) @@ -641,20 +659,33 @@ ts, logs := tu.MustPrepareServer(t) } }) - t.Run("10. Update CA enabled status and assert audit", func(t *testing.T) { - _ = logs.TakeAll() - statusCode, updateResp, err := tu.UpdateCertificateAuthority(ts.URL, client, adminToken, 1, tu.UpdateCertificateAuthorityParams{Status: "active"}) - if err != nil { t.Fatal(err) } - if statusCode != http.StatusOK { t.Fatalf("expected %d, got %d", http.StatusOK, statusCode) } - if updateResp.Error != "" { t.Fatalf("expected success, got %s", updateResp.Error) } - entries := logs.TakeAll() - var haveCAUpdated bool - for _, e := range entries { - if e.LoggerName != "audit" { continue } - if findStringField(e, "event") == "ca_updated" { haveCAUpdated = true; break } - } - if !haveCAUpdated { t.Fatalf("expected CAUpdated audit entry") } - }) + t.Run("10. Update CA enabled status and assert audit", func(t *testing.T) { + _ = logs.TakeAll() + statusCode, updateResp, err := tu.UpdateCertificateAuthority(ts.URL, client, adminToken, 1, tu.UpdateCertificateAuthorityParams{Status: "active"}) + if err != nil { + t.Fatal(err) + } + if statusCode != http.StatusOK { + t.Fatalf("expected %d, got %d", http.StatusOK, statusCode) + } + if updateResp.Error != "" { + t.Fatalf("expected success, got %s", updateResp.Error) + } + entries := logs.TakeAll() + var haveCAUpdated bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + if findStringField(e, "event") == "ca_updated" { + haveCAUpdated = true + break + } + } + if !haveCAUpdated { + t.Fatalf("expected CAUpdated audit entry") + } + }) t.Run("10. Create 2nd CSR's", func(t *testing.T) { createCertificateRequestRequest := tu.CreateCertificateRequestParams{CSR: tu.StrawberryCSR} statusCode, createCertResponse, err := tu.CreateCertificateRequest(ts.URL, client, adminToken, createCertificateRequestRequest) @@ -716,9 +747,9 @@ ts, logs := tu.MustPrepareServer(t) if len(listCSRsResponse.Result) != 2 { t.Fatalf("expected 2 certificates, got %d", len(listCSRsResponse.Result)) } - if listCSRsResponse.Result[0].Status != "Active" { - t.Fatalf("expected first csr to be active, got %s", listCSRsResponse.Result[0].Status) - } + if listCSRsResponse.Result[0].Status != "Active" { + t.Fatalf("expected first csr to be active, got %s", listCSRsResponse.Result[0].Status) + } if strings.Count(listCSRsResponse.Result[0].CertificateChain, "BEGIN CERTIFICATE") != 2 { t.Fatalf("expected first csr to have a chain with 2 certificates") } @@ -732,7 +763,7 @@ ts, logs := tu.MustPrepareServer(t) } func TestUnsuccessfulRequestsMadeToCACSRs(t *testing.T) { -ts, _ := tu.MustPrepareServer(t) + ts, _ := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") client := ts.Client() @@ -1118,7 +1149,7 @@ func TestCertificateRevocationListsEndToEnd(t *testing.T) { }) t.Run("10. Revoke Intermediate CA", func(t *testing.T) { - _ = logs.TakeAll() + _ = logs.TakeAll() statusCode, response, err := tu.RevokeCertificateAuthority(ts.URL, client, adminToken, 2) if err != nil { t.Fatalf("expected no error, got: %s", err) @@ -1129,13 +1160,20 @@ func TestCertificateRevocationListsEndToEnd(t *testing.T) { if response.Error != "" { t.Fatalf("expected success, got %s", response.Error) } - entries := logs.TakeAll() - var haveCARevoked bool - for _, e := range entries { - if e.LoggerName != "audit" { continue } - if findStringField(e, "event") == "ca_cert_revoked" { haveCARevoked = true; break } - } - if !haveCARevoked { t.Fatalf("expected CACertificateRevoked audit entry") } + entries := logs.TakeAll() + var haveCARevoked bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + if findStringField(e, "event") == "ca_cert_revoked" { + haveCARevoked = true + break + } + } + if !haveCARevoked { + t.Fatalf("expected CACertificateRevoked audit entry") + } statusCode, cas, err := tu.ListCertificateAuthorities(ts.URL, client, adminToken) if statusCode != http.StatusOK { t.Fatalf("expected status %d, got %d", http.StatusOK, statusCode) diff --git a/internal/server/handlers_certificate_requests.go b/internal/server/handlers_certificate_requests.go index c9d5445c..b10a2b21 100644 --- a/internal/server/handlers_certificate_requests.go +++ b/internal/server/handlers_certificate_requests.go @@ -155,12 +155,12 @@ func CreateCertificateRequest(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } - + env.AuditLogger.CertificateRequested(strconv.FormatInt(newCSRID, 10), 0, logging.WithActor(claims.Email), logging.WithRequest(r), ) - + successResponse := CreateSuccessResponse{Message: "success", ID: newCSRID} err = writeResponse(w, successResponse, http.StatusCreated) if err != nil { @@ -253,13 +253,13 @@ func DeleteCertificateRequest(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } - + claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) if headerErr != nil { writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.SystemLogger) return } - + _, err = env.DB.GetCertificateAuthority(db.ByCertificateAuthorityCSRID(idNum)) if rowFound(err) { writeError(w, http.StatusNotFound, "Not Found", fmt.Errorf("not found"), env.SystemLogger) @@ -278,12 +278,12 @@ func DeleteCertificateRequest(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } - + env.AuditLogger.CertificateRequestDeleted(id, logging.WithActor(claims.Email), logging.WithRequest(r), ) - + successResponse := SuccessResponse{Message: "success"} err = writeResponse(w, successResponse, http.StatusAccepted) if err != nil { @@ -313,13 +313,13 @@ func PostCertificateRequestCertificate(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusBadRequest, "Invalid ID", err, env.SystemLogger) return } - + claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) if headerErr != nil { writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.SystemLogger) return } - + _, err = env.DB.GetCertificateAuthority(db.ByCertificateAuthorityCSRID(idNum)) if rowFound(err) { writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) @@ -339,12 +339,12 @@ func PostCertificateRequestCertificate(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } - + env.AuditLogger.CertificateIssued(id, 0, logging.WithActor(claims.Email), logging.WithRequest(r), ) - + if env.SendPebbleNotifications { err := SendPebbleNotification(CertificateUpdate, idNum) if err != nil { @@ -370,13 +370,13 @@ func RejectCertificateRequest(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusBadRequest, "Invalid ID", err, env.SystemLogger) return } - + claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) if headerErr != nil { writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.SystemLogger) return } - + _, err = env.DB.GetCertificateAuthority(db.ByCertificateAuthorityCSRID(idNum)) if rowFound(err) { err = fmt.Errorf("certificate request %d not found", idNum) @@ -396,13 +396,13 @@ func RejectCertificateRequest(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } - + env.AuditLogger.CertificateRejected(id, 0, logging.WithActor(claims.Email), logging.WithRequest(r), logging.WithReason("rejected by administrator"), ) - + if env.SendPebbleNotifications { err := SendPebbleNotification(CertificateUpdate, idNum) if err != nil { @@ -472,13 +472,13 @@ func RevokeCertificate(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusBadRequest, "Invalid ID", err, env.SystemLogger) return } - + claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) if headerErr != nil { writeError(w, http.StatusUnauthorized, "Unauthorized", headerErr, env.SystemLogger) return } - + _, err = env.DB.GetCertificateAuthority(db.ByCertificateAuthorityCSRID(idNum)) if rowFound(err) { writeError(w, http.StatusNotFound, "Not Found", err, env.SystemLogger) @@ -497,12 +497,12 @@ func RevokeCertificate(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Internal Error", err, env.SystemLogger) return } - + env.AuditLogger.CertificateRevoked(id, logging.WithActor(claims.Email), logging.WithRequest(r), ) - + if env.SendPebbleNotifications { err := SendPebbleNotification(CertificateUpdate, idNum) if err != nil { diff --git a/internal/server/handlers_certificate_requests_test.go b/internal/server/handlers_certificate_requests_test.go index 6470d07f..b96dbafa 100644 --- a/internal/server/handlers_certificate_requests_test.go +++ b/internal/server/handlers_certificate_requests_test.go @@ -12,7 +12,7 @@ import ( // The order of the tests is important, as some tests depend on the // state of the server after previous tests. func TestCertificateRequestsEndToEnd(t *testing.T) { - ts, logs := tu.MustPrepareServer(t) + ts, logs := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "testadmin@canonical.com", tu.RoleAdmin, "") client := ts.Client() @@ -32,8 +32,8 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { } }) - t.Run("2. Create certificate request", func(t *testing.T) { - _ = logs.TakeAll() + t.Run("2. Create certificate request", func(t *testing.T) { + _ = logs.TakeAll() createCertificateRequestRequest := tu.CreateCertificateRequestParams{ CSR: tu.AppleCSR, @@ -48,19 +48,21 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { if createCertResponse.Error != "" { t.Fatalf("expected no error, got %s", createCertResponse.Error) } - entries := logs.TakeAll() - var haveRequested bool - for _, e := range entries { - if e.LoggerName != "audit" { continue } - if findStringField(e, "event") == "cert_requested" { - haveRequested = true - break - } - } - if !haveRequested { - t.Fatalf("expected CertificateRequested audit entry") - } - }) + entries := logs.TakeAll() + var haveRequested bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + if findStringField(e, "event") == "cert_requested" { + haveRequested = true + break + } + } + if !haveRequested { + t.Fatalf("expected CertificateRequested audit entry") + } + }) t.Run("3. List certificate requests - 1 Certificate", func(t *testing.T) { statusCode, listCertRequestsResponse, err := tu.ListCertificateRequests(ts.URL, client, adminToken) @@ -84,7 +86,7 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { } }) - t.Run("4. Get certificate request", func(t *testing.T) { + t.Run("4. Get certificate request", func(t *testing.T) { statusCode, getCertRequestResponse, err := tu.GetCertificateRequest(ts.URL, client, adminToken, 1) if err != nil { t.Fatal(err) @@ -125,7 +127,7 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { } }) - t.Run("6. List certificate requests - 1 Certificate", func(t *testing.T) { + t.Run("6. List certificate requests - 1 Certificate", func(t *testing.T) { statusCode, listCertRequestsResponse, err := tu.ListCertificateRequests(ts.URL, client, adminToken) if err != nil { t.Fatal(err) @@ -141,7 +143,7 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { } }) - t.Run("7. Create another certificate request", func(t *testing.T) { + t.Run("7. Create another certificate request", func(t *testing.T) { createCertificateRequestRequest := tu.CreateCertificateRequestParams{ CSR: tu.StrawberryCSR, } @@ -157,7 +159,7 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { } }) - t.Run("8. List certificate requests - 2 Certificates", func(t *testing.T) { + t.Run("8. List certificate requests - 2 Certificates", func(t *testing.T) { statusCode, listCertRequestsResponse, err := tu.ListCertificateRequests(ts.URL, client, adminToken) if err != nil { t.Fatal(err) @@ -173,7 +175,7 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { } }) - t.Run("9. Get certificate request 2", func(t *testing.T) { + t.Run("9. Get certificate request 2", func(t *testing.T) { statusCode, getCertRequestResponse, err := tu.GetCertificateRequest(ts.URL, client, adminToken, 2) if err != nil { t.Fatal(err) @@ -198,8 +200,8 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { } }) - t.Run("10. Delete certificate request 1", func(t *testing.T) { - _ = logs.TakeAll() + t.Run("10. Delete certificate request 1", func(t *testing.T) { + _ = logs.TakeAll() statusCode, err := tu.DeleteCertificateRequest(ts.URL, client, adminToken, 1) if err != nil { t.Fatal(err) @@ -207,18 +209,20 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { if statusCode != http.StatusAccepted { t.Fatalf("expected status %d, got %d", http.StatusAccepted, statusCode) } - entries := logs.TakeAll() - var haveDeleted bool - for _, e := range entries { - if e.LoggerName != "audit" { continue } - if findStringField(e, "event") == "cert_request_deleted" { - haveDeleted = true - break - } - } - if !haveDeleted { - t.Fatalf("expected CertificateRequestDeleted audit entry") - } + entries := logs.TakeAll() + var haveDeleted bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + if findStringField(e, "event") == "cert_request_deleted" { + haveDeleted = true + break + } + } + if !haveDeleted { + t.Fatalf("expected CertificateRequestDeleted audit entry") + } }) t.Run("11. List certificate requests - 1 Certificate", func(t *testing.T) { @@ -237,8 +241,8 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { } }) - t.Run("12. Delete certificate request 2 and assert revoke audit", func(t *testing.T) { - _ = logs.TakeAll() + t.Run("12. Delete certificate request 2 and assert revoke audit", func(t *testing.T) { + _ = logs.TakeAll() statusCode, err := tu.DeleteCertificateRequest(ts.URL, client, adminToken, 2) if err != nil { t.Fatal(err) @@ -246,19 +250,26 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { if statusCode != http.StatusAccepted { t.Fatalf("expected status %d, got %d", http.StatusAccepted, statusCode) } - entries := logs.TakeAll() - var haveCertDeleted bool - for _, e := range entries { - if e.LoggerName != "audit" { continue } - if findStringField(e, "event") == "cert_request_deleted" { haveCertDeleted = true; break } - } - if !haveCertDeleted { t.Fatalf("expected CertificateRequestDeleted audit entry") } + entries := logs.TakeAll() + var haveCertDeleted bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + if findStringField(e, "event") == "cert_request_deleted" { + haveCertDeleted = true + break + } + } + if !haveCertDeleted { + t.Fatalf("expected CertificateRequestDeleted audit entry") + } }) } // TestListCertificateRequestsRequestorRole tests that a certificate requestor can only view their own requests. func TestListCertificateRequestsRequestorRole(t *testing.T) { -ts, _ := tu.MustPrepareServer(t) + ts, _ := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "testadmin@canonical.com", tu.RoleAdmin, "") client := ts.Client() @@ -327,7 +338,7 @@ ts, _ := tu.MustPrepareServer(t) // The order of the tests is important, as some tests depend on the // state of the server after previous tests. func TestCertificatesEndToEnd(t *testing.T) { -ts, logs := tu.MustPrepareServer(t) + ts, logs := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") client := ts.Client() @@ -347,8 +358,8 @@ ts, logs := tu.MustPrepareServer(t) } }) - t.Run("2. Create Certificate", func(t *testing.T) { - _ = logs.TakeAll() + t.Run("2. Create Certificate", func(t *testing.T) { + _ = logs.TakeAll() createCertificateRequest := tu.CreateCertificateParams{ Certificate: fmt.Sprintf("%s\n%s", tu.ExampleCSRCertificate, tu.ExampleCSRIssuerCertificate), } @@ -362,13 +373,20 @@ ts, logs := tu.MustPrepareServer(t) if createCertResponse.Error != "" { t.Fatalf("expected no error, got %s", createCertResponse.Error) } - entries := logs.TakeAll() - var haveIssued bool - for _, e := range entries { - if e.LoggerName != "audit" { continue } - if findStringField(e, "event") == "cert_issued" { haveIssued = true; break } - } - if !haveIssued { t.Fatalf("expected CertificateIssued audit entry") } + entries := logs.TakeAll() + var haveIssued bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + if findStringField(e, "event") == "cert_issued" { + haveIssued = true + break + } + } + if !haveIssued { + t.Fatalf("expected CertificateIssued audit entry") + } }) t.Run("3. Get Certificate", func(t *testing.T) { @@ -387,8 +405,8 @@ ts, logs := tu.MustPrepareServer(t) } }) - t.Run("4. Reject Certificate", func(t *testing.T) { - _ = logs.TakeAll() + t.Run("4. Reject Certificate", func(t *testing.T) { + _ = logs.TakeAll() statusCode, err := tu.RejectCertificate(ts.URL, client, adminToken, 1) if err != nil { t.Fatal(err) @@ -396,13 +414,20 @@ ts, logs := tu.MustPrepareServer(t) if statusCode != http.StatusAccepted { t.Fatalf("expected status %d, got %d", http.StatusAccepted, statusCode) } - entries := logs.TakeAll() - var haveRejected bool - for _, e := range entries { - if e.LoggerName != "audit" { continue } - if findStringField(e, "event") == "cert_rejected" { haveRejected = true; break } - } - if !haveRejected { t.Fatalf("expected CertificateRejected audit entry") } + entries := logs.TakeAll() + var haveRejected bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + if findStringField(e, "event") == "cert_rejected" { + haveRejected = true + break + } + } + if !haveRejected { + t.Fatalf("expected CertificateRejected audit entry") + } }) t.Run("5. Get Certificate", func(t *testing.T) { @@ -421,8 +446,8 @@ ts, logs := tu.MustPrepareServer(t) } }) - t.Run("6. Delete Certificate (revocation)", func(t *testing.T) { - _ = logs.TakeAll() + t.Run("6. Delete Certificate (revocation)", func(t *testing.T) { + _ = logs.TakeAll() statusCode, err := tu.DeleteCertificateRequest(ts.URL, client, adminToken, 1) if err != nil { t.Fatal(err) @@ -430,14 +455,21 @@ ts, logs := tu.MustPrepareServer(t) if statusCode != http.StatusAccepted { t.Fatalf("expected status %d, got %d", http.StatusAccepted, statusCode) } - entries := logs.TakeAll() - var haveDeletedOrRevoked bool - for _, e := range entries { - if e.LoggerName != "audit" { continue } - ev := findStringField(e, "event") - if ev == "cert_request_deleted" || ev == "cert_revoked" { haveDeletedOrRevoked = true; break } - } - if !haveDeletedOrRevoked { t.Fatalf("expected CertificateRequestDeleted or CertificateRevoked audit entry") } + entries := logs.TakeAll() + var haveDeletedOrRevoked bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + ev := findStringField(e, "event") + if ev == "cert_request_deleted" || ev == "cert_revoked" { + haveDeletedOrRevoked = true + break + } + } + if !haveDeletedOrRevoked { + t.Fatalf("expected CertificateRequestDeleted or CertificateRevoked audit entry") + } }) t.Run("7. Get Certificate", func(t *testing.T) { @@ -456,7 +488,7 @@ ts, logs := tu.MustPrepareServer(t) } func TestCreateCertificateRequestInvalidInputs(t *testing.T) { -ts, _ := tu.MustPrepareServer(t) + ts, _ := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") client := ts.Client() @@ -511,7 +543,7 @@ MIIBVwIBADANBgkqhkiG9w0BAQEFAASCAT8wggE7AgEAAkEAuQ== } func TestCreateCertificateInvalidInputs(t *testing.T) { -ts, _ := tu.MustPrepareServer(t) + ts, _ := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") client := ts.Client() diff --git a/internal/server/handlers_config_test.go b/internal/server/handlers_config_test.go index 07cbf293..27e299ae 100644 --- a/internal/server/handlers_config_test.go +++ b/internal/server/handlers_config_test.go @@ -42,7 +42,7 @@ func getConfig(url string, client *http.Client, token string) (int, *GetConfigRe } func TestConfigEndToEnd(t *testing.T) { - ts, logs := tu.MustPrepareServer(t) + ts, logs := tu.MustPrepareServer(t) adminToken := tu.MustPrepareAccount(t, ts, "admin@canonical.com", tu.RoleAdmin, "") nonAdminToken := tu.MustPrepareAccount(t, ts, "whatever@canonical.com", tu.RoleCertificateManager, adminToken) client := ts.Client() @@ -60,8 +60,8 @@ func TestConfigEndToEnd(t *testing.T) { } }) - t.Run("2. Get config - admin token", func(t *testing.T) { - _ = logs.TakeAll() + t.Run("2. Get config - admin token", func(t *testing.T) { + _ = logs.TakeAll() statusCode, response, err := getConfig(ts.URL, client, adminToken) if err != nil { t.Fatalf("couldn't get config: %s", err) @@ -86,18 +86,20 @@ func TestConfigEndToEnd(t *testing.T) { t.Fatalf("expected encryption backend type to be set, got %q", response.Result.EncryptionBackendType) } - entries := logs.TakeAll() - var haveAPISuccess bool - for _, e := range entries { - if e.LoggerName != "audit" { continue } - if findStringField(e, "event") == "api_action" && findStringField(e, "action") == "GET config" { - haveAPISuccess = true - break - } - } - if !haveAPISuccess { - t.Fatalf("expected APIAction success audit entry for GET config") - } + entries := logs.TakeAll() + var haveAPISuccess bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + if findStringField(e, "event") == "api_action" && findStringField(e, "action") == "GET config" { + haveAPISuccess = true + break + } + } + if !haveAPISuccess { + t.Fatalf("expected APIAction success audit entry for GET config") + } }) t.Run("3. Get config - non-admin token", func(t *testing.T) { diff --git a/internal/server/handlers_login.go b/internal/server/handlers_login.go index 6dcce4de..4715aa29 100644 --- a/internal/server/handlers_login.go +++ b/internal/server/handlers_login.go @@ -100,7 +100,7 @@ func Login(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "internal error", err, env.SystemLogger) return } - + env.AuditLogger.LoginSuccess(userAccount.Email, logging.WithRequest(r)) env.AuditLogger.TokenCreated(userAccount.Email, logging.WithRequest(r)) } diff --git a/internal/server/handlers_login_test.go b/internal/server/handlers_login_test.go index 7c7e7340..98b7cd3b 100644 --- a/internal/server/handlers_login_test.go +++ b/internal/server/handlers_login_test.go @@ -9,8 +9,8 @@ import ( ) func TestLoginEndToEnd(t *testing.T) { - ts, logs := tu.MustPrepareServer(t) - client := ts.Client() + ts, logs := tu.MustPrepareServer(t) + client := ts.Client() t.Run("Create admin user", func(t *testing.T) { adminUser := &tu.CreateAccountParams{ @@ -27,8 +27,8 @@ func TestLoginEndToEnd(t *testing.T) { } }) - t.Run("Login success", func(t *testing.T) { - _ = logs.TakeAll() + t.Run("Login success", func(t *testing.T) { + _ = logs.TakeAll() adminUser := &tu.LoginParams{ Email: "testadmin@canonical.com", Password: "Admin123", @@ -47,31 +47,31 @@ func TestLoginEndToEnd(t *testing.T) { if err != nil { t.Fatalf("couldn't parse token: %s", err) } - if claims, ok := token.Claims.(jwt.MapClaims); ok { + if claims, ok := token.Claims.(jwt.MapClaims); ok { if claims["email"] != "testadmin@canonical.com" { t.Fatalf("expected email %q, got %q", "testadmin@canonical.com", claims["email"]) } } - entries := logs.TakeAll() - var haveLoginSuccess, haveTokenCreated bool - for _, e := range entries { - if e.LoggerName != "audit" { - continue - } - switch findStringField(e, "event") { - case "authn_login_success:testadmin@canonical.com": - haveLoginSuccess = true - case "authn_token_created:testadmin@canonical.com": - haveTokenCreated = true - } - } - if !haveLoginSuccess { - t.Fatalf("expected LoginSuccess audit entry") - } - if !haveTokenCreated { - t.Fatalf("expected TokenCreated audit entry") - } + entries := logs.TakeAll() + var haveLoginSuccess, haveTokenCreated bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + switch findStringField(e, "event") { + case "authn_login_success:testadmin@canonical.com": + haveLoginSuccess = true + case "authn_token_created:testadmin@canonical.com": + haveTokenCreated = true + } + } + if !haveLoginSuccess { + t.Fatalf("expected LoginSuccess audit entry") + } + if !haveTokenCreated { + t.Fatalf("expected TokenCreated audit entry") + } }) t.Run("Login failure missing email", func(t *testing.T) { @@ -108,37 +108,37 @@ func TestLoginEndToEnd(t *testing.T) { } }) - t.Run("Login failure invalid password (with audit)", func(t *testing.T) { - _ = logs.TakeAll() - invalidUser := &tu.LoginParams{ - Email: "testadmin@canonical.com", - Password: "a-wrong-password", - } - statusCode, loginResponse, err := tu.Login(ts.URL, client, invalidUser) - if err != nil { - t.Fatalf("couldn't login admin user: %s", err) - } - if statusCode != http.StatusUnauthorized { - t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, statusCode) - } - if loginResponse.Error != "The email or password is incorrect" { - t.Fatalf("expected error %q, got %q", "The email or password is incorrect", loginResponse.Error) - } - entries := logs.TakeAll() - var haveLoginFailed bool - for _, e := range entries { - if e.LoggerName != "audit" { - continue - } - if findStringField(e, "event") == "authn_login_fail:testadmin@canonical.com" && findStringField(e, "reason") == "invalid credentials" { - haveLoginFailed = true - break - } - } - if !haveLoginFailed { - t.Fatalf("expected LoginFailed audit entry with reason 'invalid credentials'") - } - }) + t.Run("Login failure invalid password (with audit)", func(t *testing.T) { + _ = logs.TakeAll() + invalidUser := &tu.LoginParams{ + Email: "testadmin@canonical.com", + Password: "a-wrong-password", + } + statusCode, loginResponse, err := tu.Login(ts.URL, client, invalidUser) + if err != nil { + t.Fatalf("couldn't login admin user: %s", err) + } + if statusCode != http.StatusUnauthorized { + t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, statusCode) + } + if loginResponse.Error != "The email or password is incorrect" { + t.Fatalf("expected error %q, got %q", "The email or password is incorrect", loginResponse.Error) + } + entries := logs.TakeAll() + var haveLoginFailed bool + for _, e := range entries { + if e.LoggerName != "audit" { + continue + } + if findStringField(e, "event") == "authn_login_fail:testadmin@canonical.com" && findStringField(e, "reason") == "invalid credentials" { + haveLoginFailed = true + break + } + } + if !haveLoginFailed { + t.Fatalf("expected LoginFailed audit entry with reason 'invalid credentials'") + } + }) t.Run("Login failure invalid email", func(t *testing.T) { invalidUser := &tu.LoginParams{ diff --git a/internal/server/handlers_status_test.go b/internal/server/handlers_status_test.go index 8e6b68c8..c684f1e4 100644 --- a/internal/server/handlers_status_test.go +++ b/internal/server/handlers_status_test.go @@ -9,7 +9,7 @@ import ( ) func TestStatus(t *testing.T) { - ts, _ := tu.MustPrepareServer(t) + ts, _ := tu.MustPrepareServer(t) client := ts.Client() t.Run("status not initialized", func(t *testing.T) { diff --git a/internal/server/middleware.go b/internal/server/middleware.go index ea9e05eb..f4d355bc 100644 --- a/internal/server/middleware.go +++ b/internal/server/middleware.go @@ -22,7 +22,6 @@ const ( MAX_KILOBYTES = 100 ) - // The middlewareContext type helps middleware receive and pass along information through the middleware chain. type middlewareContext struct { responseStatusCode int @@ -104,7 +103,7 @@ func loggingMiddleware(ctx *middlewareContext) middleware { func auditLoggingMiddleware(ctx *middlewareContext) middleware { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - next.ServeHTTP(w, r) + next.ServeHTTP(w, r) var actor string claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), ctx.jwtSecret) if err == nil { @@ -126,13 +125,13 @@ func auditLoggingMiddleware(ctx *middlewareContext) middleware { opts = append(opts, logging.WithResourceType(resourceType)) } - if ctx.responseStatusCode >= 400 { - opts = append(opts, logging.WithReason(fmt.Sprintf("HTTP %d: %s", ctx.responseStatusCode, http.StatusText(ctx.responseStatusCode)))) + if ctx.responseStatusCode >= 400 { + opts = append(opts, logging.WithReason(fmt.Sprintf("HTTP %d: %s", ctx.responseStatusCode, http.StatusText(ctx.responseStatusCode)))) ctx.auditLogger.APIAction(action+" (failed)", opts...) } - if ctx.responseStatusCode < 400 && (r.Method == http.MethodGet || r.Method == http.MethodHead) { - ctx.auditLogger.APIAction(action, opts...) - } + if ctx.responseStatusCode < 400 && (r.Method == http.MethodGet || r.Method == http.MethodHead) { + ctx.auditLogger.APIAction(action, opts...) + } }) } } @@ -177,7 +176,7 @@ func extractResourceType(path string) string { cleanPath := strings.Trim(path, "/") parts := strings.Split(cleanPath, "/") if len(parts) > 0 && parts[0] != "" { - return parts[0] + return parts[0] } return "" } @@ -206,11 +205,11 @@ func requirePermission(permission string, jwtSecret []byte, handler http.Handler return } - if !hasPermission(permissions, permission) { - auditLogger.AccessDenied(claims.Email, r.URL.Path, permission, - logging.WithRequest(r), - logging.WithReason("insufficient permissions"), - ) + if !hasPermission(permissions, permission) { + auditLogger.AccessDenied(claims.Email, r.URL.Path, permission, + logging.WithRequest(r), + logging.WithReason("insufficient permissions"), + ) writeError(w, http.StatusForbidden, "forbidden: insufficient permissions", errors.New("missing permission"), systemLogger) return } @@ -252,11 +251,11 @@ func requirePermissionOrFirstUser(permission string, jwtSecret []byte, db *db.Da } permissions, ok := PermissionsByRole[claims.RoleID] - if !ok || !hasPermission(permissions, permission) { - auditLogger.AccessDenied(claims.Email, r.URL.Path, permission, - logging.WithRequest(r), - logging.WithReason("insufficient permissions"), - ) + if !ok || !hasPermission(permissions, permission) { + auditLogger.AccessDenied(claims.Email, r.URL.Path, permission, + logging.WithRequest(r), + logging.WithReason("insufficient permissions"), + ) writeError(w, http.StatusForbidden, "forbidden: insufficient permissions", errors.New("missing required permission"), systemLogger) return } diff --git a/internal/server/server.go b/internal/server/server.go index c7e9e44a..3f36a8c2 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -34,7 +34,7 @@ func New(opts *ServerOpts) (*Server, error) { if err != nil { return nil, fmt.Errorf("failed to create logger for http server: %w", err) } - + cfg := &HandlerConfig{} cfg.SendPebbleNotifications = opts.EnablePebbleNotifications cfg.JWTSecret = opts.Database.JWTSecret diff --git a/internal/server/server_test.go b/internal/server/server_test.go index a38d95ee..a56150b5 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -17,15 +17,15 @@ func TestNewSuccess(t *testing.T) { if err != nil { t.Fatalf("cannot create logger: %s", err) } - s, err := server.New(&server.ServerOpts{ + s, err := server.New(&server.ServerOpts{ Port: 8000, TLSCertificate: []byte(tu.TestServerCertificate), TLSPrivateKey: []byte(tu.TestServerKey), Database: db, ExternalHostname: "example.com", EnablePebbleNotifications: false, - SystemLogger: l, - AuditLogger: l, + SystemLogger: l, + AuditLogger: l, PublicConfig: &tu.PublicConfig, }) if err != nil { diff --git a/internal/server/types.go b/internal/server/types.go index ef0f49b8..95d303db 100644 --- a/internal/server/types.go +++ b/internal/server/types.go @@ -34,4 +34,4 @@ type Server struct { type middleware func(http.Handler) http.Handler -type NotificationKey int \ No newline at end of file +type NotificationKey int diff --git a/internal/server/utils.go b/internal/server/utils.go index 43c5ce7c..cdf415a0 100644 --- a/internal/server/utils.go +++ b/internal/server/utils.go @@ -9,6 +9,7 @@ import ( "fmt" "os/exec" ) + const ( CertificateUpdate NotificationKey = 1 ) @@ -53,4 +54,3 @@ func generateSKI(priv *rsa.PrivateKey) []byte { hash := sha1.Sum(spki.SubjectPublicKey.Bytes) return hash[:] } - diff --git a/internal/testutils/db_test_utils.go b/internal/testutils/db_test_utils.go index 0ac44a71..f9940ba6 100644 --- a/internal/testutils/db_test_utils.go +++ b/internal/testutils/db_test_utils.go @@ -16,7 +16,7 @@ import ( func MustPrepareEmptyDB(t *testing.T) *db.Database { t.Helper() - tempDir := t.TempDir() + tempDir := t.TempDir() sqlConnection, err := sql.Open("sqlite3", filepath.Join(tempDir, "db.sqlite3")) if err != nil { t.Fatalf("Couldn't create temporary database: %s", err) diff --git a/internal/testutils/server_test_utils.go b/internal/testutils/server_test_utils.go index 15e80252..68961510 100644 --- a/internal/testutils/server_test_utils.go +++ b/internal/testutils/server_test_utils.go @@ -26,18 +26,18 @@ func MustPrepareServer(t *testing.T) (*httptest.Server, *observer.ObservedLogs) db := MustPrepareEmptyDB(t) // Attach observed audit logger - core, logs := observer.New(zapcore.InfoLevel) + core, logs := observer.New(zapcore.InfoLevel) auditZap := zap.New(core) - srv, err := server.New(&server.ServerOpts{ + srv, err := server.New(&server.ServerOpts{ Port: 8000, TLSCertificate: []byte(TestServerCertificate), TLSPrivateKey: []byte(TestServerKey), Database: db, ExternalHostname: "example.com", EnablePebbleNotifications: false, - SystemLogger: logger, - AuditLogger: auditZap, + SystemLogger: logger, + AuditLogger: auditZap, PublicConfig: &PublicConfig, }) if err != nil { @@ -47,7 +47,7 @@ func MustPrepareServer(t *testing.T) (*httptest.Server, *observer.ObservedLogs) t.Cleanup(func() { testServer.Close() }) - return testServer, logs + return testServer, logs } func MustPrepareAccount(t *testing.T, ts *httptest.Server, email string, roleID RoleID, token string) string { @@ -536,14 +536,14 @@ type UpdateCertificateAuthorityResponse struct { } func UpdateCertificateAuthority(url string, client *http.Client, token string, id int, status UpdateCertificateAuthorityParams) (int, *UpdateCertificateAuthorityResponse, error) { - enabled := status.Status == "active" - payload := struct{ - Enabled bool `json:"enabled"` - }{Enabled: enabled} - reqData, err := json.Marshal(payload) - if err != nil { - return 0, nil, err - } + enabled := status.Status == "active" + payload := struct { + Enabled bool `json:"enabled"` + }{Enabled: enabled} + reqData, err := json.Marshal(payload) + if err != nil { + return 0, nil, err + } req, err := http.NewRequest("PUT", url+"/api/v1/certificate_authorities/"+strconv.Itoa(id), bytes.NewReader(reqData)) if err != nil { return 0, nil, err From bd9e94e703e254353c7598a6fded5677e70e0f4b Mon Sep 17 00:00:00 2001 From: yazansalti Date: Wed, 22 Oct 2025 14:13:08 +0400 Subject: [PATCH 4/4] Address review comment, use a separate function to init the audit logger --- internal/config/config.go | 48 +++++++++++++++++++++++++++------------ 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index 5c4bcd98..733e39ed 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -36,15 +36,15 @@ func CreateAppContext(cmdFlags *pflag.FlagSet, configFilePath string) (*NotaryAp return nil, err } - // initialize system logger - systemLogger, err := initializeLogger(cfg.Sub("logging.system"), "") + // initialize system logger + systemLogger, err := initializeLogger(cfg.Sub("logging.system")) if err != nil { return nil, fmt.Errorf("couldn't initialize system logger: %w", err) } - // initialize audit logger - // Audit logs are always at INFO level - auditLogger, err := initializeLogger(cfg.Sub("logging.audit"), "info") + // initialize audit logger + // Audit logs are always at INFO level + auditLogger, err := initializeAuditLogger(cfg.Sub("logging.audit")) if err != nil { return nil, fmt.Errorf("couldn't initialize audit logger: %w", err) } @@ -219,23 +219,16 @@ func initializeEncryptionBackend(encryptionCfg *viper.Viper, logger *zap.Logger) } // initializeLogger creates and configures a logger based on the provided configuration. -// cfg is the logger configuration subsection (e.g., logging.system or logging.audit). -// levelOverride allows overriding the configured level (e.g., "info" for audit logs). -// If levelOverride is empty, the level from cfg is used. +// cfg is the logger configuration subsection (e.g., logging.system). // output can be "stdout", "stderr", or a file path. -func initializeLogger(cfg *viper.Viper, levelOverride string) (*zap.Logger, error) { +func initializeLogger(cfg *viper.Viper) (*zap.Logger, error) { if cfg == nil { return nil, fmt.Errorf("logger configuration is not defined") } zapConfig := zap.NewProductionConfig() - level := levelOverride - if level == "" { - level = cfg.GetString("level") - } - - logLevel, err := zapcore.ParseLevel(level) + logLevel, err := zapcore.ParseLevel(cfg.GetString("level")) if err != nil { return nil, fmt.Errorf("invalid log level: %w", err) } @@ -253,3 +246,28 @@ func initializeLogger(cfg *viper.Viper, levelOverride string) (*zap.Logger, erro return logger, nil } + +// initializeAuditLogger creates an audit logger that always logs at INFO level, regardless of config. +// cfg is the logger configuration subsection (e.g., logging.audit). +// output can be "stdout", "stderr", or a file path. +func initializeAuditLogger(cfg *viper.Viper) (*zap.Logger, error) { + if cfg == nil { + return nil, fmt.Errorf("logger configuration is not defined") + } + + zapConfig := zap.NewProductionConfig() + // Force INFO level for audit logs + zapConfig.Level.SetLevel(zapcore.InfoLevel) + + output := cfg.GetString("output") + zapConfig.OutputPaths = []string{output} + + zapConfig.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder + + logger, err := zapConfig.Build() + if err != nil { + return nil, err + } + + return logger, nil +}