Skip to content

Commit bd682ab

Browse files
committed
modify postgres to use query string instea of parameter to remove escape logic
1 parent 8754167 commit bd682ab

File tree

1 file changed

+10
-72
lines changed
  • core/src/plugins/postgres

1 file changed

+10
-72
lines changed

core/src/plugins/postgres/db.go

Lines changed: 10 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -18,95 +18,33 @@ package postgres
1818

1919
import (
2020
"fmt"
21+
"net/url"
2122
"strings"
2223

2324
"github.com/clidey/whodb/core/src/engine"
2425
"gorm.io/driver/postgres"
2526
"gorm.io/gorm"
2627
)
2728

28-
func escapeConnectionParam(x string) string {
29-
// PostgreSQL libpq connection string escaping rules:
30-
// 1. Single quotes must be doubled: ' -> ''
31-
// 2. Backslashes must be doubled: \ -> \\
32-
// IMPORTANT: Escape single quotes first, then backslashes to avoid double-escaping
33-
x = strings.ReplaceAll(x, "'", "''")
34-
x = strings.ReplaceAll(x, "\\", "\\\\")
35-
return x
36-
}
37-
38-
func validateConnectionParam(param, paramName string) error {
39-
// Check for null bytes which can terminate the connection string
40-
if strings.Contains(param, "\x00") {
41-
return fmt.Errorf("invalid %s: contains null byte", paramName)
42-
}
43-
44-
// Check for potentially dangerous control characters
45-
for _, char := range param {
46-
if char < 32 && char != '\t' && char != '\n' && char != '\r' {
47-
return fmt.Errorf("invalid %s: contains control character", paramName)
48-
}
49-
}
50-
51-
return nil
52-
}
53-
54-
func isValidConnectionParamKey(key string) bool {
55-
// Connection parameter keys should only contain alphanumeric characters and underscores
56-
for _, char := range key {
57-
if !((char >= 'a' && char <= 'z') || (char >= 'A' && char <= 'Z') ||
58-
(char >= '0' && char <= '9') || char == '_') {
59-
return false
60-
}
61-
}
62-
return len(key) > 0
63-
}
64-
6529
func (p *PostgresPlugin) DB(config *engine.PluginConfig) (*gorm.DB, error) {
6630
connectionInput, err := p.ParseConnectionConfig(config)
6731
if err != nil {
6832
return nil, err
6933
}
7034

71-
// Validate connection parameters for security
72-
if err := validateConnectionParam(connectionInput.Hostname, "hostname"); err != nil {
73-
return nil, err
74-
}
75-
if err := validateConnectionParam(connectionInput.Username, "username"); err != nil {
76-
return nil, err
77-
}
78-
if err := validateConnectionParam(connectionInput.Password, "password"); err != nil {
79-
return nil, err
80-
}
81-
if err := validateConnectionParam(connectionInput.Database, "database"); err != nil {
82-
return nil, err
83-
}
84-
85-
host := escapeConnectionParam(connectionInput.Hostname)
86-
username := escapeConnectionParam(connectionInput.Username)
87-
password := escapeConnectionParam(connectionInput.Password)
88-
database := escapeConnectionParam(connectionInput.Database)
35+
dsn := fmt.Sprintf("postgresql://%s:%s@%s:%v/%s",
36+
url.QueryEscape(connectionInput.Username),
37+
url.QueryEscape(connectionInput.Password),
38+
url.QueryEscape(connectionInput.Hostname),
39+
connectionInput.Port,
40+
url.QueryEscape(connectionInput.Database))
8941

90-
params := strings.Builder{}
9142
if connectionInput.ExtraOptions != nil {
43+
params := url.Values{}
9244
for key, value := range connectionInput.ExtraOptions {
93-
// Validate extra option values
94-
if err := validateConnectionParam(value, fmt.Sprintf("extra option '%s'", key)); err != nil {
95-
return nil, err
96-
}
97-
// Validate key names (should only contain alphanumeric characters and underscores)
98-
if !isValidConnectionParamKey(key) {
99-
return nil, fmt.Errorf("invalid extra option key '%s': only alphanumeric characters and underscores allowed", key)
100-
}
101-
params.WriteString(fmt.Sprintf("%v='%v' ", strings.ToLower(key), escapeConnectionParam(value)))
45+
params.Add(strings.ToLower(key), value)
10246
}
103-
}
104-
105-
dsn := fmt.Sprintf("host='%v' user='%v' password='%v' dbname='%v' port='%v'",
106-
host, username, password, database, connectionInput.Port)
107-
108-
if params.Len() > 0 {
109-
dsn = fmt.Sprintf("%v %v", dsn, params.String())
47+
dsn += "?" + params.Encode()
11048
}
11149

11250
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})

0 commit comments

Comments
 (0)