Added optional Discourse SSO.
This commit is contained in:
parent
5527c1eb91
commit
54da04763f
8 changed files with 337 additions and 8 deletions
190
handle_sso.go
Normal file
190
handle_sso.go
Normal file
|
|
@ -0,0 +1,190 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (app *App) getSSOConfig() (ssoURL, ssoSecret string) {
|
||||
app.db.QueryRow(`SELECT value FROM config WHERE key = 'discourse_sso_url'`).Scan(&ssoURL)
|
||||
app.db.QueryRow(`SELECT value FROM config WHERE key = 'discourse_sso_secret'`).Scan(&ssoSecret)
|
||||
return
|
||||
}
|
||||
|
||||
func (app *App) handleSSOEnabled(w http.ResponseWriter, r *http.Request) {
|
||||
ssoURL, ssoSecret := app.getSSOConfig()
|
||||
writeJSON(w, map[string]bool{"enabled": ssoURL != "" && ssoSecret != ""})
|
||||
}
|
||||
|
||||
func (app *App) getBaseURL() string {
|
||||
if app.baseURL != "" {
|
||||
return app.baseURL
|
||||
}
|
||||
var u string
|
||||
app.db.QueryRow(`SELECT value FROM config WHERE key = 'base_url'`).Scan(&u)
|
||||
return u
|
||||
}
|
||||
|
||||
func (app *App) handleSSOInit(w http.ResponseWriter, r *http.Request) {
|
||||
ssoURL, ssoSecret := app.getSSOConfig()
|
||||
if ssoURL == "" || ssoSecret == "" {
|
||||
writeError(w, "SSO not configured", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
baseURL := app.getBaseURL()
|
||||
if baseURL == "" {
|
||||
writeError(w, "base_url must be configured for SSO", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
b := make([]byte, 32)
|
||||
rand.Read(b)
|
||||
nonce := hex.EncodeToString(b)
|
||||
|
||||
app.cleanExpiredNonces()
|
||||
if err := app.createSSONonce(nonce); err != nil {
|
||||
writeError(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
returnURL := strings.TrimRight(baseURL, "/") + "/api/sso/callback"
|
||||
|
||||
payload := fmt.Sprintf("nonce=%s&return_sso_url=%s", url.QueryEscape(nonce), url.QueryEscape(returnURL))
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(payload))
|
||||
|
||||
mac := hmac.New(sha256.New, []byte(ssoSecret))
|
||||
mac.Write([]byte(encoded))
|
||||
sig := hex.EncodeToString(mac.Sum(nil))
|
||||
|
||||
redirect := fmt.Sprintf("%s/session/sso_provider?sso=%s&sig=%s",
|
||||
strings.TrimRight(ssoURL, "/"), url.QueryEscape(encoded), url.QueryEscape(sig))
|
||||
|
||||
writeJSON(w, map[string]string{"redirect_url": redirect})
|
||||
}
|
||||
|
||||
func (app *App) handleSSOCallback(w http.ResponseWriter, r *http.Request) {
|
||||
baseURL := app.getBaseURL()
|
||||
|
||||
ssoRedirectError := func(msg string) {
|
||||
if baseURL != "" {
|
||||
http.Redirect(w, r, strings.TrimRight(baseURL, "/")+"/#sso_error="+url.QueryEscape(msg), http.StatusFound)
|
||||
} else {
|
||||
writeError(w, msg, http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
_, ssoSecret := app.getSSOConfig()
|
||||
if ssoSecret == "" {
|
||||
ssoRedirectError("SSO not configured")
|
||||
return
|
||||
}
|
||||
|
||||
ssoParam := r.URL.Query().Get("sso")
|
||||
sigParam := r.URL.Query().Get("sig")
|
||||
if ssoParam == "" || sigParam == "" {
|
||||
ssoRedirectError("Invalid SSO response")
|
||||
return
|
||||
}
|
||||
|
||||
mac := hmac.New(sha256.New, []byte(ssoSecret))
|
||||
mac.Write([]byte(ssoParam))
|
||||
expectedSig := hex.EncodeToString(mac.Sum(nil))
|
||||
if !hmac.Equal([]byte(expectedSig), []byte(sigParam)) {
|
||||
ssoRedirectError("Invalid SSO signature")
|
||||
return
|
||||
}
|
||||
|
||||
decoded, err := base64.StdEncoding.DecodeString(ssoParam)
|
||||
if err != nil {
|
||||
ssoRedirectError("Invalid SSO payload")
|
||||
return
|
||||
}
|
||||
|
||||
vals, err := url.ParseQuery(string(decoded))
|
||||
if err != nil {
|
||||
ssoRedirectError("Invalid SSO payload")
|
||||
return
|
||||
}
|
||||
|
||||
nonce := vals.Get("nonce")
|
||||
valid, err := app.consumeSSONonce(nonce)
|
||||
if err != nil || !valid {
|
||||
ssoRedirectError("SSO session expired. Please try again.")
|
||||
return
|
||||
}
|
||||
|
||||
email := strings.ToLower(vals.Get("email"))
|
||||
if email == "" {
|
||||
ssoRedirectError("No email in SSO response")
|
||||
return
|
||||
}
|
||||
|
||||
name := vals.Get("name")
|
||||
if name == "" {
|
||||
name = vals.Get("username")
|
||||
}
|
||||
|
||||
user, _, err := app.getLoginParticipant(email)
|
||||
if err != nil {
|
||||
ssoRedirectError("Login failed. Please try again.")
|
||||
return
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
p, err := app.getParticipantByEmail(email)
|
||||
if err != nil {
|
||||
ssoRedirectError("Login failed. Please try again.")
|
||||
return
|
||||
}
|
||||
if p != nil {
|
||||
if _, err := app.db.Exec(
|
||||
`UPDATE participants SET login_enabled = 1, updated_at = ? WHERE id = ?`,
|
||||
now(), p.ID,
|
||||
); err != nil {
|
||||
ssoRedirectError("Login failed. Please try again.")
|
||||
return
|
||||
}
|
||||
user, err = app.getUser(p.ID)
|
||||
if err != nil {
|
||||
ssoRedirectError("Login failed. Please try again.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
if name == "" {
|
||||
name = strings.Split(email, "@")[0]
|
||||
}
|
||||
res, err := app.db.Exec(
|
||||
`INSERT INTO participants (email, preferred_name, login_enabled, updated_at) VALUES (?, ?, 1, ?)`,
|
||||
email, name, now(),
|
||||
)
|
||||
if err != nil {
|
||||
ssoRedirectError("Login failed. Please try again.")
|
||||
return
|
||||
}
|
||||
id, _ := res.LastInsertId()
|
||||
user, err = app.getUser(int(id))
|
||||
if err != nil || user == nil {
|
||||
ssoRedirectError("Login failed. Please try again.")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
token, err := app.signToken(user)
|
||||
if err != nil {
|
||||
ssoRedirectError("Login failed. Please try again.")
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, strings.TrimRight(baseURL, "/")+"/#sso_token="+url.QueryEscape(token), http.StatusFound)
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue