Skip to content

Instantly share code, notes, and snippets.

@Ferada
Last active November 14, 2018 16:41
Show Gist options
  • Save Ferada/61254f6984823d3b590924b5219ae195 to your computer and use it in GitHub Desktop.
Save Ferada/61254f6984823d3b590924b5219ae195 to your computer and use it in GitHub Desktop.
time go run timeoutreader.go http://www.example.com/ simple 1s >/dev/null
time go run timeoutreader.go http://www.example.com/ anything-but-simple 1s >/dev/null
time curl http://www.example.com/ >/dev/null
package main
import (
"bytes"
"fmt"
"io"
"net/http"
"os"
"time"
)
// ChunkOrError is the message passed between the goroutines started by
// NewTimeoutReader and TimeoutReader.Read. Exactly one of the fields is
// meaningful: chunk carries a private copy of freshly read bytes, error
// carries the terminal error (including io.EOF or a timeout).
type ChunkOrError struct {
	chunk []byte
	error error
}

// TimeoutReader is an io.Reader that fails with a timeout error when the
// wrapped reader produces no data for longer than a configured duration.
// Create one with NewTimeoutReader; the zero value is not usable.
type TimeoutReader struct {
	// scratch holds the tail of a chunk that did not fit into the caller's
	// buffer on a previous Read call.
	scratch *bytes.Buffer
	// channel delivers chunks and the terminal error; it is closed once the
	// terminal error has been delivered.
	channel chan ChunkOrError
}

// bufferSize is the size of the buffer used for each read from the wrapped
// reader, and the initial capacity of the scratch buffer.
const bufferSize = 32 * 1024

// NewTimeoutReader wraps r so that reads fail with a "timeout after ..."
// error when r delivers no data for the duration d.
//
// Two goroutines are started: a producer that reads from r, and a watchdog
// that forwards the producer's chunks to the returned reader while watching a
// timer. The producer cannot be interrupted while it is blocked inside
// r.Read, so after a timeout the caller should close the underlying source to
// release it. The watchdog blocks while handing a chunk to the consumer, so
// the returned reader should be read until it reports an error.
func NewTimeoutReader(r io.Reader, d time.Duration) *TimeoutReader {
	reads := make(chan ChunkOrError, 1)
	chunks := make(chan ChunkOrError, 1)
	// stop is closed by the watchdog on timeout. It replaces a shared boolean
	// flag, which would be a data race between the two goroutines.
	stop := make(chan struct{})
	timer := time.NewTimer(d)

	// Producer: read from r and pass private copies of the data along.
	go func() {
		defer close(reads)
		buffer := make([]byte, bufferSize)
		for {
			n, err := r.Read(buffer)
			if n != 0 {
				// buffer is reused by the next Read, so hand over a copy.
				data := make([]byte, n)
				copy(data, buffer[:n])
				select {
				case reads <- ChunkOrError{data, nil}:
				case <-stop:
					return
				}
			}
			if err != nil {
				select {
				case reads <- ChunkOrError{nil, err}:
				case <-stop:
				}
				return
			}
		}
	}()

	// Watchdog: forward chunks to the consumer and enforce the idle timeout.
	go func() {
		defer close(chunks)
		defer timer.Stop()
		for {
			select {
			case msg, ok := <-reads:
				if !ok {
					// Defensive: the producer only closes reads after sending
					// an error or after stop was closed.
					return
				}
				chunks <- msg
				if msg.error != nil {
					return
				}
				// Re-arm only after the hand-off so that time spent waiting
				// for a slow consumer is not counted against the source.
				if !timer.Stop() {
					select {
					case <-timer.C:
					default:
					}
				}
				timer.Reset(d)
			case <-timer.C:
				// Release the producer before the (possibly blocking) send.
				close(stop)
				chunks <- ChunkOrError{nil, fmt.Errorf("timeout after %v", d)}
				return
			}
		}
	}()

	return &TimeoutReader{
		scratch: bytes.NewBuffer(make([]byte, 0, bufferSize)),
		channel: chunks,
	}
}

// Read implements io.Reader. It returns as soon as any data is available
// rather than waiting for p to be filled: leftover bytes from a previous
// chunk are served first, otherwise it blocks for the next chunk or error.
// The terminal error of the wrapped reader (or the timeout error) is returned
// once; later calls return io.EOF because the channel is closed by then.
func (r *TimeoutReader) Read(p []byte) (int, error) {
	if len(p) == 0 {
		return 0, nil
	}
	if r.scratch.Len() > 0 {
		return r.scratch.Read(p)
	}
	for {
		msg, ok := <-r.channel
		if !ok {
			return 0, io.EOF
		}
		if msg.error != nil {
			return 0, msg.error
		}
		if len(msg.chunk) == 0 {
			// Avoid reporting (0, nil); wait for real data instead.
			continue
		}
		n := copy(p, msg.chunk)
		if n < len(msg.chunk) {
			// Keep what did not fit for the next call.
			r.scratch.Write(msg.chunk[n:])
		}
		return n, nil
	}
}
// ReadResult carries the return values of a single nested Read call from the
// background goroutine back to TimeoutReader2.Read.
type ReadResult struct {
	n     int
	error error
}

// TimeoutReader2 is an io.Reader that gives up on an individual Read call
// when the nested reader has not returned within timeout.
type TimeoutReader2 struct {
	nested  io.Reader
	timeout time.Duration
}

// NewTimeoutReader2 wraps r so that each Read call fails with a
// "timeout after ..." error if r takes longer than d to return.
func NewTimeoutReader2(r io.Reader, d time.Duration) *TimeoutReader2 {
	return &TimeoutReader2{nested: r, timeout: d}
}

// Read implements io.Reader. The nested Read runs in its own goroutine and
// fills a private buffer; the data is copied into p only if it arrives before
// the timeout. p is therefore never touched after Read has returned, as the
// io.Reader contract requires ("implementations must not retain p").
//
// After a timeout the abandoned nested Read may still complete later; its
// data is discarded, so callers should stop using the reader (and close the
// underlying source) once a timeout error has been returned.
func (r *TimeoutReader2) Read(p []byte) (int, error) {
	// One allocation per call is the price for not sharing p with a goroutine
	// that may outlive this call.
	buf := make([]byte, len(p))
	// Buffered so the goroutine can always deliver its result and exit, even
	// when nobody is waiting any more.
	reads := make(chan ReadResult, 1)
	go func() {
		n, err := r.nested.Read(buf)
		reads <- ReadResult{n, err}
	}()

	timer := time.NewTimer(r.timeout)
	defer timer.Stop()

	select {
	case result := <-reads:
		copy(p, buf[:result.n])
		return result.n, result.error
	case <-timer.C:
		return 0, fmt.Errorf("timeout after %v", r.timeout)
	}
}
// main fetches a URL and copies the response body to standard output through
// one of the two timeout readers.
//
// Usage: timeoutreader URL MODE TIMEOUT
//
// MODE "simple" selects NewTimeoutReader, anything else selects
// NewTimeoutReader2. TIMEOUT is a duration understood by time.ParseDuration,
// e.g. "1s". Errors are reported on standard error with exit status 1.
func main() {
	if err := run(os.Args); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
}

// run implements main; it returns an error instead of exiting so that the
// deferred close of the response body always runs.
func run(args []string) error {
	if len(args) < 4 {
		return fmt.Errorf("usage: timeoutreader URL MODE TIMEOUT")
	}
	// args[1] is the URL; the mode is the second argument.
	target, mode := args[1], args[2]

	timeout, err := time.ParseDuration(args[3])
	if err != nil {
		return fmt.Errorf("parsing timeout %q: %v", args[3], err)
	}

	response, err := http.Get(target)
	if err != nil {
		return fmt.Errorf("fetching %q: %v", target, err)
	}
	defer response.Body.Close()

	var body io.Reader
	if mode == "simple" {
		body = NewTimeoutReader(response.Body, timeout)
	} else {
		body = NewTimeoutReader2(response.Body, timeout)
	}

	if _, err := io.Copy(os.Stdout, body); err != nil {
		return fmt.Errorf("copying response body: %v", err)
	}
	return nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment