Skip to content

Commit 81c9254

Browse files
authored
Merge pull request #14 from aclowes/master
fix race condition
2 parents 2fa078e + f1e0020 commit 81c9254

File tree

1 file changed

+28
-48
lines changed

1 file changed

+28
-48
lines changed

main.go

Lines changed: 28 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@ import (
1212
"secrets-init/pkg/secrets"
1313
"secrets-init/pkg/secrets/aws"
1414
"secrets-init/pkg/secrets/google"
15-
"sync"
1615
"syscall"
17-
"time"
1816

1917
log "github.com/sirupsen/logrus"
2018
"github.com/urfave/cli/v2"
@@ -110,11 +108,7 @@ func copyCmd(c *cli.Context) error {
110108
}
111109

112110
func mainCmd(c *cli.Context) error {
113-
// Routine to reap zombies (it's the job of init)
114-
ctx, cancel := context.WithCancel(context.Background())
115-
var wg sync.WaitGroup
116-
wg.Add(1)
117-
go removeZombies(ctx, &wg)
111+
ctx := context.Background()
118112

119113
// get provider
120114
var provider secrets.Provider
@@ -127,53 +121,56 @@ func mainCmd(c *cli.Context) error {
127121
if err != nil {
128122
log.WithField("provider", c.String("provider")).WithError(err).Error("failed to initialize secrets provider")
129123
}
124+
130125
// Launch main command
131-
var mainRC int
132-
err = run(ctx, provider, c.Args().Slice())
126+
var childPid int
127+
childPid, err = run(ctx, provider, c.Args().Slice())
133128
if err != nil {
134129
log.WithError(err).Error("failed to run")
135-
mainRC = 1
130+
os.Exit(1)
136131
}
137132

138-
// Wait removeZombies goroutine
139-
cleanQuit(cancel, &wg, mainRC)
133+
// Routine to reap zombies (it's the job of init)
134+
removeZombies(childPid)
140135
return nil
141136
}
142137

143-
func removeZombies(ctx context.Context, wg *sync.WaitGroup) {
138+
func removeZombies(childPid int) {
139+
var exitCode int
144140
for {
145141
var status syscall.WaitStatus
146142

147143
// wait for an orphaned zombie process
148-
pid, _ := syscall.Wait4(-1, &status, syscall.WNOHANG, nil)
144+
pid, err := syscall.Wait4(-1, &status, syscall.WNOHANG, nil)
149145

150-
if pid <= 0 {
151-
// PID is 0 or -1 if no child waiting, so we wait for 1 second for next check
152-
time.Sleep(1 * time.Second)
146+
if pid == -1 {
147+
// if errno == ECHILD then no children remain; exit cleanly
148+
if err == syscall.ECHILD {
149+
break
150+
}
151+
log.WithError(err).Error("unexpected wait4 error")
152+
os.Exit(1)
153153
} else {
154+
// check if pid is child, if so save
154155
// PID is > 0 if a child was reaped and we immediately check if another one is waiting
156+
if pid == childPid {
157+
exitCode = status.ExitStatus()
158+
}
155159
continue
156160
}
157-
158-
// non-blocking test if context is done
159-
select {
160-
case <-ctx.Done():
161-
// context is done, so we stop goroutine
162-
wg.Done()
163-
return
164-
default:
165-
}
166161
}
162+
// no more children, exit with the same code as the child process
163+
os.Exit(exitCode)
167164
}
168165

169166
// run passed command
170-
func run(ctx context.Context, provider secrets.Provider, commandSlice []string) error {
167+
func run(ctx context.Context, provider secrets.Provider, commandSlice []string) (childPid int, err error) {
171168
var commandStr string
172169
var argsSlice []string
173170

174171
if len(commandSlice) == 0 {
175172
log.Warn("no command specified")
176-
return nil
173+
return childPid, err
177174
}
178175

179176
// split command and arguments
@@ -185,9 +182,7 @@ func run(ctx context.Context, provider secrets.Provider, commandSlice []string)
185182

186183
// register a channel to receive system signals
187184
sigs := make(chan os.Signal, 1)
188-
defer close(sigs)
189185
signal.Notify(sigs)
190-
defer signal.Reset()
191186

192187
// define a command and rebind its stdout and stdin
193188
cmd := exec.Command(commandStr, argsSlice...)
@@ -196,7 +191,6 @@ func run(ctx context.Context, provider secrets.Provider, commandSlice []string)
196191
// create a dedicated pidgroup used to forward signals to the main process and its children
197192
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
198193

199-
var err error
200194
// set environment variables
201195
if provider != nil {
202196
cmd.Env, err = provider.ResolveSecrets(ctx, os.Environ())
@@ -217,8 +211,9 @@ func run(ctx context.Context, provider secrets.Provider, commandSlice []string)
217211
err = cmd.Start()
218212
if err != nil {
219213
log.WithError(err).Error("failed to start command")
220-
return err
214+
return childPid, err
221215
}
216+
childPid = cmd.Process.Pid
222217

223218
// Goroutine for signals forwarding
224219
go func() {
@@ -239,20 +234,5 @@ func run(ctx context.Context, provider secrets.Provider, commandSlice []string)
239234
}
240235
}()
241236

242-
// wait for the command to exit
243-
err = cmd.Wait()
244-
if err != nil {
245-
log.WithError(err).Error("failed to wait for command to complete")
246-
return err
247-
}
248-
249-
return nil
250-
}
251-
252-
func cleanQuit(cancel context.CancelFunc, wg *sync.WaitGroup, code int) {
253-
// signal zombie goroutine to stop and wait for it to release waitgroup
254-
cancel()
255-
wg.Wait()
256-
257-
os.Exit(code)
237+
return childPid, err
258238
}

0 commit comments

Comments
 (0)