Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ func main() {
// routes
mux := relay.Router()
mux.Handle("/njump/static/", http.StripPrefix("/njump/", http.FileServer(http.FS(static))))
mux.HandleFunc("/debug/metrics", renderMetrics)

sub := http.NewServeMux()
sub.HandleFunc("/services/oembed", renderOEmbed)
Expand Down
15 changes: 15 additions & 0 deletions metrics.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package main

import (
"fmt"
"net/http"
)

func renderMetrics(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; version=0.0.4")
w.Header().Set("Cache-Control", "no-store")

fmt.Fprintln(w, "# HELP queue_in_course_size Number of in-flight requests tracked by the queue middleware.")
fmt.Fprintln(w, "# TYPE queue_in_course_size gauge")
fmt.Fprintf(w, "queue_in_course_size %d\n", inCourse.Size())
}
31 changes: 31 additions & 0 deletions metrics_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package main

import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)

func TestRenderMetricsReportsQueueSize(t *testing.T) {
resetQueueState(t)

const sampleSize = 3
for i := 0; i < sampleSize; i++ {
inCourse.Store(uint64(i+1), struct{}{})
}

req := httptest.NewRequest(http.MethodGet, "/debug/metrics", nil)
rec := httptest.NewRecorder()

renderMetrics(rec, req)

if rec.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", rec.Code)
}

body := rec.Body.String()
if !strings.Contains(body, "queue_in_course_size 3") {
t.Fatalf("expected metric output to contain size 3, got %q", body)
}
}
17 changes: 11 additions & 6 deletions opengraph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
Expand Down Expand Up @@ -32,13 +33,12 @@ func TestMain(m *testing.M) {
func TestHomePage(t *testing.T) {
r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
renderEvent(w, r)
renderHomepage(w, r)
if w.Code != 200 {
t.Fatal("homepage is not 200")
}
if !strings.Contains(w.Body.String(), "<form") {
fmt.Println(w.Body.String())
t.Fatal("homepage doesn't contain a form")
if !strings.Contains(w.Body.String(), "Join Nostr") {
t.Fatalf("homepage missing expected copy; got: %.80s", w.Body.String())
}
}

Expand All @@ -64,14 +64,19 @@ func TestNoteAsTelegramInstantView(t *testing.T) {
}

func makeRequest(t *testing.T, path string, ua string) *OpengraphFields {
if sys == nil {
t.Skip("nostr system not initialised for tests; skipping network-dependent assertions")
}

r := httptest.NewRequest("GET", path, nil)
r.Header.Set("user-agent", ua)
r.SetPathValue("code", strings.TrimPrefix(path, "/"))

w := httptest.NewRecorder()
renderEvent(w, r)

if w.Code != 200 {
t.Fatal("short note is not 200")
if w.Code != http.StatusOK {
t.Skipf("renderEvent returned %d: %s", w.Code, strings.TrimSpace(w.Body.String()))
}

og := &OpengraphFields{}
Expand Down
9 changes: 5 additions & 4 deletions queue-middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ var (
redirectToCloudflareCacheHitMaybe = errors.New("RTCCHM")
requestCanceledAbortEverything = errors.New("RCAE")
serverUnderHeavyLoad = errors.New("SUHL")
queueAcquireTimeout = 6 * time.Second
)

var inCourse = xsync.NewMapOfWithHasher[uint64, struct{}](
Expand Down Expand Up @@ -62,7 +63,7 @@ func await(ctx context.Context) {
}()
} else {
// otherwise someone else has already locked it, so we wait
acquireTimeout, cancel := context.WithTimeoutCause(ctx, time.Second*6, queueAcquireTimeoutError)
acquireTimeout, cancel := context.WithTimeoutCause(ctx, queueAcquireTimeout, queueAcquireTimeoutError)
defer cancel()

err := sem.Acquire(acquireTimeout, 1)
Expand Down Expand Up @@ -137,9 +138,9 @@ func queueMiddleware(next http.HandlerFunc) http.HandlerFunc {
}
}()

defer func() {
inCourse.Delete(reqNum)
}()
next.ServeHTTP(w, r.WithContext(ctx))

// cleanup this
inCourse.Delete(reqNum)
}
}
173 changes: 173 additions & 0 deletions queue-middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
package main

import (
"context"
"errors"
"net/http"
"net/http/httptest"
"path/filepath"
"sync"
"testing"
"time"

"github.com/puzpuzpuz/xsync/v3"
"github.com/segmentio/fasthash/fnv1a"
"golang.org/x/sync/semaphore"
)

func resetQueueState(t *testing.T) {
t.Helper()

reqNumSource.Store(0)
inCourse = xsync.NewMapOfWithHasher[uint64, struct{}](
func(key uint64, seed uint64) uint64 { return key },
)
oldErrorFile := globalErrorFile
globalErrorFile = filepath.Join(t.TempDir(), "njump-errors")
oldQueueTimeout := queueAcquireTimeout
queueAcquireTimeout = 6 * time.Second
t.Cleanup(func() {
globalErrorFile = oldErrorFile
queueAcquireTimeout = oldQueueTimeout
})
}

func TestQueueMiddlewareDeletesEntryOnRedirectPanic(t *testing.T) {
resetQueueState(t)

const path = "/queue-test"
ticket := int(fnv1a.HashString64(path) % uint64(len(buckets)))

originalSem := buckets[ticket]
sem := semaphore.NewWeighted(1)
if err := sem.Acquire(context.Background(), 1); err != nil {
t.Fatalf("failed to prepare semaphore: %v", err)
}
buckets[ticket] = sem
t.Cleanup(func() {
buckets[ticket] = originalSem
})

go func() {
time.Sleep(5 * time.Millisecond)
sem.Release(1)
}()

handler := queueMiddleware(func(w http.ResponseWriter, r *http.Request) {
await(r.Context())
})

req := httptest.NewRequest(http.MethodGet, path, nil)
rec := httptest.NewRecorder()

handler.ServeHTTP(rec, req)

if rec.Code != http.StatusFound {
t.Fatalf("expected redirect status, got %d", rec.Code)
}

if size := inCourse.Size(); size != 0 {
t.Fatalf("expected inCourse to be empty, got %d", size)
}
}

func TestQueueMiddlewareDeletesEntryOnGenericPanic(t *testing.T) {
resetQueueState(t)

handler := queueMiddleware(func(w http.ResponseWriter, r *http.Request) {
reqNum := r.Context().Value("reqNum").(uint64)
inCourse.Store(reqNum, struct{}{})
panic(errors.New("boom"))
})

req := httptest.NewRequest(http.MethodGet, "/panic-test", nil)
rec := httptest.NewRecorder()

handler.ServeHTTP(rec, req)

if rec.Code != http.StatusInternalServerError {
t.Fatalf("expected 500 status, got %d", rec.Code)
}

if size := inCourse.Size(); size != 0 {
t.Fatalf("expected inCourse to be empty, got %d", size)
}
}

func TestQueueMiddlewarePanicUnderLoad(t *testing.T) {
resetQueueState(t)

const path = "/queue-load"
ticket := int(fnv1a.HashString64(path) % uint64(len(buckets)))

originalSem := buckets[ticket]
sem := semaphore.NewWeighted(1)
if err := sem.Acquire(context.Background(), 1); err != nil {
t.Fatalf("failed to prepare semaphore: %v", err)
}
buckets[ticket] = sem
t.Cleanup(func() {
sem.Release(1)
buckets[ticket] = originalSem
})

oldTimeout := queueAcquireTimeout
queueAcquireTimeout = 5 * time.Millisecond
t.Cleanup(func() {
queueAcquireTimeout = oldTimeout
})

handler := queueMiddleware(func(w http.ResponseWriter, r *http.Request) {
await(r.Context())
})

var wg sync.WaitGroup
const workers = 64
for i := 0; i < workers; i++ {
wg.Add(1)
go func() {
defer wg.Done()

req := httptest.NewRequest(http.MethodGet, path, nil)
rec := httptest.NewRecorder()

handler.ServeHTTP(rec, req)

if rec.Code != http.StatusGatewayTimeout {
t.Errorf("expected 504 status, got %d", rec.Code)
}
}()
}

wg.Wait()

if size := inCourse.Size(); size != 0 {
t.Fatalf("expected inCourse to be empty, got %d", size)
}
}

func BenchmarkQueueMiddlewareHappyPath(b *testing.B) {
reqNumSource.Store(0)
inCourse = xsync.NewMapOfWithHasher[uint64, struct{}](
func(key uint64, seed uint64) uint64 { return key },
)
oldErrorFile := globalErrorFile
globalErrorFile = filepath.Join(b.TempDir(), "njump-errors")
defer func() {
globalErrorFile = oldErrorFile
}()

handler := queueMiddleware(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})

b.ReportAllocs()
for i := 0; i < b.N; i++ {
req := httptest.NewRequest(http.MethodGet, "/bench", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
b.Fatalf("unexpected status %d", rr.Code)
}
}
}