Skip to content

update: coverage test #10

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@
go.work

vendor/
.idea/
.idea/
coverage.txt
105 changes: 102 additions & 3 deletions response_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,27 @@ import (
)

// mockResponseWriter implements http.ResponseWriter for testing
// 更新 mockResponseWriter 以支持错误测试
type mockResponseWriter struct {
headers http.Header
statuscode int
body bytes.Buffer
writeError error // 添加这个字段
}

func (m *mockResponseWriter) Write(b []byte) (int, error) {
if m.writeError != nil {
return 0, m.writeError
}
return m.body.Write(b)
}

func newMockResponseWriter() *mockResponseWriter {
return &mockResponseWriter{headers: make(http.Header)}
}

func (m *mockResponseWriter) Header() http.Header { return m.headers }
func (m *mockResponseWriter) Write(b []byte) (int, error) { return m.body.Write(b) }
func (m *mockResponseWriter) WriteHeader(code int) { m.statuscode = code }
func (m *mockResponseWriter) Header() http.Header { return m.headers }
func (m *mockResponseWriter) WriteHeader(code int) { m.statuscode = code }

// TestResponseWriterBasic tests basic functionality of ResponseWriter
func TestResponseWriterBasic(t *testing.T) {
Expand Down Expand Up @@ -208,3 +216,94 @@ func TestResponseWriterFlush(t *testing.T) {
t.Errorf("Expected chunk2, got %s, err: %v", chunk2, err)
}
}

// TestResponseWriterRead 测试 Read 方法
func TestResponseWriterRead(t *testing.T) {
mock := newMockResponseWriter()
w := newResponseWriter(mock)

// 写入测试数据
testData := []byte("test data for reading")
_, err := w.Write(testData)
if err != nil {
t.Fatalf("Failed to write test data: %v", err)
}

// 测试读取
buf := make([]byte, len(testData))
n, err := w.Read(buf)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
if n != len(testData) {
t.Errorf("Expected to read %d bytes, got %d", len(testData), n)
}
if string(buf) != string(testData) {
t.Errorf("Expected to read '%s', got '%s'", string(testData), string(buf))
}

// 测试读取完后的EOF
n, err = w.Read(buf)
if err != io.EOF {
t.Errorf("Expected EOF after reading all data, got %v", err)
}
}

// TestResponseWriterPush 测试 HTTP/2 Push 功能
func TestResponseWriterPush(t *testing.T) {
// 创建支持 HTTP/2 的测试服务器
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rw := newResponseWriter(w)
err := rw.Push("/style.css", &http.PushOptions{
Method: "GET",
Header: http.Header{
"Content-Type": []string{"text/css"},
},
})
if err != nil && err != http.ErrNotSupported {
t.Errorf("Push failed: %v", err)
}
rw.Write([]byte("main content"))
})

server := httptest.NewUnstartedServer(handler)
server.EnableHTTP2 = true
server.StartTLS()
defer server.Close()

// 发起请求
client := server.Client()
resp, err := client.Get(server.URL)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read response: %v", err)
}
if string(body) != "main content" {
t.Errorf("Expected 'main content', got '%s'", string(body))
}
}

// TestResponseWriterWriteError 测试 Write 方法的错误处理
func TestResponseWriterWriteError(t *testing.T) {
// 创建一个会返回错误的 mock
errMock := &mockResponseWriter{
headers: make(http.Header),
writeError: fmt.Errorf("write error"),
}

w := newResponseWriter(errMock)

// 测试写入错误
n, err := w.Write([]byte("test"))
if err == nil {
t.Error("Expected write error, got nil")
}
if n != 0 {
t.Errorf("Expected 0 bytes written on error, got %d", n)
}
}
48 changes: 0 additions & 48 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,51 +78,3 @@ func newTransport(opts ...Option) *http.Transport {
},
}
}

// RoundTrip implements the RoundTripper interface.
// It processes requests by calling the RoundTripper method.
// func (t *Transport) RoundTrip(r *http.Request) (*http.Response, error) {
// return t.RoundTripper().RoundTrip(r)
// }

// RoundTripper returns a configured http.RoundTripper.
// It applies all registered middleware in reverse order.
// func (t *Transport) RoundTripper(opts ...Option) http.RoundTripper {
// return RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
// options := newOptions(t.opts, opts...)
// if options.Transport == nil {
// options.Transport = t.Transport
// }
// // Apply middleware in reverse order
// for i := len(options.HttpRoundTripper) - 1; i >= 0; i-- {
// options.Transport = options.HttpRoundTripper[i](options.Transport)
// }
// return options.Transport.RoundTrip(r)
// })
// }

// Redirect creates a middleware for handling HTTP redirects.
// It handles 301 (Moved Permanently) and 302 (Found) status codes.
func Redirect(next http.RoundTripper) http.RoundTripper {
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
response, err := next.RoundTrip(req)
if err != nil {
return response, err
}
// Check if redirection is needed
if response.StatusCode != http.StatusMovedPermanently && response.StatusCode != http.StatusFound {
return response, err
}
// Create redirect request
if req, err = NewRequestWithContext(req.Context(), Options{
Method: req.Method,
URL: response.Header.Get("Location"),
Header: req.Header,
body: req.Body,
}); err != nil {
return response, err
}
// Execute redirect request
return next.RoundTrip(req)
})
}
123 changes: 123 additions & 0 deletions transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@ package requests
import (
"context"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
)

func Test_Setup(t *testing.T) {
Expand Down Expand Up @@ -50,3 +53,123 @@ func Test_Setup(t *testing.T) {
}

}

func TestWarpRoundTripper(t *testing.T) {
// 测试装饰器链
var order []string
rt1 := RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
order = append(order, "rt1")
return &http.Response{StatusCode: 200}, nil
})

rt2 := WarpRoundTripper(rt1)(http.DefaultTransport)
_, err := rt2.RoundTrip(&http.Request{})
if err != nil {
t.Fatal(err)
}
if len(order) != 1 || order[0] != "rt1" {
t.Error("装饰器执行顺序错误")
}
}

func TestNewTransport(t *testing.T) {
tests := []struct {
name string
opts []Option
test func(*testing.T, *http.Transport)
}{
{
name: "Unix套接字",
opts: []Option{URL("unix:///tmp/test.sock")},
test: func(t *testing.T, tr *http.Transport) {
_, err := tr.DialContext(context.Background(), "unix", "/tmp/test.sock")
if err == nil {
t.Error("期望Unix套接字连接失败")
}
},
},
{
name: "本地地址绑定",
opts: []Option{LocalAddr(&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})},
test: func(t *testing.T, tr *http.Transport) {
conn, err := tr.DialContext(context.Background(), "tcp", "example.com:80")
if err == nil {
conn.Close()
}
},
},
{
name: "TLS配置",
opts: []Option{Verify(false)},
test: func(t *testing.T, tr *http.Transport) {
if tr.TLSClientConfig.InsecureSkipVerify != true {
t.Error("TLS验证配置错误")
}
},
},
{
name: "连接池配置",
opts: []Option{MaxConns(100)},
test: func(t *testing.T, tr *http.Transport) {
if tr.MaxIdleConns != 100 || tr.MaxIdleConnsPerHost != 100 {
t.Error("连接池配置错误")
}
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tr := newTransport(tt.opts...)
tt.test(t, tr)
})
}
}

func TestTransportWithRealServer(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(10 * time.Millisecond) // 模拟处理延迟
w.Write([]byte("ok"))
}))
defer server.Close()

tr := newTransport(
Timeout(100*time.Millisecond),
MaxConns(10),
Verify(false),
)

client := &http.Client{Transport: tr}

// 并发测试
for i := 0; i < 10; i++ {
go func() {
resp, err := client.Get(server.URL)
if err != nil {
t.Error(err)
return
}
defer resp.Body.Close()
}()
}

time.Sleep(200 * time.Millisecond)
}

func TestTransportProxy(t *testing.T) {
proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("proxy response"))
}))
defer proxyServer.Close()

tr := newTransport(Proxy(proxyServer.URL))

// 验证代理设置是否生效
proxyURL, err := tr.Proxy(&http.Request{URL: &url.URL{Scheme: "http", Host: "example.com"}})
if err != nil {
t.Fatal(err)
}
if proxyURL == nil {
t.Error("代理未正确设置")
}
}