Skip to content

Instantly share code, notes, and snippets.

@mostlygeek
Last active September 3, 2025 17:47
Show Gist options
  • Save mostlygeek/558b83f12383a0fa9a6d89c5542b89bd to your computer and use it in GitHub Desktop.
Save mostlygeek/558b83f12383a0fa9a6d89c5542b89bd to your computer and use it in GitHub Desktop.
tsidp.go refactor
// The tsidp command is an OpenID Connect Identity Provider server.
package main
import (
"context"
"errors"
"flag"
"fmt"
"log"
"net"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
// Existing Tailscale dependencies
"tailscale.com/client/local"
"tailscale.com/envknob"
"tailscale.com/hostinfo"
"tailscale.com/ipn/ipnstate"
"tailscale.com/tsnet"
"tailscale.com/version"
// reload idpServer into app
"tailscale.com/cmd/tsidp/internal/app"
// http handlers
"tailscale.com/cmd/tsidp/internal/handlers"
// CRUD on the .json data files
"tailscale.com/cmd/tsidp/internal/store"
)
const (
	// FunnelClientsFile names the JSON file that persists OIDC client IDs
	// and secrets. It is a relative name and is handed unchanged to
	// store.NewFileStore by main.
	FunnelClientsFile = "oidc-funnel-clients.json"
)
// Config is the complete runtime configuration of tsidp. Most fields are
// filled from command-line flags by parseFlags; ServerURL and LoopbackURL are
// derived values rather than flags.
type Config struct {
	// Verbose is -verbose. In tsnet mode it routes tsnet's logs to log.Printf.
	Verbose bool

	// Port is -port (default 443), the HTTPS port tsidp serves on; 443 is
	// omitted from ServerURL. LocalPort is -local-port (default -1), an
	// optional plain-HTTP port on localhost.
	Port, LocalPort int

	// UseLocalTailscaled (-use-local-tailscaled) talks to the system
	// tailscaled instead of embedding tsnet. Funnel (-funnel) exposes tsidp
	// on the public internet via Tailscale Funnel.
	UseLocalTailscaled, Funnel bool

	// Hostname and Dir are the tsnet hostname and state directory
	// (-hostname, -dir).
	Hostname, Dir string

	// EnableSTS is -enable-sts, turning on OIDC STS token exchange support.
	EnableSTS bool

	// ServerURL is the public https base URL, computed in main once the
	// node's DNS name is known. LoopbackURL is the localhost base URL,
	// derived from LocalPort in parseFlags. Both are populated at runtime.
	ServerURL, LoopbackURL string
}
// parseFlags registers tsidp's flags on the default flag set, parses the
// command line, and returns the resulting Config.
//
// ServerURL is left empty; main fills it in once the node's DNS name is
// known. LoopbackURL is set only when -local-port names a real port (> 0),
// the same condition under which setupListeners opens the loopback listener.
// The default of -1, and any other non-positive value, disables it, so no
// localhost URL is advertised without something listening on it.
func parseFlags() *Config {
	cfg := &Config{}
	flag.BoolVar(&cfg.Verbose, "verbose", false, "be verbose")
	flag.IntVar(&cfg.Port, "port", 443, "port to listen on")
	flag.IntVar(&cfg.LocalPort, "local-port", -1, "allow requests from localhost")
	flag.BoolVar(&cfg.UseLocalTailscaled, "use-local-tailscaled", false, "use local tailscaled instead of tsnet")
	flag.BoolVar(&cfg.Funnel, "funnel", false, "use Tailscale Funnel to make tsidp available on the public internet")
	flag.StringVar(&cfg.Hostname, "hostname", "idp", "tsnet hostname to use instead of idp")
	flag.StringVar(&cfg.Dir, "dir", "", "tsnet state directory")
	flag.BoolVar(&cfg.EnableSTS, "enable-sts", false, "enable OIDC STS token exchange support")
	flag.Parse()

	if cfg.LocalPort > 0 {
		cfg.LoopbackURL = fmt.Sprintf("http://localhost:%d", cfg.LocalPort)
	}
	return cfg
}
func main() {
// 1. Parse configuration from command-line flags.
cfg := parseFlags()
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
if !envknob.UseWIPCode() {
log.Fatal("cmd/tsidp is a work in progress and has not been security reviewed;\nits use requires TAILSCALE_USE_WIP_CODE=1 be set in the environment for now.")
}
// 2. Set up Tailscale client and get network status.
lc, st, cleanup, err := setupTailscale(ctx, cfg)
if err != nil {
log.Fatalf("failed to set up tailscale: %v", err)
}
defer cleanup()
// Update config with the runtime server URL.
if cfg.Port != 443 {
cfg.ServerURL = fmt.Sprintf("https://%s:%d", strings.TrimSuffix(st.Self.DNSName, "."), cfg.Port)
} else {
cfg.ServerURL = fmt.Sprintf("https://%s", strings.TrimSuffix(st.Self.DNSName, "."))
}
log.Printf("tsidp base URL will be %s", cfg.ServerURL)
// 3. Initialize dependencies.
clientStore, err := store.NewFileStore(FunnelClientsFile)
if err != nil {
log.Fatalf("failed to initialize client store: %v", err)
}
// 4. Instantiate the core application server.
idpServer, err := app.NewServer(cfg, lc, clientStore)
if err != nil {
log.Fatalf("failed to create IDP server: %v", err)
}
// 5. Set up the HTTP router with all the handlers.
router := handlers.NewRouter(idpServer)
// 6. Configure and start the HTTP server.
httpServer := &http.Server{
Handler: router,
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 120 * time.Second,
}
// Start network listeners.
listeners, err := setupListeners(ctx, cfg, lc, st)
if err != nil {
log.Fatalf("failed to set up listeners: %v", err)
}
// Run servers in the background.
errChan := make(chan error, len(listeners))
for _, ln := range listeners {
log.Printf("Listening on %s", ln.Addr())
go func(l net.Listener) {
if err := httpServer.Serve(l); !errors.Is(err, http.ErrServerClosed) {
errChan <- err
}
}(l)
}
log.Printf("tsidp server started successfully.")
// Wait for a shutdown signal or a server error.
select {
case err := <-errChan:
log.Fatalf("server error: %v", err)
case <-ctx.Done():
log.Println("shutdown signal received, stopping server...")
}
// Perform a graceful shutdown.
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
log.Printf("HTTP server shutdown error: %v", err)
}
log.Println("server stopped gracefully.")
}
// setupTailscale connects tsidp to Tailscale. It returns a LocalAPI client,
// the node's status, and a cleanup func the caller must run on shutdown.
//
// With cfg.UseLocalTailscaled, the system tailscaled is used. The status is
// fetched without peers, and cleanup is a no-op.
//
// Otherwise an embedded tsnet.Server is started with cfg.Hostname and
// cfg.Dir, logging through log.Printf when cfg.Verbose is set. Cleanup then
// closes that server.
//
// On error, the returned cleanup is a no-op, and any tsnet.Server this
// function created has already been closed.
func setupTailscale(ctx context.Context, cfg *Config) (*local.Client, *ipnstate.Status, func(), error) {
	noop := func() {}

	if cfg.UseLocalTailscaled {
		lc := &local.Client{}
		st, err := lc.StatusWithoutPeers(ctx)
		if err != nil {
			return nil, nil, noop, fmt.Errorf("getting status: %w", err)
		}
		return lc, st, noop, nil
	}

	hostinfo.SetApp("tsidp")
	ts := &tsnet.Server{
		Hostname: cfg.Hostname,
		Dir:      cfg.Dir,
	}
	if cfg.Verbose {
		ts.Logf = log.Printf
	}

	// closeServer is best effort: when it runs we are either already
	// returning another error or shutting down, so a Close failure is only
	// logged.
	closeServer := func() {
		if err := ts.Close(); err != nil {
			log.Printf("closing tsnet server: %v", err)
		}
	}

	st, err := ts.Up(ctx)
	if err != nil {
		// Up starts the server before waiting for it to come up (e.g. it can
		// fail because ctx was cancelled mid-wait), so close it rather than
		// leak it.
		closeServer()
		return nil, nil, noop, fmt.Errorf("starting tsnet.Server: %w", err)
	}
	lc, err := ts.LocalClient()
	if err != nil {
		closeServer()
		return nil, nil, noop, fmt.Errorf("getting local client: %w", err)
	}
	return lc, st, closeServer, nil
}
// setupListeners opens the network listeners tsidp serves on and returns
// them. It returns an error if none could be opened.
//
// In tsnet mode (cfg.UseLocalTailscaled false), it opens one listener on
// ":<cfg.Port>" through the tsnet server: a Funnel listener with cfg.Funnel,
// otherwise a tailnet TLS listener. In local-tailscaled mode, serve-config
// setup is still a stub, so no tailnet listener is opened. Independently,
// when cfg.LocalPort > 0, a plain TCP listener is opened on
// localhost:<cfg.LocalPort>.
//
// If a later listener fails to open, any already-opened listener is closed
// before the error is returned, so no port is left bound with nothing
// serving it. ctx is currently unused.
func setupListeners(ctx context.Context, cfg *Config, lc *local.Client, st *ipnstate.Status) ([]net.Listener, error) {
	var lns []net.Listener

	// fail releases every listener opened so far and reports err. Close
	// errors are ignored: err is the failure worth reporting.
	fail := func(err error) ([]net.Listener, error) {
		for _, ln := range lns {
			ln.Close()
		}
		return nil, err
	}

	if cfg.UseLocalTailscaled {
		// Logic for local tailscaled listeners.
		// TODO: configure serve on the local tailscaled. Until then this
		// branch opens nothing, and only the loopback listener below (if
		// enabled) is available.
		if version.AtLeast(st.Version, "1.71.0") {
			// Setup serve config...
		}
	} else {
		// NOTE(review): confirm local.Client actually exposes TsnetServer;
		// if it does not, setupTailscale should hand the *tsnet.Server to
		// this function directly.
		ts, ok := lc.TsnetServer()
		if !ok {
			return nil, errors.New("cannot get tsnet server from local client")
		}
		addr := fmt.Sprintf(":%d", cfg.Port)
		var ln net.Listener
		var err error
		if cfg.Funnel {
			ln, err = ts.ListenFunnel("tcp", addr)
		} else {
			ln, err = ts.ListenTLS("tcp", addr)
		}
		if err != nil {
			return nil, fmt.Errorf("listening on tailnet %s: %w", addr, err)
		}
		lns = append(lns, ln)
	}

	if cfg.LocalPort > 0 {
		ln, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", cfg.LocalPort))
		if err != nil {
			return fail(fmt.Errorf("listening on localhost:%d: %w", cfg.LocalPort, err))
		}
		lns = append(lns, ln)
	}

	if len(lns) == 0 {
		return nil, errors.New("no listeners were successfully started")
	}
	return lns, nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment