diff --git a/gomock/callset.go b/gomock/callset.go index f5cc592..57768a5 100644 --- a/gomock/callset.go +++ b/gomock/callset.go @@ -18,6 +18,7 @@ import ( "bytes" "errors" "fmt" + "slices" "sync" ) @@ -31,6 +32,9 @@ type callSet struct { exhausted map[callSetKey][]*Call // when set to true, existing call expectations are overridden when new call expectations are made allowOverride bool + // when set to true, existing call expectations that match the call arguments are overridden when new call + // expectations are made + allowOverrideArgsAware bool } // callSetKey is the key in the maps in callSet @@ -56,6 +60,16 @@ func newOverridableCallSet() *callSet { } } +func newOverridableArgsAwareCallSet() *callSet { + return &callSet{ + expected: make(map[callSetKey][]*Call), + expectedMu: &sync.Mutex{}, + exhausted: make(map[callSetKey][]*Call), + allowOverride: false, + allowOverrideArgsAware: true, + } +} + // Add adds a new expected call. func (cs callSet) Add(call *Call) { key := callSetKey{call.receiver, call.method} @@ -69,6 +83,13 @@ func (cs callSet) Add(call *Call) { } if cs.allowOverride { m[key] = make([]*Call, 0) + } else if cs.allowOverrideArgsAware { + calls := cs.expected[key] + for i, c := range calls { + if slices.Equal(c.args, call.args) { + cs.expected[key] = append(calls[:i], calls[i+1:]...) + } + } } m[key] = append(m[key], call) diff --git a/gomock/callset_test.go b/gomock/callset_test.go index d8150c5..6244997 100644 --- a/gomock/callset_test.go +++ b/gomock/callset_test.go @@ -60,6 +60,31 @@ func TestCallSetAdd_WhenOverridable_ClearsPreviousExpectedAndExhausted(t *testin } } +func TestCallSetAdd_WhenOverridableArgsAware_ClearsPreviousExpectedAndExhausted(t *testing.T) { + method := "TestMethod" + var receiver any = "TestReceiver" + cs := newOverridableArgsAwareCallSet() + + cs.Add(newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func), "foo")) + numExpectedCalls := len(cs.expected[callSetKey{receiver, method}]) + if numExpectedCalls != 1 { + t.Fatalf("Expected 1 expected call in callset, got %d", numExpectedCalls) + } + + cs.Add(newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func), "bar")) + numExpectedCalls = len(cs.expected[callSetKey{receiver, method}]) + if numExpectedCalls != 2 { + t.Fatalf("Expected 2 expected call in callset, got %d", numExpectedCalls) + } + + // Only override the first call with "foo" argument. + cs.Add(newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func), "foo")) + newNumExpectedCalls := len(cs.expected[callSetKey{receiver, method}]) + if newNumExpectedCalls != 2 { + t.Fatalf("Expected 2 expected call in callset, got %d", newNumExpectedCalls) + } +} + func TestCallSetRemove(t *testing.T) { method := "TestMethod" var receiver any = "TestReceiver" diff --git a/gomock/controller.go b/gomock/controller.go index 674c329..2f4922c 100644 --- a/gomock/controller.go +++ b/gomock/controller.go @@ -120,6 +120,18 @@ func (o overridableExpectationsOption) apply(ctrl *Controller) { ctrl.expectedCalls = newOverridableCallSet() } +type overridableExpectationsArgsAwareOption struct{} + +// WithOverridableExpectationsArgsAware allows for overridable call expectations +// i.e., subsequent call expectations override existing call expectations when matching arguments +func WithOverridableExpectationsArgsAware() overridableExpectationsArgsAwareOption { + return overridableExpectationsArgsAwareOption{} +} + +func (o overridableExpectationsArgsAwareOption) apply(ctrl *Controller) { + ctrl.expectedCalls = newOverridableArgsAwareCallSet() +} + type cancelReporter struct { t TestHelper cancel func() diff --git a/gomock/example_test.go b/gomock/example_test.go index 6051b71..a6ab0b9 100644 --- a/gomock/example_test.go +++ b/gomock/example_test.go @@ -67,3 +67,28 @@ func ExampleCall_DoAndReturn_withOverridableExpectations() { fmt.Printf("%s %s", r, s) // Output: I'm sleepy foo } + +func ExampleCall_DoAndReturn_withOverridableExpectationsArgsAware() { + t := &testing.T{} // provided by test + ctrl := gomock.NewController(t, gomock.WithOverridableExpectationsArgsAware()) + mockIndex := NewMockFoo(ctrl) + var s string + + mockIndex.EXPECT().Bar("foo").DoAndReturn( + func(arg string) any { + s = arg + return "I'm sleepy" + }, + ) + + mockIndex.EXPECT().Bar("foo").DoAndReturn( + func(arg string) any { + s = arg + return "I'm NOT sleepy" + }, + ) + + r := mockIndex.Bar("foo") + fmt.Printf("%s %s", r, s) + // Output: I'm NOT sleepy foo +} diff --git a/gomock/overridable_controller_test.go b/gomock/overridable_controller_test.go index 3d75e6a..ac8b7a1 100644 --- a/gomock/overridable_controller_test.go +++ b/gomock/overridable_controller_test.go @@ -32,3 +32,46 @@ func TestEcho_WithOverride_BaseCase(t *testing.T) { t.Fatalf("expected response to equal 'bar', got %s", res) } } + +func TestEcho_WithOverrideArgsAware_BaseCase(t *testing.T) { + ctrl := gomock.NewController(t, gomock.WithOverridableExpectationsArgsAware()) + mockIndex := NewMockFoo(ctrl) + + // initial expectation set + mockIndex.EXPECT().Bar("first").Return("first initial") + // another expectation + mockIndex.EXPECT().Bar("second").Return("second initial") + // reset first expectation + mockIndex.EXPECT().Bar("first").Return("first changed") + + res := mockIndex.Bar("first") + + if res != "first changed" { + t.Fatalf("expected response to equal 'first changed', got %s", res) + } + + res = mockIndex.Bar("second") + if res != "second initial" { + t.Fatalf("expected response to equal 'second initial', got %s", res) + } +} + +func TestEcho_WithOverrideArgsAware_OverrideEqualMatchersOnly(t *testing.T) { + ctrl := gomock.NewController(t, gomock.WithOverridableExpectationsArgsAware()) + mockIndex := NewMockFoo(ctrl) + + // initial expectation set + mockIndex.EXPECT().Bar("foo").Return("foo").Times(1) + mockIndex.EXPECT().Bar(gomock.Any()).Return("bar").Times(1) + + res := mockIndex.Bar("foo") + + if res != "foo" { + t.Fatalf("expected response to equal 'foo', got %s", res) + } + + res = mockIndex.Bar("bar") + if res != "bar" { + t.Fatalf("expected response to equal 'bar', got %s", res) + } +}