Skip to content

Commit f76d926

Browse files
authored
Merge pull request #1580 from bhandras/copy-fixes
tapfreighter: fix `OutboundParcel.Copy` and add generic `Copy` fn test
2 parents 74d38bc + 7ce1501 commit f76d926

File tree

4 files changed

+403
-15
lines changed

4 files changed

+403
-15
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ require (
3737
github.com/lightningnetwork/lnd/tlv v1.3.1
3838
github.com/lightningnetwork/lnd/tor v1.1.6
3939
github.com/ory/dockertest/v3 v3.10.0
40+
github.com/pmezard/go-difflib v1.0.0
4041
github.com/prometheus/client_golang v1.14.0
4142
github.com/stretchr/testify v1.10.0
4243
github.com/urfave/cli v1.22.14
@@ -148,7 +149,6 @@ require (
148149
github.com/opencontainers/image-spec v1.1.0 // indirect
149150
github.com/opencontainers/runc v1.2.0 // indirect
150151
github.com/pkg/errors v0.9.1 // indirect
151-
github.com/pmezard/go-difflib v1.0.0 // indirect
152152
github.com/prometheus/client_model v0.3.0 // indirect
153153
github.com/prometheus/common v0.37.0 // indirect
154154
github.com/prometheus/procfs v0.8.0 // indirect

internal/test/copy.go

Lines changed: 334 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
package test
2+
3+
import (
4+
"fmt"
5+
"reflect"
6+
"strings"
7+
"testing"
8+
9+
"github.com/davecgh/go-spew/spew"
10+
"github.com/lightninglabs/taproot-assets/fn"
11+
"github.com/pmezard/go-difflib/difflib"
12+
)
13+
14+
// FillFakeData recursively fills a struct with dummy values.
15+
func FillFakeData[T any](t *testing.T, debug bool, maxDepth int, v T) {
16+
if t != nil {
17+
t.Helper()
18+
}
19+
20+
val := reflect.ValueOf(v)
21+
name := val.Type().Elem().Name()
22+
fillFakeData(t, debug, 0, maxDepth, val, name)
23+
}
24+
25+
// fillFakeData is the recursive helper to fill a value with fake data.
26+
func fillFakeData(t *testing.T, debug bool, depth, maxDepth int,
27+
v reflect.Value, path string) {
28+
29+
if t != nil {
30+
t.Helper()
31+
}
32+
33+
if depth > maxDepth || !v.IsValid() {
34+
return
35+
}
36+
37+
indent := strings.Repeat(" ", depth)
38+
39+
log := func(format string, args ...any) {
40+
if debug {
41+
if t != nil {
42+
t.Logf(indent+format, args...)
43+
} else {
44+
fmt.Printf(indent+format+"\n", args...)
45+
}
46+
}
47+
}
48+
switch v.Kind() {
49+
case reflect.Ptr:
50+
if v.IsNil() {
51+
ptr := reflect.New(v.Type().Elem())
52+
v.Set(ptr)
53+
54+
log("ptr: %s (%s)", path, v.Type())
55+
}
56+
57+
fillFakeData(t, debug, depth+1, maxDepth, v.Elem(), path)
58+
59+
case reflect.Struct:
60+
typ := v.Type()
61+
for i := range v.NumField() {
62+
field := v.Field(i)
63+
fieldType := typ.Field(i)
64+
65+
if !field.CanSet() {
66+
continue
67+
}
68+
69+
fieldPath := fmt.Sprintf("%s.%s", path, fieldType.Name)
70+
fillFakeData(
71+
t, debug, depth+1, maxDepth, field, fieldPath,
72+
)
73+
}
74+
75+
case reflect.Slice:
76+
if v.Type().Elem().Kind() == reflect.Uint8 {
77+
// Special case: []byte.
78+
b := make([]byte, randomLen())
79+
for i := range b {
80+
b[i] = byte(rand.Intn(256))
81+
}
82+
83+
v.SetBytes(b)
84+
log("[]byte: %s = %v", path, b)
85+
86+
return
87+
}
88+
89+
elemType := v.Type().Elem()
90+
length := randomLen()
91+
slice := reflect.MakeSlice(v.Type(), length, length)
92+
93+
for i := range length {
94+
elemPath := fmt.Sprintf("%s[%d]", path, i)
95+
96+
var elem reflect.Value
97+
if elemType.Kind() == reflect.Ptr {
98+
elem = reflect.New(elemType.Elem())
99+
100+
fillFakeData(
101+
t, debug, depth+1, maxDepth,
102+
elem.Elem(), elemPath,
103+
)
104+
} else {
105+
elem = reflect.New(elemType).Elem()
106+
107+
fillFakeData(
108+
t, debug, depth+1, maxDepth, elem,
109+
elemPath,
110+
)
111+
}
112+
113+
slice.Index(i).Set(elem)
114+
}
115+
116+
v.Set(slice)
117+
log("slice: %s (len=%d)", path, length)
118+
119+
case reflect.Array:
120+
for i := range v.Len() {
121+
fillFakeData(
122+
t, debug, depth+1, maxDepth, v.Index(i),
123+
fmt.Sprintf("%s[%d]", path, i),
124+
)
125+
}
126+
127+
log("array: %s (len=%d)", path, v.Len())
128+
129+
case reflect.Map:
130+
keyType := v.Type().Key()
131+
valType := v.Type().Elem()
132+
m := reflect.MakeMap(v.Type())
133+
length := randomLen()
134+
135+
for i := range length {
136+
key := reflect.New(keyType).Elem()
137+
138+
fillFakeData(
139+
t, debug, depth+1, maxDepth, key,
140+
fmt.Sprintf("%s[key%d]", path, i),
141+
)
142+
143+
val := reflect.New(valType).Elem()
144+
145+
fillFakeData(
146+
t, debug, depth+1, maxDepth, val,
147+
fmt.Sprintf("%s[val%d]", path, i),
148+
)
149+
150+
m.SetMapIndex(key, val)
151+
}
152+
153+
v.Set(m)
154+
log("map: %s (len=%d)", path, length)
155+
156+
default:
157+
assignDummyPrimitive(t, debug, indent, v, path)
158+
}
159+
}
160+
161+
// assignDummyPrimitive assigns dummy values to primitive type values.
162+
func assignDummyPrimitive(t *testing.T, debug bool, indent string,
163+
v reflect.Value, path string) {
164+
165+
log := func(format string, args ...any) {
166+
if debug {
167+
if t != nil {
168+
t.Logf(indent+format, args...)
169+
} else {
170+
fmt.Printf(indent+format+"\n", args...)
171+
}
172+
}
173+
}
174+
175+
switch v.Kind() {
176+
case reflect.String:
177+
s := randomString()
178+
v.SetString(s)
179+
log("string: %s = %q", path, s)
180+
181+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
182+
reflect.Int64:
183+
184+
i := rand.Int63n(1_000_000)
185+
v.SetInt(i)
186+
log("int: %s = %d", path, i)
187+
188+
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
189+
reflect.Uint64:
190+
191+
u := uint64(rand.Intn(1_000_000))
192+
v.SetUint(u)
193+
log("uint: %s = %d", path, u)
194+
195+
case reflect.Bool:
196+
b := rand.Intn(2) == 0
197+
v.SetBool(b)
198+
log("bool: %s = %v", path, b)
199+
200+
case reflect.Float32, reflect.Float64:
201+
f := rand.Float64() * 1_000
202+
v.SetFloat(f)
203+
log("float: %s = %f", path, f)
204+
205+
default:
206+
}
207+
}
208+
209+
func randomString() string {
210+
return fmt.Sprintf("val_%d", rand.Intn(100_000))
211+
}
212+
213+
func randomLen() int {
214+
return rand.Intn(3)
215+
}
216+
217+
// checkAliasing walks the fields and check for shared references.
218+
func checkAliasing(t *testing.T, debug, strict bool, f1, f2 reflect.Value,
219+
path string) {
220+
221+
t.Helper()
222+
223+
if !f1.IsValid() || !f2.IsValid() {
224+
return
225+
}
226+
227+
switch f1.Kind() {
228+
case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Func,
229+
reflect.Chan:
230+
231+
if f1.IsNil() || f2.IsNil() {
232+
return
233+
}
234+
235+
if f1.Pointer() == f2.Pointer() {
236+
msg := fmt.Sprintf("Aliasing detected at path: %s "+
237+
"(shared %s)", path, f1.Kind())
238+
239+
if strict {
240+
t.Fatalf(msg)
241+
}
242+
243+
if debug {
244+
t.Logf("WARNING %s", msg)
245+
}
246+
}
247+
248+
// Recurse into slice/map values.
249+
switch f1.Kind() {
250+
case reflect.Slice:
251+
for i := 0; i < f1.Len() && i < f2.Len(); i++ {
252+
checkAliasing(
253+
t, debug, strict,
254+
f1.Index(i), f2.Index(i),
255+
fmt.Sprintf("%s[%d]", path, i),
256+
)
257+
}
258+
case reflect.Map:
259+
for _, key := range f1.MapKeys() {
260+
v1 := f1.MapIndex(key)
261+
v2 := f2.MapIndex(key)
262+
checkAliasing(
263+
t, debug, strict,
264+
v1, v2, fmt.Sprintf("%s[%v]", path,
265+
key.Interface()),
266+
)
267+
}
268+
269+
default:
270+
}
271+
272+
case reflect.Struct:
273+
for i := range f1.NumField() {
274+
field := f1.Type().Field(i)
275+
276+
// Skip unexported fields.
277+
if !f1.Field(i).CanInterface() {
278+
continue
279+
}
280+
281+
childPath := fmt.Sprintf("%s.%s", path, field.Name)
282+
checkAliasing(
283+
t, debug, strict,
284+
f1.Field(i), f2.Field(i), childPath,
285+
)
286+
}
287+
288+
default:
289+
}
290+
}
291+
292+
// AssertCopyEqual checks that the Copy method returns a value that:
293+
// 1) is deeply equal
294+
// 2) does not alias mutable fields (pointers, slices, maps)
295+
func AssertCopyEqual[T fn.Copyable[T]](t *testing.T, debug, strict bool,
296+
original T) {
297+
298+
originalVal := reflect.ValueOf(original)
299+
copied := original.Copy()
300+
copiedVal := reflect.ValueOf(copied)
301+
302+
if !reflect.DeepEqual(original, copied) {
303+
diff := difflib.UnifiedDiff{
304+
A: difflib.SplitLines(
305+
spew.Sdump(original),
306+
),
307+
B: difflib.SplitLines(
308+
spew.Sdump(copied),
309+
),
310+
FromFile: "Original",
311+
FromDate: "",
312+
ToFile: "Copied",
313+
ToDate: "",
314+
Context: 3,
315+
}
316+
diffText, _ := difflib.GetUnifiedDiffString(diff)
317+
318+
t.Fatalf("Copied value is not deeply equal to the orginal:\n%v",
319+
diffText)
320+
}
321+
322+
if originalVal.Kind() == reflect.Ptr {
323+
originalVal = originalVal.Elem()
324+
copiedVal = copiedVal.Elem()
325+
}
326+
327+
for i := range originalVal.NumField() {
328+
f1 := originalVal.Field(i)
329+
f2 := copiedVal.Field(i)
330+
name := originalVal.Type().Field(i).Name
331+
332+
checkAliasing(t, debug, strict, f1, f2, name)
333+
}
334+
}

tapfreighter/copy_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package tapfreighter
2+
3+
import (
4+
"testing"
5+
6+
"github.com/lightninglabs/taproot-assets/internal/test"
7+
)
8+
9+
// TestOutboundParcelCopy tests that OutboundParcel.Copy() works as expected.
10+
func TestOutboundParcelCopy(t *testing.T) {
11+
// Set to true to debug print.
12+
debug := false
13+
14+
// Please set the depth value carefully. Sometimes our copy functions
15+
// are deeply nested in other packages and do not need changes. Often
16+
// types are recursive and too deep copy may end up in stack-overlow.
17+
const maxDepth = 5
18+
p := &OutboundParcel{}
19+
test.FillFakeData(t, debug, maxDepth, p)
20+
21+
// We allow aliasing here deep down (for now).
22+
strict := false
23+
test.AssertCopyEqual(t, debug, strict, p)
24+
}

0 commit comments

Comments
 (0)