Skip to content

Commit 988cd88

Browse files
authored
Fix Postgres login, improve testing suite (#545)
* remove validation input for now * aim to fix the mysql/mariadb bug where update doesn't properly create the query * added views back to pg query * fix e2e dropdown selection, add coverage to cypress, add setup/cleanup shell scripts for cypress * modify postgres password handling to use pgx config builder * tidy * add go coverage * remove the sqlite test db * add initial redis support for zset * add redis to e2e docker compose * update existing test + fix dropdown selection * add redis e2e tests * add elasticsearch to e2e docker compose * update elasticsearch init * elasticsearch e2e tests
1 parent 06a3e99 commit 988cd88

28 files changed

+3807
-396
lines changed

.gitignore

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,8 @@
44
DS_Store
55
.DS_Store
66
__debug*
7-
*.db
7+
*.db
8+
frontend/coverage
9+
frontend/.nyc_output
10+
tmp/
11+
core/tmp/

core/go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ require (
1212
github.com/go-redis/redis/v8 v8.11.5
1313
github.com/go-sql-driver/mysql v1.9.2
1414
github.com/google/uuid v1.6.0
15+
github.com/jackc/pgx/v5 v5.5.5
1516
github.com/pkg/errors v0.9.1
1617
github.com/sirupsen/logrus v1.9.3
1718
github.com/vektah/gqlparser/v2 v2.5.28
@@ -44,7 +45,6 @@ require (
4445
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
4546
github.com/jackc/pgpassfile v1.0.0 // indirect
4647
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
47-
github.com/jackc/pgx/v5 v5.5.5 // indirect
4848
github.com/jackc/puddle/v2 v2.2.2 // indirect
4949
github.com/jinzhu/inflection v1.0.0 // indirect
5050
github.com/jinzhu/now v1.1.5 // indirect

core/server_test.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Copyright 2025 Clidey, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package main
18+
19+
import (
20+
"context"
21+
"embed"
22+
"fmt"
23+
"net/http"
24+
"os"
25+
"os/signal"
26+
"syscall"
27+
"testing"
28+
"time"
29+
30+
"github.com/clidey/whodb/core/src"
31+
"github.com/clidey/whodb/core/src/log"
32+
"github.com/clidey/whodb/core/src/router"
33+
"github.com/pkg/errors"
34+
)
35+
36+
//go:embed build/*
37+
var staticFilesTest embed.FS
38+
39+
const defaultPortTest = "8080"
40+
41+
var srv *http.Server
42+
43+
func TestMain(m *testing.M) {
44+
log.Logger.Info("Starting WhoDB in test mode (Ctrl+C to exit)...")
45+
46+
src.InitializeEngine()
47+
r := router.InitializeRouter(staticFilesTest)
48+
49+
port := os.Getenv("PORT")
50+
if port == "" {
51+
port = defaultPortTest
52+
}
53+
54+
srv = &http.Server{
55+
Addr: fmt.Sprintf(":%s", port),
56+
Handler: r,
57+
ReadHeaderTimeout: 5 * time.Second,
58+
ReadTimeout: 10 * time.Second,
59+
WriteTimeout: 1 * time.Minute,
60+
IdleTimeout: 30 * time.Second,
61+
}
62+
63+
go func() {
64+
log.Logger.Info("Server starting...")
65+
if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
66+
log.Logger.Fatalf("listen: %s", err)
67+
os.Exit(1)
68+
}
69+
}()
70+
71+
log.Logger.Infof("🎉 WhoDB test server running at http://localhost:%s 🎉", port)
72+
73+
// Wait for SIGINT (Ctrl+C) or SIGTERM
74+
quit := make(chan os.Signal, 1)
75+
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
76+
<-quit
77+
log.Logger.Info("Received shutdown signal (Ctrl+C). Shutting down...")
78+
79+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
80+
defer cancel()
81+
if err := srv.Shutdown(ctx); err != nil {
82+
log.Logger.Errorf("Graceful shutdown failed: %v", err)
83+
}
84+
85+
log.Logger.Info("Test server shut down. Exiting and writing coverage.")
86+
os.Exit(m.Run())
87+
}

core/src/plugins/gorm/plugin.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ func (p *GormPlugin) applyWhereConditions(query *gorm.DB, condition *model.Where
180180
if err != nil {
181181
return nil, err
182182
}
183-
query = query.Where(fmt.Sprintf("%s = ?", condition.Atomic.Key), value)
183+
query = query.Where(fmt.Sprintf("%s = ?", p.EscapeIdentifier(condition.Atomic.Key)), value)
184184
}
185185

186186
case model.WhereConditionTypeAnd:

core/src/plugins/gorm/update.go

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package gorm_plugin
1919
import (
2020
"errors"
2121
"fmt"
22+
2223
"github.com/clidey/whodb/core/src/common"
2324
"github.com/clidey/whodb/core/src/engine"
2425
"github.com/clidey/whodb/core/src/plugins"
@@ -74,9 +75,17 @@ func (p *GormPlugin) UpdateStorageUnit(config *engine.PluginConfig, schema strin
7475

7576
var result *gorm.DB
7677
if len(conditions) == 0 {
77-
result = db.Table(tableName).Where(unchangedValues).Updates(convertedValues)
78+
if p.Type == engine.DatabaseType_MySQL || p.Type == engine.DatabaseType_MariaDB {
79+
result = p.executeUpdateWithWhereMap(db, tableName, unchangedValues, convertedValues)
80+
} else {
81+
result = db.Table(tableName).Where(unchangedValues).Updates(convertedValues)
82+
}
7883
} else {
79-
result = db.Table(tableName).Where(conditions, nil).Updates(convertedValues)
84+
if p.Type == engine.DatabaseType_MySQL || p.Type == engine.DatabaseType_MariaDB {
85+
result = p.executeUpdateWithWhereMap(db, tableName, conditions, convertedValues)
86+
} else {
87+
result = db.Table(tableName).Where(conditions, nil).Updates(convertedValues)
88+
}
8089
}
8190

8291
if result.Error != nil {
@@ -91,3 +100,17 @@ func (p *GormPlugin) UpdateStorageUnit(config *engine.PluginConfig, schema strin
91100
return true, nil
92101
})
93102
}
103+
104+
// weird bug for mysql/mariadb driver where the where clause is not properly escaped, so have to do it manually below
105+
// should be fine as it still uses the query builder
106+
// todo: need to investigate underlying driver to see what's going on
107+
func (p *GormPlugin) executeUpdateWithWhereMap(db *gorm.DB, tableName string, whereConditions map[string]interface{}, updateValues map[string]interface{}) *gorm.DB {
108+
query := db.Table(tableName)
109+
110+
for column, value := range whereConditions {
111+
escapedColumn := p.EscapeIdentifier(column)
112+
query = query.Where(fmt.Sprintf("%s = ?", escapedColumn), value)
113+
}
114+
115+
return query.Updates(updateValues)
116+
}

core/src/plugins/postgres/db.go

Lines changed: 15 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -17,122 +17,43 @@
1717
package postgres
1818

1919
import (
20-
"fmt"
21-
"net"
22-
"net/url"
23-
"regexp"
24-
"strconv"
25-
2620
"github.com/clidey/whodb/core/src/engine"
21+
"github.com/jackc/pgx/v5"
22+
"github.com/jackc/pgx/v5/stdlib"
2723
"gorm.io/driver/postgres"
2824
"gorm.io/gorm"
2925
)
3026

31-
// validateInput uses allowlist validation to ensure only safe characters are used
32-
// This prevents all forms of injection and path traversal attacks
33-
func validateInput(input, inputType string) error {
34-
if len(input) == 0 {
35-
return fmt.Errorf("%s cannot be empty", inputType)
36-
}
37-
38-
// Define allowlist patterns for different input types
39-
var allowedPattern *regexp.Regexp
40-
var maxLength int
41-
42-
switch inputType {
43-
case "database":
44-
// Database names: only alphanumeric, underscore, hyphen (no dots to prevent traversal)
45-
allowedPattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
46-
maxLength = 63 // PostgreSQL database name limit
47-
case "hostname":
48-
// Hostnames: alphanumeric, dots, hyphens (standard hostname characters)
49-
allowedPattern = regexp.MustCompile(`^[a-zA-Z0-9.-]+$`)
50-
maxLength = 253 // RFC hostname limit
51-
case "username":
52-
// Usernames: alphanumeric and underscore only
53-
allowedPattern = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
54-
maxLength = 63 // PostgreSQL username limit
55-
default:
56-
return fmt.Errorf("unknown input type: %s", inputType)
57-
}
58-
59-
// Check length
60-
if len(input) > maxLength {
61-
return fmt.Errorf("%s too long: maximum %d characters", inputType, maxLength)
62-
}
63-
64-
// Check against allowlist pattern
65-
if !allowedPattern.MatchString(input) {
66-
return fmt.Errorf("invalid %s: contains disallowed characters (only alphanumeric, underscore, hyphen allowed)", inputType)
67-
}
68-
69-
return nil
70-
}
71-
7227
func (p *PostgresPlugin) DB(config *engine.PluginConfig) (*gorm.DB, error) {
7328
connectionInput, err := p.ParseConnectionConfig(config)
7429
if err != nil {
7530
return nil, err
7631
}
7732

78-
// Validate all connection parameters using allowlist validation
79-
if err := validateInput(connectionInput.Hostname, "hostname"); err != nil {
80-
return nil, fmt.Errorf("hostname validation failed: %w", err)
81-
}
82-
83-
if err := validateInput(connectionInput.Database, "database"); err != nil {
84-
return nil, fmt.Errorf("database validation failed: %w", err)
85-
}
86-
87-
if err := validateInput(connectionInput.Username, "username"); err != nil {
88-
return nil, fmt.Errorf("username validation failed: %w", err)
33+
pgxConfig, err := pgx.ParseConfig("")
34+
if err != nil {
35+
return nil, err
8936
}
9037

91-
// Construct PostgreSQL URL securely using url.URL struct
92-
u := &url.URL{
93-
Scheme: "postgresql",
94-
User: url.UserPassword(connectionInput.Username, connectionInput.Password),
95-
Host: net.JoinHostPort(connectionInput.Hostname, strconv.Itoa(connectionInput.Port)),
96-
Path: "/" + connectionInput.Database,
97-
}
38+
pgxConfig.Host = connectionInput.Hostname
39+
pgxConfig.Port = uint16(connectionInput.Port)
40+
pgxConfig.User = connectionInput.Username
41+
pgxConfig.Password = connectionInput.Password
42+
pgxConfig.Database = connectionInput.Database
9843

99-
// Add query parameters securely
100-
q := u.Query()
101-
q.Set("sslmode", "prefer")
102-
103-
// Validate and add extra options as query parameters (allowlist approach)
10444
if connectionInput.ExtraOptions != nil {
105-
allowedOptions := map[string]bool{
106-
"sslmode": true,
107-
"sslcert": true,
108-
"sslkey": true,
109-
"sslrootcert": true,
110-
"connect_timeout": true,
111-
"application_name": true,
45+
if pgxConfig.RuntimeParams == nil {
46+
pgxConfig.RuntimeParams = make(map[string]string)
11247
}
113-
11448
for key, value := range connectionInput.ExtraOptions {
115-
// Only allow predefined safe options
116-
if !allowedOptions[key] {
117-
return nil, fmt.Errorf("extra option '%s' is not allowed for security reasons", key)
118-
}
119-
120-
// Validate option values using basic allowlist (no special characters)
121-
if !regexp.MustCompile(`^[a-zA-Z0-9._/-]+$`).MatchString(value) {
122-
return nil, fmt.Errorf("extra option value for '%s' contains invalid characters", key)
123-
}
124-
125-
q.Set(key, value)
49+
pgxConfig.RuntimeParams[key] = value
12650
}
12751
}
128-
129-
u.RawQuery = q.Encode()
130-
dsn := u.String()
13152

132-
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
53+
db, err := gorm.Open(postgres.New(postgres.Config{Conn: stdlib.OpenDB(*pgxConfig)}), &gorm.Config{})
54+
13355
if err != nil {
13456
return nil, err
13557
}
13658
return db, nil
13759
}
138-

core/src/plugins/postgres/postgres.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,10 @@ func (p *PostgresPlugin) GetTableInfoQuery() string {
8080
LEFT JOIN
8181
pg_stat_user_tables s ON t.table_name = s.relname
8282
WHERE
83-
t.table_schema = ?
84-
AND t.table_type = 'BASE TABLE';
83+
t.table_schema = ?;
8584
`
85+
86+
// AND t.table_type = 'BASE TABLE' this removes the view tables
8687
}
8788

8889
func (p *PostgresPlugin) GetTableNameAndAttributes(rows *sql.Rows, db *gorm.DB) (string, []engine.Record) {

core/src/plugins/redis/redis.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"context"
1919
"errors"
2020
"fmt"
21+
"sort"
2122
"strconv"
2223

2324
"github.com/clidey/whodb/core/graph/model"
@@ -149,6 +150,19 @@ func (p *RedisPlugin) GetStorageUnits(config *engine.PluginConfig, schema string
149150
{Key: "Type", Value: "set"},
150151
{Key: "Size", Value: fmt.Sprintf("%d", size)},
151152
}
153+
case "zset":
154+
sizeCmd := pipe.ZCard(ctx, key)
155+
if _, err := pipe.Exec(ctx); err != nil {
156+
return nil, err
157+
}
158+
size, err := sizeCmd.Result()
159+
if err != nil {
160+
return nil, err
161+
}
162+
attributes = []engine.Record{
163+
{Key: "Type", Value: "zset"},
164+
{Key: "Size", Value: fmt.Sprintf("%d", size)},
165+
}
152166
default:
153167
attributes = []engine.Record{
154168
{Key: "Type", Value: "unknown"},
@@ -206,6 +220,10 @@ func (p *RedisPlugin) GetRows(
206220
rows = append(rows, []string{field, value})
207221
}
208222
}
223+
// Sort rows by field name (first column) alphabetically
224+
sort.Slice(rows, func(i, j int) bool {
225+
return rows[i][0] < rows[j][0]
226+
})
209227
result = &engine.GetRowsResult{
210228
Columns: []engine.Column{{Name: "field", Type: "string"}, {Name: "value", Type: "string"}},
211229
Rows: rows,
@@ -239,6 +257,19 @@ func (p *RedisPlugin) GetRows(
239257
Rows: rows,
240258
DisableUpdate: true,
241259
}
260+
case "zset":
261+
zsetValues, err := client.ZRangeWithScores(ctx, storageUnit, 0, -1).Result()
262+
if err != nil {
263+
return nil, err
264+
}
265+
rows := [][]string{}
266+
for i, member := range zsetValues {
267+
rows = append(rows, []string{strconv.Itoa(i), member.Member.(string), fmt.Sprintf("%.2f", member.Score)})
268+
}
269+
result = &engine.GetRowsResult{
270+
Columns: []engine.Column{{Name: "index", Type: "string"}, {Name: "member", Type: "string"}, {Name: "score", Type: "string"}},
271+
Rows: rows,
272+
}
242273
default:
243274
return nil, errors.New("unsupported Redis data type")
244275
}

0 commit comments

Comments
 (0)