diff --git a/.gitignore b/.gitignore index cad7f12..4f021ce 100644 --- a/.gitignore +++ b/.gitignore @@ -21,4 +21,5 @@ go.work vendor/ -.idea/ \ No newline at end of file +.idea/ +coverage.txt \ No newline at end of file diff --git a/response_writer_test.go b/response_writer_test.go index 3b9a636..26896e3 100644 --- a/response_writer_test.go +++ b/response_writer_test.go @@ -13,19 +13,27 @@ import ( ) // mockResponseWriter implements http.ResponseWriter for testing +// 更新 mockResponseWriter 以支持错误测试 type mockResponseWriter struct { headers http.Header statuscode int body bytes.Buffer + writeError error // 添加这个字段 +} + +func (m *mockResponseWriter) Write(b []byte) (int, error) { + if m.writeError != nil { + return 0, m.writeError + } + return m.body.Write(b) } func newMockResponseWriter() *mockResponseWriter { return &mockResponseWriter{headers: make(http.Header)} } -func (m *mockResponseWriter) Header() http.Header { return m.headers } -func (m *mockResponseWriter) Write(b []byte) (int, error) { return m.body.Write(b) } -func (m *mockResponseWriter) WriteHeader(code int) { m.statuscode = code } +func (m *mockResponseWriter) Header() http.Header { return m.headers } +func (m *mockResponseWriter) WriteHeader(code int) { m.statuscode = code } // TestResponseWriterBasic tests basic functionality of ResponseWriter func TestResponseWriterBasic(t *testing.T) { @@ -208,3 +216,94 @@ func TestResponseWriterFlush(t *testing.T) { t.Errorf("Expected chunk2, got %s, err: %v", chunk2, err) } } + +// TestResponseWriterRead 测试 Read 方法 +func TestResponseWriterRead(t *testing.T) { + mock := newMockResponseWriter() + w := newResponseWriter(mock) + + // 写入测试数据 + testData := []byte("test data for reading") + _, err := w.Write(testData) + if err != nil { + t.Fatalf("Failed to write test data: %v", err) + } + + // 测试读取 + buf := make([]byte, len(testData)) + n, err := w.Read(buf) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + if n != len(testData) { + t.Errorf("Expected to read %d bytes, got %d", len(testData), n) + } + if string(buf) != string(testData) { + t.Errorf("Expected to read '%s', got '%s'", string(testData), string(buf)) + } + + // 测试读取完后的EOF + n, err = w.Read(buf) + if err != io.EOF { + t.Errorf("Expected EOF after reading all data, got %v", err) + } +} + +// TestResponseWriterPush 测试 HTTP/2 Push 功能 +func TestResponseWriterPush(t *testing.T) { + // 创建支持 HTTP/2 的测试服务器 + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rw := newResponseWriter(w) + err := rw.Push("/style.css", &http.PushOptions{ + Method: "GET", + Header: http.Header{ + "Content-Type": []string{"text/css"}, + }, + }) + if err != nil && err != http.ErrNotSupported { + t.Errorf("Push failed: %v", err) + } + rw.Write([]byte("main content")) + }) + + server := httptest.NewUnstartedServer(handler) + server.EnableHTTP2 = true + server.StartTLS() + defer server.Close() + + // 发起请求 + client := server.Client() + resp, err := client.Get(server.URL) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read response: %v", err) + } + if string(body) != "main content" { + t.Errorf("Expected 'main content', got '%s'", string(body)) + } +} + +// TestResponseWriterWriteError 测试 Write 方法的错误处理 +func TestResponseWriterWriteError(t *testing.T) { + // 创建一个会返回错误的 mock + errMock := &mockResponseWriter{ + headers: make(http.Header), + writeError: fmt.Errorf("write error"), + } + + w := newResponseWriter(errMock) + + // 测试写入错误 + n, err := w.Write([]byte("test")) + if err == nil { + t.Error("Expected write error, got nil") + } + if n != 0 { + t.Errorf("Expected 0 bytes written on error, got %d", n) + } +} diff --git a/transport.go b/transport.go index 6366744..46161ad 100644 --- a/transport.go +++ b/transport.go @@ -78,51 +78,3 @@ func newTransport(opts ...Option) *http.Transport { }, } } - -// RoundTrip implements the RoundTripper interface. -// It processes requests by calling the RoundTripper method. -// func (t *Transport) RoundTrip(r *http.Request) (*http.Response, error) { -// return t.RoundTripper().RoundTrip(r) -// } - -// RoundTripper returns a configured http.RoundTripper. -// It applies all registered middleware in reverse order. -// func (t *Transport) RoundTripper(opts ...Option) http.RoundTripper { -// return RoundTripperFunc(func(r *http.Request) (*http.Response, error) { -// options := newOptions(t.opts, opts...) -// if options.Transport == nil { -// options.Transport = t.Transport -// } -// // Apply middleware in reverse order -// for i := len(options.HttpRoundTripper) - 1; i >= 0; i-- { -// options.Transport = options.HttpRoundTripper[i](options.Transport) -// } -// return options.Transport.RoundTrip(r) -// }) -// } - -// Redirect creates a middleware for handling HTTP redirects. -// It handles 301 (Moved Permanently) and 302 (Found) status codes. -func Redirect(next http.RoundTripper) http.RoundTripper { - return RoundTripperFunc(func(req *http.Request) (*http.Response, error) { - response, err := next.RoundTrip(req) - if err != nil { - return response, err - } - // Check if redirection is needed - if response.StatusCode != http.StatusMovedPermanently && response.StatusCode != http.StatusFound { - return response, err - } - // Create redirect request - if req, err = NewRequestWithContext(req.Context(), Options{ - Method: req.Method, - URL: response.Header.Get("Location"), - Header: req.Header, - body: req.Body, - }); err != nil { - return response, err - } - // Execute redirect request - return next.RoundTrip(req) - }) -} diff --git a/transport_test.go b/transport_test.go index cbb5482..554ca08 100644 --- a/transport_test.go +++ b/transport_test.go @@ -3,10 +3,13 @@ package requests import ( "context" "io" + "net" "net/http" "net/http/httptest" + "net/url" "strings" "testing" + "time" ) func Test_Setup(t *testing.T) { @@ -50,3 +53,123 @@ func Test_Setup(t *testing.T) { } } + +func TestWarpRoundTripper(t *testing.T) { + // 测试装饰器链 + var order []string + rt1 := RoundTripperFunc(func(r *http.Request) (*http.Response, error) { + order = append(order, "rt1") + return &http.Response{StatusCode: 200}, nil + }) + + rt2 := WarpRoundTripper(rt1)(http.DefaultTransport) + _, err := rt2.RoundTrip(&http.Request{}) + if err != nil { + t.Fatal(err) + } + if len(order) != 1 || order[0] != "rt1" { + t.Error("装饰器执行顺序错误") + } +} + +func TestNewTransport(t *testing.T) { + tests := []struct { + name string + opts []Option + test func(*testing.T, *http.Transport) + }{ + { + name: "Unix套接字", + opts: []Option{URL("unix:///tmp/test.sock")}, + test: func(t *testing.T, tr *http.Transport) { + _, err := tr.DialContext(context.Background(), "unix", "/tmp/test.sock") + if err == nil { + t.Error("期望Unix套接字连接失败") + } + }, + }, + { + name: "本地地址绑定", + opts: []Option{LocalAddr(&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})}, + test: func(t *testing.T, tr *http.Transport) { + conn, err := tr.DialContext(context.Background(), "tcp", "example.com:80") + if err == nil { + conn.Close() + } + }, + }, + { + name: "TLS配置", + opts: []Option{Verify(false)}, + test: func(t *testing.T, tr *http.Transport) { + if tr.TLSClientConfig.InsecureSkipVerify != true { + t.Error("TLS验证配置错误") + } + }, + }, + { + name: "连接池配置", + opts: []Option{MaxConns(100)}, + test: func(t *testing.T, tr *http.Transport) { + if tr.MaxIdleConns != 100 || tr.MaxIdleConnsPerHost != 100 { + t.Error("连接池配置错误") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tr := newTransport(tt.opts...) + tt.test(t, tr) + }) + } +} + +func TestTransportWithRealServer(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(10 * time.Millisecond) // 模拟处理延迟 + w.Write([]byte("ok")) + })) + defer server.Close() + + tr := newTransport( + Timeout(100*time.Millisecond), + MaxConns(10), + Verify(false), + ) + + client := &http.Client{Transport: tr} + + // 并发测试 + for i := 0; i < 10; i++ { + go func() { + resp, err := client.Get(server.URL) + if err != nil { + t.Error(err) + return + } + defer resp.Body.Close() + }() + } + + time.Sleep(200 * time.Millisecond) +} + +func TestTransportProxy(t *testing.T) { + proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("proxy response")) + })) + defer proxyServer.Close() + + tr := newTransport(Proxy(proxyServer.URL)) + + // 验证代理设置是否生效 + proxyURL, err := tr.Proxy(&http.Request{URL: &url.URL{Scheme: "http", Host: "example.com"}}) + if err != nil { + t.Fatal(err) + } + if proxyURL == nil { + t.Error("代理未正确设置") + } +}