|
package main |
|
|
|
import ( |
|
"flag" |
|
"fmt" |
|
"io" |
|
"log" |
|
"net" |
|
"os" |
|
"strings" |
|
|
|
"golang.org/x/net/websocket" |
|
) |
|
|
|
// Protocol identifiers used by the -source-protocol and -target-protocol
// flags. Flag values are lower-cased before being compared against these.
const (
	// protocolNameTCP selects a plain TCP stream.
	protocolNameTCP string = "tcp"

	// protocolNameUnixDomainSocket selects a Unix domain socket; the
	// corresponding address flag is then a file path.
	protocolNameUnixDomainSocket string = "unix"

	// protocolNameWebSocket selects a WebSocket connection; the corresponding
	// address flag is then a URL.
	protocolNameWebSocket string = "websocket"

	// protocolNameNamedPipe is declared but is not currently among the
	// protocols accepted during flag validation.
	protocolNameNamedPipe string = "namedpipe"
)
|
|
|
// Proxy configuration, filled in from the command-line flags registered in
// init and validated by parseFlags.
var (
	// Listening (source) side.
	sourceProtocol string
	sourceAddress  string

	// Dialing (target) side.
	targetProtocol  string
	targetAddress   string
	websocketOrigin string // only used when targetProtocol is websocket
)
|
|
|
// init registers the command-line flags that configure the proxy. The values
// are checked by parseFlags after flag.Parse has run.
func init() {
	flag.StringVar(&sourceAddress, "source-address", "", "The address of the source we will listen on. This is a file path when using unix domain sockets, a url when using websockets.")
	flag.StringVar(&targetAddress, "target-address", "", "The address of the target we will forward connections to. This is a file path when using unix domain sockets, a url when using websockets.")
	flag.StringVar(&sourceProtocol, "source-protocol", "", "The protocol used to listen for source connections: tcp or unix (case-insensitive).")
	flag.StringVar(&targetProtocol, "target-protocol", "", "The protocol used to connect to the target: websocket, tcp or unix (case-insensitive).")
	flag.StringVar(&websocketOrigin, "websocket-origin", "", "The Origin to send when dialing the target over websockets. Required when -target-protocol is websocket.")
}
|
|
|
// parseFlags parses the command line and validates the resulting settings,
// terminating the program via log.Fatal on the first problem it finds.
//
// The protocol flags are matched case-insensitively: sourceProtocol and
// targetProtocol are lower-cased in place so that later comparisons against
// the protocolName* constants work.
func parseFlags() {
	flag.Parse()

	if sourceAddress == "" {
		log.Fatal("Source address not defined.")
	}
	if targetAddress == "" {
		log.Fatal("Target address not defined.")
	}
	if sourceProtocol == "" {
		log.Fatal("Source protocol not defined.")
	}
	if targetProtocol == "" {
		log.Fatal("Target protocol not defined.")
	}

	possibleProtocols := []string{protocolNameWebSocket, protocolNameTCP, protocolNameUnixDomainSocket}

	sourceProtocol = strings.ToLower(sourceProtocol)
	if !isSupportedProtocol(sourceProtocol, possibleProtocols) {
		log.Fatal("Not a correct value for source protocol. Possible values are: ", strings.Join(possibleProtocols, ","))
	}
	// TODO: Remove this when websocket for the source is supported.
	if sourceProtocol == protocolNameWebSocket {
		log.Fatal("Web socket as the source is not currently supported.")
	}

	targetProtocol = strings.ToLower(targetProtocol)
	if !isSupportedProtocol(targetProtocol, possibleProtocols) {
		log.Fatal("Not a correct value for target protocol. Possible values are: ", strings.Join(possibleProtocols, ","))
	}

	// websocket.Dial needs an Origin, so insist on one up front.
	if targetProtocol == protocolNameWebSocket && websocketOrigin == "" {
		log.Fatal("You must specify an origin if you're using web sockets.")
	}
}

// isSupportedProtocol reports whether protocol is one of the names in supported.
func isSupportedProtocol(protocol string, supported []string) bool {
	for _, name := range supported {
		if name == protocol {
			return true
		}
	}
	return false
}
|
|
|
// main listens on sourceAddress using sourceProtocol. For every accepted
// client it dials targetAddress using targetProtocol and starts one forward
// goroutine per direction to relay bytes between the two connections.
//
// Listen and accept failures are fatal. A failure to reach the target only
// drops the client that triggered it, so other sessions keep running.
func main() {
	parseFlags()

	if sourceProtocol == protocolNameUnixDomainSocket {
		// Remove a socket file left over from a previous run so Listen can
		// bind the path again. The file not existing is the normal first-run
		// case and must not abort startup.
		if err := os.Remove(sourceAddress); err != nil && !os.IsNotExist(err) {
			log.Fatal(err)
		}
	}

	fmt.Println("listening on the address: ", sourceAddress)
	listener, err := net.Listen(sourceProtocol, sourceAddress)
	if err != nil {
		log.Fatal("listen error:", err)
	}
	defer listener.Close()

	for {
		sourceConn, err := listener.Accept()
		if err != nil {
			log.Fatal("accept error:", err)
		}

		log.Printf("Client connected [%s]", sourceConn.RemoteAddr().Network())

		var targetConn net.Conn
		if targetProtocol == protocolNameWebSocket {
			targetConn, err = websocket.Dial(targetAddress, "", websocketOrigin)
		} else {
			targetConn, err = net.Dial(targetProtocol, targetAddress)
		}
		if err != nil {
			// Drop just this client; taking the whole proxy down would also
			// kill every session that is already running.
			log.Print("dial error:", err)
			sourceConn.Close()
			continue
		}

		go forward(targetConn, sourceConn)
		go forward(sourceConn, targetConn)
	}
}
|
|
|
// forward copies everything read from conn2 into conn1 until conn2 reports
// EOF or either side fails, and then closes both connections.
//
// Closing both ends matters when forward is run once per direction on the
// same pair: the goroutine for the opposite direction is blocked reading from
// conn1, and closing it makes that goroutine return too. Without this, every
// session would leak two connections and two goroutines once a peer hung up.
func forward(conn1 net.Conn, conn2 net.Conn) {
	defer conn1.Close()
	defer conn2.Close()

	// The copy error is deliberately ignored. When the opposite direction
	// finishes first it closes these connections, so an error here is the
	// normal way this goroutine learns the session is over.
	_, _ = io.Copy(conn1, conn2)
}