Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 40 additions & 55 deletions pkg/sources/s3/s3_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,37 +82,6 @@ func TestSource_ChunksLarge(t *testing.T) {
assert.Equal(t, got, wantChunkCount)
}

// TestSourceChunksNoResumption scans a public S3 bucket end-to-end with no
// resumption state and verifies the total number of chunks emitted.
func TestSourceChunksNoResumption(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
	defer cancel()

	s := Source{}
	connection := &sourcespb.S3{
		Credential: &sourcespb.S3_Unauthenticated{},
		Buckets:    []string{"trufflesec-ahrav-test-2"},
	}
	conn, err := anypb.New(connection)
	if err != nil {
		t.Fatal(err)
	}

	// Fail fast if Init errors; the original ignored this error and would
	// have proceeded to scan an uninitialized source.
	if err := s.Init(ctx, "test name", 0, 0, false, conn, 1); err != nil {
		t.Fatal(err)
	}

	chunksCh := make(chan *sources.Chunk)
	go func() {
		defer close(chunksCh)
		// The error is checked inside the goroutine; testing.T methods are
		// safe to call from other goroutines while the test is running.
		assert.NoError(t, s.Chunks(ctx, chunksCh))
	}()

	wantChunkCount := 19787
	got := 0

	// Drain the channel fully so Chunks can complete; the count is the
	// assertion target.
	for range chunksCh {
		got++
	}
	// assert.Equal takes (t, expected, actual); the original had the
	// arguments swapped, which produces a misleading failure message.
	assert.Equal(t, wantChunkCount, got)
}

func TestSource_Validate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
defer cancel()
Expand Down Expand Up @@ -251,34 +220,50 @@ func TestSource_Validate(t *testing.T) {
func TestSourceChunksNoResumption(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()

s := Source{}
connection := &sourcespb.S3{
Credential: &sourcespb.S3_Unauthenticated{},
Buckets: []string{"integration-resumption-tests"},
}
conn, err := anypb.New(connection)
if err != nil {
t.Fatal(err)
tests := []struct {
bucket string
wantChunkCount int
}{
{
bucket: "trufflesec-ahrav-test-2",
wantChunkCount: 19787,
},
{
bucket: "integration-resumption-tests",
wantChunkCount: 19787,
},
}

err = s.Init(ctx, "test name", 0, 0, false, conn, 1)
chunksCh := make(chan *sources.Chunk)
go func() {
defer close(chunksCh)
err = s.Chunks(ctx, chunksCh)
assert.Nil(t, err)
}()

wantChunkCount := 19787
got := 0
for _, tt := range tests {
t.Run(tt.bucket, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()

s := Source{}
connection := &sourcespb.S3{
Credential: &sourcespb.S3_Unauthenticated{},
Buckets: []string{tt.bucket},
}
conn, err := anypb.New(connection)
if err != nil {
t.Fatal(err)
}

for range chunksCh {
got++
err = s.Init(ctx, "test name", 0, 0, false, conn, 1)
chunksCh := make(chan *sources.Chunk)
go func() {
defer close(chunksCh)
err = s.Chunks(ctx, chunksCh)
assert.Nil(t, err)
}()

got := 0
for range chunksCh {
got++
}
assert.Equal(t, tt.wantChunkCount, got)
})
}
assert.Equal(t, wantChunkCount, got)
}

func TestSourceChunksResumption(t *testing.T) {
Expand Down
46 changes: 32 additions & 14 deletions pkg/sources/s3/s3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"encoding/base64"
"fmt"
"os"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -99,8 +98,7 @@ func TestSource_Chunks(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
var cancelOnce sync.Once
defer cancelOnce.Do(cancel)
defer cancel()

for k, v := range tt.init.setEnv {
t.Setenv(k, v)
Expand All @@ -117,26 +115,46 @@ func TestSource_Chunks(t *testing.T) {
t.Errorf("Source.Init() error = %v, wantErr %v", err, tt.wantErr)
return
}
chunksCh := make(chan *sources.Chunk)
var wg sync.WaitGroup
wg.Add(1)
chunksCh := make(chan *sources.Chunk, 1)
go func() {
defer wg.Done()
defer close(chunksCh)
err = s.Chunks(ctx, chunksCh)
Comment on lines +118 to 121
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/cc @rosecodym

Here is a more detailed explanation of why TestSource_Chunks may block indefinitely.

The issue arises after the call to s.Chunks(ctx, chunksCh). The relevant call stack is:

  1. (*Source).Chunks

    // Chunks emits chunks of bytes over a channel.
    func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error {
    visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error {
    s.scanBuckets(c, defaultRegionClient, roleArn, buckets, chunksChan)
    return nil
    }
    return s.visitRoles(ctx, visitor)
    }

  2. (*Source).scanBuckets

    s.pageChunker(ctx, pageMetadata, processingState, chunksChan)

  3. (*Source).pageChunker

    if err := handlers.HandleFile(ctx, res.Body, chunkSkel, sources.ChanReporter{Ch: chunksChan}); err != nil {
    ctx.Logger().Error(err, "error handling file")
    s.metricsCollector.RecordObjectError(metadata.bucket)
    return nil
    }

  4. HandleFile

    return handleChunksWithError(processingCtx, dataOrErrChan, chunkSkel, reporter)

  5. handleChunksWithError

    if err := reporter.ChunkOk(ctx, chunk); err != nil {
    return fmt.Errorf("error reporting chunk: %w", err)
    }

  6. (ChanReporter).ChunkOk

    func (c ChanReporter) ChunkOk(ctx context.Context, chunk Chunk) error {
    return common.CancellableWrite(ctx, c.Ch, &chunk)
    }

  7. And blocks indefinitely in CancellableWrite. Because chunksCh is an unbuffered channel, the ch <- item case can never proceed since we only receive from chunksCh once.

    func CancellableWrite[T any](ctx context.Context, ch chan<- T, item T) error {
    select {
    case <-ctx.Done(): // priority to context cancellation
    return ctx.Err()
    default:
    select {
    case <-ctx.Done():
    return ctx.Err()
    case ch <- item:
    return nil
    }
    }
    }

After changing chunksCh to a buffered channel, I created a public S3 bucket and can confirm that the test no longer hangs indefinitely:

TestSource_Chunk_buffered_channel.mp4

if (err != nil) != tt.wantErr {
t.Errorf("Source.Chunks() error = %v, wantErr %v", err, tt.wantErr)
os.Exit(1)
}
}()
gotChunk := <-chunksCh
wantData, _ := base64.StdEncoding.DecodeString(tt.wantChunkData)

if diff := pretty.Compare(gotChunk.Data, wantData); diff != "" {
t.Errorf("%s: Source.Chunks() diff: (-got +want)\n%s", tt.name, diff)
waitFn := func() {
receivedFirstChunk := false
for {
select {
case <-ctx.Done():
t.Errorf("TestSource_Chunks timed out: %v", ctx.Err())
case gotChunk, ok := <-chunksCh:
if !ok {
t.Logf("Source.Chunks() finished, channel closed")
assert.Equal(t, "", s.GetProgress().EncodedResumeInfo)
assert.Equal(t, int64(100), s.GetProgress().PercentComplete)
return
}
if receivedFirstChunk {
// wantChunkData is the first chunk data. After the first chunk has
// been received and matched below, we want to drain chunksCh
// so Source.Chunks() can finish completely.
continue
}

receivedFirstChunk = true
wantData, _ := base64.StdEncoding.DecodeString(tt.wantChunkData)

if diff := pretty.Compare(gotChunk.Data, wantData); diff != "" {
t.Logf("%s: Source.Chunks() diff: (-got +want)\n%s", tt.name, diff)
}
}
}
}
Copy link
Contributor Author

@Juneezee Juneezee Apr 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rosecodym Could you try running the test again with the latest commit (1ca8a95)?

I believe this can be resolved by re-adding the removed waitgroup functionality.

Actually, a WaitGroup alone is not sufficient here. For s.Chunks(ctx, chunksCh) to finish completely, the chunksCh channel must be fully drained.

If we use a WaitGroup without draining chunksCh, the test will still block indefinitely, unless chunksCh is a buffered channel with a large enough buffer to hold all the chunks.

wg.Wait()
assert.Equal(t, "", s.GetProgress().EncodedResumeInfo)
assert.Equal(t, int64(100), s.GetProgress().PercentComplete)
waitFn()
})
}
}