Skip to content

Commit 0ef8cd4

Browse files
authored
srv tests added (#4)
1 parent 6b33b43 commit 0ef8cd4

File tree

3 files changed

+204
-12
lines changed

3 files changed

+204
-12
lines changed

pkg/handlers/parser/parser.go

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -115,24 +115,23 @@ func parsePart(part string) handlers.Handler {
115115
return ret
116116
}
117117

118-
modifiers := parseModifiers()
119118
switch parts[0] {
120119
case defaultip.ShortName, defaultip.Name:
121-
return defaultip.NewHandler(modifiers...)
120+
return defaultip.NewHandler(parseModifiers()...)
122121
case cname.ShortName, cname.Name:
123-
return cname.NewHandler(modifiers...)
122+
return cname.NewHandler(parseModifiers()...)
124123
case proxy.ShortName, proxy.Name:
125-
return proxy.NewHandler(modifiers...)
124+
return proxy.NewHandler(parseModifiers()...)
126125
case random.ShortName, random.Name:
127-
return random.NewHandler(modifiers...)
126+
return random.NewHandler(parseModifiers()...)
128127
case loop.ShortName, loop.Name:
129-
return loop.NewHandler(modifiers...)
128+
return loop.NewHandler(parseModifiers()...)
130129
case sticky.ShortName, sticky.Name:
131-
return sticky.NewHandler(modifiers...)
130+
return sticky.NewHandler(parseModifiers()...)
132131
case ipv4.ShortName, ipv4.Name:
133-
return ipv4.NewHandler(modifiers...)
132+
return ipv4.NewHandler(parseModifiers()...)
134133
case ipv6.ShortName, ipv6.Name:
135-
return ipv6.NewHandler(modifiers...)
134+
return ipv6.NewHandler(parseModifiers()...)
136135
default:
137136
return parseIPHandler(part)
138137
}

pkg/nssrv/server_test.go

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
package nssrv_test
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"github.com/buglloc/rip/v2/pkg/cfg"
7+
"github.com/buglloc/rip/v2/pkg/nssrv"
8+
"github.com/miekg/dns"
9+
"github.com/stretchr/testify/require"
10+
"net"
11+
"testing"
12+
"time"
13+
)
14+
15+
func getFreePort() (int, error) {
16+
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
17+
if err != nil {
18+
return 0, err
19+
}
20+
21+
l, err := net.ListenTCP("tcp", addr)
22+
if err != nil {
23+
return 0, err
24+
}
25+
defer func() { _ = l.Close() }()
26+
27+
return l.Addr().(*net.TCPAddr).Port, nil
28+
}
29+
30+
func newSRV(t *testing.T) *nssrv.NSSrv {
31+
port, err := getFreePort()
32+
require.NoError(t, err)
33+
34+
cfg.Zones = []string{"tst"}
35+
cfg.Addr = fmt.Sprintf("localhost:%d", port)
36+
srv, err := nssrv.NewSrv()
37+
require.NoError(t, err)
38+
39+
go func() {
40+
err = srv.ListenAndServe()
41+
}()
42+
43+
// TODO((buglloc): too ugly
44+
time.Sleep(1 * time.Second)
45+
if err != nil {
46+
_ = srv.Shutdown(context.Background())
47+
require.NoError(t, err)
48+
}
49+
50+
return srv
51+
}
52+
53+
func resolve(t *testing.T, client *dns.Client, msg *dns.Msg) net.IP {
54+
res, _, err := client.Exchange(msg, cfg.Addr)
55+
require.NoError(t, err)
56+
require.NotEmpty(t, res.Answer)
57+
58+
var ip net.IP
59+
switch v := res.Answer[0].(type) {
60+
case *dns.A:
61+
ip = v.A.To4()
62+
case *dns.AAAA:
63+
ip = v.AAAA.To16()
64+
}
65+
66+
return ip
67+
}
68+
69+
func TestServer_simple(t *testing.T) {
70+
cases := []struct {
71+
in string
72+
reqType uint16
73+
ip net.IP
74+
}{
75+
{
76+
in: "1-1-1-1.4.tst",
77+
reqType: dns.TypeA,
78+
ip: net.ParseIP("1.1.1.1").To4(),
79+
},
80+
{
81+
in: "1-1-1-1.v4.tst",
82+
reqType: dns.TypeA,
83+
ip: net.ParseIP("1.1.1.1").To4(),
84+
},
85+
{
86+
in: "1-1-1-1.v4.tst",
87+
reqType: dns.TypeA,
88+
ip: net.ParseIP("1.1.1.1").To4(),
89+
},
90+
{
91+
in: "fe80--fa94-c2ff-fee5-3cf6.6.tst",
92+
reqType: dns.TypeAAAA,
93+
ip: net.ParseIP("fe80::fa94:c2ff:fee5:3cf6").To16(),
94+
},
95+
{
96+
in: "fe80000000000000fa94c2fffee53cf6.v6.tst",
97+
reqType: dns.TypeAAAA,
98+
ip: net.ParseIP("fe80::fa94:c2ff:fee5:3cf6").To16(),
99+
},
100+
{
101+
in: "2-2-2-2.3-3-3-3.4.l.tst",
102+
reqType: dns.TypeA,
103+
ip: net.ParseIP("3.3.3.3").To4(),
104+
},
105+
{
106+
in: "2-2-2-2.3-3-3-3.4.s.tst",
107+
reqType: dns.TypeA,
108+
ip: net.ParseIP("3.3.3.3").To4(),
109+
},
110+
}
111+
112+
srv := newSRV(t)
113+
defer func() { _ = srv.Shutdown(context.Background()) }()
114+
115+
client := &dns.Client{
116+
Net: "tcp",
117+
ReadTimeout: time.Second * 1,
118+
WriteTimeout: time.Second * 1,
119+
}
120+
for _, tc := range cases {
121+
t.Run(tc.in, func(t *testing.T) {
122+
msg := &dns.Msg{}
123+
msg.SetQuestion(dns.Fqdn(tc.in), tc.reqType)
124+
ip := resolve(t, client, msg)
125+
require.Equal(t, tc.ip, ip)
126+
})
127+
}
128+
}
129+
130+
func TestServer_loop(t *testing.T) {
131+
srv := newSRV(t)
132+
defer func() { _ = srv.Shutdown(context.Background()) }()
133+
134+
client := &dns.Client{
135+
Net: "tcp",
136+
ReadTimeout: time.Second * 1,
137+
WriteTimeout: time.Second * 1,
138+
}
139+
140+
msg := &dns.Msg{}
141+
msg.SetQuestion(dns.Fqdn("1-1-1-1.v4.2-2-2-2.v4.loop.tst"), dns.TypeA)
142+
ip := resolve(t, client, msg)
143+
require.Equal(t, net.ParseIP("2.2.2.2").To4(), ip)
144+
ip = resolve(t, client, msg)
145+
require.Equal(t, net.ParseIP("1.1.1.1").To4(), ip)
146+
ip = resolve(t, client, msg)
147+
require.Equal(t, net.ParseIP("2.2.2.2").To4(), ip)
148+
}
149+
150+
func TestServer_multiLoop(t *testing.T) {
151+
srv := newSRV(t)
152+
defer func() { _ = srv.Shutdown(context.Background()) }()
153+
154+
client := &dns.Client{
155+
Net: "tcp",
156+
ReadTimeout: time.Second * 1,
157+
WriteTimeout: time.Second * 1,
158+
}
159+
160+
msg := &dns.Msg{}
161+
msg.SetQuestion(dns.Fqdn("1-1-1-1.v4.2-2-2-2.v4.loop-cnt-2.3-3-3-3.v4.loop.tst"), dns.TypeA)
162+
ip := resolve(t, client, msg)
163+
require.Equal(t, net.ParseIP("3.3.3.3").To4(), ip)
164+
ip = resolve(t, client, msg)
165+
require.Equal(t, net.ParseIP("2.2.2.2").To4(), ip)
166+
ip = resolve(t, client, msg)
167+
require.Equal(t, net.ParseIP("1.1.1.1").To4(), ip)
168+
ip = resolve(t, client, msg)
169+
require.Equal(t, net.ParseIP("3.3.3.3").To4(), ip)
170+
}
171+
172+
func TestServer_multiLoopWithTTL(t *testing.T) {
173+
srv := newSRV(t)
174+
defer func() { _ = srv.Shutdown(context.Background()) }()
175+
176+
client := &dns.Client{
177+
Net: "tcp",
178+
ReadTimeout: time.Second * 1,
179+
WriteTimeout: time.Second * 1,
180+
}
181+
182+
msg := &dns.Msg{}
183+
msg.SetQuestion(dns.Fqdn("1-1-1-1.v4.2-2-2-2.v4.loop-ttl-20s.3-3-3-3.v4.loop.tst"), dns.TypeA)
184+
ip := resolve(t, client, msg)
185+
require.Equal(t, net.ParseIP("3.3.3.3").To4(), ip)
186+
ip = resolve(t, client, msg)
187+
require.Equal(t, net.ParseIP("2.2.2.2").To4(), ip)
188+
ip = resolve(t, client, msg)
189+
require.Equal(t, net.ParseIP("1.1.1.1").To4(), ip)
190+
ip = resolve(t, client, msg)
191+
require.Equal(t, net.ParseIP("2.2.2.2").To4(), ip)
192+
ip = resolve(t, client, msg)
193+
require.Equal(t, net.ParseIP("1.1.1.1").To4(), ip)
194+
}

pkg/resolver/upstreams.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,11 @@ func ResolveIp(reqType uint16, name string) ([]net.IP, error) {
4242
}
4343

4444
ttl := time.Duration(res.Answer[0].(dns.RR).Header().Ttl) * time.Second
45-
dnsCache.Set(dns.TypeA, name, ttl, ipv4)
46-
dnsCache.Set(dns.TypeAAAA, name, ttl, ipv6)
47-
4845
if reqType == dns.TypeA {
46+
dnsCache.Set(dns.TypeA, name, ttl, ipv4)
4947
return ipv4, nil
5048
}
5149

50+
dnsCache.Set(dns.TypeAAAA, name, ttl, ipv6)
5251
return ipv6, nil
5352
}

0 commit comments

Comments
 (0)