Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 40 additions & 55 deletions pkg/sources/s3/s3_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,37 +82,6 @@ func TestSource_ChunksLarge(t *testing.T) {
assert.Equal(t, got, wantChunkCount)
}

// TestSourceChunksNoResumption runs a full scan of a public S3 bucket from a
// clean state (no resumption info) and verifies the expected number of chunks
// is emitted.
func TestSourceChunksNoResumption(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
	defer cancel()

	s := Source{}
	connection := &sourcespb.S3{
		Credential: &sourcespb.S3_Unauthenticated{},
		Buckets:    []string{"trufflesec-ahrav-test-2"},
	}
	conn, err := anypb.New(connection)
	if err != nil {
		t.Fatal(err)
	}

	// Bug fix: the Init error was previously assigned but never checked, so a
	// failed Init surfaced only as a confusing chunk-count mismatch later.
	if err := s.Init(ctx, "test name", 0, 0, false, conn, 1); err != nil {
		t.Fatal(err)
	}

	chunksCh := make(chan *sources.Chunk)
	go func() {
		defer close(chunksCh)
		// Use the Chunks error locally instead of writing the outer err
		// variable from this goroutine (data race with the test goroutine).
		assert.NoError(t, s.Chunks(ctx, chunksCh))
	}()

	wantChunkCount := 19787
	got := 0

	for range chunksCh {
		got++
	}
	// assert.Equal takes (expected, actual); the original had them swapped,
	// which makes failure messages report the values backwards.
	assert.Equal(t, wantChunkCount, got)
}

func TestSource_Validate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
defer cancel()
Expand Down Expand Up @@ -251,34 +220,50 @@ func TestSource_Validate(t *testing.T) {
func TestSourceChunksNoResumption(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()

s := Source{}
connection := &sourcespb.S3{
Credential: &sourcespb.S3_Unauthenticated{},
Buckets: []string{"integration-resumption-tests"},
}
conn, err := anypb.New(connection)
if err != nil {
t.Fatal(err)
tests := []struct {
bucket string
wantChunkCount int
}{
{
bucket: "trufflesec-ahrav-test-2",
wantChunkCount: 19787,
},
{
bucket: "integration-resumption-tests",
wantChunkCount: 19787,
},
}

err = s.Init(ctx, "test name", 0, 0, false, conn, 1)
chunksCh := make(chan *sources.Chunk)
go func() {
defer close(chunksCh)
err = s.Chunks(ctx, chunksCh)
assert.Nil(t, err)
}()

wantChunkCount := 19787
got := 0
for _, tt := range tests {
t.Run(tt.bucket, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()

s := Source{}
connection := &sourcespb.S3{
Credential: &sourcespb.S3_Unauthenticated{},
Buckets: []string{tt.bucket},
}
conn, err := anypb.New(connection)
if err != nil {
t.Fatal(err)
}

for range chunksCh {
got++
err = s.Init(ctx, "test name", 0, 0, false, conn, 1)
chunksCh := make(chan *sources.Chunk)
go func() {
defer close(chunksCh)
err = s.Chunks(ctx, chunksCh)
assert.Nil(t, err)
}()

got := 0
for range chunksCh {
got++
}
assert.Equal(t, tt.wantChunkCount, got)
})
}
assert.Equal(t, wantChunkCount, got)
}

func TestSourceChunksResumption(t *testing.T) {
Expand Down
47 changes: 33 additions & 14 deletions pkg/sources/s3/s3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"encoding/base64"
"fmt"
"os"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -99,8 +98,7 @@ func TestSource_Chunks(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
var cancelOnce sync.Once
defer cancelOnce.Do(cancel)
defer cancel()

for k, v := range tt.init.setEnv {
t.Setenv(k, v)
Expand All @@ -117,26 +115,47 @@ func TestSource_Chunks(t *testing.T) {
t.Errorf("Source.Init() error = %v, wantErr %v", err, tt.wantErr)
return
}
chunksCh := make(chan *sources.Chunk)
var wg sync.WaitGroup
wg.Add(1)
chunksCh := make(chan *sources.Chunk, 1)
go func() {
defer wg.Done()
defer close(chunksCh)
err = s.Chunks(ctx, chunksCh)
Comment on lines +118 to 121
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/cc @rosecodym

Here is a more detailed explanation of why TestSource_Chunks may block indefinitely.

The issue arises after the call to s.Chunks(ctx, chunksCh). The relevant call stack is:

  1. (*Source).Chunks

    // Chunks emits chunks of bytes over a channel.
    func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error {
    visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error {
    s.scanBuckets(c, defaultRegionClient, roleArn, buckets, chunksChan)
    return nil
    }
    return s.visitRoles(ctx, visitor)
    }

  2. (*Source).scanBuckets

    s.pageChunker(ctx, pageMetadata, processingState, chunksChan)

  3. (*Source).pageChunker

    if err := handlers.HandleFile(ctx, res.Body, chunkSkel, sources.ChanReporter{Ch: chunksChan}); err != nil {
    ctx.Logger().Error(err, "error handling file")
    s.metricsCollector.RecordObjectError(metadata.bucket)
    return nil
    }

  4. HandleFile

    return handleChunksWithError(processingCtx, dataOrErrChan, chunkSkel, reporter)

  5. handleChunksWithError

    if err := reporter.ChunkOk(ctx, chunk); err != nil {
    return fmt.Errorf("error reporting chunk: %w", err)
    }

  6. (ChanReporter).ChunkOk

    func (c ChanReporter) ChunkOk(ctx context.Context, chunk Chunk) error {
    return common.CancellableWrite(ctx, c.Ch, &chunk)
    }

  7. The write then blocks indefinitely in CancellableWrite: because chunksCh is an unbuffered channel and the test only receives from it once, the ch <- item case can never proceed for any subsequent chunk.

    func CancellableWrite[T any](ctx context.Context, ch chan<- T, item T) error {
    select {
    case <-ctx.Done(): // priority to context cancellation
    return ctx.Err()
    default:
    select {
    case <-ctx.Done():
    return ctx.Err()
    case ch <- item:
    return nil
    }
    }
    }

After changing chunksCh to a buffered channel, I created a public S3 bucket and confirmed that the test no longer hangs indefinitely:

TestSource_Chunk_buffered_channel.mp4

if (err != nil) != tt.wantErr {
t.Errorf("Source.Chunks() error = %v, wantErr %v", err, tt.wantErr)
os.Exit(1)
}
}()
gotChunk := <-chunksCh
wantData, _ := base64.StdEncoding.DecodeString(tt.wantChunkData)

if diff := pretty.Compare(gotChunk.Data, wantData); diff != "" {
t.Errorf("%s: Source.Chunks() diff: (-got +want)\n%s", tt.name, diff)
waitFn := func() {
receivedFirstChunk := false
for {
select {
case <-ctx.Done():
t.Errorf("TestSource_Chunks timed out: %v", ctx.Err())
return
case gotChunk, ok := <-chunksCh:
if !ok {
t.Logf("Source.Chunks() finished, channel closed")
assert.Equal(t, "", s.GetProgress().EncodedResumeInfo)
assert.Equal(t, int64(100), s.GetProgress().PercentComplete)
return
}
if receivedFirstChunk {
// wantChunkData is the first chunk data. After the first chunk has
// been received and matched below, we want to drain chunksCh
// so Source.Chunks() can finish completely.
continue
}

receivedFirstChunk = true
wantData, _ := base64.StdEncoding.DecodeString(tt.wantChunkData)

if diff := pretty.Compare(gotChunk.Data, wantData); diff != "" {
t.Logf("%s: Source.Chunks() diff: (-got +want)\n%s", tt.name, diff)
}
}
}
}
wg.Wait()
assert.Equal(t, "", s.GetProgress().EncodedResumeInfo)
assert.Equal(t, int64(100), s.GetProgress().PercentComplete)
waitFn()
})
}
}
Loading