Created
July 4, 2020 03:21
-
-
Save korc/e65d4e64240364c8649f33a291c4654c to your computer and use it in GitHub Desktop.
SMTP front-end daemon (with optional SPF checking via the -spf flag)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import (
	"crypto/tls"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"regexp"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/emersion/go-smtp"
	"github.com/mileusna/spf"
)
// frontBackend is the go-smtp backend of the front-end daemon. It carries the
// settings shared by every session: how to reach the next hop and which
// checks are applied to incoming mail.
type frontBackend struct {
	// Next-hop relay settings.
	nextHop     string // dialed once per session: TCP address, or a unix socket path with nextHopLMTP
	nextHopLMTP bool   // reach nextHop over a unix socket and speak LMTP
	myHello     string // hello name sent to the next hop; empty skips the explicit Hello

	// Checks applied to incoming transactions.
	rcptFilter *regexp.Regexp // RCPT TO addresses must match this when non-nil
	checkSpf   bool           // run an SPF check for every MAIL FROM
}
func (be *frontBackend) Login(st *smtp.ConnectionState, username, password string) (smtp.Session, error) { | |
log.Printf("Login attempt from %s [%s]: %#v / %#v", st.RemoteAddr, st.Hostname, username, password) | |
return nil, smtp.ErrAuthUnsupported | |
} | |
func (be *frontBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) { | |
log.Printf("[%s] Anonymous login from [%s] (tls=%t)", | |
state.RemoteAddr, state.Hostname, state.TLS.HandshakeComplete) | |
sess := &frontSession{cs: state, be: be, to: []string{}} | |
if err := sess.createSMTPClient(); err != nil { | |
sess.log("Cannot connect to next hop %#v: %s", be.nextHop, err) | |
if _, ok := err.(*smtp.SMTPError); ok { | |
return nil, err | |
} | |
return nil, &smtp.SMTPError{Code: 441, Message: "remote error"} | |
} | |
return sess, nil | |
} | |
type frontSession struct { | |
cs *smtp.ConnectionState | |
be *frontBackend | |
c *smtp.Client | |
from string | |
to []string | |
opts smtp.MailOptions | |
spfResult spf.Result | |
} | |
func (s *frontSession) createSMTPClient() error { | |
proto := "tcp" | |
if s.be.nextHopLMTP { | |
proto = "unix" | |
} | |
conn, err := net.Dial(proto, s.be.nextHop) | |
if err != nil { | |
return err | |
} | |
if s.be.nextHopLMTP { | |
s.c, err = smtp.NewClientLMTP(conn, "lmtp-server") | |
} else { | |
host, _, _ := net.SplitHostPort(s.be.nextHop) | |
s.c, err = smtp.NewClient(conn, host) | |
} | |
if err != nil { | |
return err | |
} | |
if s.be.myHello != "" { | |
if err := s.c.Hello(s.be.myHello); err != nil { | |
return err | |
} | |
} | |
return nil | |
} | |
func (s *frontSession) log(format string, args ...interface{}) { | |
log.Printf("[%s] "+format, append([]interface{}{s.cs.RemoteAddr}, args...)...) | |
} | |
func (s *frontSession) Reset() { | |
//s.log("Session reset") | |
s.from = "" | |
s.to = []string{} | |
if err := s.c.Reset(); err != nil { | |
s.log("Could not remote reset: %s", err) | |
} | |
s.opts = smtp.MailOptions{} | |
} | |
func (s *frontSession) Logout() error { | |
//s.log("Session logout") | |
if err := s.c.Close(); err != nil { | |
s.log("Cannot close remote: %s", err) | |
return &smtp.SMTPError{Code: 441, Message: "remote error"} | |
} | |
return nil | |
} | |
func (s *frontSession) Mail(from string, opts smtp.MailOptions) error { | |
s.log("MAIL FROM: %#v, Size: %d, RequireTLS: %t, UTF8: %t", from, opts.Size, opts.RequireTLS, opts.UTF8) | |
s.from = from | |
if s.be.checkSpf { | |
host, _, _ := net.SplitHostPort(s.cs.RemoteAddr.String()) | |
emlParts := strings.Split(from, "@") | |
s.spfResult = spf.CheckHost(net.ParseIP(host), emlParts[len(emlParts)-1], from, s.cs.Hostname) | |
s.log("SPF result for %s as %s: %s", host, emlParts[len(emlParts)-1], s.spfResult) | |
if s.spfResult == spf.Fail { | |
return &smtp.SMTPError{Code: 550, EnhancedCode: smtp.EnhancedCode{5, 7, 1}, Message: "SPF failed"} | |
} | |
} | |
s.opts = opts | |
if err := s.c.Mail(from, &opts); err != nil { | |
s.log("Cannot remote Mail: %s", err) | |
return err | |
} | |
return nil | |
} | |
func (s *frontSession) Rcpt(to string) error { | |
s.log("RCPT TO: %#v", to) | |
if s.be.rcptFilter != nil { | |
if !s.be.rcptFilter.MatchString(to) { | |
s.log("Recipient does not match filter: %#v", to) | |
return &smtp.SMTPError{Code: 510, Message: "Invalid recipient"} | |
} | |
} | |
if err := s.c.Rcpt(to); err != nil { | |
return err | |
} | |
s.to = append(s.to, to) | |
return nil | |
} | |
func (s *frontSession) receivedHeader() string { | |
tlsInfo := "no TLS" | |
if t := s.cs.TLS; t.HandshakeComplete { | |
tlsInfo = fmt.Sprintf("TLS v%d.%d cph=0x%x sn=%s nCerts=%d okCerts=%d", | |
t.Version>>8, t.Version&0xff, t.CipherSuite, t.ServerName, len(t.PeerCertificates), len(t.VerifiedChains)) | |
} | |
return fmt.Sprintf("from %s ([%s])\n by %s [%s]\n (%s)\n (envelope-from <%s>)\n for <%s>; %s", | |
s.cs.Hostname, s.cs.RemoteAddr, s.be.myHello, s.cs.LocalAddr, tlsInfo, s.from, | |
strings.Join(s.to, ","), time.Now().Format(time.RFC822)) | |
} | |
func (s *frontSession) Data(r io.Reader) error { | |
//s.log("Data: reading from: %#v", r) | |
wr, err := s.c.Data() | |
if err != nil { | |
s.log("Remote Data error: %s", err) | |
return err | |
} | |
var n int64 | |
spfResult := "" | |
if s.spfResult.IsSet() { | |
spfResult = fmt.Sprintf("X-SPF-Result: %s\r\n", s.spfResult) | |
} | |
_, err = fmt.Fprintf(wr, "Received: %s\r\n%s", s.receivedHeader(), spfResult) | |
if err != nil { | |
s.log("Could not send Received: %s", err) | |
} else { | |
n, err = io.Copy(wr, r) | |
} | |
if err != nil { | |
s.log("Could not copy to next hop: %s", err) | |
if _, ok := err.(*smtp.SMTPError); ok { | |
return err | |
} | |
return &smtp.SMTPError{Code: 442, Message: "remote error"} | |
} | |
s.log("DATA transferred %d bytes of data to next hop at %#v", n, s.be.nextHop) | |
return nil | |
} | |
// loggingConn wraps a client connection and logs when it is closed.
type loggingConn struct {
	net.Conn
}

// Close logs the peer address, then closes the wrapped connection.
func (c loggingConn) Close() error {
	peer := c.RemoteAddr()
	log.Printf("Closing connection to: %s", peer)
	return c.Conn.Close()
}

// loggingListener wraps a listener so that every accepted connection is
// logged and handed out as a loggingConn.
type loggingListener struct {
	net.Listener
}

// Accept waits for the next connection, logs its remote address and wraps
// it in a loggingConn. Errors from the wrapped listener are returned as-is.
func (l loggingListener) Accept() (net.Conn, error) {
	conn, err := l.Listener.Accept()
	if err == nil {
		log.Printf("Accepted connection from: %s", conn.RemoteAddr())
		return loggingConn{Conn: conn}, nil
	}
	return nil, err
}
// dbgLog is the io.Writer used for the -debug protocol trace when it goes to
// a file. A write that comes more than a second after the last timestamp is
// preceded by a "[RFC 3339 time]" line, so bursts of traffic are grouped
// under one timestamp.
//
// main installs a single dbgLog as srv.Debug, which the server shares
// between connections it serves concurrently. mu therefore serializes
// Write, protecting lastStamp and keeping each stamp next to its data.
type dbgLog struct {
	mu        sync.Mutex
	outFile   *os.File
	lastStamp time.Time
}

// Write appends p to the log file, emitting a timestamp line first when the
// previous one is more than a second old. It returns the result of writing p.
func (d *dbgLog) Write(p []byte) (n int, err error) {
	d.mu.Lock()
	defer d.mu.Unlock()
	now := time.Now()
	if len(p) > 0 && now.Sub(d.lastStamp) > time.Second {
		// Best effort: a failed timestamp must not drop the trace data itself.
		_, _ = fmt.Fprintf(d.outFile, "[%s]\n", now.Format(time.RFC3339))
		d.lastStamp = now
	}
	return d.outFile.Write(p)
}

// newDbgLog opens fileName for appending, creating it with mode 0644 if it
// does not exist, and returns a dbgLog writing to it.
func newDbgLog(fileName string) (*dbgLog, error) {
	f, err := os.OpenFile(fileName, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
	if err != nil {
		return nil, err
	}
	return &dbgLog{outFile: f}, nil
}
func main() { | |
listenAddr := flag.String("listen", ":25", "Listen address") | |
listenLMTP := flag.Bool("listen-lmtp", false, "Listen on unix socket as LMTP server") | |
domain := flag.String("domain", "localhost.localdomain", "Domain name") | |
noAuth := flag.Bool("no-auth", false, "Disable authentication") | |
allowSecureAuth := flag.Bool("allow-insecure-auth", false, "Allow insecure auth") | |
certFile := flag.String("cert", "", "TLS certificate") | |
keyFile := flag.String("key", "", "TLS certificate key (default: same as -cert)") | |
maxMessageSize := flag.Int("max-msg-size", 20*1024*1024, "Maximum message size in bytes") | |
maxRecipients := flag.Int("max-rcpt", 50, "Maximum recipients") | |
readTimeout := flag.String("read-timeout", "10s", "Read timeout") | |
writeTimeout := flag.String("write-timeout", "10s", "Write timeout") | |
nextHop := flag.String("next", "", "forwarding server address") | |
nextHopLMTP := flag.Bool("next-lmtp", false, "next server is LMTP over unix socket") | |
myHello := flag.String("hello", "", "My hello name (default -domain)") | |
rcptFilter := flag.String("rcpt-check", "", "regex for allowed RCPTs") | |
chroot := flag.String("chroot", "", "change root to directory before run (need CAP_SYS_CHROOT)") | |
debugFlag := flag.String("debug", "", "Debug output to file or 'stderr'") | |
checkSpf := flag.Bool("spf", false, "Check SPF or not") | |
flag.Parse() | |
if *myHello == "" { | |
*myHello = *domain | |
} | |
if *nextHop == "" { | |
log.Fatal("Need to set next SMTP server with -next") | |
} | |
be := &frontBackend{nextHop: *nextHop, nextHopLMTP: *nextHopLMTP, myHello: *myHello, checkSpf: *checkSpf} | |
if *rcptFilter != "" { | |
var err error | |
if be.rcptFilter, err = regexp.Compile(*rcptFilter); err != nil { | |
log.Fatal("Cannot compile rcpt-check to regexp: ", err) | |
} | |
} | |
srv := smtp.NewServer(be) | |
srv.Addr = *listenAddr | |
srv.LMTP = *listenLMTP | |
srv.Domain = *domain | |
srv.AuthDisabled = *noAuth | |
srv.AllowInsecureAuth = *allowSecureAuth | |
srv.MaxMessageBytes = *maxMessageSize | |
srv.MaxRecipients = *maxRecipients | |
if *debugFlag == "stderr" { | |
srv.Debug = os.Stderr | |
} else if *debugFlag != "" { | |
var err error | |
if srv.Debug, err = newDbgLog(*debugFlag); err != nil { | |
log.Fatal("Could not open debug output: ", err) | |
} | |
} | |
if *readTimeout != "" { | |
tm, err := time.ParseDuration(*readTimeout) | |
if err != nil { | |
log.Fatal("Cannot parse read timeout: ", err) | |
} | |
srv.ReadTimeout = tm | |
} | |
if *writeTimeout != "" { | |
tm, err := time.ParseDuration(*writeTimeout) | |
if err != nil { | |
log.Fatal("Cannot parse write timeout: ", err) | |
} | |
srv.WriteTimeout = tm | |
} | |
if *certFile != "" { | |
if *keyFile == "" { | |
*keyFile = *certFile | |
} | |
cert, err := tls.LoadX509KeyPair(*certFile, *keyFile) | |
if err != nil { | |
log.Fatalf("Cannot load PEM keypair from %#v / %#v: %s", *certFile, *keyFile, err) | |
} | |
srv.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}} | |
} | |
if *chroot != "" { | |
if err := syscall.Chroot(*chroot); err != nil { | |
log.Fatalf("chroot to %#v failed: %s", *chroot, err) | |
} | |
log.Printf("Changed root to %#v", *chroot) | |
if err := os.Chdir("/"); err != nil { | |
log.Fatal("Could not do chdir after chroot: ", err) | |
} | |
} | |
log.Printf("Listening on %s", srv.Addr) | |
proto := "tcp" | |
if srv.LMTP { | |
proto = "unix" | |
} | |
l, err := net.Listen(proto, srv.Addr) | |
if err != nil { | |
log.Fatal("Cannot listen: ", err) | |
} | |
if err := srv.Serve(loggingListener{l}); err != nil { | |
log.Fatal("Cannot ListenAndServe: ", err) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment