Skip to content

Instantly share code, notes, and snippets.

@jan-bar
Last active December 18, 2024 22:17
Show Gist options
  • Save jan-bar/04651037175954f2e0ed1b2d02cedaf2 to your computer and use it in GitHub Desktop.
Save jan-bar/04651037175954f2e0ed1b2d02cedaf2 to your computer and use it in GitHub Desktop.
http代理的内网穿透
package main
import (
"bytes"
"context"
"crypto/tls"
"flag"
"fmt"
"io"
"log"
"net"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
)
// main parses the command line and runs in one of two modes:
//
//	-s addr            run the proxy server listening on addr
//	-c url [-p addr]   run the tunnel client, dialing the server's websocket
//	                   URL; targets are reached through the CONNECT proxy at
//	                   addr when -p is given
//
// -s takes priority over -c; when neither is given main returns immediately.
func main() {
	log.SetFlags(log.LstdFlags | log.Lshortfile)

	listenAddr := flag.String("s", "", "server listen address")
	dialAddr := flag.String("c", "", "client dial address")
	upstream := flag.String("p", "", "client proxy address")
	flag.Parse()

	switch {
	case *listenAddr != "":
		if err := server(*listenAddr); err != nil {
			log.Println(err)
		}
	case *dialAddr != "":
		client(*dialAddr, *upstream)
	}
}
// tunnel upgrades the tunnel client's request to a websocket and serves it
// until the client disconnects.
//
// A writer goroutine forwards every frame queued on s.send. The calling
// goroutine reads frames and routes each one to the proxyConn registered
// under its connection id. It returns the upgrade error, or the read error
// that ended the session.
func (s *serverHandle) tunnel(w http.ResponseWriter, r *http.Request) error {
	ws, err := s.ws.Upgrade(w, r, nil)
	if err != nil {
		return err
	}
	//goland:noinspection GoUnhandledErrorResult
	defer ws.Close()

	stop := make(chan struct{})
	writeLoop := func() {
		for {
			select {
			case <-stop:
				return
			case out := <-s.send:
				if werr := ws.WriteMessage(websocket.BinaryMessage, out.buf); werr != nil {
					log.Println(werr)
				}
				putData(out)
			}
		}
	}
	go writeLoop()

	for {
		in, rerr := recvData(ws)
		if rerr != nil {
			putData(in)
			// The tunnel client is gone: stop the writer goroutine before returning.
			stop <- struct{}{}
			return rerr
		}

		s.mux.Lock()
		pc, found := s.save[in.cnt]
		switch {
		case found && in.cmd == cmdClose:
			// The peer closed this connection: forget it and wake its reader with EOF.
			delete(s.save, in.cnt)
			close(pc.recv)
			putData(in)
		case found:
			// Ownership of the frame passes to the receiver (proxyConn.Read or
			// proxyDial), which recycles it.
			pc.recv <- in
		default:
			putData(in)
		}
		s.mux.Unlock()
	}
}
// httpEstablished is written back to the browser by connect once the tunnel
// client has reached the CONNECT target.
var httpEstablished = []byte("HTTP/1.1 200 Connection Established" + "\r\n\r\n")
// connect serves an HTTP CONNECT request (used for https targets) on the
// hijacked browser connection c.
//
// It asks the tunnel client to dial r.Host and replies with httpEstablished.
// It then relays bytes in both directions: browser→target in a goroutine,
// target→browser on the calling goroutine. It returns when the target side
// reaches EOF or fails; the caller is expected to close c afterwards.
func (s *serverHandle) connect(c net.Conn, r *http.Request) error {
	rc, err := s.proxyDial(r.Context(), "tcp", r.Host)
	if err != nil {
		return err
	}
	if _, err = c.Write(httpEstablished); err != nil {
		// The far side is already connected; close it so the tunnel client
		// drops it instead of leaking the connection.
		if ec := rc.Close(); ec != nil {
			log.Println(ec)
		}
		return err
	}
	go func() {
		_, ec := io.Copy(rc, c)
		if ec != nil {
			log.Println(ec)
		}
		// The browser side c is done: send the close command to the tunnel client.
		if ec = rc.Close(); ec != nil {
			log.Println(ec)
		}
	}()
	_, err = io.Copy(c, rc)
	return err
}
// general relays a plain-http proxy request (anything but CONNECT).
//
// It re-issues r through s.ct, whose transport dials through the tunnel. It
// writes the upstream response verbatim onto the hijacked browser connection
// c. Closing the response body is best-effort (logged); the returned error is
// from Do or from writing the response.
func (s *serverHandle) general(c net.Conn, r *http.Request) error {
	// A server-side request carries RequestURI, which http.Client rejects.
	r.RequestURI = ""
	// Do not keep the upstream connection for reuse after this response.
	r.Close = true

	resp, err := s.ct.Do(r)
	if err != nil {
		return err
	}
	writeErr := resp.Write(c)
	if body := resp.Body; body != nil {
		if closeErr := body.Close(); closeErr != nil {
			log.Println(closeErr)
		}
	}
	return writeErr
}
// ServeHTTP dispatches every request that reaches the proxy server:
//   - a request whose flagHeaderKey header equals flagHeaderVal is the tunnel
//     client's websocket handshake and is handed to tunnel;
//   - any other request is proxied on the hijacked connection. CONNECT
//     (https targets) goes to connect, everything else (plain http targets)
//     goes to general.
//
// Errors are only logged. The hijacked connection is always closed after the
// proxy step.
func (s *serverHandle) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if r.Body != nil {
		defer func() {
			if ec := r.Body.Close(); ec != nil {
				log.Println(ec)
			}
		}()
	}
	var err error
	if r.Header.Get(flagHeaderKey) == flagHeaderVal {
		err = s.tunnel(w, r)
	} else if hj, ok := w.(http.Hijacker); !ok {
		// An unchecked assertion would panic here (e.g. HTTP/2 writers, or a
		// wrapping middleware); answer the browser instead.
		http.Error(w, "proxy requires a hijackable connection", http.StatusInternalServerError)
		err = fmt.Errorf("proxy %s %s: %T does not support hijacking", r.Method, r.Host, w)
	} else {
		var c net.Conn
		c, _, err = hj.Hijack()
		if err == nil {
			if r.Method == http.MethodConnect {
				// https targets arrive as CONNECT requests.
				err = s.connect(c, r)
			} else {
				// plain http targets: forward the request and relay the response.
				err = s.general(c, r)
			}
			if ec := c.Close(); ec != nil {
				log.Println(ec)
			}
		}
	}
	if err != nil {
		log.Println(err)
	}
}
// proxyDial opens a connection to address through the tunnel client.
//
// It has the signature of http.Transport.DialContext; the network argument
// is ignored.
//
// It registers a new proxyConn under a connection id that is not currently
// in use and sends a cmdConnect frame carrying address. It then waits for the
// client's cmdConnect reply: an empty payload means success, otherwise the
// payload is the client's error text. ctx can abort the dial only while the
// connect frame is still queued. Once the client has the request, the reply
// is always awaited, because the tunnel reader would otherwise block
// delivering it.
func (s *serverHandle) proxyDial(ctx context.Context, _, address string) (net.Conn, error) {
	// Connection ids are uint8, so at most this many connections can be live.
	const maxIDs = 1 << 8

	pc := &proxyConn{
		send: s.send,
		recv: make(chan *data),
	}

	s.mux.Lock()
	if len(s.save) >= maxIDs {
		s.mux.Unlock()
		return nil, fmt.Errorf("dial %s: all %d tunnel connection ids are in use", address, maxIDs)
	}
	// Every connection needs its own id so client and server frames are not
	// mixed up. After the counter wraps, skip ids that are still live instead
	// of overwriting them.
	for {
		s.cnt++
		if _, used := s.save[s.cnt]; !used {
			break
		}
	}
	pc.cnt = s.cnt
	s.save[pc.cnt] = pc
	s.mux.Unlock()

	// forget unregisters pc unless tunnel already removed it (and the id was
	// possibly handed out again).
	forget := func() {
		s.mux.Lock()
		if s.save[pc.cnt] == pc {
			delete(s.save, pc.cnt)
		}
		s.mux.Unlock()
	}

	req := sendData(pc.cnt, cmdConnect, []byte(address))
	select {
	case pc.send <- req:
	case <-ctx.Done():
		// The client never saw the request, so no reply will arrive for this id.
		putData(req)
		forget()
		return nil, ctx.Err()
	}

	recv, ok := <-pc.recv
	if !ok {
		// tunnel closed the channel (and already dropped the entry).
		return nil, fmt.Errorf("dial %s: tunnel closed the connection", address)
	}
	defer putData(recv)
	if len(recv.data) != 0 {
		forget()
		return nil, fmt.Errorf("error: %s", recv.data)
	}
	return pc, nil
}
// serverHandle is the server half of the proxy. It is the http.Handler for
// the listening socket and multiplexes proxied connections over the websocket
// opened by the tunnel client.
type serverHandle struct {
	// ws upgrades the tunnel client's handshake request (see tunnel).
	ws *websocket.Upgrader
	// ct forwards plain-http proxy requests (see general).
	ct *http.Client
	// send carries outgoing frames to the websocket writer goroutine in tunnel.
	send chan *data
	// mux guards cnt and save.
	mux sync.Mutex
	// cnt is the most recently assigned connection id; being a uint8 it
	// wraps around within [0,255].
	cnt uint8
	// save maps connection ids to the live proxyConn that owns them.
	save map[uint8]*proxyConn
}
// server listens on addr and serves the proxy until the listener fails.
//
// The same port accepts the tunnel client's websocket (see ServeHTTP) and
// ordinary browser proxy requests. Those requests are carried to the tunnel
// client, which makes the actual outbound connections. It returns the error
// from net.Listen or http.Server.Serve.
func server(addr string) error {
	l, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}
	addr = l.Addr().String()
	sh := &serverHandle{
		ws: &websocket.Upgrader{
			// NOTE(review): any origin may open the tunnel, and ServeHTTP only
			// checks a fixed header value; there is no authentication.
			CheckOrigin: func(r *http.Request) bool { return true },
		},
		ct: &http.Client{
			Timeout: time.Second * 10,
			// A forward proxy must hand redirects back to the browser. Following
			// them here would return the final page under the original URL and
			// drop the intermediate responses (Location, Set-Cookie) the browser
			// needs.
			CheckRedirect: func(*http.Request, []*http.Request) error {
				return http.ErrUseLastResponse
			},
		},
		send: make(chan *data),
		save: make(map[uint8]*proxyConn),
	}
	sh.ct.Transport = &http.Transport{
		TLSClientConfig: &tls.Config{
			// NOTE(review): upstream certificates are not verified for https
			// requests forwarded through general.
			InsecureSkipVerify: true,
		},
		// Plain-http proxy requests reach their target through the tunnel.
		DialContext: sh.proxyDial,
	}
	srv := &http.Server{
		Addr:              addr,
		Handler:           sh,
		ReadTimeout:       time.Second * 3,
		ReadHeaderTimeout: time.Second * 3,
	}
	log.Println("http server started on ", addr)
	return srv.Serve(l)
}
// -----------------------------------------------------------------------------
// The tunnel client sends this header on its websocket handshake; ServeHTTP
// uses it to tell the tunnel apart from ordinary proxy requests.
const (
	// flagHeaderVal is the value flagHeaderKey must carry.
	flagHeaderVal = "tunnel"
	// flagHeaderKey is the header name checked by ServeHTTP.
	flagHeaderKey = "Proxy-Ctrl"
)
func client(addr, proxy string) {
var (
send = make(chan *data)
exit = make(chan struct{})
)
RetryConn:
c, _, err := websocket.DefaultDialer.Dial(addr, http.Header{
flagHeaderKey: []string{flagHeaderVal},
})
if err != nil {
log.Println(err)
time.Sleep(time.Second * 5)
goto RetryConn
}
log.Printf("connect [%s],proxy [%s] ok", addr, proxy)
go func() {
for {
select {
case d := <-send:
ew := c.WriteMessage(websocket.BinaryMessage, d.buf)
if ew != nil {
log.Println(ew)
}
putData(d)
case <-exit:
return
}
}
}()
var (
d *data
mux sync.Mutex
save = make(map[uint8]*proxyConn)
)
for {
d, err = recvData(c)
if err != nil {
putData(d)
log.Println(err)
if err = c.Close(); err != nil {
log.Println(err)
}
exit <- struct{}{}
goto RetryConn
}
mux.Lock()
pc, ok := save[d.cnt]
switch d.cmd {
case cmdConnect:
pc = newClient(d.cnt, send, string(d.data), proxy)
if pc == nil {
continue
}
save[d.cnt] = pc
ok = false
case cmdClose:
if ok {
delete(save, d.cnt)
close(pc.recv)
}
ok = false
default:
if ok {
// 只有这种情况才需要在Read方法中回收对象
pc.recv <- d
}
}
if !ok {
putData(d)
}
mux.Unlock()
}
}
// newClient dials addr on behalf of the server's connection cnt and reports
// the result to the server as a cmdConnect frame. The payload is the error
// text on failure and empty on success.
//
// On success it starts two goroutines relaying bytes between the target and
// the returned proxyConn. On failure it returns nil.
func newClient(cnt uint8, send chan *data, addr, proxy string) *proxyConn {
	cc := &proxyConn{
		cnt:  cnt,
		send: send,
		recv: make(chan *data),
	}
	rc, err := dialTarget(addr, proxy)
	if err != nil {
		cc.send <- sendData(cnt, cmdConnect, []byte(err.Error()))
		log.Println(err)
		return nil
	}
	cc.send <- sendData(cnt, cmdConnect, nil)
	go func() {
		_, ec := io.Copy(cc, rc)
		if ec != nil {
			log.Println(ec)
		}
		// The target finished: tell the server, then drop our socket.
		if ec = cc.Close(); ec != nil {
			log.Println(ec)
		}
		if ec = rc.Close(); ec != nil {
			log.Println(ec)
		}
	}()
	go func() {
		// When the server closes this connection, cc.Read returns io.EOF.
		_, ec := io.Copy(rc, cc)
		if ec != nil {
			log.Println(ec)
		}
	}()
	return cc
}

// dialTarget opens a TCP connection to addr, either directly or (when proxy
// is non-empty) through an HTTP proxy using CONNECT. If the proxy handshake
// fails, the socket to the proxy is closed instead of leaked.
func dialTarget(addr, proxy string) (net.Conn, error) {
	if proxy == "" {
		return net.Dial("tcp", addr)
	}
	rc, err := net.Dial("tcp", proxy)
	if err != nil {
		return nil, err
	}
	// Only proxies that support CONNECT can be used (enough for https targets).
	_, err = fmt.Fprintf(rc, "CONNECT %s HTTP/1.1\r\n"+
		"Host: %s\r\n"+
		"User-Agent: Go-http-client/1.1\r\n"+
		"Proxy-Connection: Keep-Alive\r\n\r\n", addr, addr)
	if err == nil {
		// Consume the proxy's reply so it is not forwarded as tunnel payload.
		err = discardData(rc)
	}
	if err != nil {
		if ec := rc.Close(); ec != nil {
			log.Println(ec)
		}
		return nil, err
	}
	return rc, nil
}
// Frame commands, carried in the second byte of every tunnel message (after
// the connection id).
const (
	// cmdData: payload bytes for an established connection.
	cmdData = 0
	// cmdConnect: from the server, a dial request whose payload is the
	// address. From the client, the reply, whose payload is the error text
	// (empty on success).
	cmdConnect = 1
	// cmdClose: the sender has closed its side of the connection.
	cmdClose = 2
)
// proxyConn is one connection multiplexed over the websocket tunnel.
//
// It implements net.Conn:
//   - Write queues cmdData frames on send;
//   - Read consumes frames that the tunnel reader delivers on recv;
//   - a closed recv channel reads as io.EOF.
type proxyConn struct {
	cnt        uint8        // connection id written into every outgoing frame
	rb         bytes.Buffer // payload received but not yet returned by Read
	send, recv chan *data
}

var _ net.Conn = (*proxyConn)(nil)

// Read returns buffered payload, blocking for the next frame when none is
// left. Empty frames are skipped: bytes.Buffer.Read reports io.EOF on an
// empty buffer, which would otherwise end the stream early.
func (p *proxyConn) Read(b []byte) (int, error) {
	if len(b) == 0 {
		return 0, nil
	}
	for p.rb.Len() == 0 {
		d, ok := <-p.recv
		if !ok {
			return 0, io.EOF
		}
		p.rb.Write(d.data)
		putData(d)
	}
	return p.rb.Read(b)
}

// Write sends b as one cmdData frame. It blocks until the tunnel writer
// accepts the frame and never reports an error. An empty b sends nothing,
// since an empty frame carries no data.
func (p *proxyConn) Write(b []byte) (int, error) {
	if len(b) == 0 {
		return 0, nil
	}
	p.send <- sendData(p.cnt, cmdData, b)
	return len(b), nil
}

// Close tells the peer to close its side by sending a cmdClose frame.
func (p *proxyConn) Close() error {
	p.send <- sendData(p.cnt, cmdClose, nil)
	return nil
}

// LocalAddr returns a placeholder loopback address; the real endpoints live
// on the other side of the tunnel.
func (p *proxyConn) LocalAddr() net.Addr {
	return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1)}
}

// RemoteAddr returns a placeholder loopback address.
func (p *proxyConn) RemoteAddr() net.Addr {
	return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1)}
}

// SetDeadline is a no-op: deadlines are not supported.
func (p *proxyConn) SetDeadline(time.Time) error {
	return nil
}

// SetReadDeadline is a no-op: deadlines are not supported.
func (p *proxyConn) SetReadDeadline(time.Time) error {
	return nil
}

// SetWriteDeadline is a no-op: deadlines are not supported.
func (p *proxyConn) SetWriteDeadline(time.Time) error {
	return nil
}
// data is one tunnel frame, recycled through dataPool to avoid per-message
// allocations.
type data struct {
	// cnt is the connection id and cmd the frame command (cmdData, ...).
	cnt, cmd uint8
	// buf holds the whole frame; for received frames data is buf[2:].
	buf, data []byte
}

// maxPooledBuf caps the buffer capacity putData returns to dataPool.
// recvData and sendData grow buf to fit a message. Pooling a buffer grown by
// one large message would pin that memory indefinitely, so such buffers are
// left to the garbage collector.
const maxPooledBuf = 64 << 10

var dataPool = &sync.Pool{New: func() any {
	return &data{
		buf: make([]byte, 0, 512),
	}
}}

// putData returns d to dataPool. A nil d is ignored, and frames whose buffer
// has grown beyond maxPooledBuf are dropped instead of pooled.
func putData(d *data) {
	if d == nil || cap(d.buf) > maxPooledBuf {
		return
	}
	d.data = nil
	dataPool.Put(d)
}
// recvData reads the next websocket message from c into a pooled frame and
// decodes its header. The first byte is the connection id, the second the
// command, and the rest is the payload.
//
// When a frame was taken from the pool, it is returned even on error so the
// caller can hand it back with putData. A message shorter than the 2-byte
// header is reported as an error rather than causing an index panic.
func recvData(c *websocket.Conn) (*data, error) {
	_, r, err := c.NextReader()
	if err != nil {
		return nil, err
	}
	d := dataPool.Get().(*data)

	// Read the whole message, reusing d.buf's capacity (modelled on io.ReadAll).
	d.buf = d.buf[:0]
	for {
		n, rerr := r.Read(d.buf[len(d.buf):cap(d.buf)])
		d.buf = d.buf[:len(d.buf)+n]
		if rerr == io.EOF {
			break
		}
		if rerr != nil {
			return d, rerr
		}
		if len(d.buf) == cap(d.buf) {
			// Grow the buffer; append picks the new capacity.
			d.buf = append(d.buf, 0)[:len(d.buf)]
		}
	}

	if len(d.buf) < 2 {
		return d, fmt.Errorf("tunnel frame too short: %d bytes", len(d.buf))
	}
	d.cnt, d.cmd = d.buf[0], d.buf[1]
	d.data = d.buf[2:]
	return d, nil
}
// sendData takes a frame from dataPool and fills its buffer with the two-byte
// header (connection id, command) followed by a copy of b. The caller owns
// the returned frame until it is queued for sending.
func sendData(cnt, cmd uint8, b []byte) *data {
	frame := dataPool.Get().(*data)
	out := frame.buf[:0]
	out = append(out, cnt, cmd)
	frame.buf = append(out, b...)
	return frame
}
// discardData consumes the reply an upstream HTTP proxy sends to a CONNECT
// request and reports whether the tunnel was established.
//
// The reply is read one byte at a time up to and including the blank line
// that ends its header. This way, bytes that belong to the tunnelled stream
// are never swallowed, and a reply split across several reads is consumed
// completely. An error is returned in any of these cases:
//   - the status line is not a 2xx;
//   - the header exceeds 8 KiB;
//   - reading fails before the blank line.
func discardData(r io.Reader) error {
	const maxHeader = 8 << 10
	var (
		head []byte
		one  [1]byte
	)
	// The header ends with an empty line: "\r\n\r\n", or "\n\n" from lenient peers.
	for !bytes.HasSuffix(head, []byte("\n\r\n")) && !bytes.HasSuffix(head, []byte("\n\n")) {
		if len(head) >= maxHeader {
			return fmt.Errorf("proxy CONNECT response header exceeds %d bytes", maxHeader)
		}
		if _, err := io.ReadFull(r, one[:]); err != nil {
			return fmt.Errorf("reading proxy CONNECT response: %w", err)
		}
		head = append(head, one[0])
	}

	status := head
	if i := bytes.IndexByte(status, '\n'); i >= 0 {
		status = status[:i]
	}
	// Status line: "HTTP/1.x 2xx reason"; any 2xx means the tunnel is open.
	f := bytes.Fields(status)
	if len(f) < 2 || !bytes.HasPrefix(f[0], []byte("HTTP/")) || len(f[1]) != 3 || f[1][0] != '2' {
		return fmt.Errorf("proxy refused CONNECT: %q", bytes.TrimSpace(status))
	}
	return nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment