Skip to content

Commit 70b714f

Browse files
authored
replace exec with execContext (#21)
2 parents 17b42d7 + b224fb7 commit 70b714f

File tree

3 files changed

+75
-12
lines changed

3 files changed

+75
-12
lines changed

postgresql/helpers.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
package postgresql
22

33
import (
4+
"context"
45
"database/sql"
56
"fmt"
67
"log"
78
"regexp"
89
"strings"
910

11+
"github.com/hashicorp/terraform-plugin-sdk/v2/diag"
1012
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
1113
"github.com/lib/pq"
1214
)
@@ -37,6 +39,23 @@ func PGResourceExistsFunc(fn func(*DBConnection, *schema.ResourceData) (bool, er
3739
}
3840
}
3941

42+
func PGResourceContextFunc(fn func(context.Context, *DBConnection, *schema.ResourceData) diag.Diagnostics) func(context.Context, *schema.ResourceData, interface{}) diag.Diagnostics {
43+
return func(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
44+
client := meta.(*Client)
45+
46+
db, err := client.Connect()
47+
if err != nil {
48+
return diag.Diagnostics{diag.Diagnostic{
49+
Severity: diag.Error,
50+
Summary: "Failled to connext",
51+
Detail: err.Error(),
52+
}}
53+
}
54+
55+
return fn(ctx, db, d)
56+
}
57+
}
58+
4059
// QueryAble is a DB connection (sql.DB/Tx)
4160
type QueryAble interface {
4261
Exec(query string, args ...interface{}) (sql.Result, error)

postgresql/resource_postgresql_script.go

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,31 @@
11
package postgresql
22

33
import (
4+
"context"
45
"crypto/sha1"
56
"encoding/hex"
67
"fmt"
78
"log"
89
"time"
910

11+
"github.com/hashicorp/terraform-plugin-sdk/v2/diag"
1012
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
1113
)
1214

1315
const (
1416
scriptCommandsAttr = "commands"
1517
scriptTriesAttr = "tries"
1618
scriptBackoffDelayAttr = "backoff_delay"
19+
scriptTimeoutAttr = "timeout"
1720
scriptShasumAttr = "shasum"
1821
)
1922

2023
func resourcePostgreSQLScript() *schema.Resource {
2124
return &schema.Resource{
22-
Create: PGResourceFunc(resourcePostgreSQLScriptCreateOrUpdate),
23-
Read: PGResourceFunc(resourcePostgreSQLScriptRead),
24-
Update: PGResourceFunc(resourcePostgreSQLScriptCreateOrUpdate),
25-
Delete: PGResourceFunc(resourcePostgreSQLScriptDelete),
25+
CreateContext: PGResourceContextFunc(resourcePostgreSQLScriptCreateOrUpdate),
26+
Read: PGResourceFunc(resourcePostgreSQLScriptRead),
27+
UpdateContext: PGResourceContextFunc(resourcePostgreSQLScriptCreateOrUpdate),
28+
Delete: PGResourceFunc(resourcePostgreSQLScriptDelete),
2629

2730
Schema: map[string]*schema.Schema{
2831
scriptCommandsAttr: {
@@ -45,6 +48,12 @@ func resourcePostgreSQLScript() *schema.Resource {
4548
Default: 1,
4649
Description: "Number of seconds between two tries of the batch of commands",
4750
},
51+
scriptTimeoutAttr: {
52+
Type: schema.TypeInt,
53+
Optional: true,
54+
Default: 5 * 60,
55+
Description: "Number of seconds for a batch of command to timeout",
56+
},
4857
scriptShasumAttr: {
4958
Type: schema.TypeString,
5059
Computed: true,
@@ -54,19 +63,28 @@ func resourcePostgreSQLScript() *schema.Resource {
5463
}
5564
}
5665

57-
func resourcePostgreSQLScriptCreateOrUpdate(db *DBConnection, d *schema.ResourceData) error {
66+
func resourcePostgreSQLScriptCreateOrUpdate(ctx context.Context, db *DBConnection, d *schema.ResourceData) diag.Diagnostics {
5867
commands, err := toStringArray(d.Get(scriptCommandsAttr).([]any))
5968
tries := d.Get(scriptTriesAttr).(int)
6069
backoffDelay := d.Get(scriptBackoffDelayAttr).(int)
70+
timeout := d.Get(scriptTimeoutAttr).(int)
6171

6272
if err != nil {
63-
return err
73+
return diag.Diagnostics{diag.Diagnostic{
74+
Severity: diag.Error,
75+
Summary: "Commands input is not valid",
76+
Detail: err.Error(),
77+
}}
6478
}
6579

6680
sum := shasumCommands(commands)
6781

68-
if err := executeCommands(db, commands, tries, backoffDelay); err != nil {
69-
return err
82+
if err := executeCommands(ctx, db, commands, tries, backoffDelay, timeout); err != nil {
83+
return diag.Diagnostics{diag.Diagnostic{
84+
Severity: diag.Error,
85+
Summary: "Commands execution failed",
86+
Detail: err.Error(),
87+
}}
7088
}
7189

7290
d.Set(scriptShasumAttr, sum)
@@ -89,9 +107,9 @@ func resourcePostgreSQLScriptDelete(db *DBConnection, d *schema.ResourceData) er
89107
return nil
90108
}
91109

92-
func executeCommands(db *DBConnection, commands []string, tries int, backoffDelay int) error {
110+
func executeCommands(ctx context.Context, db *DBConnection, commands []string, tries int, backoffDelay int, timeout int) error {
93111
for try := 1; ; try++ {
94-
err := executeBatch(db, commands)
112+
err := executeBatch(ctx, db, commands, timeout)
95113
if err == nil {
96114
return nil
97115
} else {
@@ -103,10 +121,12 @@ func executeCommands(db *DBConnection, commands []string, tries int, backoffDela
103121
}
104122
}
105123

106-
func executeBatch(db *DBConnection, commands []string) error {
124+
func executeBatch(ctx context.Context, db *DBConnection, commands []string, timeout int) error {
125+
timeoutContext, timeoutCancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
126+
defer timeoutCancel()
107127
for _, command := range commands {
108128
log.Printf("[DEBUG] Executing %s", command)
109-
_, err := db.Exec(command)
129+
_, err := db.ExecContext(timeoutContext, command)
110130
log.Printf("[DEBUG] Result %s: %v", command, err)
111131
if err != nil {
112132
log.Println("[DEBUG] Error catched:", err)

postgresql/resource_postgresql_script_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,3 +203,27 @@ func TestAccPostgresqlScript_failMultiple(t *testing.T) {
203203
},
204204
})
205205
}
206+
207+
func TestAccPostgresqlScript_timeout(t *testing.T) {
208+
config := `
209+
resource "postgresql_script" "invalid" {
210+
commands = [
211+
"BEGIN",
212+
"SELECT pg_sleep(2);",
213+
"COMMIT"
214+
]
215+
timeout = 1
216+
}
217+
`
218+
219+
resource.Test(t, resource.TestCase{
220+
PreCheck: func() { testAccPreCheck(t) },
221+
Providers: testAccProviders,
222+
Steps: []resource.TestStep{
223+
{
224+
Config: config,
225+
ExpectError: regexp.MustCompile("canceling statement"),
226+
},
227+
},
228+
})
229+
}

0 commit comments

Comments
 (0)