Skip to content

Commit bfda6a9

Browse files
authored
Merge pull request #15 from golang-io/dev
feat: socket
2 parents c9e9d49 + a2cd7bb commit bfda6a9

File tree

7 files changed

+216
-17
lines changed

7 files changed

+216
-17
lines changed

requests_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,9 @@ func TestDownload(t *testing.T) {
189189
return
190190
}
191191
defer f.Close()
192-
sum := 0
192+
sum, cnt := 0, 0
193193
_ = Stream(func(i int64, row []byte) error {
194-
cnt, err := f.Write(row)
194+
cnt, err = f.Write(row)
195195
sum += cnt
196196
return err
197197
})

server_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ func generateTestCert(certFile, keyFile string) error {
514514
return err
515515
}
516516
defer certOut.Close()
517-
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
517+
if err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
518518
return err
519519
}
520520

session.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package requests
33
import (
44
"bytes"
55
"context"
6-
"fmt"
76
"io"
87
"net/http"
98
)
@@ -49,7 +48,7 @@ func (s *Session) Do(ctx context.Context, opts ...Option) (*http.Response, error
4948
options := newOptions(s.opts, opts...)
5049
req, err := NewRequestWithContext(ctx, options)
5150
if err != nil {
52-
return &http.Response{}, fmt.Errorf("newRequest: %w", err)
51+
return &http.Response{}, err
5352
}
5453
return s.RoundTripper(opts...).RoundTrip(req)
5554
}

socket.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package requests
2+
3+
import (
4+
"context"
5+
"net"
6+
"net/url"
7+
"time"
8+
)
9+
10+
// Socket socket ..
11+
func socket(ctx context.Context, src net.Addr, network, address string, timeout time.Duration) (net.Conn, error) {
12+
dialer := net.Dialer{
13+
Timeout: timeout, // TCP connection timeout
14+
KeepAlive: 60 * time.Second, // TCP keepalive interval
15+
LocalAddr: src, // Local address binding
16+
Resolver: &net.Resolver{ // DNS resolver configuration
17+
PreferGo: true, // Prefer Go's DNS resolver
18+
StrictErrors: false, // Tolerate DNS resolution errors
19+
},
20+
}
21+
return dialer.DialContext(ctx, network, address)
22+
}
23+
24+
func Socket(ctx context.Context, opts ...Option) (net.Conn, error) {
25+
options := newOptions(opts)
26+
u, err := url.Parse(options.URL)
27+
if err != nil {
28+
return nil, err
29+
}
30+
return socket(ctx, options.LocalAddr, u.Scheme, u.Host, options.Timeout)
31+
}

socket_test.go

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
package requests
2+
3+
import (
4+
"context"
5+
"errors"
6+
"net"
7+
"strings"
8+
"testing"
9+
"time"
10+
)
11+
12+
func TestSocket(t *testing.T) {
13+
// 启动测试服务器
14+
listener, err := net.Listen("tcp", "127.0.0.1:0")
15+
if err != nil {
16+
t.Fatalf("创建监听器失败: %v", err)
17+
}
18+
defer listener.Close()
19+
// 在后台接受连接
20+
go func() {
21+
conn, err := listener.Accept()
22+
if err != nil {
23+
return
24+
}
25+
defer conn.Close()
26+
// 简单的回显服务
27+
buf := make([]byte, 1024)
28+
n, _ := conn.Read(buf)
29+
conn.Write(buf[:n])
30+
}()
31+
32+
tests := []struct {
33+
name string
34+
opts []Option
35+
wantErr bool
36+
}{
37+
{
38+
name: "TCP正常连接",
39+
opts: []Option{
40+
URL("tcp://" + listener.Addr().String()),
41+
Timeout(time.Second),
42+
},
43+
wantErr: false,
44+
},
45+
{
46+
name: "无效URL",
47+
opts: []Option{
48+
URL("invalid://localhost"),
49+
Timeout(time.Second),
50+
},
51+
wantErr: true,
52+
},
53+
{
54+
name: "错误URL",
55+
opts: []Option{
56+
URL("://:::"),
57+
Timeout(time.Second),
58+
},
59+
wantErr: true,
60+
},
61+
{
62+
name: "连接超时",
63+
opts: []Option{
64+
URL("tcp://240.0.0.1:12345"), // 不可达的地址
65+
Timeout(1),
66+
},
67+
wantErr: true,
68+
},
69+
{
70+
name: "Unix socket连接",
71+
opts: []Option{
72+
URL("unix:///tmp/test.sock"),
73+
Timeout(time.Second),
74+
},
75+
wantErr: true, // Unix socket文件不存在,应该失败
76+
},
77+
}
78+
79+
for _, tt := range tests {
80+
t.Run(tt.name, func(t *testing.T) {
81+
ctx := context.Background()
82+
conn, err := Socket(ctx, tt.opts...)
83+
if (err != nil) != tt.wantErr {
84+
t.Errorf("Socket() error = %v, wantErr %v", err, tt.wantErr)
85+
return
86+
}
87+
if err == nil {
88+
defer conn.Close()
89+
// 测试连接是否可用
90+
_, err = conn.Write([]byte("test"))
91+
if err != nil {
92+
t.Errorf("写入数据失败: %v", err)
93+
}
94+
buf := make([]byte, 4)
95+
_, err = conn.Read(buf)
96+
if err != nil {
97+
t.Errorf("读取数据失败: %v", err)
98+
}
99+
if string(buf) != "test" {
100+
t.Errorf("期望读取到 'test',得到 %s", string(buf))
101+
}
102+
}
103+
})
104+
}
105+
}
106+
107+
func TestSocket_ContextCancel(t *testing.T) {
108+
listener, err := net.Listen("tcp", "127.0.0.1:0")
109+
if err != nil {
110+
t.Fatalf("创建监听器失败: %v", err)
111+
}
112+
defer listener.Close()
113+
// 在后台接受连接
114+
go func() {
115+
conn, err := listener.Accept()
116+
if err != nil {
117+
return
118+
}
119+
defer conn.Close()
120+
// 简单的回显服务
121+
buf := make([]byte, 1024)
122+
n, _ := conn.Read(buf)
123+
conn.Write(buf[:n])
124+
}()
125+
// 创建一个可取消的上下文
126+
ctx, cancel := context.WithCancel(context.Background())
127+
128+
// 立即取消
129+
cancel()
130+
131+
// 尝试建立连接
132+
if _, err = Socket(ctx, URL("tcp://"+listener.Addr().String())); errors.Is(err, context.Canceled) {
133+
t.Log(err)
134+
return
135+
}
136+
t.Errorf("期望错误为 context.Canceled,得到 %v", err)
137+
138+
}
139+
140+
func TestSocket_WithCustomDialer(t *testing.T) {
141+
// 测试自定义本地地址
142+
localAddr := &net.TCPAddr{
143+
IP: net.ParseIP("127.0.0.1"),
144+
Port: 0, // 系统自动分配端口
145+
}
146+
147+
// 启动测试服务器
148+
listener, err := net.Listen("tcp", "127.0.0.1:0")
149+
if err != nil {
150+
t.Fatalf("创建监听器失败: %v", err)
151+
}
152+
defer listener.Close()
153+
154+
// 使用自定义本地地址建立连接
155+
conn, err := Socket(context.Background(),
156+
URL("tcp://"+listener.Addr().String()),
157+
LocalAddr(localAddr),
158+
)
159+
160+
if err != nil {
161+
t.Fatalf("建立连接失败: %v", err)
162+
}
163+
defer conn.Close()
164+
165+
// 验证连接的本地地址
166+
localAddrStr := conn.LocalAddr().String()
167+
if !strings.Contains(localAddrStr, "127.0.0.1") {
168+
t.Errorf("期望本地地址为 127.0.0.1,得到 %s", localAddrStr)
169+
}
170+
}

transport.go

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,7 @@ func newTransport(opts ...Option) *http.Transport {
4949
}
5050
network, addr = u.Scheme, u.Path
5151
}
52-
53-
// Configure dialer parameters
54-
dialer := net.Dialer{
55-
Timeout: 10 * time.Second, // TCP connection timeout
56-
KeepAlive: 60 * time.Second, // TCP keepalive interval
57-
LocalAddr: options.LocalAddr, // Local address binding
58-
Resolver: &net.Resolver{ // DNS resolver configuration
59-
PreferGo: true, // Prefer Go's DNS resolver
60-
StrictErrors: false, // Tolerate DNS resolution errors
61-
},
62-
}
63-
return dialer.DialContext(ctx, network, addr)
52+
return socket(ctx, options.LocalAddr, network, addr, 10*time.Second)
6453
},
6554

6655
// Connection pool configuration

transport_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,16 @@ func TestNewTransport(t *testing.T) {
9898
}
9999
},
100100
},
101+
{
102+
name: "Unix套接字",
103+
opts: []Option{URL("unix://:::")},
104+
test: func(t *testing.T, tr *http.Transport) {
105+
_, err := tr.DialContext(context.Background(), "unix", ":::")
106+
if err == nil {
107+
t.Error("期望Unix套接字连接失败")
108+
}
109+
},
110+
},
101111
{
102112
name: "TLS配置",
103113
opts: []Option{Verify(false)},

0 commit comments

Comments
 (0)