Skip to content

Instantly share code, notes, and snippets.

@yinheli
Forked from maliubiao/socks5.go
Created August 15, 2014 22:15
Show Gist options
  • Save yinheli/ade1e52edad60999ac9a to your computer and use it in GitHub Desktop.
Save yinheli/ade1e52edad60999ac9a to your computer and use it in GitHub Desktop.
package socks5
import (
	"bufio"
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"strconv"
	"strings"
	"time"
)
// Socks5RoundTripper is an http.RoundTripper that sends requests through the
// SOCKS5 proxy at Proxy. SetProxy installs one as the Transport of
// http.DefaultClient.
type Socks5RoundTripper struct {
	// Proxy is the host:port address of the SOCKS5 server to dial.
	Proxy string
	// Resp appears unused in the code shown.
	Resp *http.Response

	// Scratch state for a round trip: the proxy connection, its buffered
	// reader/writer, and channels used to hand results back. Because these
	// live on the struct rather than in locals, code that uses them is not
	// safe for concurrent round trips.
	conn    net.Conn
	reqch   chan *http.Response
	writech chan bool
	writer  *bufio.Writer
	reader  *bufio.Reader

	// sendreq bounds writing the HTTP request; recvresp bounds reading the
	// response headers.
	sendreq, recvresp time.Duration

	// ModifyHeader, if non-nil, is called with each request's header before
	// the request is sent.
	ModifyHeader func(*http.Header)
}
// generateRequest encodes the SOCKS5 CONNECT request (RFC 1928, section 4)
// for req.URL.Host. The destination is always sent with the domain-name
// address type (ATYP 0x03), which leaves name resolution to the proxy. A
// bracketed IPv6 literal is sent without its brackets. When req.URL.Host
// carries no port, 80 is used.
//
// Wire format: 0x05 (version), 0x01 (CONNECT), 0x00 (reserved), 0x03 (ATYP),
// one length byte, the host bytes, then the port as a big-endian uint16.
//
// It returns an error when the host is empty, when it is longer than the 255
// bytes a single length byte can express, or when the port is not a decimal
// number in 0-65535.
func (Socks5 *Socks5RoundTripper) generateRequest(req *http.Request) (message []byte, err error) {
	hostport := req.URL.Host
	host, port := hostport, "80"
	switch h, p, splitErr := net.SplitHostPort(hostport); {
	case splitErr == nil:
		host, port = h, p
	case !strings.Contains(hostport, ":"):
		// Bare host name: keep the default port.
	case strings.HasPrefix(hostport, "[") && strings.HasSuffix(hostport, "]"):
		// Bracketed IPv6 literal without a port.
		host = hostport[1 : len(hostport)-1]
	default:
		return nil, errors.New("socks5: invalid req.URL.Host")
	}
	if host == "" {
		return nil, errors.New("socks5: empty host in req.URL.Host")
	}
	// The host length travels in a single byte.
	if len(host) > 255 {
		return nil, errors.New("req.URL.Host too long")
	}
	// bitSize 16 rejects ports that would silently truncate in uint16.
	portNum, err := strconv.ParseUint(port, 10, 16)
	if err != nil {
		return nil, errors.New("socks5: invalid port in req.URL.Host")
	}

	message = make([]byte, 5+len(host)+2)
	copy(message, "\x05\x01\x00\x03")
	message[4] = byte(len(host))
	copy(message[5:], host)
	binary.BigEndian.PutUint16(message[5+len(host):], uint16(portNum))
	return message, nil
}
// readLoop parses one HTTP response for req from Socks5.reader and hands it
// back on Socks5.reqch. A read or parse failure is signalled by sending nil;
// the underlying error is not propagated.
func (Socks5 *Socks5RoundTripper) readLoop(req *http.Request) {
	var parsed *http.Response
	if res, err := http.ReadResponse(Socks5.reader, req); err == nil {
		parsed = res
	}
	Socks5.reqch <- parsed
}
// writeLoop serialises req onto Socks5.writer, flushes it, and reports on
// Socks5.writech whether both steps succeeded. A failed Flush counts as a
// failure: until the flush succeeds, the request may exist only in the
// buffer and never reach the connection.
func (Socks5 *Socks5RoundTripper) writeLoop(req *http.Request) {
	if err := req.Write(Socks5.writer); err != nil {
		Socks5.writech <- false
		return
	}
	Socks5.writech <- Socks5.writer.Flush() == nil
}
// Fixed limits for the SOCKS5 negotiation: dialing the proxy, each send or
// receive of the method greeting, sending the CONNECT request, and waiting
// for the CONNECT reply. The HTTP exchange itself is bounded by the
// configurable sendreq/recvresp fields instead.
const (
	socks5DialTimeout    = 1 * time.Second
	socks5StepTimeout    = 1 * time.Second
	socks5ConnectTimeout = 5 * time.Second
)

// RoundTrip implements http.RoundTripper by tunnelling req through the SOCKS5
// proxy at Socks5.Proxy:
//
//  1. dial the proxy and offer only the "no authentication" method;
//  2. send a CONNECT request for req.URL.Host and consume the full reply;
//  3. write req over the tunnel as plain HTTP (bounded by sendreq) and read
//     the response headers (bounded by recvresp).
//
// A non-positive sendreq or recvresp means "no limit". No TLS is performed,
// so only http:// targets work. If ModifyHeader is set, it is applied to
// req.Header first, which mutates the caller's request.
//
// Each call dials its own connection and keeps all per-request state in
// locals, so concurrent calls from an http.Client do not interfere. The
// connection is closed on every error path. On success it is closed when the
// caller closes resp.Body. Errors are fixed "socks5: ..." messages per step;
// the timeout variant is chosen when a deadline expired.
func (Socks5 *Socks5RoundTripper) RoundTrip(req *http.Request) (resp *http.Response, err error) {
	if Socks5.ModifyHeader != nil {
		Socks5.ModifyHeader(&req.Header)
	}
	// Build the CONNECT request before dialing so that a bad host fails
	// without opening a connection.
	connectReq, err := Socks5.generateRequest(req)
	if err != nil {
		return nil, errors.New("socks5: send request failed")
	}

	conn, err := net.DialTimeout("tcp", Socks5.Proxy, socks5DialTimeout)
	if err != nil {
		return nil, socks5StepError(err, "socks5: net.Dial failed", "socks5: connect proxy timeout")
	}
	defer func() {
		// On any failure below this function still owns the connection; on
		// success ownership passes to resp.Body.
		if err != nil {
			conn.Close()
		}
	}()
	if tcpconn, ok := conn.(*net.TCPConn); ok {
		tcpconn.SetNoDelay(true)
	}
	// One buffered reader serves both the SOCKS5 replies and the HTTP
	// response, so bytes read ahead during negotiation are not lost.
	reader := bufio.NewReader(conn)
	if err = socks5Greet(conn, reader); err != nil {
		return nil, err
	}
	if err = socks5Connect(conn, reader, connectReq); err != nil {
		return nil, err
	}

	// SetDeadline results are ignored throughout: a failure there would
	// surface on the very next read or write anyway.
	writer := bufio.NewWriter(conn)
	conn.SetDeadline(socks5Deadline(Socks5.sendreq))
	if err = req.Write(writer); err == nil {
		err = writer.Flush()
	}
	if err != nil {
		return nil, socks5StepError(err, "socks5: write request failed", "socks5: write request timeout")
	}

	conn.SetDeadline(socks5Deadline(Socks5.recvresp))
	if resp, err = http.ReadResponse(reader, req); err != nil {
		return nil, socks5StepError(err, "socks5: wried response", "socks5: remote timeout")
	}
	// Only the header read is bounded; the caller streams the body at its
	// own pace.
	conn.SetDeadline(time.Time{})
	resp.Body = &socks5Body{ReadCloser: resp.Body, conn: conn}
	return resp, nil
}

// socks5Greet runs the RFC 1928 method negotiation. It offers only
// "no authentication required" (0x00) and fails unless the proxy selects it.
func socks5Greet(conn net.Conn, reader *bufio.Reader) error {
	conn.SetDeadline(time.Now().Add(socks5StepTimeout))
	if _, err := conn.Write([]byte("\x05\x01\x00")); err != nil {
		return socks5StepError(err, "socks5: handshake error", "socks5: handshake timeout")
	}
	conn.SetDeadline(time.Now().Add(socks5StepTimeout))
	reply := make([]byte, 2)
	if _, err := io.ReadFull(reader, reply); err != nil {
		return socks5StepError(err, "sock5: handshake failed", "socks5: send handshake timout")
	}
	if !bytes.Equal(reply, []byte("\x05\x00")) {
		return errors.New("sock5: handshake failed")
	}
	return nil
}

// socks5Connect sends the encoded CONNECT request and consumes the proxy's
// whole reply (VER REP RSV ATYP BND.ADDR BND.PORT). This leaves reader at
// the first byte of tunnelled data. Any REP other than 0x00 is a failure.
func socks5Connect(conn net.Conn, reader *bufio.Reader, request []byte) error {
	conn.SetDeadline(time.Now().Add(socks5StepTimeout))
	if _, err := conn.Write(request); err != nil {
		return socks5StepError(err, "socks5: send request failed", "send request timeout")
	}

	conn.SetDeadline(time.Now().Add(socks5ConnectTimeout))
	replyErr := func(err error) error {
		return socks5StepError(err, "socks5: request failed", "socks5: request timeout")
	}
	head := make([]byte, 4)
	if _, err := io.ReadFull(reader, head); err != nil {
		return replyErr(err)
	}
	if !bytes.HasPrefix(head, []byte("\x05\x00")) {
		return errors.New("socks5: request failed")
	}
	var addrLen int
	switch head[3] {
	case 0x01: // IPv4 address
		addrLen = net.IPv4len
	case 0x04: // IPv6 address
		addrLen = net.IPv6len
	case 0x03: // domain name, preceded by a one-byte length
		n, err := reader.ReadByte()
		if err != nil {
			return replyErr(err)
		}
		addrLen = int(n)
	default:
		return errors.New("socks5: request failed")
	}
	// The bound address and port are not needed. Read and drop them so they
	// do not end up in front of the HTTP response.
	if _, err := io.ReadFull(reader, make([]byte, addrLen+2)); err != nil {
		return replyErr(err)
	}
	return nil
}

// socks5StepError maps a failure in one proxy step to that step's error. It
// picks timeoutMsg when err is a network timeout (such as an expired
// deadline) and failMsg otherwise.
func socks5StepError(err error, failMsg, timeoutMsg string) error {
	if ne, ok := err.(net.Error); ok && ne.Timeout() {
		return errors.New(timeoutMsg)
	}
	return errors.New(failMsg)
}

// socks5Deadline turns a relative timeout into an absolute deadline. A value
// d <= 0 yields the zero time, which means "no deadline" to net.Conn.
func socks5Deadline(d time.Duration) time.Time {
	if d <= 0 {
		return time.Time{}
	}
	return time.Now().Add(d)
}

// socks5Body is a response body that also closes the tunnelled proxy
// connection, which nothing else owns once RoundTrip has returned.
type socks5Body struct {
	io.ReadCloser
	conn net.Conn
}

// Close closes the proxy connection, then the wrapped body. The connection
// goes first so that closing the body cannot block draining unread data
// from a connection that will never be reused. The body's own Close error
// is dropped for the same reason: after the connection is gone, a drain
// error is expected.
func (b *socks5Body) Close() error {
	err := b.conn.Close()
	b.ReadCloser.Close()
	return err
}
// SetProxy replaces http.DefaultClient with a fresh client. The new client's
// Transport is a *Socks5RoundTripper that targets the proxy at url (host:port),
// with its sendreq and recvresp timeouts both set to 30 seconds.
//
// url is first validated with net.ResolveTCPAddr. If that fails, the error is
// returned and http.DefaultClient is left untouched.
func SetProxy(url string) (err error) {
	if _, err = net.ResolveTCPAddr("tcp", url); err != nil {
		return err
	}
	http.DefaultClient = &http.Client{
		Transport: &Socks5RoundTripper{
			Proxy:    url,
			sendreq:  30 * time.Second,
			recvresp: 30 * time.Second,
		},
	}
	return nil
}
// ModifyHeader stores mh as the ModifyHeader hook of the SOCKS5 transport
// currently installed on http.DefaultClient. If that client's Transport is
// not a *Socks5RoundTripper (for example, because SetProxy has not been
// called), ModifyHeader does nothing.
func ModifyHeader(mh func(*http.Header)) {
	transport, isSocks := http.DefaultClient.Transport.(*Socks5RoundTripper)
	if !isSocks {
		return
	}
	transport.ModifyHeader = mh
}
// SetTimeout sets the send-request and receive-response timeouts of the
// SOCKS5 transport installed on http.DefaultClient.
//
// Both values are converted with time.Duration(x), so they are in
// nanoseconds. For example, pass int64(30*time.Second) for 30 seconds. The
// call does nothing unless both values are positive and the default client's
// Transport is a *Socks5RoundTripper.
func SetTimeout(sendreq, recvresp int64) {
	transport, isSocks := http.DefaultClient.Transport.(*Socks5RoundTripper)
	if !isSocks || sendreq <= 0 || recvresp <= 0 {
		return
	}
	transport.sendreq, transport.recvresp = time.Duration(sendreq), time.Duration(recvresp)
}
// UnsetProxy drops any SOCKS5 transport by replacing http.DefaultClient with
// a brand-new zero-value http.Client. Its nil Transport means
// http.DefaultTransport is used.
func UnsetProxy() {
	fresh := new(http.Client)
	http.DefaultClient = fresh
}
func testGet(url string) {
socks5.SetProxy("127.0.0.1:9988")
socks5.ModifyHeader(func(header *http.Header) {
header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64; rv:25.0) Gecko/20100101 Firefox/25.0")
})
resp, err := http.Get(url)
if err != nil {
panic(err)
}
fmt.Println(resp.Status)
io.Copy(os.Stdout, resp.Body)
fmt.Println("")
}
// main runs testGet against the URL given as the first command-line
// argument. If the argument is missing, it prints usage and exits with
// status 2. Note that this file declares package socks5, so this function is
// not a program entry point; it only runs when called explicitly.
func main() {
	if len(os.Args) < 2 {
		fmt.Fprintln(os.Stderr, "usage: socks5 URL")
		os.Exit(2)
	}
	testGet(os.Args[1])
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment