Skip to content

Commit

Permalink
httpclient: add count_bytes_reader (#1065)
Browse files Browse the repository at this point in the history
  • Loading branch information
gabor authored Aug 28, 2024
1 parent 4cba513 commit f1cb16c
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 0 deletions.
39 changes: 39 additions & 0 deletions backend/httpclient/count_bytes_reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package httpclient

import (
"io"
)

type CloseCallbackFunc func(bytesRead int64)

// CountBytesReader counts the total amount of bytes read from the underlying reader.
//
// The provided callback func will be called before the underlying reader is closed.
func CountBytesReader(reader io.ReadCloser, callback CloseCallbackFunc) io.ReadCloser {
if reader == nil {
panic("reader cannot be nil")
}

if callback == nil {
panic("callback cannot be nil")
}

return &countBytesReader{reader: reader, callback: callback}
}

type countBytesReader struct {
reader io.ReadCloser
callback CloseCallbackFunc
counter int64
}

func (r *countBytesReader) Read(p []byte) (int, error) {
n, err := r.reader.Read(p)
r.counter += int64(n)
return n, err
}

func (r *countBytesReader) Close() error {
r.callback(r.counter)
return r.reader.Close()
}
38 changes: 38 additions & 0 deletions backend/httpclient/count_bytes_reader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package httpclient

import (
"fmt"
"io"
"strings"
"testing"

"github.com/stretchr/testify/require"
)

func TestCountBytesReader(t *testing.T) {
tcs := []struct {
body string
expectedBytesCount int64
}{
{body: "d", expectedBytesCount: 1},
{body: "dummy", expectedBytesCount: 5},
}

for index, tc := range tcs {
t.Run(fmt.Sprintf("Test CountBytesReader %d", index), func(t *testing.T) {
body := io.NopCloser(strings.NewReader(tc.body))
var actualBytesRead int64

readCloser := CountBytesReader(body, func(bytesRead int64) {
actualBytesRead = bytesRead
})

bodyBytes, err := io.ReadAll(readCloser)
require.NoError(t, err)
err = readCloser.Close()
require.NoError(t, err)
require.Equal(t, tc.expectedBytesCount, actualBytesRead)
require.Equal(t, string(bodyBytes), tc.body)
})
}
}

0 comments on commit f1cb16c

Please sign in to comment.