Last active
November 14, 2018 16:41
-
-
Save Ferada/61254f6984823d3b590924b5219ae195 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
time go run timeoutreader.go http://www.example.com/ simple 1s >/dev/null
time go run timeoutreader.go http://www.example.com/ anything-but-simple 1s >/dev/null
time curl http://www.example.com/ >/dev/null
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"bytes" | |
"fmt" | |
"io" | |
"net/http" | |
"os" | |
"time" | |
) | |
// ChunkOrError is one message from the background goroutines of a
// TimeoutReader: either a chunk of data (error == nil) or the error that
// ended the stream (chunk == nil).
type ChunkOrError struct {
	chunk []byte
	error error
}

// TimeoutReader wraps an io.Reader and fails with a "timeout after ..." error
// when the wrapped reader goes longer than a fixed duration without producing
// any data. The wrapped reader is consumed by goroutines started in
// NewTimeoutReader; construct values with that function, not as zero values.
type TimeoutReader struct {
	// scratch holds the part of the last chunk that did not fit into the
	// caller's buffer; it is drained before the channel is read again.
	scratch *bytes.Buffer
	// channel delivers data chunks, then at most one terminal error, and is
	// then closed.
	channel chan ChunkOrError
	// err is the terminal error once one has been received, so every later
	// Read reports that same error.
	err error
}

// bufferSize is the size of each Read issued against the wrapped reader and
// the initial capacity of the scratch buffer.
const bufferSize = 32 * 1024

// NewTimeoutReader returns a TimeoutReader that reads from r and reports a
// "timeout after d" error if r delivers no data for longer than d.
//
// The clock only runs while the forwarding goroutine is waiting for r. It is
// paused while a chunk waits to be accepted by Read, so a slow consumer does
// not cause a spurious timeout.
//
// Two goroutines are started. The forwarder exits after handing the terminal
// error (r's error or the timeout) to the channel. That hand-off may block
// until Read has drained earlier chunks. The goroutine reading r exits when
// r returns an error or once the forwarder is gone. If a call to r.Read
// never returns, that goroutine stays blocked in it, so callers should close
// the underlying source (e.g. an HTTP response body) after a timeout.
func NewTimeoutReader(r io.Reader, d time.Duration) *TimeoutReader {
	reads := make(chan ChunkOrError, 1)
	chunks := make(chan ChunkOrError, 1)
	// stop is closed when the forwarder exits. It releases the reading
	// goroutine from a send that nobody would ever receive.
	stop := make(chan struct{})

	// Reading goroutine: pulls data from r and passes copies to the forwarder.
	go func() {
		defer close(reads)
		buffer := make([]byte, bufferSize)
		for {
			n, err := r.Read(buffer)
			if n > 0 {
				// Copy, because buffer is reused by the next r.Read.
				chunk := make([]byte, n)
				copy(chunk, buffer[:n])
				select {
				case reads <- ChunkOrError{chunk: chunk}:
				case <-stop:
					return
				}
			}
			if err != nil {
				select {
				case reads <- ChunkOrError{error: err}:
				case <-stop:
				}
				return
			}
		}
	}()

	// Forwarder: enforces the idle timeout and feeds the consumer channel.
	go func() {
		defer close(chunks)
		defer close(stop)
		timer := time.NewTimer(d)
		defer timer.Stop()
		for {
			select {
			case msg, ok := <-reads:
				if !ok {
					// Not expected: the reading goroutine only closes reads
					// after sending its error or after stop is closed.
					// Treat this as the end of the stream.
					return
				}
				if msg.error != nil {
					chunks <- msg
					return
				}
				// Pause the clock while waiting for Read to accept the
				// chunk, so consumer slowness is not mistaken for a stall.
				if !timer.Stop() {
					<-timer.C // pre-Go 1.23 timers may hold an undelivered tick
				}
				chunks <- msg
				timer.Reset(d)
			case <-timer.C:
				chunks <- ChunkOrError{error: fmt.Errorf("timeout after %v", d)}
				return
			}
		}
	}()

	return &TimeoutReader{
		scratch: bytes.NewBuffer(make([]byte, 0, bufferSize)),
		channel: chunks,
	}
}

// Read copies leftover or newly received data into p. Following the io.Reader
// convention, it returns as soon as some data is available rather than
// waiting to fill p. Bytes of a chunk that do not fit into p are kept for the
// next call. Once the stream has ended, every later call returns the same
// error: io.EOF (or another error) from the wrapped reader, or the timeout.
func (r *TimeoutReader) Read(p []byte) (int, error) {
	if len(p) == 0 {
		return 0, nil
	}
	// Leftovers precede anything still queued in the channel.
	if r.scratch.Len() > 0 {
		return r.scratch.Read(p)
	}
	if r.err != nil {
		return 0, r.err
	}
	for {
		msg, ok := <-r.channel
		if !ok {
			r.err = io.EOF
			return 0, r.err
		}
		if msg.error != nil {
			r.err = msg.error
			return 0, r.err
		}
		if len(msg.chunk) == 0 {
			continue
		}
		n := copy(p, msg.chunk)
		r.scratch.Write(msg.chunk[n:])
		return n, nil
	}
}
// ReadResult is the outcome of one Read call on the wrapped reader, passed
// from the helper goroutine back to TimeoutReader2.Read.
type ReadResult struct {
	n     int
	error error
}

// TimeoutReader2 wraps an io.Reader and bounds each individual Read call by a
// fixed timeout. No goroutine runs between calls; each Read starts its own.
type TimeoutReader2 struct {
	nested  io.Reader
	timeout time.Duration
}

// NewTimeoutReader2 returns a TimeoutReader2 that bounds every Read on r by d.
func NewTimeoutReader2(r io.Reader, d time.Duration) *TimeoutReader2 {
	return &TimeoutReader2{nested: r, timeout: d}
}

// Read performs one Read on the wrapped reader in a goroutine and waits at
// most r.timeout for it. On timeout it returns 0 and a "timeout after ..."
// error.
//
// The wrapped Read fills a private buffer. That buffer is copied into p only
// if the Read finishes in time, so p is never written after Read returns
// (io.Reader implementations must not retain p). A Read that times out is
// abandoned, not cancelled. Its goroutine lives until the wrapped Read
// returns, and any bytes that Read yields are discarded. Calling Read again
// after a timeout therefore loses data and may run two wrapped Reads
// concurrently. Close the underlying source instead.
func (r *TimeoutReader2) Read(p []byte) (int, error) {
	buf := make([]byte, len(p))
	// Buffered so the goroutine can always deliver its result and exit,
	// even after this call has given up waiting.
	results := make(chan ReadResult, 1)
	go func() {
		n, err := r.nested.Read(buf)
		results <- ReadResult{n, err}
	}()

	timer := time.NewTimer(r.timeout)
	defer timer.Stop()
	select {
	case res := <-results:
		return copy(p, buf[:res.n]), res.error
	case <-timer.C:
		return 0, fmt.Errorf("timeout after %v", r.timeout)
	}
}
func main() { | |
// func() { | |
request, _ := http.NewRequest("GET", os.Args[1], nil) | |
response, _ := http.DefaultClient.Do(request) | |
defer response.Body.Close() | |
timeout, _ := time.ParseDuration(os.Args[3]) | |
// written, err := | |
if os.Args[1] == "simple" { | |
io.Copy(os.Stdout, NewTimeoutReader(response.Body, timeout)) | |
} else { | |
io.Copy(os.Stdout, NewTimeoutReader2(response.Body, timeout)) | |
} | |
// written, err := io.Copy(os.Stdout, NewTimeoutReader2(response.Body, 1*time.Second)) | |
// fmt.Fprintf(os.Stderr, "written, err, %d, %v\n", written, err) | |
// }() | |
// fmt.Fprintf(os.Stderr, "done\n") | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment