|
// Package serverTiming implements a small Server-Timing middleware and helpers to measure named timings during HTTP
// request handling and emit them as the `Server-Timing` response header, as described by the W3C Server Timing
// specification (exposed to browsers through the Performance API).
package serverTiming |
|
|
|
import ( |
|
"bufio" |
|
"context" |
|
"errors" |
|
"fmt" |
|
"maps" |
|
"net" |
|
"net/http" |
|
"sort" |
|
"strings" |
|
"sync" |
|
"time" |
|
) |
|
|
|
// ErrNotImplemented is returned, wrapped with additional context, by the optional Hijack and Push methods of the
// middleware's response writer when the underlying http.ResponseWriter does not support the operation.
// Check for it with errors.Is.
var ErrNotImplemented = errors.New("server timings middleware: not implemented")
|
|
|
// timingsMutexCtxKey is the unexported context key under which the middleware stores the *sync.Mutex that guards
// the request's timingsMap.
type timingsMutexCtxKey struct{}

// timingsMapCtxKey is the unexported context key under which the middleware stores the request's timingsMap.
type timingsMapCtxKey struct{}
|
|
|
// timingsItem is one recorded metric: the measured duration plus an optional description.
type timingsItem struct {
	desc     string        // optional human-readable description; omitted from the header when empty
	duration time.Duration // time elapsed between StartMetric and the first commit
}

// timingsMap holds the metrics recorded for a single request, keyed by metric name.
type timingsMap map[string]timingsItem
|
|
|
// StartMetric begins a timing with the given name and description for the provided *http.Request. |
|
// |
|
// It returns a commit() function which records the elapsed time into the request-scoped timings map. |
|
// The returned commit function is safe to call multiple times. |
|
// |
|
// Usage example: |
|
// |
|
// commit := serverTiming.StartMetric(r, "db_query", "Database query execution") |
|
// defer commit() |
|
// |
|
// // .. some database query operation .. |
|
// |
|
// commit() // record the timing |
|
func StartMetric(r *http.Request, name, desc string) (commit func()) { |
|
var ( |
|
tm timingsMap |
|
mu *sync.Mutex |
|
tmFound bool |
|
mxFound bool |
|
) |
|
|
|
if v := r.Context().Value(timingsMapCtxKey{}); v != nil { |
|
tm, tmFound = v.(timingsMap) |
|
} |
|
|
|
if v := r.Context().Value(timingsMutexCtxKey{}); v != nil { |
|
mu, mxFound = v.(*sync.Mutex) |
|
} |
|
|
|
if !tmFound || !mxFound || tm == nil || mu == nil { |
|
return func() {} // middleware not installed for this request |
|
} |
|
|
|
startedAt := time.Now() |
|
|
|
return sync.OnceFunc(func() { |
|
mu.Lock() |
|
|
|
tm[name] = timingsItem{duration: time.Since(startedAt), desc: desc} |
|
|
|
mu.Unlock() |
|
}) |
|
} |
|
|
|
// responseWriter wraps http.ResponseWriter to ensure the Server-Timing header is written exactly once (either when |
|
// the headers are first emitted, or after the handler completes). It implements common optional interfaces when the |
|
// underlying writer supports them (http.Flusher, http.Hijacker, http.Pusher). |
|
type responseWriter struct { |
|
orig http.ResponseWriter |
|
|
|
// headerWritten tracks whether the headers have already been sent through this response writer wrapper. |
|
headerWritten bool |
|
|
|
tm timingsMap |
|
mu *sync.Mutex |
|
} |
|
|
|
var ( // ensure responseWriter implements the most common interfaces |
|
_ http.ResponseWriter = (*responseWriter)(nil) |
|
_ http.Flusher = (*responseWriter)(nil) |
|
_ http.Hijacker = (*responseWriter)(nil) |
|
_ http.Pusher = (*responseWriter)(nil) |
|
) |
|
|
|
// Header delegates to the underlying ResponseWriter's Header method. |
|
func (rw *responseWriter) Header() http.Header { return rw.orig.Header() } |
|
|
|
// Write ensures the Server-Timing header is set before the first write. |
|
func (rw *responseWriter) Write(b []byte) (int, error) { |
|
if !rw.headerWritten { |
|
rw.headerWritten = true |
|
rw.setServerTimingHeader() |
|
} |
|
|
|
return rw.orig.Write(b) |
|
} |
|
|
|
// WriteHeader ensures the Server-Timing header is set before the status code is sent to the client. |
|
func (rw *responseWriter) WriteHeader(statusCode int) { |
|
if !rw.headerWritten { |
|
rw.headerWritten = true |
|
rw.setServerTimingHeader() |
|
} |
|
|
|
rw.orig.WriteHeader(statusCode) |
|
} |
|
|
|
// Flush delegates to the underlying Flusher if present. |
|
func (rw *responseWriter) Flush() { |
|
if flusher, ok := rw.orig.(http.Flusher); ok { |
|
flusher.Flush() |
|
} |
|
} |
|
|
|
// Hijack delegates to the underlying Hijacker if present; otherwise return a wrapped error indicating the |
|
// operation is not supported. |
|
func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { |
|
if hj, ok := rw.orig.(http.Hijacker); ok { |
|
return hj.Hijack() |
|
} |
|
|
|
return nil, nil, fmt.Errorf("%w: underlying ResponseWriter does not implement http.Hijacker", ErrNotImplemented) |
|
} |
|
|
|
// Push delegates to the underlying Pusher if present; otherwise return an error. |
|
func (rw *responseWriter) Push(target string, opts *http.PushOptions) error { |
|
if pusher, ok := rw.orig.(http.Pusher); ok { |
|
return pusher.Push(target, opts) |
|
} |
|
|
|
return fmt.Errorf("%w: underlying ResponseWriter does not implement http.Pusher", ErrNotImplemented) |
|
} |
|
|
|
// setServerTimingHeader formats the request-local timings map according to the Server-Timing header specification |
|
// and sets the header on the underlying ResponseWriter. It takes a snapshot of the timings map while holding |
|
// the mutex and performs formatting without holding the mutex to keep the critical section short. |
|
func (rw *responseWriter) setServerTimingHeader() { |
|
// make a thread-safe copy of the map to avoid holding the mutex while formatting |
|
rw.mu.Lock() |
|
tmCopy := maps.Clone(rw.tm) |
|
rw.mu.Unlock() |
|
|
|
var buf strings.Builder |
|
|
|
// for stable output, iterate keys in sorted order |
|
names := make([]string, 0, len(tmCopy)) |
|
for name := range tmCopy { |
|
names = append(names, name) |
|
} |
|
|
|
sort.Strings(names) |
|
|
|
buf.Grow(len(tmCopy) * 50) //nolint:mnd // rough estimate |
|
|
|
// build header entries: name[;desc="..."];dur=... (descriptions are quoted) |
|
for i, name := range names { |
|
item := tmCopy[name] |
|
|
|
if i > 0 { |
|
buf.WriteString(", ") |
|
} |
|
|
|
buf.WriteString(name) |
|
|
|
if item.desc != "" { |
|
// replace double quotes with single quotes as a simple, safe sanitization |
|
desc := strings.ReplaceAll(item.desc, "\"", "'") |
|
|
|
buf.WriteString(`;desc="`) |
|
buf.WriteString(desc) |
|
buf.WriteRune('"') |
|
} |
|
|
|
buf.WriteString(";dur=") |
|
|
|
// write duration in milliseconds with 3 decimal places |
|
buf.WriteString(fmt.Sprintf("%.3f", item.duration.Seconds()*1000.0)) //nolint:mnd |
|
} |
|
|
|
if buf.Len() > 0 { |
|
rw.Header().Set("Server-Timing", buf.String()) |
|
} |
|
} |
|
|
|
// New creates a middleware for dealing with Server-Timing header. |
|
// |
|
// The skipper function may be nil. If provided, and it returns true for a request, the middleware will skip |
|
// instrumentation for that request. |
|
// |
|
// Link: https://developer.mozilla.org/en-US/docs/Web/API/Performance_API/Server_timing |
|
func New(skipper func(*http.Request) bool) func(http.Handler) http.Handler { |
|
return func(next http.Handler) http.Handler { |
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
if skipper != nil && skipper(r) { |
|
next.ServeHTTP(w, r) |
|
|
|
return |
|
} |
|
|
|
// create a fresh timings map and mutex for this request |
|
tm := make(timingsMap) |
|
|
|
var mx sync.Mutex |
|
|
|
// store them in the request context so handlers can call StartMetric |
|
r = r.WithContext(context.WithValue(r.Context(), timingsMapCtxKey{}, tm)) |
|
r = r.WithContext(context.WithValue(r.Context(), timingsMutexCtxKey{}, &mx)) //nolint:contextcheck |
|
|
|
rw := responseWriter{orig: w, tm: tm, mu: &mx} |
|
|
|
defer func() { |
|
if !rw.headerWritten { |
|
rw.setServerTimingHeader() |
|
} |
|
}() |
|
|
|
next.ServeHTTP(&rw, r) |
|
}) |
|
} |
|
} |