Skip to content

Commit 7fbbedd

Browse files
committed
progress: support ReadCloser
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
1 parent 54fbd25 commit 7fbbedd

File tree

2 files changed

+44
-4
lines changed

2 files changed

+44
-4
lines changed

progress/reader.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ package progress
22

33
import "io"
44

5-
// Reader is an [io.Reader] that reports the number of bytes read from it via a
6-
// callback. The callback is called at most every "updateInterval" bytes. The
7-
// updateInterval can be set using the [Reader.WithUpdateInterval] method.
5+
// Reader is an [io.ReadCloser] that reports the number of bytes read from it
6+
// via a callback. The callback is called at most every "updateInterval" bytes.
7+
// The updateInterval can be set using the [Reader.WithUpdateInterval] method.
88
//
99
// The following is an example of how to use [Reader] to report the progress of
1010
// reading from a file:
@@ -44,4 +44,11 @@ func (r *Reader) Read(p []byte) (n int, err error) {
4444
return n, nil
4545
}
4646

47-
var _ io.Reader = (*Reader)(nil)
47+
func (r *Reader) Close() error {
48+
if closer, ok := r.inner.(io.Closer); ok {
49+
return closer.Close()
50+
}
51+
return nil
52+
}
53+
54+
var _ io.ReadCloser = (*Reader)(nil)

progress/reader_test.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package progress_test
22

33
import (
44
"bytes"
5+
"errors"
56
"io"
67
"testing"
78

@@ -24,3 +25,35 @@ func TestReader(t *testing.T) {
2425
assert.Greater(t, len(progressUpdates), 1)
2526
assert.IsIncreasing(t, progressUpdates)
2627
}
28+
29+
type testReader struct {
30+
*bytes.Reader
31+
closed bool
32+
}
33+
34+
func (r *testReader) Close() error {
35+
if r.closed {
36+
return errors.New("already closed")
37+
}
38+
r.closed = true
39+
return nil
40+
}
41+
42+
func TestReadCloser(t *testing.T) {
43+
readCloser := &testReader{Reader: bytes.NewReader(bytes.Repeat([]byte{42}, 1024*1024))}
44+
45+
var progressUpdates []int
46+
progressReader := progress.NewReader(readCloser, func(readBytes int) {
47+
progressUpdates = append(progressUpdates, readBytes)
48+
})
49+
50+
data, err := io.ReadAll(progressReader)
51+
assert.NoError(t, err)
52+
assert.Equal(t, data, bytes.Repeat([]byte{42}, 1024*1024))
53+
54+
assert.Greater(t, len(progressUpdates), 1)
55+
assert.IsIncreasing(t, progressUpdates)
56+
57+
assert.NoError(t, progressReader.Close())
58+
assert.ErrorContains(t, progressReader.Close(), "already closed")
59+
}

0 commit comments

Comments
 (0)