Turnpike/auth.go

140 lines
3.5 KiB
Go
Raw Permalink Normal View History

package main
import (
"context"
"encoding/json"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"
)
// Claims is the JWT payload produced by signToken and decoded by parseToken.
//
// Alongside the application fields it embeds jwt.RegisteredClaims, which
// carries the standard registered claims (signToken fills iat and exp).
//
// NOTE: Email is tagged with the JSON name "sub", the same name the embedded
// RegisteredClaims.Subject uses. Under encoding/json's precedence rules, the
// shallower field wins. So "sub" always maps to Email, and
// RegisteredClaims.Subject is never encoded or populated. Anything that reads
// the subject through the promoted GetSubject method, or validates it with
// jwt.WithSubject, will see "". Changing the tag would change the wire format
// of already-issued tokens, so it is intentionally left alone.
type Claims struct {
// ParticipantID is the user's ID (User.ID at signing time); JSON key "pid".
ParticipantID int `json:"pid"`
// Email is the user's email address, carried in the "sub" claim (see NOTE above).
Email string `json:"sub"`
// Roles lists the user's role names; requireAuth checks them via hasAnyRole.
// Not omitempty: a nil slice is encoded as null.
Roles []string `json:"roles"`
// DeptIDs lists the user's department IDs; omitted from the token when empty.
DeptIDs []int `json:"dept_ids,omitempty"`
jwt.RegisteredClaims
}
// hashPassword returns the bcrypt hash of password, computed at
// bcrypt.DefaultCost, as a string suitable for storage.
//
// On failure it returns "" and the bcrypt error. NOTE(review): recent
// golang.org/x/crypto versions reject passwords longer than 72 bytes here
// rather than truncating them; callers may want to surface that as a
// validation error.
func hashPassword(password string) (string, error) {
	hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return "", err
	}
	return string(hashed), nil
}
// checkPassword reports whether password matches the stored bcrypt hash.
// Any error from bcrypt yields false, whether it is a mismatch or a
// malformed hash.
func checkPassword(hash, password string) bool {
	err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
	return err == nil
}
// signToken issues an HS256-signed JWT for user s, signed with app.secret and
// valid for app.tokenExpiry hours from now.
//
// The payload carries the user's ID, email, roles and department IDs (see
// Claims), plus the iat and exp registered claims.
//
// Both timestamps are derived from a single time.Now() reading, so exp-iat
// always equals the configured lifetime. With two separate readings,
// jwt.NewNumericDate's truncation to whole seconds could make them straddle
// a second boundary.
func (app *App) signToken(s *User) (string, error) {
	now := time.Now()
	lifetime := time.Duration(app.tokenExpiry) * time.Hour
	claims := Claims{
		ParticipantID: s.ID,
		Email:         s.Email,
		Roles:         s.Roles,
		DeptIDs:       s.DepartmentIDs,
		RegisteredClaims: jwt.RegisteredClaims{
			IssuedAt:  jwt.NewNumericDate(now),
			ExpiresAt: jwt.NewNumericDate(now.Add(lifetime)),
		},
	}
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	return token.SignedString([]byte(app.secret))
}
// parseToken verifies tokenStr against app.secret and returns its claims.
//
// Only HS256 is accepted, which is the algorithm signToken issues. The parser
// rejects any other "alg" header before the key function runs, and the key
// function also refuses non-HMAC methods as a second line of defence. This
// pins the algorithm as recommended by RFC 8725 §3.1.
//
// golang-jwt v5 runs its default claim validation (exp/nbf) inside
// ParseWithClaims, so an expired token comes back as an error.
func (app *App) parseToken(tokenStr string) (*Claims, error) {
	keyFunc := func(t *jwt.Token) (any, error) {
		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, jwt.ErrSignatureInvalid
		}
		return []byte(app.secret), nil
	}
	token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, keyFunc,
		jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}))
	if err != nil {
		return nil, err
	}
	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return nil, jwt.ErrTokenInvalidClaims
	}
	return claims, nil
}
// bearerToken extracts the access token from r, or returns "" if none is
// supplied.
//
// It prefers an "Authorization: Bearer <token>" header. The scheme name is
// matched case-insensitively (RFC 7235 §2.1), and surrounding whitespace is
// trimmed from the token.
//
// Otherwise (no Authorization header, or a non-Bearer scheme) it falls back
// to the "token" query parameter. Browser EventSource (SSE) connections
// cannot set request headers, so they need this fallback.
//
// NOTE(review): tokens in URLs can leak into access logs, proxies and browser
// history (RFC 6750 §5.3). Consider limiting the query fallback to SSE
// endpoints.
func bearerToken(r *http.Request) string {
	const prefix = "Bearer "
	h := r.Header.Get("Authorization")
	if len(h) >= len(prefix) && strings.EqualFold(h[:len(prefix)], prefix) {
		return strings.TrimSpace(h[len(prefix):])
	}
	return r.URL.Query().Get("token")
}
// requireAuth (below) wraps a handler and passes the parsed claims to it as a
// request-scoped context value stored under claimsKey, not via an X-Claims header.

// contextKey is an unexported type for context.Context keys. Because the type
// is private to this package, its keys cannot collide with keys set by other
// packages, even when the underlying strings are equal.
type contextKey string

// claimsKey is the context key under which requireAuth stores the
// authenticated request's *Claims. claimsFromContext reads it back.
const claimsKey = contextKey("claims")
// requireAuth wraps next so that it only runs for requests carrying a token
// (located by bearerToken) that parseToken accepts.
//
// If roles is non-empty, the token must also hold at least one of them. On
// success the parsed *Claims are stored in the request context under
// claimsKey (read them with claimsFromContext) before next is called.
//
// Responses:
//   - 401 "unauthorized" when the token is missing or fails verification. A
//     WWW-Authenticate Bearer challenge is included, because RFC 7235 §3.1
//     requires one on every 401. RFC 6750 §3 defines the Bearer form and the
//     invalid_token error code.
//   - 403 "forbidden" when the token is valid but holds none of roles.
func (app *App) requireAuth(next http.HandlerFunc, roles ...string) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		token := bearerToken(r)
		if token == "" {
			w.Header().Set("WWW-Authenticate", "Bearer")
			writeError(w, "unauthorized", http.StatusUnauthorized)
			return
		}
		claims, err := app.parseToken(token)
		if err != nil {
			w.Header().Set("WWW-Authenticate", `Bearer error="invalid_token"`)
			writeError(w, "unauthorized", http.StatusUnauthorized)
			return
		}
		if len(roles) > 0 && !hasAnyRole(claims.Roles, roles) {
			writeError(w, "forbidden", http.StatusForbidden)
			return
		}
		ctx := context.WithValue(r.Context(), claimsKey, claims)
		next(w, r.WithContext(ctx))
	}
}
// hasAnyRole reports whether at least one entry of roles also appears in
// allowed. It returns false if either slice is empty.
func hasAnyRole(roles []string, allowed []string) bool {
	for i := range allowed {
		for j := range roles {
			if roles[j] == allowed[i] {
				return true
			}
		}
	}
	return false
}
// isCoLeadOnly reports whether claims hold the "colead" role while holding
// neither "admin" nor "staffing", i.e. the caller's elevated access comes only
// from being a co-lead. claims must be non-nil.
func isCoLeadOnly(claims *Claims) bool {
	if !hasAnyRole(claims.Roles, []string{"colead"}) {
		return false
	}
	return !hasAnyRole(claims.Roles, []string{"admin", "staffing"})
}
// inSlice reports whether v occurs anywhere in s (linear scan; false for an
// empty or nil slice).
func inSlice(v int, s []int) bool {
	for i := range s {
		if s[i] == v {
			return true
		}
	}
	return false
}
// claimsFromContext returns the *Claims that requireAuth stored in r's
// context. It returns nil when the request did not pass through requireAuth,
// or when the stored value is not a *Claims.
func claimsFromContext(r *http.Request) *Claims {
	if c, ok := r.Context().Value(claimsKey).(*Claims); ok {
		return c
	}
	return nil
}
// writeJSON writes v as a JSON response body with Content-Type:
// application/json. The status is the implicit 200 unless the caller already
// wrote a header. The body is newline-terminated, exactly like
// json.Encoder output.
//
// v is marshaled before anything is written. If it cannot be encoded (for
// example, it contains a channel, a func or a NaN float), the client receives
// a 500 with an {"error": "internal server error"} body, the same shape
// writeError produces. Previously it received a 200 with an empty body.
func writeJSON(w http.ResponseWriter, v any) {
	w.Header().Set("Content-Type", "application/json")
	body, err := json.Marshal(v)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		// A map[string]string always marshals. A write error here means the
		// client went away, and there is nobody left to report it to.
		_ = json.NewEncoder(w).Encode(map[string]string{"error": "internal server error"})
		return
	}
	// Deliberately ignored: once the response is being written, a failure
	// (e.g. client disconnected) cannot be reported to the client.
	_, _ = w.Write(append(body, '\n'))
}
// writeError sends a JSON error response with the given HTTP status code and
// a body of the form {"error": msg}, newline-terminated.
func writeError(w http.ResponseWriter, msg string, code int) {
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(code)
	payload := map[string]string{"error": msg}
	// A map[string]string always marshals; a write failure means the client
	// is gone, so the error is intentionally dropped.
	_ = json.NewEncoder(w).Encode(payload)
}