Skip to content

Commit cce342d

Browse files
committed
parallel container startup with deferred values
This commit tries to be as unobtrusive as possible, attaching new behavior to existing types where possible rather than building out new infrastructure. constructorNode returns a deferred value when called. On the first call, it asks paramList to start building an arg slice, which may also be deferred. Once the arg slice is resolved, constructorNode schedules its constructor function to be called. Once it's called, it resolves its own deferral. Multiple paramSingles can observe the same constructorNode before it's ready. If there's an error, they may all see the same error, which is a change in behavior. There are two schedulers: synchronous and parallel. The synchronous scheduler returns things in the same order as before. The parallel may not (and the tests that rely on shuffle order will fail). The scheduler needs to be flushed after deferred values are created. The synchronous scheduler does nothing on when flushing, but the parallel scheduler runs a pool of goroutines to resolve constructors. Calls to dig functions always happen on the same goroutine as Scope.Invoke(). Calls to constructor functions can happen on pooled goroutines. The choice of scheduler is up to the Scope. Whether constructor functions are safe to call in parallel seems most logically to be a property of the scope, and the scope is passed down the constructor/param call chain.
1 parent f478a90 commit cce342d

12 files changed

+609
-91
lines changed

constructor.go

+57-27
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,18 @@ type constructorNode struct {
4444
// id uniquely identifies the constructor that produces a node.
4545
id dot.CtorID
4646

47+
// Whether this node is already building its paramList and calling the constructor
48+
calling bool
49+
4750
// Whether the constructor owned by this node was already called.
4851
called bool
4952

5053
// Type information about constructor parameters.
5154
paramList paramList
5255

56+
// The result of calling the constructor
57+
deferred deferred
58+
5359
// Type information about constructor results.
5460
resultList resultList
5561

@@ -121,42 +127,66 @@ func (n *constructorNode) String() string {
121127
return fmt.Sprintf("deps: %v, ctor: %v", n.paramList, n.ctype)
122128
}
123129

124-
// Call calls this constructor if it hasn't already been called and
125-
// injects any values produced by it into the provided container.
126-
func (n *constructorNode) Call(c containerStore) error {
127-
if n.called {
128-
return nil
130+
// Call calls this constructor if it hasn't already been called and injects any values produced by it into the container
131+
// passed to newConstructorNode.
132+
//
133+
// If constructorNode has a unresolved deferred already in the process of building, it will return that one. If it has
134+
// already been successfully called, it will return an already-resolved deferred. Together these mean it will try the
135+
// call again if it failed last time.
136+
//
137+
// On failure, the returned pointer is not guaranteed to stay in a failed state; another call will reset it back to its
138+
// zero value; don't store the returned pointer. (It will still call each observer only once.)
139+
func (n *constructorNode) Call(c containerStore) *deferred {
140+
if n.calling || n.called {
141+
return &n.deferred
129142
}
130143

144+
n.calling = true
145+
n.deferred = deferred{}
146+
131147
if err := shallowCheckDependencies(c, n.paramList); err != nil {
132-
return errMissingDependencies{
148+
n.deferred.resolve(errMissingDependencies{
133149
Func: n.location,
134150
Reason: err,
135-
}
151+
})
136152
}
137153

138-
args, err := n.paramList.BuildList(c)
139-
if err != nil {
140-
return errArgumentsFailed{
141-
Func: n.location,
142-
Reason: err,
154+
var args []reflect.Value
155+
d := n.paramList.BuildList(c, &args)
156+
157+
d.observe(func(err error) {
158+
if err != nil {
159+
n.calling = false
160+
n.deferred.resolve(errArgumentsFailed{
161+
Func: n.location,
162+
Reason: err,
163+
})
164+
return
143165
}
144-
}
145-
146-
receiver := newStagingContainerWriter()
147-
results := c.invoker()(reflect.ValueOf(n.ctor), args)
148-
if err := n.resultList.ExtractList(receiver, results); err != nil {
149-
return errConstructorFailed{Func: n.location, Reason: err}
150-
}
151-
152-
// Commit the result to the original container that this constructor
153-
// was supplied to. The provided constructor is only used for a view of
154-
// the rest of the graph to instantiate the dependencies of this
155-
// container.
156-
receiver.Commit(n.s)
157-
n.called = true
158166

159-
return nil
167+
var results []reflect.Value
168+
169+
c.scheduler().schedule(func() {
170+
results = c.invoker()(reflect.ValueOf(n.ctor), args)
171+
}).observe(func(_ error) {
172+
n.calling = false
173+
receiver := newStagingContainerWriter()
174+
if err := n.resultList.ExtractList(receiver, results); err != nil {
175+
n.deferred.resolve(errConstructorFailed{Func: n.location, Reason: err})
176+
return
177+
}
178+
179+
// Commit the result to the original container that this constructor
180+
// was supplied to. The provided container is only used for a view of
181+
// the rest of the graph to instantiate the dependencies of this
182+
// container.
183+
receiver.Commit(n.s)
184+
n.called = true
185+
n.deferred.resolve(nil)
186+
})
187+
})
188+
189+
return &n.deferred
160190
}
161191

162192
// stagingContainerWriter is a containerWriter that records the changes that

constructor_test.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,11 @@ func TestNodeAlreadyCalled(t *testing.T) {
5959
require.False(t, n.called, "node must not have been called")
6060

6161
c := New()
62-
require.NoError(t, n.Call(c.scope), "invoke failed")
62+
d := n.Call(c.scope)
63+
c.scope.sched.flush()
64+
require.NoError(t, d.err, "invoke failed")
6365
require.True(t, n.called, "node must be called")
64-
require.NoError(t, n.Call(c.scope), "calling again should be okay")
66+
d = n.Call(c.scope)
67+
c.scope.sched.flush()
68+
require.NoError(t, d.err, "calling again should be okay")
6569
}

container.go

+26
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ type containerStore interface {
119119

120120
// Returns invokerFn function to use when calling arguments.
121121
invoker() invokerFn
122+
123+
// Returns the scheduler to use for this scope.
124+
scheduler() scheduler
122125
}
123126

124127
// New constructs a Container.
@@ -208,6 +211,29 @@ func dryInvoker(fn reflect.Value, _ []reflect.Value) []reflect.Value {
208211
return results
209212
}
210213

214+
type maxConcurrencyOption int
215+
216+
// MaxConcurrency run constructors in this container with a fixed pool of executor
217+
// goroutines. max is the number of goroutines to start.
218+
func MaxConcurrency(max int) Option {
219+
return maxConcurrencyOption(max)
220+
}
221+
222+
func (m maxConcurrencyOption) applyOption(container *Container) {
223+
container.scope.sched = &parallelScheduler{concurrency: int(m)}
224+
}
225+
226+
type unboundedConcurrency struct{}
227+
228+
// UnboundedConcurrency run constructors in this container as concurrently as possible.
229+
// Go's resource limits like GOMAXPROCS will inherently limit how much can happen in
230+
// parallel.
231+
var UnboundedConcurrency Option = unboundedConcurrency{}
232+
233+
func (u unboundedConcurrency) applyOption(container *Container) {
234+
container.scope.sched = &unboundedScheduler{}
235+
}
236+
211237
// String representation of the entire Container
212238
func (c *Container) String() string {
213239
return c.scope.String()

deferred.go

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
package dig
2+
3+
type observer func(error)
4+
5+
// A deferred is an observable future result that may fail. Its zero value is unresolved and has no observers. It can
6+
// be resolved once, at which point every observer will be called.
7+
type deferred struct {
8+
observers []observer
9+
settled bool
10+
err error
11+
}
12+
13+
// alreadyResolved is a deferred that has already been resolved with a nil error.
14+
var alreadyResolved = deferred{settled: true}
15+
16+
// failedDeferred returns a deferred that is resolved with the given error.
17+
func failedDeferred(err error) *deferred {
18+
return &deferred{settled: true, err: err}
19+
}
20+
21+
// observe registers an observer to receive a callback when this deferred is resolved. It will be called at most one
22+
// time. If this deferred is already resolved, the observer is called immediately, before observe returns.
23+
func (d *deferred) observe(obs observer) {
24+
if d.settled {
25+
obs(d.err)
26+
return
27+
}
28+
29+
d.observers = append(d.observers, obs)
30+
}
31+
32+
// resolve sets the status of this deferred and notifies all observers if it's not already resolved.
33+
func (d *deferred) resolve(err error) {
34+
if d.settled {
35+
return
36+
}
37+
38+
d.settled = true
39+
d.err = err
40+
for _, obs := range d.observers {
41+
obs(err)
42+
}
43+
d.observers = nil
44+
}
45+
46+
// then returns a new deferred that is either resolved with the same error as this deferred, or any error returned from
47+
// the supplied function. The supplied function is only called if this deferred is resolved without error.
48+
func (d *deferred) then(res func() error) *deferred {
49+
d2 := new(deferred)
50+
d.observe(func(err error) {
51+
if err != nil {
52+
d2.resolve(err)
53+
return
54+
}
55+
d2.resolve(res())
56+
})
57+
return d2
58+
}
59+
60+
// catch maps any error from this deferred using the supplied function. The supplied function is only called if this
61+
// deferred is resolved with an error. If the supplied function returns a nil error, the new deferred will resolve
62+
// successfully.
63+
func (d *deferred) catch(rej func(error) error) *deferred {
64+
d2 := new(deferred)
65+
d.observe(func(err error) {
66+
if err != nil {
67+
err = rej(err)
68+
}
69+
d2.resolve(err)
70+
})
71+
return d2
72+
}
73+
74+
// whenAll returns a new deferred that resolves when all the supplied deferreds resolve. It resolves with the first
75+
// error reported by any deferred, or nil if they all succeed.
76+
func whenAll(others ...*deferred) *deferred {
77+
if len(others) == 0 {
78+
return &alreadyResolved
79+
}
80+
81+
d := new(deferred)
82+
count := len(others)
83+
84+
onResolved := func(err error) {
85+
if d.settled {
86+
return
87+
}
88+
89+
if err != nil {
90+
d.resolve(err)
91+
}
92+
93+
count--
94+
if count == 0 {
95+
d.resolve(nil)
96+
}
97+
}
98+
99+
for _, other := range others {
100+
other.observe(onResolved)
101+
}
102+
103+
return d
104+
}

dig_test.go

+89
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import (
2929
"math/rand"
3030
"os"
3131
"reflect"
32+
"sync/atomic"
3233
"testing"
3334
"time"
3435

@@ -3566,3 +3567,91 @@ func TestEndToEndSuccessWithAliases(t *testing.T) {
35663567
})
35673568

35683569
}
3570+
3571+
func TestConcurrency(t *testing.T) {
3572+
// Ensures providers will run at the same time
3573+
t.Run("TestMaxConcurrency", func(t *testing.T) {
3574+
t.Parallel()
3575+
3576+
type (
3577+
A int
3578+
B int
3579+
C int
3580+
)
3581+
3582+
var (
3583+
timer = time.NewTimer(10 * time.Second)
3584+
max int32 = 3
3585+
done = make(chan struct{})
3586+
running int32 = 0
3587+
waitForUs = func() error {
3588+
if atomic.AddInt32(&running, 1) == max {
3589+
close(done)
3590+
}
3591+
select {
3592+
case <-timer.C:
3593+
return errors.New("timeout expired")
3594+
case <-done:
3595+
return nil
3596+
}
3597+
}
3598+
c = digtest.New(t, dig.MaxConcurrency(int(max)))
3599+
)
3600+
3601+
c.RequireProvide(func() (A, error) { return 0, waitForUs() })
3602+
c.RequireProvide(func() (B, error) { return 1, waitForUs() })
3603+
c.RequireProvide(func() (C, error) { return 2, waitForUs() })
3604+
3605+
c.RequireInvoke(func(a A, b B, c C) {
3606+
require.Equal(t, a, A(0))
3607+
require.Equal(t, b, B(1))
3608+
require.Equal(t, c, C(2))
3609+
require.Equal(t, running, int32(3))
3610+
})
3611+
})
3612+
3613+
t.Run("TestUnboundConcurrency", func(t *testing.T) {
3614+
t.Parallel()
3615+
3616+
var (
3617+
timer = time.NewTimer(10 * time.Second)
3618+
max int32 = 20
3619+
done = make(chan struct{})
3620+
running int32 = 0
3621+
waitForUs = func() error {
3622+
if atomic.AddInt32(&running, 1) >= max {
3623+
close(done)
3624+
}
3625+
select {
3626+
case <-timer.C:
3627+
return errors.New("timeout expired")
3628+
case <-done:
3629+
return nil
3630+
}
3631+
}
3632+
c = digtest.New(t, dig.UnboundedConcurrency)
3633+
expected []int
3634+
)
3635+
3636+
for i := 0; i < int(max); i++ {
3637+
i := i
3638+
expected = append(expected, i)
3639+
type out struct {
3640+
dig.Out
3641+
3642+
Value int `group:"a"`
3643+
}
3644+
c.RequireProvide(func() (out, error) { return out{Value: i}, waitForUs() })
3645+
}
3646+
3647+
type in struct {
3648+
dig.In
3649+
3650+
Values []int `group:"a"`
3651+
}
3652+
3653+
c.RequireInvoke(func(i in) {
3654+
require.ElementsMatch(t, expected, i.Values)
3655+
})
3656+
})
3657+
}

0 commit comments

Comments
 (0)