Skip to content

Instantly share code, notes, and snippets.

@korc
Created July 4, 2020 03:21
Show Gist options
  • Save korc/e65d4e64240364c8649f33a291c4654c to your computer and use it in GitHub Desktop.
Save korc/e65d4e64240364c8649f33a291c4654c to your computer and use it in GitHub Desktop.
SMTP front-end daemon (with optional SPF checking via the -spf option)
package main
import (
	"crypto/tls"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"regexp"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/emersion/go-smtp"
	"github.com/mileusna/spf"
)
// frontBackend is the go-smtp backend for this daemon: it carries the settings
// parsed from the command line and creates one frontSession per client.
type frontBackend struct {
	// nextHop is where messages are relayed: a TCP host:port, or a unix
	// socket path when nextHopLMTP is set.
	nextHop string
	// myHello, when non-empty, is sent as HELO/EHLO to the next hop; it is
	// also the "by" name written into the Received header.
	myHello string
	// rcptFilter, when non-nil, must match every RCPT TO address.
	rcptFilter *regexp.Regexp
	// nextHopLMTP selects LMTP over a unix socket instead of SMTP over TCP.
	nextHopLMTP bool
	// checkSpf enables the SPF check in Mail; an SPF "fail" is rejected.
	checkSpf bool
}
// Login rejects every AUTH attempt with smtp.ErrAuthUnsupported. The attempt
// is still logged for visibility, but only the password's length is recorded:
// writing submitted credentials to the log in plaintext would expose real
// passwords to anyone who can read the log.
func (be *frontBackend) Login(st *smtp.ConnectionState, username, password string) (smtp.Session, error) {
	log.Printf("Login attempt from %s [%s]: %#v / <%d-byte password redacted>",
		st.RemoteAddr, st.Hostname, username, len(password))
	return nil, smtp.ErrAuthUnsupported
}
// AnonymousLogin is called for each client that does not authenticate. It
// builds a new frontSession and opens that session's own connection to the
// next hop. If the connection cannot be made, the client is refused: an SMTP
// error from the next hop is passed through unchanged, and any other error is
// reported as a generic 441 "remote error".
func (be *frontBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) {
	log.Printf("[%s] Anonymous login from [%s] (tls=%t)",
		state.RemoteAddr, state.Hostname, state.TLS.HandshakeComplete)

	session := &frontSession{be: be, cs: state, to: []string{}}
	err := session.createSMTPClient()
	if err == nil {
		return session, nil
	}

	session.log("Cannot connect to next hop %#v: %s", be.nextHop, err)
	if _, isSMTPErr := err.(*smtp.SMTPError); isSMTPErr {
		return nil, err
	}
	return nil, &smtp.SMTPError{Code: 441, Message: "remote error"}
}
type frontSession struct {
cs *smtp.ConnectionState
be *frontBackend
c *smtp.Client
from string
to []string
opts smtp.MailOptions
spfResult spf.Result
}
// createSMTPClient connects this session to the configured next hop. With
// nextHopLMTP the next hop is dialled as a unix socket and spoken to over
// LMTP; otherwise it is dialled over TCP as SMTP, using the host part of
// nextHop as the server name. If a hello name is configured it is sent
// immediately.
//
// s.c is assigned only on success. On any failure the connection opened here
// is closed before returning, so a failed attempt does not leak a socket.
func (s *frontSession) createSMTPClient() error {
	network := "tcp"
	if s.be.nextHopLMTP {
		network = "unix"
	}
	conn, err := net.Dial(network, s.be.nextHop)
	if err != nil {
		return err
	}

	var client *smtp.Client
	if s.be.nextHopLMTP {
		client, err = smtp.NewClientLMTP(conn, "lmtp-server")
	} else {
		host, _, _ := net.SplitHostPort(s.be.nextHop)
		client, err = smtp.NewClient(conn, host)
	}
	if err != nil {
		// The client constructor may already have closed conn; a second
		// Close merely returns an error, which is irrelevant here.
		_ = conn.Close()
		return err
	}

	if s.be.myHello != "" {
		if err := client.Hello(s.be.myHello); err != nil {
			// The client is already connected; release it instead of leaking it.
			_ = client.Close()
			return err
		}
	}
	s.c = client
	return nil
}
// log writes a log line prefixed with "[<client address>] ". format and args
// follow log.Printf conventions; the client address is inserted as the first
// argument so that it fills the "[%s]" prefix.
func (s *frontSession) log(format string, args ...interface{}) {
	prefixed := make([]interface{}, 0, len(args)+1)
	prefixed = append(prefixed, s.cs.RemoteAddr)
	prefixed = append(prefixed, args...)
	log.Printf("[%s] "+format, prefixed...)
}
// Reset abandons the current transaction. It clears the locally recorded
// envelope (sender, recipients, MAIL options) and sends RSET to the next hop.
// A failed remote reset is only logged, because Reset cannot report errors.
func (s *frontSession) Reset() {
	s.from, s.to = "", []string{}
	s.opts = smtp.MailOptions{}
	if err := s.c.Reset(); err != nil {
		s.log("Could not remote reset: %s", err)
	}
}
// Logout closes the next-hop connection when the client session ends. A close
// failure is logged and reported as a 441 "remote error".
func (s *frontSession) Logout() error {
	err := s.c.Close()
	if err == nil {
		return nil
	}
	s.log("Cannot close remote: %s", err)
	return &smtp.SMTPError{Code: 441, Message: "remote error"}
}
// Mail handles MAIL FROM. It records the sender and, when SPF checking is
// enabled, evaluates the client's IP address against the sender's domain. An
// SPF "fail" is rejected with 550 5.7.1. Otherwise the MAIL command, with the
// client's options, is forwarded to the next hop and any error from it is
// returned to the client.
func (s *frontSession) Mail(from string, opts smtp.MailOptions) error {
	s.log("MAIL FROM: %#v, Size: %d, RequireTLS: %t, UTF8: %t", from, opts.Size, opts.RequireTLS, opts.UTF8)
	s.from = from
	if s.be.checkSpf {
		// NOTE(review): for a unix-socket listener RemoteAddr has no port, so
		// host is "" and ParseIP yields nil — SPF is not meaningful there.
		host, _, _ := net.SplitHostPort(s.cs.RemoteAddr.String())
		domain, sender := spfIdentity(from, s.cs.Hostname)
		s.spfResult = spf.CheckHost(net.ParseIP(host), domain, sender, s.cs.Hostname)
		s.log("SPF result for %s as %s: %s", host, domain, s.spfResult)
		if s.spfResult == spf.Fail {
			return &smtp.SMTPError{Code: 550, EnhancedCode: smtp.EnhancedCode{5, 7, 1}, Message: "SPF failed"}
		}
	}
	s.opts = opts
	if err := s.c.Mail(from, &opts); err != nil {
		s.log("Cannot remote Mail: %s", err)
		return err
	}
	return nil
}

// spfIdentity returns the domain and sender address to evaluate for SPF.
// For a non-empty MAIL FROM the domain is everything after the last "@" (the
// whole string when there is no "@"). For the null reverse-path ("MAIL
// FROM:<>", e.g. bounces) RFC 7208 section 2.4 requires checking
// "postmaster@" + the HELO/EHLO identity instead of an empty domain.
func spfIdentity(from, helo string) (domain, sender string) {
	if from == "" {
		return helo, "postmaster@" + helo
	}
	return from[strings.LastIndexByte(from, '@')+1:], from
}
// Rcpt handles RCPT TO. When a recipient filter is configured, an address that
// does not match it is refused locally with 510 "Invalid recipient". Otherwise
// the recipient is forwarded to the next hop. If the next hop accepts it, the
// address is appended to s.to; if it refuses, its error is returned unchanged.
func (s *frontSession) Rcpt(to string) error {
	s.log("RCPT TO: %#v", to)
	if filter := s.be.rcptFilter; filter != nil && !filter.MatchString(to) {
		s.log("Recipient does not match filter: %#v", to)
		return &smtp.SMTPError{Code: 510, Message: "Invalid recipient"}
	}
	err := s.c.Rcpt(to)
	if err != nil {
		return err
	}
	s.to = append(s.to, to)
	return nil
}
// receivedHeader builds the value of the Received: trace header that Data
// prepends to each relayed message. It records:
//   - the client's HELO name and address;
//   - this server's hello name and local address;
//   - TLS details;
//   - the envelope sender;
//   - the recipient, only when there is exactly one;
//   - the current time.
//
// Continuation lines are folded with CRLF followed by a space: bare LF must not
// be sent inside SMTP DATA (RFC 5321 section 2.3.8). The date uses the RFC 5322
// date-time form (four-digit year, seconds, numeric zone).
func (s *frontSession) receivedHeader() string {
	tlsInfo := "no TLS"
	if t := s.cs.TLS; t.HandshakeComplete {
		tlsInfo = fmt.Sprintf("TLS v%d.%d cph=0x%x sn=%s nCerts=%d okCerts=%d",
			t.Version>>8, t.Version&0xff, t.CipherSuite, t.ServerName, len(t.PeerCertificates), len(t.VerifiedChains))
	}
	// RFC 5321 section 4.4: a "for" clause MUST contain exactly one path. A
	// comma-joined list is invalid and would reveal Bcc recipients to every
	// recipient of the message, so the clause is omitted unless there is one.
	forClause := ""
	if len(s.to) == 1 {
		forClause = fmt.Sprintf("\r\n for <%s>", s.to[0])
	}
	return fmt.Sprintf("from %s ([%s])\r\n by %s [%s]\r\n (%s)\r\n (envelope-from <%s>)%s; %s",
		s.cs.Hostname, s.cs.RemoteAddr, s.be.myHello, s.cs.LocalAddr, tlsInfo, s.from,
		forClause, time.Now().Format(time.RFC1123Z))
}
// Data relays the message to the next hop. It proceeds in four steps:
//  1. Issue DATA on the next-hop client.
//  2. Write a Received header, plus X-SPF-Result when an SPF result was
//     recorded.
//  3. Stream the client's message from r.
//  4. Close the data writer. Closing is what terminates DATA and returns the
//     next hop's accept/reject verdict, so the client is only told "OK" once
//     the next hop has accepted the message.
//
// SMTP errors (from the next hop or from reading r) are passed through to the
// client. Any other failure is reported as 442 "remote error".
func (s *frontSession) Data(r io.Reader) error {
	wr, err := s.c.Data()
	if err != nil {
		s.log("Remote Data error: %s", err)
		return err
	}

	spfResult := ""
	if s.spfResult.IsSet() {
		spfResult = fmt.Sprintf("X-SPF-Result: %s\r\n", s.spfResult)
	}

	var n int64
	_, err = fmt.Fprintf(wr, "Received: %s\r\n%s", s.receivedHeader(), spfResult)
	if err != nil {
		s.log("Could not send Received: %s", err)
	} else {
		n, err = io.Copy(wr, r)
	}
	if err != nil {
		s.log("Could not copy to next hop: %s", err)
		// Closing wr would send the terminating "." and make the next hop
		// accept a truncated message. Drop the next-hop connection instead
		// so the incomplete transaction is discarded; later commands on s.c
		// then fail with an error rather than being sent mid-DATA.
		if cerr := s.c.Close(); cerr != nil {
			s.log("Cannot close remote: %s", cerr)
		}
		if _, ok := err.(*smtp.SMTPError); ok {
			return err
		}
		return &smtp.SMTPError{Code: 442, Message: "remote error"}
	}

	if err := wr.Close(); err != nil {
		s.log("Next hop did not accept message: %s", err)
		if _, ok := err.(*smtp.SMTPError); ok {
			return err
		}
		return &smtp.SMTPError{Code: 442, Message: "remote error"}
	}
	s.log("DATA transferred %d bytes of data to next hop at %#v", n, s.be.nextHop)
	return nil
}
// loggingConn wraps a net.Conn and writes a log line when it is closed.
type loggingConn struct {
	net.Conn
}

// Close logs the peer's address and then closes the wrapped connection,
// returning its error.
func (lc loggingConn) Close() error {
	peer := lc.Conn.RemoteAddr()
	log.Printf("Closing connection to: %s", peer)
	return lc.Conn.Close()
}

// loggingListener wraps a net.Listener so that every accepted connection is
// logged and handed out as a loggingConn (which in turn logs its Close).
type loggingListener struct {
	net.Listener
}

// Accept waits for the next connection on the wrapped listener. An error from
// the wrapped listener is returned unchanged together with a nil connection.
func (ll loggingListener) Accept() (net.Conn, error) {
	c, err := ll.Listener.Accept()
	if err == nil {
		log.Printf("Accepted connection from: %s", c.RemoteAddr())
		return loggingConn{Conn: c}, nil
	}
	return nil, err
}
// dbgLog is an io.Writer for the -debug file output. Everything written is
// appended to outFile. Whenever more than a second has passed since the
// previous stamp, a non-empty write is preceded by a "[<RFC 3339 time>]" line.
//
// main installs a single dbgLog as srv.Debug for the whole server. Assuming
// go-smtp writes each connection's traffic to it from that connection's own
// goroutine, writes can be concurrent. mu therefore guards lastStamp and also
// keeps a stamp line together with the data that follows it.
type dbgLog struct {
	mu        sync.Mutex
	outFile   *os.File
	lastStamp time.Time
}

// Write implements io.Writer. The returned count covers p only: the optional
// timestamp line is extra output and is not counted, since io.Writer requires
// n <= len(p).
func (d *dbgLog) Write(p []byte) (n int, err error) {
	d.mu.Lock()
	defer d.mu.Unlock()
	now := time.Now()
	if len(p) > 0 && now.Sub(d.lastStamp) > time.Second {
		// Best effort: a failed stamp must not stop the data itself being logged.
		_, _ = fmt.Fprintf(d.outFile, "[%s]\n", now.Format(time.RFC3339))
		d.lastStamp = now
	}
	return d.outFile.Write(p)
}

// newDbgLog opens fileName for appending, creating it with mode 0644 if it
// does not exist, and returns a dbgLog that writes to it.
func newDbgLog(fileName string) (*dbgLog, error) {
	f, err := os.OpenFile(fileName, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
	if err != nil {
		return nil, err
	}
	return &dbgLog{outFile: f}, nil
}
// main parses the command-line flags and builds a go-smtp server around a
// frontBackend that relays to -next. It then optionally chroots and serves on
// -listen (a unix socket speaking LMTP when -listen-lmtp is set). Any
// configuration or startup error is fatal.
func main() {
	listenAddr := flag.String("listen", ":25", "Listen address")
	listenLMTP := flag.Bool("listen-lmtp", false, "Listen on unix socket as LMTP server")
	domain := flag.String("domain", "localhost.localdomain", "Domain name")
	noAuth := flag.Bool("no-auth", false, "Disable authentication")
	allowInsecureAuth := flag.Bool("allow-insecure-auth", false, "Allow insecure auth")
	certFile := flag.String("cert", "", "TLS certificate")
	keyFile := flag.String("key", "", "TLS certificate key (default: same as -cert)")
	maxMessageSize := flag.Int("max-msg-size", 20*1024*1024, "Maximum message size in bytes")
	maxRecipients := flag.Int("max-rcpt", 50, "Maximum recipients")
	readTimeout := flag.String("read-timeout", "10s", "Read timeout")
	writeTimeout := flag.String("write-timeout", "10s", "Write timeout")
	nextHop := flag.String("next", "", "forwarding server address")
	nextHopLMTP := flag.Bool("next-lmtp", false, "next server is LMTP over unix socket")
	myHello := flag.String("hello", "", "My hello name (default -domain)")
	rcptFilter := flag.String("rcpt-check", "", "regex for allowed RCPTs")
	chroot := flag.String("chroot", "", "change root to directory before run (need CAP_SYS_CHROOT)")
	debugFlag := flag.String("debug", "", "Debug output to file or 'stderr'")
	checkSpf := flag.Bool("spf", false, "Check SPF or not")
	flag.Parse()

	if *myHello == "" {
		*myHello = *domain
	}
	if *nextHop == "" {
		log.Fatal("Need to set next SMTP server with -next")
	}

	be := &frontBackend{nextHop: *nextHop, nextHopLMTP: *nextHopLMTP, myHello: *myHello, checkSpf: *checkSpf}
	if *rcptFilter != "" {
		var err error
		if be.rcptFilter, err = regexp.Compile(*rcptFilter); err != nil {
			log.Fatal("Cannot compile rcpt-check to regexp: ", err)
		}
	}

	srv := smtp.NewServer(be)
	srv.Addr = *listenAddr
	srv.LMTP = *listenLMTP
	srv.Domain = *domain
	srv.AuthDisabled = *noAuth
	srv.AllowInsecureAuth = *allowInsecureAuth
	srv.MaxMessageBytes = *maxMessageSize
	srv.MaxRecipients = *maxRecipients

	if *debugFlag == "stderr" {
		srv.Debug = os.Stderr
	} else if *debugFlag != "" {
		var err error
		if srv.Debug, err = newDbgLog(*debugFlag); err != nil {
			log.Fatal("Could not open debug output: ", err)
		}
	}

	if *readTimeout != "" {
		tm, err := time.ParseDuration(*readTimeout)
		if err != nil {
			log.Fatal("Cannot parse read timeout: ", err)
		}
		srv.ReadTimeout = tm
	}
	if *writeTimeout != "" {
		tm, err := time.ParseDuration(*writeTimeout)
		if err != nil {
			log.Fatal("Cannot parse write timeout: ", err)
		}
		srv.WriteTimeout = tm
	}

	if *certFile != "" {
		if *keyFile == "" {
			*keyFile = *certFile
		}
		cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
		if err != nil {
			log.Fatalf("Cannot load PEM keypair from %#v / %#v: %s", *certFile, *keyFile, err)
		}
		srv.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
	}

	// Files needed before the chroot (debug log, TLS keypair) are opened
	// above. Paths used afterwards, such as the unix listen socket below or a
	// unix-socket next hop, are resolved inside the new root.
	if *chroot != "" {
		if err := syscall.Chroot(*chroot); err != nil {
			log.Fatalf("chroot to %#v failed: %s", *chroot, err)
		}
		log.Printf("Changed root to %#v", *chroot)
		if err := os.Chdir("/"); err != nil {
			log.Fatal("Could not do chdir after chroot: ", err)
		}
	}

	proto := "tcp"
	if srv.LMTP {
		proto = "unix"
	}
	l, err := net.Listen(proto, srv.Addr)
	if err != nil {
		log.Fatal("Cannot listen: ", err)
	}
	// Announce the address only once the socket is actually bound.
	log.Printf("Listening on %s", srv.Addr)
	if err = srv.Serve(loggingListener{l}); err != nil {
		log.Fatal("Cannot serve: ", err)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment