Skip to content

Instantly share code, notes, and snippets.

@dchw
Created December 17, 2021 22:50
Show Gist options
  • Save dchw/cba1558c78e0875cf0b05c191d797136 to your computer and use it in GitHub Desktop.
Save dchw/cba1558c78e0875cf0b05c191d797136 to your computer and use it in GitHub Desktop.
package main
import (
	"encoding/base64"
	"flag"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"os"
	"os/signal"
	"sync"
	"syscall"
	"time"
)
// main runs a mock OAuth authorization server intended for local testing.
//
// It registers two handlers on the default ServeMux:
//   - the authorize endpoint mints a code, remembers it, and redirects to the
//     caller-supplied redirect_uri with "code" and "state" query parameters;
//   - the token endpoint exchanges a previously issued code for a freshly
//     minted token. The body is rendered with tokenJsonFmt when the request's
//     Accept header is exactly "application/json", and with tokenFmt otherwise.
//
// Endpoint paths, response formats, the listen address and TLS are all
// configured through command-line flags. The process exits on SIGINT/SIGTERM.
func main() {
	log.Println("Starting Mock OAuth Server")

	authPathPtr := flag.String("authorizePath", "/authorize", "Authorize endpoint path")
	tokenPathPtr := flag.String("tokenPath", "/token", "Token endpoint path")
	tokenJSONFmt := flag.String("tokenJsonFmt", `{"access_token": "%s"}`, "Format string for if token response is JSON")
	tokenFmt := flag.String("tokenFmt", `access_token=%s`, "Format string for if token response is not JSON")
	portPtr := flag.String("port", ":8080", "The port")
	tlsPtr := flag.Bool("tls", false, "Should we use TLS?")
	tlsKeyPathPtr := flag.String("tlsKeyPath", "", "TLS Key Path")
	tlsCertPathPtr := flag.String("tlsCertPath", "", "TLS Cert Path")
	flag.Parse()

	// net/http serves each request on its own goroutine, so the list of
	// issued codes is shared mutable state and must be guarded.
	var (
		mu         sync.Mutex
		validCodes []string
	)

	log.Printf("Auth handler at: %s\n", *authPathPtr)
	http.HandleFunc(*authPathPtr, func(w http.ResponseWriter, r *http.Request) {
		log.Println("Authorization request received")
		clientID := r.URL.Query().Get("client_id")
		redirectURI := r.URL.Query().Get("redirect_uri")
		scope := r.URL.Query().Get("scope")
		state := r.URL.Query().Get("state")
		log.Printf("ClientID: %s, Redirect: %s, Scope: %s, State: %s\n", clientID, redirectURI, scope, state)

		code := timeCode()
		mu.Lock()
		validCodes = append(validCodes, code)
		mu.Unlock()
		log.Printf("Returning access code in query param code: %s\n", code)

		redirect, err := url.Parse(redirectURI)
		if err != nil {
			msg := fmt.Sprintf("%s is an invalid URI", redirectURI)
			log.Println(msg)
			http.Error(w, msg, http.StatusTeapot)
			return
		}

		// Append to (rather than replace) any query the redirect URI already has.
		query := redirect.Query()
		query.Add("code", code)
		query.Add("state", state)
		redirect.RawQuery = query.Encode()
		log.Printf("Redirecting to %s\n", redirect.String())
		http.Redirect(w, r, redirect.String(), http.StatusFound)
	})

	log.Printf("Token handler at: %s\n", *tokenPathPtr)
	http.HandleFunc(*tokenPathPtr, func(w http.ResponseWriter, r *http.Request) {
		log.Println("Token request received")
		clientID := r.URL.Query().Get("client_id")
		clientSecret := r.URL.Query().Get("client_secret")
		code := r.URL.Query().Get("code")
		log.Printf("ClientID: %s, ClientSecret: %s, Code: %s,\n", clientID, clientSecret, code)

		acceptCode := false
		mu.Lock()
		for _, valid := range validCodes {
			log.Printf("Checking valid code %s\n", valid)
			if code == valid {
				log.Printf("Code found: %s\n", valid)
				acceptCode = true
				break
			}
		}
		mu.Unlock()
		if !acceptCode {
			msg := fmt.Sprintf("Code is not valid: %s", code)
			log.Println(msg)
			http.Error(w, msg, http.StatusTeapot)
			return
		}

		token := timeCode()
		log.Printf("Returning token: %s for code %s\n", token, code)

		var response string
		if r.Header.Get("Accept") == "application/json" {
			log.Println("Response as JSON requested")
			response = fmt.Sprintf(*tokenJSONFmt, token)
			// Without an explicit type the body is content-sniffed as
			// text/plain, which content-type-aware clients will not parse
			// as JSON.
			w.Header().Set("Content-Type", "application/json")
		} else {
			log.Printf("Default response method")
			response = fmt.Sprintf(*tokenFmt, token)
		}
		log.Printf("Response payload: %s\n", response)

		// Headers and status must be sent before the body; a WriteHeader
		// issued after Write is ignored and logged as superfluous.
		w.WriteHeader(http.StatusOK)
		if _, err := w.Write([]byte(response)); err != nil {
			log.Printf("Writing token response: %v\n", err)
		}
	})

	// SIGKILL cannot be caught, so only interrupt and terminate are registered.
	c := make(chan os.Signal, 1)
	signal.Notify(c, os.Interrupt, syscall.SIGTERM)
	go func() {
		log.Println("Waiting for exit signal")
		<-c
		log.Println("Got signal, exiting...")
		os.Exit(0)
	}()

	if *tlsPtr {
		// Nothing is generated here: the key and cert must already exist at
		// the configured paths, so fail early (non-zero exit) if they do not.
		log.Println("Checking TLS key and cert files")
		if _, err := os.Stat(*tlsKeyPathPtr); err != nil {
			log.Printf("Key at %s error\n", *tlsKeyPathPtr)
			log.Fatalln(err.Error())
		}
		if _, err := os.Stat(*tlsCertPathPtr); err != nil {
			log.Printf("Cert at %s error\n", *tlsCertPathPtr)
			log.Fatalln(err.Error())
		}
		log.Printf("Listening on %s\n", *portPtr)
		log.Fatal(http.ListenAndServeTLS(*portPtr, *tlsCertPathPtr, *tlsKeyPathPtr, nil))
	} else {
		log.Printf("Listening on %s\n", *portPtr)
		log.Fatal(http.ListenAndServe(*portPtr, nil))
	}
}
// timeCode returns the current time's default string form, encoded with the
// padded URL-safe base64 alphabet. It is used as a throwaway code/token value.
func timeCode() string {
	stamp := []byte(time.Now().String())
	encoded := make([]byte, base64.URLEncoding.EncodedLen(len(stamp)))
	base64.URLEncoding.Encode(encoded, stamp)
	return string(encoded)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment