Last active
November 30, 2022 01:31
-
-
Save kirk91/ec25703848172e8f56f671e0e1c73751 to your computer and use it in GitHub Desktop.
Pass File Descriptor over Unix Domain Socket
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"fmt" | |
"log" | |
"net" | |
"net/http" | |
"os" | |
"syscall" | |
"golang.org/x/sys/unix" | |
) | |
const udsPath = "/tmp/fd-pass-example.sock" | |
func main() { | |
os.Remove(udsPath) //nolint: errcheck | |
lis, err := net.Listen("unix", udsPath) | |
if err != nil { | |
panic(err) | |
} | |
defer lis.Close() | |
log.Println("Wait receiving listener ...") | |
conn, err := lis.Accept() | |
if err != nil { | |
panic(err) | |
} | |
defer conn.Close() | |
httpLis := receiveListener(conn.(*net.UnixConn)) | |
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { | |
fmt.Fprintf(w, "[server2] Hello, world!") | |
}) | |
log.Printf("Server is listening on %s ...\n", httpLis.Addr()) | |
http.Serve(httpLis, nil) | |
} | |
func receiveListener(conn *net.UnixConn) net.Listener { | |
connFd, err := getConnFd(conn) | |
if err != nil { | |
panic(err) | |
} | |
// receive socket control message | |
b := make([]byte, unix.CmsgSpace(4)) | |
_, _, _, _, err = unix.Recvmsg(connFd, nil, b, 0) | |
if err != nil { | |
panic(err) | |
} | |
// parse socket control message | |
cmsgs, err := unix.ParseSocketControlMessage(b) | |
if err != nil { | |
panic(err) | |
} | |
fds, err := unix.ParseUnixRights(&cmsgs[0]) | |
if err != nil { | |
panic(err) | |
} | |
fd := fds[0] | |
log.Printf("Got socket fd %d\n", fd) | |
// construct net listener | |
f := os.NewFile(uintptr(fd), "listener") | |
defer f.Close() | |
l, err := net.FileListener(f) | |
if err != nil { | |
panic(err) | |
} | |
return l | |
} | |
// getConnFd reports the file descriptor underlying conn.
//
// The descriptor is captured inside RawConn.Control and returned to the
// caller. If obtaining the raw connection fails, the descriptor is 0 and the
// error is returned; an error from Control is returned alongside whatever
// descriptor value was captured.
func getConnFd(conn syscall.Conn) (connFd int, err error) {
	raw, err := conn.SyscallConn()
	if err != nil {
		return 0, err
	}

	capture := func(fd uintptr) { connFd = int(fd) }
	err = raw.Control(capture)
	return connFd, err
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"context" | |
"fmt" | |
"log" | |
"net" | |
"net/http" | |
"syscall" | |
"time" | |
"golang.org/x/sys/unix" | |
) | |
const serverAddr = "127.0.0.1:8080" | |
func main() { | |
lis, err := net.Listen("tcp", serverAddr) | |
if err != nil { | |
panic(err) | |
} | |
var s http.Server | |
mux := new(http.ServeMux) | |
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { | |
fmt.Fprintf(w, "[server1] Hello, world!") | |
}) | |
mux.HandleFunc("/passfd", func(w http.ResponseWriter, r *http.Request) { | |
if err := sendListener(lis.(*net.TCPListener)); err != nil { | |
fmt.Fprintf(w, "Error: %v", err) | |
return | |
} | |
fmt.Fprintf(w, "Success") | |
time.AfterFunc(time.Millisecond*50, func() { | |
log.Println("Shutdown server ...") | |
s.Shutdown(context.Background()) | |
}) | |
}) | |
s.Handler = mux | |
log.Printf("Server is listening on %s ...\n", serverAddr) | |
s.Serve(lis) | |
log.Println("Bye bye") | |
} | |
func sendListener(lis *net.TCPListener) error { | |
// connect to the unix socket | |
const udsPath = "/tmp/fd-pass-example.sock" | |
conn, err := net.Dial("unix", udsPath) | |
if err != nil { | |
return err | |
} | |
defer conn.Close() | |
connFd, err := getConnFd(conn.(*net.UnixConn)) | |
if err != nil { | |
return err | |
} | |
// pass listener fd | |
lisFd, err := getConnFd(lis) | |
if err != nil { | |
return err | |
} | |
rights := unix.UnixRights(int(lisFd)) | |
return unix.Sendmsg(connFd, nil, rights, nil, 0) | |
} | |
// getConnFd reports the file descriptor underlying conn.
//
// The descriptor is captured inside RawConn.Control and returned to the
// caller. If obtaining the raw connection fails, the descriptor is 0 and the
// error is returned; an error from Control is returned alongside whatever
// descriptor value was captured.
func getConnFd(conn syscall.Conn) (connFd int, err error) {
	raw, err := conn.SyscallConn()
	if err != nil {
		return 0, err
	}

	capture := func(fd uintptr) { connFd = int(fd) }
	err = raw.Control(capture)
	return connFd, err
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment