Skip to content

Implement arguments aware overridable. #247

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions gomock/callset.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"bytes"
"errors"
"fmt"
"slices"
"sync"
)

Expand All @@ -31,6 +32,9 @@ type callSet struct {
exhausted map[callSetKey][]*Call
// when set to true, existing call expectations are overridden when new call expectations are made
allowOverride bool
// when set to true, existing call expectations that match the call arguments are overridden when new call
// expectations are made
allowOverrideArgsAware bool
}

// callSetKey is the key in the maps in callSet
Expand All @@ -56,6 +60,16 @@ func newOverridableCallSet() *callSet {
}
}

func newOverridableArgsAwareCallSet() *callSet {
return &callSet{
expected: make(map[callSetKey][]*Call),
expectedMu: &sync.Mutex{},
exhausted: make(map[callSetKey][]*Call),
allowOverride: false,
allowOverrideArgsAware: true,
}
}

// Add adds a new expected call.
func (cs callSet) Add(call *Call) {
key := callSetKey{call.receiver, call.method}
Expand All @@ -69,6 +83,13 @@ func (cs callSet) Add(call *Call) {
}
if cs.allowOverride {
m[key] = make([]*Call, 0)
} else if cs.allowOverrideArgsAware {
calls := cs.expected[key]
for i, c := range calls {
if slices.Equal(c.args, call.args) {
cs.expected[key] = append(calls[:i], calls[i+1:]...)
}
}
}

m[key] = append(m[key], call)
Expand Down
25 changes: 25 additions & 0 deletions gomock/callset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,31 @@ func TestCallSetAdd_WhenOverridable_ClearsPreviousExpectedAndExhausted(t *testin
}
}

func TestCallSetAdd_WhenOverridableArgsAware_ClearsPreviousExpectedAndExhausted(t *testing.T) {
method := "TestMethod"
var receiver any = "TestReceiver"
cs := newOverridableArgsAwareCallSet()

cs.Add(newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func), "foo"))
numExpectedCalls := len(cs.expected[callSetKey{receiver, method}])
if numExpectedCalls != 1 {
t.Fatalf("Expected 1 expected call in callset, got %d", numExpectedCalls)
}

cs.Add(newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func), "bar"))
numExpectedCalls = len(cs.expected[callSetKey{receiver, method}])
if numExpectedCalls != 2 {
t.Fatalf("Expected 2 expected call in callset, got %d", numExpectedCalls)
}

// Only override the first call with "foo" argument.
cs.Add(newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func), "foo"))
newNumExpectedCalls := len(cs.expected[callSetKey{receiver, method}])
if newNumExpectedCalls != 2 {
t.Fatalf("Expected 2 expected call in callset, got %d", newNumExpectedCalls)
}
}

func TestCallSetRemove(t *testing.T) {
method := "TestMethod"
var receiver any = "TestReceiver"
Expand Down
12 changes: 12 additions & 0 deletions gomock/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,18 @@ func (o overridableExpectationsOption) apply(ctrl *Controller) {
ctrl.expectedCalls = newOverridableCallSet()
}

type overridableExpectationsArgsAwareOption struct{}

// WithOverridableExpectationsArgsAware allows for overridable call expectations
// i.e., subsequent call expectations override existing call expectations when matching arguments
func WithOverridableExpectationsArgsAware() overridableExpectationsArgsAwareOption {
return overridableExpectationsArgsAwareOption{}
}

func (o overridableExpectationsArgsAwareOption) apply(ctrl *Controller) {
ctrl.expectedCalls = newOverridableArgsAwareCallSet()
}

type cancelReporter struct {
t TestHelper
cancel func()
Expand Down
25 changes: 25 additions & 0 deletions gomock/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,28 @@ func ExampleCall_DoAndReturn_withOverridableExpectations() {
fmt.Printf("%s %s", r, s)
// Output: I'm sleepy foo
}

func ExampleCall_DoAndReturn_withOverridableExpectationsArgsAware() {
t := &testing.T{} // provided by test
ctrl := gomock.NewController(t, gomock.WithOverridableExpectationsArgsAware())
mockIndex := NewMockFoo(ctrl)
var s string

mockIndex.EXPECT().Bar("foo").DoAndReturn(
func(arg string) any {
s = arg
return "I'm sleepy"
},
)

mockIndex.EXPECT().Bar("foo").DoAndReturn(
func(arg string) any {
s = arg
return "I'm NOT sleepy"
},
)

r := mockIndex.Bar("foo")
fmt.Printf("%s %s", r, s)
// Output: I'm NOT sleepy foo
}
43 changes: 43 additions & 0 deletions gomock/overridable_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,46 @@ func TestEcho_WithOverride_BaseCase(t *testing.T) {
t.Fatalf("expected response to equal 'bar', got %s", res)
}
}

func TestEcho_WithOverrideArgsAware_BaseCase(t *testing.T) {
ctrl := gomock.NewController(t, gomock.WithOverridableExpectationsArgsAware())
mockIndex := NewMockFoo(ctrl)

// initial expectation set
mockIndex.EXPECT().Bar("first").Return("first initial")
// another expectation
mockIndex.EXPECT().Bar("second").Return("second initial")
// reset first expectation
mockIndex.EXPECT().Bar("first").Return("first changed")

res := mockIndex.Bar("first")

if res != "first changed" {
t.Fatalf("expected response to equal 'first changed', got %s", res)
}

res = mockIndex.Bar("second")
if res != "second initial" {
t.Fatalf("expected response to equal 'second initial', got %s", res)
}
}

func TestEcho_WithOverrideArgsAware_OverrideEqualMatchersOnly(t *testing.T) {
ctrl := gomock.NewController(t, gomock.WithOverridableExpectationsArgsAware())
mockIndex := NewMockFoo(ctrl)

// initial expectation set
mockIndex.EXPECT().Bar("foo").Return("foo").Times(1)
mockIndex.EXPECT().Bar(gomock.Any()).Return("bar").Times(1)

res := mockIndex.Bar("foo")

if res != "foo" {
t.Fatalf("expected response to equal 'foo', got %s", res)
}

res = mockIndex.Bar("bar")
if res != "bar" {
t.Fatalf("expected response to equal 'bar', got %s", res)
}
}