diff --git a/pkg/adaptation/adaptation_suite_test.go b/pkg/adaptation/adaptation_suite_test.go index 0c1cf82d..18d8e7b1 100644 --- a/pkg/adaptation/adaptation_suite_test.go +++ b/pkg/adaptation/adaptation_suite_test.go @@ -18,6 +18,7 @@ package adaptation_test import ( "context" + "errors" "fmt" "os" "path/filepath" @@ -38,6 +39,7 @@ import ( "github.com/containerd/nri/pkg/api" "github.com/containerd/nri/pkg/plugin" validator "github.com/containerd/nri/plugins/default-validator/builtin" + "github.com/containerd/ttrpc" rspec "github.com/opencontainers/runtime-spec/specs-go" ) @@ -94,6 +96,51 @@ var _ = Describe("Configuration", func() { Expect(plugin.Start(s.dir)).ToNot(Succeed()) }) }) + + When("early connection loss during plugin startup", func() { + BeforeEach(func() { + nri.SetPluginRegistrationTimeout(1 * time.Nanosecond) + + s.Prepare( + &mockRuntime{}, + &mockPlugin{ + idx: "00", + name: "test", + }, + ) + }) + + AfterEach(func() { + nri.SetPluginRegistrationTimeout(nri.DefaultPluginRegistrationTimeout) + }) + + It("should not cause a plugin to get stuck", func() { + var ( + runtime = s.runtime + plugin = s.plugins[0] + errCh = make(chan error, 1) + err error + ) + + Expect(runtime.Start(s.dir)).To(Succeed()) + + go func() { + err := plugin.Start(s.dir) + errCh <- err + }() + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + select { + case <-ctx.Done(): + err = ctx.Err() + case err = <-errCh: + } + + Expect(errors.Is(err, ttrpc.ErrClosed)).To(BeTrue()) + }) + }) }) var _ = Describe("Adaptation", func() { diff --git a/pkg/stub/stub.go b/pkg/stub/stub.go index 4dc7fbb2..1b0be902 100644 --- a/pkg/stub/stub.go +++ b/pkg/stub/stub.go @@ -584,6 +584,12 @@ func (stub *stub) register(ctx context.Context) error { // Handle a lost connection. func (stub *stub) connClosed() { + select { + // if our connection gets closed before we get Configure()'d, let Start() know + case stub.cfgErrC <- ttrpc.ErrClosed: + default: + } + stub.Lock() stub.close() stub.Unlock() @@ -628,7 +634,10 @@ func (stub *stub) Configure(ctx context.Context, req *api.ConfigureRequest) (rpl stub.requestTimeout = time.Duration(req.RequestTimeout * int64(time.Millisecond)) defer func() { - stub.cfgErrC <- retErr + select { + case stub.cfgErrC <- retErr: + default: + } }() if handler := stub.handlers.Configure; handler == nil {