Last active
December 18, 2024 22:17
-
-
Save jan-bar/04651037175954f2e0ed1b2d02cedaf2 to your computer and use it in GitHub Desktop.
http代理的内网穿透 (HTTP proxy tunnelling: reach hosts inside a private network through a websocket tunnel)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"bytes" | |
"context" | |
"crypto/tls" | |
"flag" | |
"fmt" | |
"io" | |
"log" | |
"net" | |
"net/http" | |
"sync" | |
"time" | |
"github.com/gorilla/websocket" | |
) | |
func main() { | |
serAddr := flag.String("s", "", "server listen address") | |
cliAddr := flag.String("c", "", "client dial address") | |
proxy := flag.String("p", "", "client proxy address") | |
flag.Parse() | |
log.SetFlags(log.LstdFlags | log.Lshortfile) | |
if *serAddr != "" { | |
if err := server(*serAddr); err != nil { | |
log.Println(err) | |
} | |
} else if *cliAddr != "" { | |
client(*cliAddr, *proxy) | |
} | |
} | |
// tunnel upgrades r to a websocket and uses it as the server's tunnel to the
// client. While the tunnel is up, a goroutine forwards every frame queued on
// s.send to the client. Meanwhile this goroutine dispatches every frame the
// client sends to the proxyConn registered under its connection id in s.save.
// It returns the read error once the websocket read fails (e.g. the client
// disconnected); the websocket is closed on return.
func (s *serverHandle) tunnel(w http.ResponseWriter, r *http.Request) error {
	c, err := s.ws.Upgrade(w, r, nil)
	if err != nil {
		return err
	}
	//goland:noinspection GoUnhandledErrorResult
	defer c.Close()
	exit := make(chan struct{})
	// Writer: send queued frames to the client and recycle them. A write
	// error is only logged; the loop keeps draining s.send until exit fires.
	go func() {
		for {
			select {
			case d := <-s.send:
				ew := c.WriteMessage(websocket.BinaryMessage, d.buf)
				if ew != nil {
					log.Println(ew)
				}
				putData(d)
			case <-exit:
				return
			}
		}
	}()
	var d *data
	for {
		d, err = recvData(c)
		if err != nil {
			putData(d)
			// The client disconnected; make the writer goroutine above exit.
			exit <- struct{}{}
			return err
		}
		s.mux.Lock()
		pc, ok := s.save[d.cnt]
		if ok {
			if d.cmd == cmdClose {
				// The client finished this connection: unregister it and
				// close recv so the proxyConn's Read returns io.EOF.
				delete(s.save, d.cnt)
				close(pc.recv)
				ok = false
			} else {
				// Only in this case is the object recycled elsewhere: by the
				// proxyConn's Read method, or by proxyDial for the connect reply.
				// NOTE(review): recv is unbuffered and this send happens while
				// s.mux is held, so a proxyConn nobody reads stalls the tunnel.
				pc.recv <- d
			}
		}
		if !ok {
			// Close frames and frames for unknown ids are recycled here.
			putData(d)
		}
		s.mux.Unlock()
	}
}
var httpEstablished = []byte("HTTP/1.1 200 Connection Established\r\n\r\n") | |
func (s *serverHandle) connect(c net.Conn, r *http.Request) error { | |
rc, err := s.proxyDial(r.Context(), "tcp", r.Host) | |
if err != nil { | |
return err | |
} | |
_, err = c.Write(httpEstablished) | |
if err != nil { | |
return err | |
} | |
go func() { | |
_, ec := io.Copy(rc, c) | |
if ec != nil { | |
log.Println(ec) | |
} | |
// 客户端c关闭,向rc发送close命令 | |
if ec = rc.Close(); ec != nil { | |
log.Println(ec) | |
} | |
}() | |
_, err = io.Copy(c, rc) | |
return err | |
} | |
func (s *serverHandle) general(c net.Conn, r *http.Request) error { | |
r.RequestURI = "" | |
r.Close = true | |
resp, err := s.ct.Do(r) | |
if err != nil { | |
return err | |
} | |
err = resp.Write(c) | |
if resp.Body != nil { | |
if ec := resp.Body.Close(); ec != nil { | |
log.Println(ec) | |
} | |
} | |
return err | |
} | |
func (s *serverHandle) ServeHTTP(w http.ResponseWriter, r *http.Request) { | |
if r.Body != nil { | |
defer func() { | |
if ec := r.Body.Close(); ec != nil { | |
log.Println(ec) | |
} | |
}() | |
} | |
var err error | |
if r.Header.Get(flagHeaderKey) == flagHeaderVal { | |
err = s.tunnel(w, r) | |
} else { | |
var c net.Conn | |
c, _, err = w.(http.Hijacker).Hijack() | |
if err == nil { | |
if r.Method == http.MethodConnect { | |
// 目标为 https 地址时会用 CONNECT 方法 | |
err = s.connect(c, r) | |
} else { | |
// 目标为 http 地址时直接转发内容 | |
err = s.general(c, r) | |
} | |
if ec := c.Close(); ec != nil { | |
log.Println(ec) | |
} | |
} | |
} | |
if err != nil { | |
log.Println(err) | |
} | |
} | |
func (s *serverHandle) proxyDial(_ context.Context, _, address string) (net.Conn, error) { | |
pc := &proxyConn{ | |
send: s.send, | |
recv: make(chan *data), | |
} | |
s.mux.Lock() | |
s.cnt++ // 每个连接有自己的编号,确保客户端和服务器交换数据不会乱 | |
pc.cnt = s.cnt | |
s.save[pc.cnt] = pc | |
s.mux.Unlock() | |
pc.send <- sendData(pc.cnt, cmdConnect, []byte(address)) | |
recv := <-pc.recv // 通知客户端建立连接 | |
defer putData(recv) | |
if len(recv.data) != 0 { | |
return nil, fmt.Errorf("error: %s", recv.data) | |
} | |
return pc, nil | |
} | |
type serverHandle struct { | |
ws *websocket.Upgrader | |
ct *http.Client | |
send chan *data | |
mux sync.Mutex | |
cnt uint8 // 自增编号,在[0,255]之间循环 | |
save map[uint8]*proxyConn | |
} | |
func server(addr string) error { | |
l, err := net.Listen("tcp", addr) | |
if err != nil { | |
return err | |
} | |
addr = l.Addr().String() | |
sh := &serverHandle{ | |
ws: &websocket.Upgrader{ | |
CheckOrigin: func(r *http.Request) bool { return true }, | |
}, | |
ct: &http.Client{ | |
Timeout: time.Second * 10, | |
}, | |
send: make(chan *data), | |
save: make(map[uint8]*proxyConn), | |
} | |
sh.ct.Transport = &http.Transport{ | |
TLSClientConfig: &tls.Config{ | |
InsecureSkipVerify: true, | |
}, | |
// 普通http代理用下面方法建立连接 | |
DialContext: sh.proxyDial, | |
} | |
srv := &http.Server{ | |
Addr: addr, | |
Handler: sh, | |
ReadTimeout: time.Second * 3, | |
ReadHeaderTimeout: time.Second * 3, | |
} | |
log.Println("http server started on ", addr) | |
return srv.Serve(l) | |
} | |
// ----------------------------------------------------------------------------- | |
// flagHeaderKey is the header the client sets on its websocket tunnel request
// so that ServeHTTP can tell it apart from ordinary proxy requests.
const flagHeaderKey = "Proxy-Ctrl"

// flagHeaderVal is the value of flagHeaderKey on tunnel requests.
const flagHeaderVal = "tunnel"
func client(addr, proxy string) { | |
var ( | |
send = make(chan *data) | |
exit = make(chan struct{}) | |
) | |
RetryConn: | |
c, _, err := websocket.DefaultDialer.Dial(addr, http.Header{ | |
flagHeaderKey: []string{flagHeaderVal}, | |
}) | |
if err != nil { | |
log.Println(err) | |
time.Sleep(time.Second * 5) | |
goto RetryConn | |
} | |
log.Printf("connect [%s],proxy [%s] ok", addr, proxy) | |
go func() { | |
for { | |
select { | |
case d := <-send: | |
ew := c.WriteMessage(websocket.BinaryMessage, d.buf) | |
if ew != nil { | |
log.Println(ew) | |
} | |
putData(d) | |
case <-exit: | |
return | |
} | |
} | |
}() | |
var ( | |
d *data | |
mux sync.Mutex | |
save = make(map[uint8]*proxyConn) | |
) | |
for { | |
d, err = recvData(c) | |
if err != nil { | |
putData(d) | |
log.Println(err) | |
if err = c.Close(); err != nil { | |
log.Println(err) | |
} | |
exit <- struct{}{} | |
goto RetryConn | |
} | |
mux.Lock() | |
pc, ok := save[d.cnt] | |
switch d.cmd { | |
case cmdConnect: | |
pc = newClient(d.cnt, send, string(d.data), proxy) | |
if pc == nil { | |
continue | |
} | |
save[d.cnt] = pc | |
ok = false | |
case cmdClose: | |
if ok { | |
delete(save, d.cnt) | |
close(pc.recv) | |
} | |
ok = false | |
default: | |
if ok { | |
// 只有这种情况才需要在Read方法中回收对象 | |
pc.recv <- d | |
} | |
} | |
if !ok { | |
putData(d) | |
} | |
mux.Unlock() | |
} | |
} | |
func newClient(cnt uint8, send chan *data, addr, proxy string) *proxyConn { | |
cc := &proxyConn{ | |
cnt: cnt, | |
send: send, | |
recv: make(chan *data), | |
} | |
var ( | |
rc net.Conn | |
err error | |
) | |
if proxy != "" { | |
rc, err = net.Dial("tcp", proxy) | |
if err == nil { | |
// 客户端使用代理转发,仅支持 CONNECT 方法的代理(访问https地址) | |
_, err = fmt.Fprintf(rc, "CONNECT %s HTTP/1.1\r\n"+ | |
"Host: %s\r\n"+ | |
"User-Agent: Go-http-client/1.1\r\n"+ | |
"Proxy-Connection: Keep-Alive\r\n\r\n", addr, addr) | |
if err == nil { | |
// 丢弃返回信息,结果包含 httpEstablished 的内容 | |
err = discardData(rc) | |
} | |
} | |
} else { | |
rc, err = net.Dial("tcp", addr) | |
} | |
if err != nil { | |
cc.send <- sendData(cnt, cmdConnect, []byte(err.Error())) | |
log.Println(err) | |
return nil | |
} | |
cc.send <- sendData(cnt, cmdConnect, nil) | |
go func() { | |
_, ec := io.Copy(cc, rc) | |
if ec != nil { | |
log.Println(ec) | |
} | |
if ec = cc.Close(); ec != nil { | |
log.Println(ec) | |
} | |
if ec = rc.Close(); ec != nil { | |
log.Println(ec) | |
} | |
}() | |
go func() { | |
// 客户端cc读到close,则会在Read返回io.EOF | |
_, ec := io.Copy(rc, cc) | |
if ec != nil { | |
log.Println(ec) | |
} | |
}() | |
return cc | |
} | |
// Frame commands. Every frame is a connection id byte, then one of these
// command bytes, then the payload (see sendData and recvData).
const (
	cmdData    = 0 // payload bytes for an established connection
	cmdConnect = 1 // server: dial the payload address; client: dial result (empty = ok, else error text)
	cmdClose   = 2 // the sender is done with the connection
)
type proxyConn struct { | |
cnt uint8 | |
rb bytes.Buffer | |
send, recv chan *data | |
} | |
func (p *proxyConn) Read(b []byte) (int, error) { | |
if p.rb.Len() == 0 { | |
if d, ok := <-p.recv; ok { | |
p.rb.Write(d.data) | |
putData(d) | |
} else { | |
return 0, io.EOF | |
} | |
} | |
return p.rb.Read(b) | |
} | |
func (p *proxyConn) Write(b []byte) (int, error) { | |
p.send <- sendData(p.cnt, cmdData, b) | |
return len(b), nil | |
} | |
func (p *proxyConn) Close() error { | |
p.send <- sendData(p.cnt, cmdClose, nil) | |
return nil | |
} | |
func (p *proxyConn) LocalAddr() net.Addr { | |
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1)} | |
} | |
func (p *proxyConn) RemoteAddr() net.Addr { | |
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1)} | |
} | |
func (p *proxyConn) SetDeadline(time.Time) error { | |
return nil | |
} | |
func (p *proxyConn) SetReadDeadline(time.Time) error { | |
return nil | |
} | |
func (p *proxyConn) SetWriteDeadline(time.Time) error { | |
return nil | |
} | |
// data is one tunnel frame. Frames are reused through dataPool to avoid
// allocating per message.
type data struct {
	// buf is the whole encoded frame: id byte, command byte, then payload.
	buf []byte
	// data is the payload part of buf (filled in by recvData).
	data []byte
	cnt  uint8 // connection id
	cmd  uint8 // cmdData, cmdConnect or cmdClose
}
var dataPool = &sync.Pool{New: func() any { | |
return &data{ | |
buf: make([]byte, 0, 512), | |
} | |
}} | |
func putData(d *data) { | |
if d != nil { | |
dataPool.Put(d) | |
} | |
} | |
func recvData(c *websocket.Conn) (*data, error) { | |
_, r, err := c.NextReader() | |
if err != nil { | |
return nil, err | |
} | |
var ( | |
n int | |
d = dataPool.Get().(*data) | |
) | |
// 参考 io.ReadAll 的代码,复用 d.buf | |
d.buf = d.buf[:0] | |
for { | |
n, err = r.Read(d.buf[len(d.buf):cap(d.buf)]) | |
d.buf = d.buf[:len(d.buf)+n] | |
if err != nil { | |
if err == io.EOF { | |
break | |
} | |
return d, err | |
} | |
if len(d.buf) == cap(d.buf) { | |
d.buf = append(d.buf, 0)[:len(d.buf)] | |
} | |
} | |
d.cnt, d.cmd = d.buf[0], d.buf[1] | |
d.data = d.buf[2:] | |
return d, nil | |
} | |
func sendData(cnt, cmd uint8, b []byte) *data { | |
d := dataPool.Get().(*data) | |
d.buf = append(d.buf[:0], cnt, cmd) | |
d.buf = append(d.buf, b...) | |
return d | |
} | |
// discardData consumes the reply an HTTP proxy sends to a CONNECT request and
// reports whether the tunnel was established.
//
// The reply header is read one byte at a time, up to and including the blank
// line that ends it. This keeps every byte of the tunnelled stream in r, and a
// header that arrives across several reads is still consumed completely. An
// error is returned for a non-2xx status, a truncated reply, or a header
// longer than 4 KiB.
func discardData(r io.Reader) error {
	const maxHeader = 4 << 10 // don't buffer an unbounded reply

	end := []byte("\r\n\r\n")
	hdr := make([]byte, 0, 64)
	var one [1]byte
	for !bytes.HasSuffix(hdr, end) {
		if len(hdr) >= maxHeader {
			return fmt.Errorf("proxy CONNECT reply exceeds %d bytes", maxHeader)
		}
		if _, err := io.ReadFull(r, one[:]); err != nil {
			if err == io.EOF {
				err = io.ErrUnexpectedEOF
			}
			return fmt.Errorf("reading proxy CONNECT reply: %w", err)
		}
		hdr = append(hdr, one[0])
	}

	// Status line, e.g. "HTTP/1.1 200 Connection Established"; any 2xx
	// status means the tunnel is up.
	line, _, _ := bytes.Cut(hdr, []byte("\r\n"))
	fields := bytes.Fields(line)
	if len(fields) < 2 || !bytes.HasPrefix(fields[0], []byte("HTTP/")) ||
		len(fields[1]) != 3 || fields[1][0] != '2' {
		return fmt.Errorf("proxy refused CONNECT: %q", line)
	}
	return nil
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment