Skip to content

feat: socket #15

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions requests_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,9 @@ func TestDownload(t *testing.T) {
return
}
defer f.Close()
sum := 0
sum, cnt := 0, 0
_ = Stream(func(i int64, row []byte) error {
cnt, err := f.Write(row)
cnt, err = f.Write(row)
sum += cnt
return err
})
Expand Down
2 changes: 1 addition & 1 deletion server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ func generateTestCert(certFile, keyFile string) error {
return err
}
defer certOut.Close()
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
if err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
return err
}

Expand Down
3 changes: 1 addition & 2 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package requests
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
)
Expand Down Expand Up @@ -49,7 +48,7 @@ func (s *Session) Do(ctx context.Context, opts ...Option) (*http.Response, error
options := newOptions(s.opts, opts...)
req, err := NewRequestWithContext(ctx, options)
if err != nil {
return &http.Response{}, fmt.Errorf("newRequest: %w", err)
return &http.Response{}, err
}
return s.RoundTripper(opts...).RoundTrip(req)
}
Expand Down
31 changes: 31 additions & 0 deletions socket.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package requests

import (
"context"
"net"
"net/url"
"time"
)

// Socket socket ..
func socket(ctx context.Context, src net.Addr, network, address string, timeout time.Duration) (net.Conn, error) {
dialer := net.Dialer{
Timeout: timeout, // TCP connection timeout
KeepAlive: 60 * time.Second, // TCP keepalive interval
LocalAddr: src, // Local address binding
Resolver: &net.Resolver{ // DNS resolver configuration
PreferGo: true, // Prefer Go's DNS resolver
StrictErrors: false, // Tolerate DNS resolution errors
},
}
return dialer.DialContext(ctx, network, address)
}

func Socket(ctx context.Context, opts ...Option) (net.Conn, error) {
options := newOptions(opts)
u, err := url.Parse(options.URL)
if err != nil {
return nil, err
}
return socket(ctx, options.LocalAddr, u.Scheme, u.Host, options.Timeout)
}
170 changes: 170 additions & 0 deletions socket_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
package requests

import (
"context"
"errors"
"net"
"strings"
"testing"
"time"
)

func TestSocket(t *testing.T) {
// 启动测试服务器
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建监听器失败: %v", err)
}
defer listener.Close()
// 在后台接受连接
go func() {
conn, err := listener.Accept()
if err != nil {
return
}
defer conn.Close()
// 简单的回显服务
buf := make([]byte, 1024)
n, _ := conn.Read(buf)
conn.Write(buf[:n])
}()

tests := []struct {
name string
opts []Option
wantErr bool
}{
{
name: "TCP正常连接",
opts: []Option{
URL("tcp://" + listener.Addr().String()),
Timeout(time.Second),
},
wantErr: false,
},
{
name: "无效URL",
opts: []Option{
URL("invalid://localhost"),
Timeout(time.Second),
},
wantErr: true,
},
{
name: "错误URL",
opts: []Option{
URL("://:::"),
Timeout(time.Second),
},
wantErr: true,
},
{
name: "连接超时",
opts: []Option{
URL("tcp://240.0.0.1:12345"), // 不可达的地址
Timeout(1),
},
wantErr: true,
},
{
name: "Unix socket连接",
opts: []Option{
URL("unix:///tmp/test.sock"),
Timeout(time.Second),
},
wantErr: true, // Unix socket文件不存在,应该失败
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
conn, err := Socket(ctx, tt.opts...)
if (err != nil) != tt.wantErr {
t.Errorf("Socket() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err == nil {
defer conn.Close()
// 测试连接是否可用
_, err = conn.Write([]byte("test"))
if err != nil {
t.Errorf("写入数据失败: %v", err)
}
buf := make([]byte, 4)
_, err = conn.Read(buf)
if err != nil {
t.Errorf("读取数据失败: %v", err)
}
if string(buf) != "test" {
t.Errorf("期望读取到 'test',得到 %s", string(buf))
}
}
})
}
}

func TestSocket_ContextCancel(t *testing.T) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建监听器失败: %v", err)
}
defer listener.Close()
// 在后台接受连接
go func() {
conn, err := listener.Accept()
if err != nil {
return
}
defer conn.Close()
// 简单的回显服务
buf := make([]byte, 1024)
n, _ := conn.Read(buf)
conn.Write(buf[:n])
}()
// 创建一个可取消的上下文
ctx, cancel := context.WithCancel(context.Background())

// 立即取消
cancel()

// 尝试建立连接
if _, err = Socket(ctx, URL("tcp://"+listener.Addr().String())); errors.Is(err, context.Canceled) {
t.Log(err)
return
}
t.Errorf("期望错误为 context.Canceled,得到 %v", err)

}

func TestSocket_WithCustomDialer(t *testing.T) {
// 测试自定义本地地址
localAddr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0, // 系统自动分配端口
}

// 启动测试服务器
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建监听器失败: %v", err)
}
defer listener.Close()

// 使用自定义本地地址建立连接
conn, err := Socket(context.Background(),
URL("tcp://"+listener.Addr().String()),
LocalAddr(localAddr),
)

if err != nil {
t.Fatalf("建立连接失败: %v", err)
}
defer conn.Close()

// 验证连接的本地地址
localAddrStr := conn.LocalAddr().String()
if !strings.Contains(localAddrStr, "127.0.0.1") {
t.Errorf("期望本地地址为 127.0.0.1,得到 %s", localAddrStr)
}
}
13 changes: 1 addition & 12 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,7 @@ func newTransport(opts ...Option) *http.Transport {
}
network, addr = u.Scheme, u.Path
}

// Configure dialer parameters
dialer := net.Dialer{
Timeout: 10 * time.Second, // TCP connection timeout
KeepAlive: 60 * time.Second, // TCP keepalive interval
LocalAddr: options.LocalAddr, // Local address binding
Resolver: &net.Resolver{ // DNS resolver configuration
PreferGo: true, // Prefer Go's DNS resolver
StrictErrors: false, // Tolerate DNS resolution errors
},
}
return dialer.DialContext(ctx, network, addr)
return socket(ctx, options.LocalAddr, network, addr, 10*time.Second)
},

// Connection pool configuration
Expand Down
10 changes: 10 additions & 0 deletions transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,16 @@ func TestNewTransport(t *testing.T) {
}
},
},
{
name: "Unix套接字",
opts: []Option{URL("unix://:::")},
test: func(t *testing.T, tr *http.Transport) {
_, err := tr.DialContext(context.Background(), "unix", ":::")
if err == nil {
t.Error("期望Unix套接字连接失败")
}
},
},
{
name: "TLS配置",
opts: []Option{Verify(false)},
Expand Down