Skip to content

Commit 27f6899

Browse files
committed
vm: func Run accepts context
It allows to use context as a single termination signal source.
1 parent 8f9cf94 commit 27f6899

File tree

20 files changed

+76
-94
lines changed

20 files changed

+76
-94
lines changed

pkg/build/netbsd.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package build
55

66
import (
7+
"context"
78
"encoding/json"
89
"fmt"
910
"os"
@@ -155,7 +156,9 @@ func (ctx netbsd) copyKernelToDisk(targetArch, vmType, outputDir, kernel string)
155156
}
156157
commands = append(commands, "mknod /dev/vhci c 355 0")
157158
commands = append(commands, "sync") // Run sync so that the copied image is stored properly.
158-
_, rep, err := inst.Run(time.Minute, reporter, strings.Join(commands, ";"))
159+
ctxTimeout, cancel := context.WithTimeout(context.Background(), time.Minute)
160+
defer cancel()
161+
_, rep, err := inst.Run(ctxTimeout, reporter, strings.Join(commands, ";"))
159162
if err != nil {
160163
return fmt.Errorf("error syncing the instance %w", err)
161164
}

pkg/instance/execprog.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package instance
55

66
import (
7+
"context"
78
"fmt"
89
"os"
910
"time"
@@ -123,7 +124,9 @@ func (inst *ExecProgInstance) runCommand(command string, duration time.Duration,
123124
if inst.BeforeContextLen != 0 {
124125
opts = append(opts, vm.OutputSize(inst.BeforeContextLen))
125126
}
126-
output, rep, err := inst.VMInstance.Run(duration, inst.reporter, command, opts...)
127+
ctxTimeout, cancel := context.WithTimeout(context.Background(), duration)
128+
defer cancel()
129+
output, rep, err := inst.VMInstance.Run(ctxTimeout, inst.reporter, command, opts...)
127130
if err != nil {
128131
return nil, fmt.Errorf("failed to run command in VM: %w", err)
129132
}

pkg/manager/diff.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -535,8 +535,10 @@ func (kc *kernelContext) runInstance(ctx context.Context, inst *vm.Instance,
535535
return nil, fmt.Errorf("failed to parse manager's address")
536536
}
537537
cmd := fmt.Sprintf("%v runner %v %v %v", executorBin, inst.Index(), host, port)
538-
_, rep, err := inst.Run(kc.cfg.Timeouts.VMRunningTime, kc.reporter, cmd,
539-
vm.ExitTimeout, vm.StopContext(ctx), vm.InjectExecuting(injectExec),
538+
ctxTimeout, cancel := context.WithTimeout(ctx, kc.cfg.Timeouts.VMRunningTime)
539+
defer cancel()
540+
_, rep, err := inst.Run(ctxTimeout, kc.reporter, cmd, vm.ExitTimeout,
541+
vm.InjectExecuting(injectExec),
540542
vm.EarlyFinishCb(func() {
541543
// Depending on the crash type and kernel config, fuzzing may continue
542544
// running for several seconds even after kernel has printed a crash report.

syz-manager/manager.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -652,8 +652,10 @@ func (mgr *Manager) runInstanceInner(ctx context.Context, inst *vm.Instance, inj
652652
return nil, nil, fmt.Errorf("failed to parse manager's address")
653653
}
654654
cmd := fmt.Sprintf("%v runner %v %v %v", executorBin, inst.Index(), host, port)
655-
_, rep, err := inst.Run(mgr.cfg.Timeouts.VMRunningTime, mgr.reporter, cmd,
656-
vm.ExitTimeout, vm.StopContext(ctx), vm.InjectExecuting(injectExec),
655+
ctxTimeout, cancel := context.WithTimeout(ctx, mgr.cfg.Timeouts.VMRunningTime)
656+
defer cancel()
657+
_, rep, err := inst.Run(ctxTimeout, mgr.reporter, cmd,
658+
vm.ExitTimeout, vm.InjectExecuting(injectExec),
657659
finishCb,
658660
)
659661
if err != nil {

syz-manager/snapshot.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ func (mgr *Manager) snapshotLoop(ctx context.Context, inst *vm.Instance) error {
4141
// All network connections (including ssh) will break once we start restoring snapshots.
4242
// So we start a background process and log to /dev/kmsg.
4343
cmd := fmt.Sprintf("nohup %v exec snapshot 1>/dev/null 2>/dev/kmsg </dev/null &", executor)
44-
if _, _, err := inst.Run(time.Hour, mgr.reporter, cmd); err != nil {
44+
ctxTimeout, cancel := context.WithTimeout(ctx, time.Hour)
45+
defer cancel()
46+
if _, _, err := inst.Run(ctxTimeout, mgr.reporter, cmd); err != nil {
4547
return err
4648
}
4749

vm/adb/adb.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package adb
77

88
import (
99
"bytes"
10+
"context"
1011
"encoding/json"
1112
"fmt"
1213
"io"
@@ -521,7 +522,7 @@ func isRemoteCuttlefish(dev string) (bool, string) {
521522
return true, ip
522523
}
523524

524-
func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command string) (
525+
func (inst *instance) Run(ctx context.Context, command string) (
525526
<-chan []byte, <-chan error, error) {
526527
var tty io.ReadCloser
527528
var err error
@@ -566,9 +567,8 @@ func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command strin
566567
merger.Add("console", tty)
567568
merger.Add("adb", adbRpipe)
568569

569-
return vmimpl.Multiplex(adb, merger, timeout, vmimpl.MultiplexConfig{
570+
return vmimpl.Multiplex(ctx, adb, merger, vmimpl.MultiplexConfig{
570571
Console: tty,
571-
Stop: stop,
572572
Close: inst.closed,
573573
Debug: inst.debug,
574574
Scale: inst.timeouts.Scale,

vm/bhyve/bhyve.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package bhyve
55

66
import (
7+
"context"
78
"fmt"
89
"io"
910
"os"
@@ -324,7 +325,7 @@ func (inst *instance) Copy(hostSrc string) (string, error) {
324325
return vmDst, nil
325326
}
326327

327-
func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command string) (
328+
func (inst *instance) Run(ctx context.Context, command string) (
328329
<-chan []byte, <-chan error, error) {
329330
rpipe, wpipe, err := osutil.LongPipe()
330331
if err != nil {
@@ -360,9 +361,7 @@ func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command strin
360361

361362
go func() {
362363
select {
363-
case <-time.After(timeout):
364-
signal(vmimpl.ErrTimeout)
365-
case <-stop:
364+
case <-ctx.Done():
366365
signal(vmimpl.ErrTimeout)
367366
case err := <-inst.merger.Err:
368367
cmd.Process.Kill()

vm/cuttlefish/cuttlefish.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
package cuttlefish
1212

1313
import (
14+
"context"
1415
"fmt"
1516
"os/exec"
1617
"path/filepath"
@@ -167,9 +168,9 @@ func (inst *instance) Close() error {
167168
return inst.gceInst.Close()
168169
}
169170

170-
func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command string) (
171+
func (inst *instance) Run(ctx context.Context, command string) (
171172
<-chan []byte, <-chan error, error) {
172-
return inst.gceInst.Run(timeout, stop, fmt.Sprintf("adb shell 'cd %s; %s'", deviceRoot, command))
173+
return inst.gceInst.Run(ctx, fmt.Sprintf("adb shell 'cd %s; %s'", deviceRoot, command))
173174
}
174175

175176
func (inst *instance) Diagnose(rep *report.Report) ([]byte, bool) {

vm/gce/gce.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ func (inst *instance) Copy(hostSrc string) (string, error) {
271271
return vmDst, nil
272272
}
273273

274-
func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command string) (
274+
func (inst *instance) Run(ctx context.Context, command string) (
275275
<-chan []byte, <-chan error, error) {
276276
conRpipe, conWpipe, err := osutil.LongPipe()
277277
if err != nil {
@@ -340,9 +340,8 @@ func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command strin
340340
sshWpipe.Close()
341341
merger.Add("ssh", sshRpipe)
342342

343-
return vmimpl.Multiplex(ssh, merger, timeout, vmimpl.MultiplexConfig{
343+
return vmimpl.Multiplex(ctx, ssh, merger, vmimpl.MultiplexConfig{
344344
Console: vmimpl.CmdCloser{Cmd: con},
345-
Stop: stop,
346345
Close: inst.closed,
347346
Debug: inst.debug,
348347
Scale: inst.timeouts.Scale,

vm/gvisor/gvisor.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package gvisor
77

88
import (
99
"bytes"
10+
"context"
1011
"fmt"
1112
"io"
1213
"net"
@@ -286,7 +287,7 @@ func (inst *instance) Copy(hostSrc string) (string, error) {
286287
return filepath.Join("/", fname), nil
287288
}
288289

289-
func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command string) (
290+
func (inst *instance) Run(ctx context.Context, command string) (
290291
<-chan []byte, <-chan error, error) {
291292
args := []string{"exec", "-user=0:0"}
292293
for _, c := range sandboxCaps {
@@ -327,9 +328,7 @@ func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command strin
327328

328329
go func() {
329330
select {
330-
case <-time.After(timeout):
331-
signal(vmimpl.ErrTimeout)
332-
case <-stop:
331+
case <-ctx.Done():
333332
signal(vmimpl.ErrTimeout)
334333
case err := <-inst.merger.Err:
335334
cmd.Process.Kill()

vm/isolated/isolated.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package isolated
55

66
import (
77
"bytes"
8+
"context"
89
"fmt"
910
"io"
1011
"os"
@@ -311,7 +312,7 @@ func (inst *instance) Copy(hostSrc string) (string, error) {
311312
return vmDst, nil
312313
}
313314

314-
func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command string) (
315+
func (inst *instance) Run(ctx context.Context, command string) (
315316
<-chan []byte, <-chan error, error) {
316317
args := append(vmimpl.SSHArgs(inst.debug, inst.Key, inst.Port, inst.cfg.SystemSSHCfg),
317318
inst.User+"@"+inst.Addr)
@@ -354,9 +355,8 @@ func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command strin
354355
merger.Add("dmesg", dmesg)
355356
merger.Add("ssh", rpipe)
356357

357-
return vmimpl.Multiplex(cmd, merger, timeout, vmimpl.MultiplexConfig{
358+
return vmimpl.Multiplex(ctx, cmd, merger, vmimpl.MultiplexConfig{
358359
Console: dmesg,
359-
Stop: stop,
360360
Close: inst.closed,
361361
Debug: inst.debug,
362362
Scale: inst.timeouts.Scale,

vm/proxyapp/proxyappclient.go

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -477,11 +477,7 @@ func buildMerger(names ...string) (*vmimpl.OutputMerger, []io.Writer) {
477477
return merger, wPipes
478478
}
479479

480-
func (inst *instance) Run(
481-
timeout time.Duration,
482-
stop <-chan bool,
483-
command string,
484-
) (<-chan []byte, <-chan error, error) {
480+
func (inst *instance) Run(ctx context.Context, command string) (<-chan []byte, <-chan error, error) {
485481
merger, wPipes := buildMerger("stdout", "stderr", "console")
486482
receivedStdoutChunks := wPipes[0]
487483
receivedStderrChunks := wPipes[1]
@@ -502,7 +498,6 @@ func (inst *instance) Run(
502498

503499
runID := reply.RunID
504500
terminationError := make(chan error, 1)
505-
timeoutSignal := time.After(timeout)
506501
signalClientErrorf := clientErrorf(receivedStderrChunks)
507502

508503
go func() {
@@ -531,13 +526,10 @@ func (inst *instance) Run(
531526
} else {
532527
continue
533528
}
534-
case <-timeoutSignal:
529+
case <-ctx.Done():
535530
// It is the happy path.
536531
inst.runStop(runID)
537532
terminationError <- vmimpl.ErrTimeout
538-
case <-stop:
539-
inst.runStop(runID)
540-
terminationError <- vmimpl.ErrTimeout
541533
}
542534
break
543535
}

vm/proxyapp/proxyappclient_test.go

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package proxyapp
55

66
import (
77
"bytes"
8+
"context"
89
"fmt"
910
"io"
1011
"net/rpc"
@@ -401,28 +402,13 @@ func TestInstance_Forward_Failure(t *testing.T) {
401402
assert.Empty(t, remoteAddressToUse)
402403
}
403404

404-
func TestInstance_Run_SimpleOk(t *testing.T) {
405-
mockInstance, inst := createInstanceFixture(t)
406-
mockInstance.
407-
On("RunStart", mock.Anything, mock.Anything).
408-
Return(nil).
409-
On("RunReadProgress", mock.Anything, mock.Anything).
410-
Return(nil).
411-
Maybe()
412-
413-
outc, errc, err := inst.Run(10*time.Second, make(chan bool), "command")
414-
assert.NotNil(t, outc)
415-
assert.NotNil(t, errc)
416-
assert.Nil(t, err)
417-
}
418-
419405
func TestInstance_Run_Failure(t *testing.T) {
420406
mockInstance, inst := createInstanceFixture(t)
421407
mockInstance.
422408
On("RunStart", mock.Anything, mock.Anything).
423409
Return(fmt.Errorf("run start error"))
424410

425-
outc, errc, err := inst.Run(10*time.Second, make(chan bool), "command")
411+
outc, errc, err := inst.Run(contextWithTimeout(t, 10*time.Second), "command")
426412
assert.Nil(t, outc)
427413
assert.Nil(t, errc)
428414
assert.NotEmpty(t, err)
@@ -438,7 +424,7 @@ func TestInstance_Run_OnTimeout(t *testing.T) {
438424
On("RunStop", mock.Anything, mock.Anything).
439425
Return(nil)
440426

441-
_, errc, _ := inst.Run(time.Second, make(chan bool), "command")
427+
_, errc, _ := inst.Run(contextWithTimeout(t, time.Second), "command")
442428
err := <-errc
443429

444430
assert.Equal(t, err, vmimpl.ErrTimeout)
@@ -455,9 +441,9 @@ func TestInstance_Run_OnStop(t *testing.T) {
455441
On("RunStop", mock.Anything, mock.Anything).
456442
Return(nil)
457443

458-
stop := make(chan bool)
459-
_, errc, _ := inst.Run(10*time.Second, stop, "command")
460-
stop <- true
444+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
445+
_, errc, _ := inst.Run(ctx, "command")
446+
cancel()
461447
err := <-errc
462448
assert.Equal(t, err, vmimpl.ErrTimeout)
463449
}
@@ -478,7 +464,7 @@ func TestInstance_RunReadProgress_OnErrorReceived(t *testing.T) {
478464
Return(nil).
479465
Once()
480466

481-
outc, _, _ := inst.Run(10*time.Second, make(chan bool), "command")
467+
outc, _, _ := inst.Run(contextWithTimeout(t, 10*time.Second), "command")
482468
output := string(<-outc)
483469

484470
assert.Equal(t, "mock error\nSYZFAIL: proxy app plugin error\n", output)
@@ -500,7 +486,7 @@ func TestInstance_RunReadProgress_OnFinished(t *testing.T) {
500486
Return(nil).
501487
Once()
502488

503-
_, errc, _ := inst.Run(10*time.Second, make(chan bool), "command")
489+
_, errc, _ := inst.Run(contextWithTimeout(t, 10*time.Second), "command")
504490
err := <-errc
505491

506492
assert.Equal(t, err, nil)
@@ -519,7 +505,7 @@ func TestInstance_RunReadProgress_Failed(t *testing.T) {
519505
Return(fmt.Errorf("runreadprogresserror")).
520506
Once()
521507

522-
outc, _, _ := inst.Run(10*time.Second, make(chan bool), "command")
508+
outc, _, _ := inst.Run(contextWithTimeout(t, 10*time.Second), "command")
523509
output := string(<-outc)
524510

525511
assert.Equal(t,
@@ -532,3 +518,9 @@ func TestInstance_RunReadProgress_Failed(t *testing.T) {
532518
// [option] check pool size was changed
533519

534520
// TODO: test pool.Close() calls plugin API and return error.
521+
522+
func contextWithTimeout(t *testing.T, timeout time.Duration) context.Context {
523+
ctx, cancel := context.WithTimeout(context.Background(), timeout)
524+
t.Cleanup(cancel)
525+
return ctx
526+
}

vm/qemu/qemu.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package qemu
55

66
import (
77
"bytes"
8+
"context"
89
"encoding/json"
910
"fmt"
1011
"io"
@@ -667,7 +668,7 @@ func (inst *instance) Copy(hostSrc string) (string, error) {
667668
return vmDst, nil
668669
}
669670

670-
func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command string) (
671+
func (inst *instance) Run(ctx context.Context, command string) (
671672
<-chan []byte, <-chan error, error) {
672673
rpipe, wpipe, err := osutil.LongPipe()
673674
if err != nil {
@@ -707,8 +708,7 @@ func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command strin
707708
return nil, nil, err
708709
}
709710
wpipe.Close()
710-
return vmimpl.Multiplex(cmd, inst.merger, timeout, vmimpl.MultiplexConfig{
711-
Stop: stop,
711+
return vmimpl.Multiplex(ctx, cmd, inst.merger, vmimpl.MultiplexConfig{
712712
Debug: inst.debug,
713713
Scale: inst.timeouts.Scale,
714714
})

0 commit comments

Comments
 (0)