@@ -20,58 +20,50 @@ import (
20
20
"fmt"
21
21
"net"
22
22
"net/url"
23
+ "regexp"
23
24
"strconv"
24
- "strings"
25
25
26
26
"github.com/clidey/whodb/core/src/engine"
27
27
"gorm.io/driver/postgres"
28
28
"gorm.io/gorm"
29
29
)
30
30
31
- // validateHostname ensures the hostname doesn't contain URL-reserved characters
32
- // that could lead to injection attacks
33
- func validateHostname (hostname string ) error {
34
- // Check for URL-reserved characters that could enable injection
35
- invalidChars := []string {"@" , "?" , "#" , "/" , "\\ " }
36
- for _ , char := range invalidChars {
37
- if strings .Contains (hostname , char ) {
38
- return fmt .Errorf ("invalid hostname: contains URL-reserved character '%s'" , char )
39
- }
40
- }
41
- return nil
42
- }
43
-
44
- // validateDatabase ensures the database name doesn't contain URL-encoded characters
45
- // or patterns that could lead to path traversal attacks
46
- func validateDatabase (database string ) error {
47
- // Check for URL-encoded forward slashes that could enable path traversal
48
- if strings .Contains (database , "%2f" ) || strings .Contains (database , "%2F" ) {
49
- return fmt .Errorf ("invalid database name: contains URL-encoded forward slash" )
31
+ // validateInput uses allowlist validation to ensure only safe characters are used
32
+ // This prevents all forms of injection and path traversal attacks
33
+ func validateInput (input , inputType string ) error {
34
+ if len (input ) == 0 {
35
+ return fmt .Errorf ("%s cannot be empty" , inputType )
50
36
}
51
37
52
- // Check for literal path traversal patterns (both Unix and Windows)
53
- // Normalize the database name to check for all path separator variations
54
- normalizedDB := strings .ReplaceAll (database , "\\ " , "/" )
55
- normalizedDB = strings .ReplaceAll (normalizedDB , "//" , "/" ) // Handle double slashes
38
+ // Define allowlist patterns for different input types
39
+ var allowedPattern * regexp.Regexp
40
+ var maxLength int
56
41
57
- if strings .Contains (normalizedDB , "../" ) || strings .Contains (normalizedDB , "./" ) ||
58
- strings .Contains (database , "../" ) || strings .Contains (database , "..\\ " ) ||
59
- strings .Contains (database , "./" ) || strings .Contains (database , ".\\ " ) ||
60
- strings .Contains (database , "..//" ) || strings .Contains (database , ".//" ) {
61
- return fmt .Errorf ("invalid database name: contains path traversal pattern" )
42
+ switch inputType {
43
+ case "database" :
44
+ // Database names: only alphanumeric, underscore, hyphen (no dots to prevent traversal)
45
+ allowedPattern = regexp .MustCompile (`^[a-zA-Z0-9_-]+$` )
46
+ maxLength = 63 // PostgreSQL database name limit
47
+ case "hostname" :
48
+ // Hostnames: alphanumeric, dots, hyphens (standard hostname characters)
49
+ allowedPattern = regexp .MustCompile (`^[a-zA-Z0-9.-]+$` )
50
+ maxLength = 253 // RFC hostname limit
51
+ case "username" :
52
+ // Usernames: alphanumeric and underscore only
53
+ allowedPattern = regexp .MustCompile (`^[a-zA-Z0-9_]+$` )
54
+ maxLength = 63 // PostgreSQL username limit
55
+ default :
56
+ return fmt .Errorf ("unknown input type: %s" , inputType )
62
57
}
63
58
64
- // Check for backticks that could enable SQL injection
65
- if strings . Contains ( database , "`" ) {
66
- return fmt .Errorf ("invalid database name: contains backtick character" )
59
+ // Check length
60
+ if len ( input ) > maxLength {
61
+ return fmt .Errorf ("%s too long: maximum %d characters" , inputType , maxLength )
67
62
}
68
63
69
- // Check for other URL-encoded characters that could be problematic
70
- problematicEncoded := []string {"%00" , "%20" , "%22" , "%27" , "%3B" , "%3C" , "%3E" }
71
- for _ , encoded := range problematicEncoded {
72
- if strings .Contains (strings .ToLower (database ), encoded ) {
73
- return fmt .Errorf ("invalid database name: contains URL-encoded character '%s'" , encoded )
74
- }
64
+ // Check against allowlist pattern
65
+ if ! allowedPattern .MatchString (input ) {
66
+ return fmt .Errorf ("invalid %s: contains disallowed characters (only alphanumeric, underscore, hyphen allowed)" , inputType )
75
67
}
76
68
77
69
return nil
@@ -83,14 +75,17 @@ func (p *PostgresPlugin) DB(config *engine.PluginConfig) (*gorm.DB, error) {
83
75
return nil , err
84
76
}
85
77
86
- // Validate hostname to prevent injection attacks
87
- if err := validateHostname (connectionInput .Hostname ); err != nil {
88
- return nil , err
78
+ // Validate all connection parameters using allowlist validation
79
+ if err := validateInput (connectionInput .Hostname , "hostname" ); err != nil {
80
+ return nil , fmt . Errorf ( "hostname validation failed: %w" , err )
89
81
}
90
82
91
- // Validate database name to prevent path traversal attacks
92
- if err := validateDatabase (connectionInput .Database ); err != nil {
93
- return nil , err
83
+ if err := validateInput (connectionInput .Database , "database" ); err != nil {
84
+ return nil , fmt .Errorf ("database validation failed: %w" , err )
85
+ }
86
+
87
+ if err := validateInput (connectionInput .Username , "username" ); err != nil {
88
+ return nil , fmt .Errorf ("username validation failed: %w" , err )
94
89
}
95
90
96
91
// Construct PostgreSQL URL securely using url.URL struct
@@ -105,9 +100,28 @@ func (p *PostgresPlugin) DB(config *engine.PluginConfig) (*gorm.DB, error) {
105
100
q := u .Query ()
106
101
q .Set ("sslmode" , "prefer" )
107
102
108
- // Add extra options as query parameters
103
+ // Validate and add extra options as query parameters (allowlist approach)
109
104
if connectionInput .ExtraOptions != nil {
105
+ allowedOptions := map [string ]bool {
106
+ "sslmode" : true ,
107
+ "sslcert" : true ,
108
+ "sslkey" : true ,
109
+ "sslrootcert" : true ,
110
+ "connect_timeout" : true ,
111
+ "application_name" : true ,
112
+ }
113
+
110
114
for key , value := range connectionInput .ExtraOptions {
115
+ // Only allow predefined safe options
116
+ if ! allowedOptions [key ] {
117
+ return nil , fmt .Errorf ("extra option '%s' is not allowed for security reasons" , key )
118
+ }
119
+
120
+ // Validate option values using basic allowlist (no special characters)
121
+ if ! regexp .MustCompile (`^[a-zA-Z0-9._/-]+$` ).MatchString (value ) {
122
+ return nil , fmt .Errorf ("extra option value for '%s' contains invalid characters" , key )
123
+ }
124
+
111
125
q .Set (key , value )
112
126
}
113
127
}
0 commit comments