diff --git a/requests_test.go b/requests_test.go index b8a740d..78fd4b0 100644 --- a/requests_test.go +++ b/requests_test.go @@ -189,9 +189,9 @@ func TestDownload(t *testing.T) { return } defer f.Close() - sum := 0 + sum, cnt := 0, 0 _ = Stream(func(i int64, row []byte) error { - cnt, err := f.Write(row) + cnt, err = f.Write(row) sum += cnt return err }) diff --git a/server_test.go b/server_test.go index f5a13ef..1f2abeb 100644 --- a/server_test.go +++ b/server_test.go @@ -514,7 +514,7 @@ func generateTestCert(certFile, keyFile string) error { return err } defer certOut.Close() - if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { + if err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { return err } diff --git a/session.go b/session.go index d84d476..bc60e5c 100644 --- a/session.go +++ b/session.go @@ -3,7 +3,6 @@ package requests import ( "bytes" "context" - "fmt" "io" "net/http" ) @@ -49,7 +48,7 @@ func (s *Session) Do(ctx context.Context, opts ...Option) (*http.Response, error options := newOptions(s.opts, opts...) req, err := NewRequestWithContext(ctx, options) if err != nil { - return &http.Response{}, fmt.Errorf("newRequest: %w", err) + return &http.Response{}, err } return s.RoundTripper(opts...).RoundTrip(req) } diff --git a/socket.go b/socket.go new file mode 100644 index 0000000..96690a0 --- /dev/null +++ b/socket.go @@ -0,0 +1,31 @@ +package requests + +import ( + "context" + "net" + "net/url" + "time" +) + +// Socket socket .. +func socket(ctx context.Context, src net.Addr, network, address string, timeout time.Duration) (net.Conn, error) { + dialer := net.Dialer{ + Timeout: timeout, // TCP connection timeout + KeepAlive: 60 * time.Second, // TCP keepalive interval + LocalAddr: src, // Local address binding + Resolver: &net.Resolver{ // DNS resolver configuration + PreferGo: true, // Prefer Go's DNS resolver + StrictErrors: false, // Tolerate DNS resolution errors + }, + } + return dialer.DialContext(ctx, network, address) +} + +func Socket(ctx context.Context, opts ...Option) (net.Conn, error) { + options := newOptions(opts) + u, err := url.Parse(options.URL) + if err != nil { + return nil, err + } + return socket(ctx, options.LocalAddr, u.Scheme, u.Host, options.Timeout) +} diff --git a/socket_test.go b/socket_test.go new file mode 100644 index 0000000..ef2bb39 --- /dev/null +++ b/socket_test.go @@ -0,0 +1,170 @@ +package requests + +import ( + "context" + "errors" + "net" + "strings" + "testing" + "time" +) + +func TestSocket(t *testing.T) { + // 启动测试服务器 + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("创建监听器失败: %v", err) + } + defer listener.Close() + // 在后台接受连接 + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + defer conn.Close() + // 简单的回显服务 + buf := make([]byte, 1024) + n, _ := conn.Read(buf) + conn.Write(buf[:n]) + }() + + tests := []struct { + name string + opts []Option + wantErr bool + }{ + { + name: "TCP正常连接", + opts: []Option{ + URL("tcp://" + listener.Addr().String()), + Timeout(time.Second), + }, + wantErr: false, + }, + { + name: "无效URL", + opts: []Option{ + URL("invalid://localhost"), + Timeout(time.Second), + }, + wantErr: true, + }, + { + name: "错误URL", + opts: []Option{ + URL("://:::"), + Timeout(time.Second), + }, + wantErr: true, + }, + { + name: "连接超时", + opts: []Option{ + URL("tcp://240.0.0.1:12345"), // 不可达的地址 + Timeout(1), + }, + wantErr: true, + }, + { + name: "Unix socket连接", + opts: []Option{ + URL("unix:///tmp/test.sock"), + Timeout(time.Second), + }, + wantErr: true, // Unix socket文件不存在,应该失败 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + conn, err := Socket(ctx, tt.opts...) + if (err != nil) != tt.wantErr { + t.Errorf("Socket() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err == nil { + defer conn.Close() + // 测试连接是否可用 + _, err = conn.Write([]byte("test")) + if err != nil { + t.Errorf("写入数据失败: %v", err) + } + buf := make([]byte, 4) + _, err = conn.Read(buf) + if err != nil { + t.Errorf("读取数据失败: %v", err) + } + if string(buf) != "test" { + t.Errorf("期望读取到 'test',得到 %s", string(buf)) + } + } + }) + } +} + +func TestSocket_ContextCancel(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("创建监听器失败: %v", err) + } + defer listener.Close() + // 在后台接受连接 + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + defer conn.Close() + // 简单的回显服务 + buf := make([]byte, 1024) + n, _ := conn.Read(buf) + conn.Write(buf[:n]) + }() + // 创建一个可取消的上下文 + ctx, cancel := context.WithCancel(context.Background()) + + // 立即取消 + cancel() + + // 尝试建立连接 + if _, err = Socket(ctx, URL("tcp://"+listener.Addr().String())); errors.Is(err, context.Canceled) { + t.Log(err) + return + } + t.Errorf("期望错误为 context.Canceled,得到 %v", err) + +} + +func TestSocket_WithCustomDialer(t *testing.T) { + // 测试自定义本地地址 + localAddr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, // 系统自动分配端口 + } + + // 启动测试服务器 + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("创建监听器失败: %v", err) + } + defer listener.Close() + + // 使用自定义本地地址建立连接 + conn, err := Socket(context.Background(), + URL("tcp://"+listener.Addr().String()), + LocalAddr(localAddr), + ) + + if err != nil { + t.Fatalf("建立连接失败: %v", err) + } + defer conn.Close() + + // 验证连接的本地地址 + localAddrStr := conn.LocalAddr().String() + if !strings.Contains(localAddrStr, "127.0.0.1") { + t.Errorf("期望本地地址为 127.0.0.1,得到 %s", localAddrStr) + } +} diff --git a/transport.go b/transport.go index 46161ad..99f4114 100644 --- a/transport.go +++ b/transport.go @@ -49,18 +49,7 @@ func newTransport(opts ...Option) *http.Transport { } network, addr = u.Scheme, u.Path } - - // Configure dialer parameters - dialer := net.Dialer{ - Timeout: 10 * time.Second, // TCP connection timeout - KeepAlive: 60 * time.Second, // TCP keepalive interval - LocalAddr: options.LocalAddr, // Local address binding - Resolver: &net.Resolver{ // DNS resolver configuration - PreferGo: true, // Prefer Go's DNS resolver - StrictErrors: false, // Tolerate DNS resolution errors - }, - } - return dialer.DialContext(ctx, network, addr) + return socket(ctx, options.LocalAddr, network, addr, 10*time.Second) }, // Connection pool configuration diff --git a/transport_test.go b/transport_test.go index 554ca08..575a948 100644 --- a/transport_test.go +++ b/transport_test.go @@ -98,6 +98,16 @@ func TestNewTransport(t *testing.T) { } }, }, + { + name: "Unix套接字", + opts: []Option{URL("unix://:::")}, + test: func(t *testing.T, tr *http.Transport) { + _, err := tr.DialContext(context.Background(), "unix", ":::") + if err == nil { + t.Error("期望Unix套接字连接失败") + } + }, + }, { name: "TLS配置", opts: []Option{Verify(false)},