Created
December 17, 2021 22:50
-
-
Save dchw/cba1558c78e0875cf0b05c191d797136 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import (
	"encoding/base64"
	"flag"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"os"
	"os/signal"
	"sync"
	"syscall"
	"time"
)
func main() { | |
log.Println("Starting Mock OAuth Server") | |
authPathPtr := flag.String("authorizePath", "/authorize", "Authorize endpoint path") | |
tokenPathPtr := flag.String("tokenPath", "/token", "Token endpoint path") | |
tokenJSONFmt := flag.String("tokenJsonFmt", `{"access_token": "%s"}`, "Format string for if token response is JSON") | |
tokenFmt := flag.String("tokenFmt", `access_token=%s`, "Format string for if token response is JSON") | |
portPtr := flag.String("port", ":8080", "The port") | |
tlsPtr := flag.Bool("tls", false, "Should we use TLS?") | |
tlsKeyPathPtr := flag.String("tlsKeyPath", "", "TLS Key Path") | |
tlsCertPathPtr := flag.String("tlsCertPath", "", "TLS Key Path") | |
flag.Parse() | |
validCodes := make([]string, 0) | |
log.Printf("Auth handler at: %s\n", *authPathPtr) | |
http.HandleFunc(*authPathPtr, func(w http.ResponseWriter, r *http.Request) { | |
log.Println("Authorization request received") | |
clientID := r.URL.Query().Get("client_id") | |
redirectURI := r.URL.Query().Get("redirect_uri") | |
scope := r.URL.Query().Get("scope") | |
state := r.URL.Query().Get("state") | |
log.Printf("ClientID: %s, Redirect: %s, Scope: %s, State: %s\n", clientID, redirectURI, scope, state) | |
code := timeCode() | |
validCodes = append(validCodes, code) | |
log.Printf("Returning access code in query param code: %s\n", code) | |
redirect, err := url.Parse(redirectURI) | |
if err != nil { | |
msg := fmt.Sprintf("%s is an invalid URI", redirectURI) | |
log.Println(msg) | |
http.Error(w, msg, http.StatusTeapot) | |
return | |
} | |
query := redirect.Query() | |
query.Add("code", code) | |
query.Add("state", state) | |
redirect.RawQuery = query.Encode() | |
log.Printf("Redirecting to %s\n", redirect.String()) | |
http.Redirect(w, r, redirect.String(), http.StatusFound) | |
}) | |
log.Printf("Token handler at: %s\n", *tokenPathPtr) | |
http.HandleFunc(*tokenPathPtr, func(w http.ResponseWriter, r *http.Request) { | |
log.Println("Token request received") | |
clientID := r.URL.Query().Get("client_id") | |
clientSecret := r.URL.Query().Get("client_secret") | |
code := r.URL.Query().Get("code") | |
log.Printf("ClientID: %s, ClientSecret: %s, Code: %s,\n", clientID, clientSecret, code) | |
acceptCode := false | |
for _, valid := range validCodes { | |
log.Printf("Checking valid code %s\n", valid) | |
if code == valid { | |
log.Printf("Code found: %s\n", valid) | |
acceptCode = true | |
break | |
} | |
} | |
if !acceptCode { | |
msg := fmt.Sprintf("Code is not valid: %s", code) | |
log.Println(msg) | |
http.Error(w, msg, http.StatusTeapot) | |
return | |
} | |
token := timeCode() | |
var response string | |
log.Printf("Returning token: %s for code %s\n", token, code) | |
accept := r.Header.Get("Accept") | |
if accept == "application/json" { | |
log.Println("Response as JSON requested") | |
response = fmt.Sprintf(*tokenJSONFmt, token) | |
} else { | |
log.Printf("Default response method") | |
response = fmt.Sprintf(*tokenFmt, token) | |
} | |
log.Printf("Response payload: %s\n", response) | |
w.Write([]byte(response)) | |
w.WriteHeader(http.StatusOK) | |
}) | |
c := make(chan os.Signal, 1) | |
signal.Notify(c, os.Interrupt, syscall.SIGTERM, syscall.SIGKILL) | |
go func() { | |
log.Println("Waiting for exit signal") | |
<-c | |
log.Println("Got signal, exiting...") | |
os.Exit(0) | |
}() | |
if *tlsPtr { | |
log.Println("Generating self-signed certs") | |
if _, err := os.Stat(*tlsKeyPathPtr); err != nil { | |
log.Printf("Key at %s error\n", *tlsKeyPathPtr) | |
log.Println(err.Error()) | |
return | |
} | |
if _, err := os.Stat(*tlsCertPathPtr); err != nil { | |
log.Printf("Cert at %s error\n", *tlsCertPathPtr) | |
log.Println(err.Error()) | |
return | |
} | |
log.Printf("Listening on %s\n", *portPtr) | |
log.Fatal(http.ListenAndServeTLS(*portPtr, *tlsCertPathPtr, *tlsKeyPathPtr, nil)) | |
} else { | |
log.Printf("Listening on %s\n", *portPtr) | |
log.Fatal(http.ListenAndServe(*portPtr, nil).Error()) | |
} | |
} | |
// timeCode returns the textual form of the current time (time.Time.String)
// encoded with the padded, URL-safe base64 alphabet. The value is derived
// purely from the clock, so it is unique only as far as the clock resolution
// allows and is not random.
func timeCode() string {
	stamp := []byte(time.Now().String())
	encoded := make([]byte, base64.URLEncoding.EncodedLen(len(stamp)))
	base64.URLEncoding.Encode(encoded, stamp)
	return string(encoded)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment