Files
antrean-anjungan/internal/middleware/auth.go
2025-08-14 09:08:34 +07:00

227 lines
5.6 KiB
Go

package middleware
/** Keycloak Auth Middleware **/
import (
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"math/big"
"net/http"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/sync/singleflight"
)
// ErrInvalidToken is a sentinel error whose message is "invalid token".
// NOTE(review): nothing in the visible part of this file returns it — confirm
// whether other packages rely on it.
var ErrInvalidToken = errors.New("invalid token")
// Keycloak parameters used by the token validation in this file. The values
// below are placeholders: replace them with the real deployment values or load
// them from config/env.
const (
	// KeycloakIssuer is the issuer value passed to the JWT parser and the
	// base of the JWKS endpoint URL.
	KeycloakIssuer = "https://keycloak.example.com/auth/realms/yourrealm"

	// KeycloakAudience is the audience value passed to the JWT parser.
	KeycloakAudience = "your-client-id"

	// JwksURL is the endpoint the JWKS document is fetched from; it is
	// derived from KeycloakIssuer.
	JwksURL = KeycloakIssuer + "/protocol/openid-connect/certs"
)
// JwksCache caches the RSA public keys of a JWKS document, indexed by key ID,
// together with the time at which the cached set stops being considered fresh.
type JwksCache struct {
// mu is held by GetKey while reading or replacing keys and expiresAt.
mu sync.RWMutex
// keys maps a JWK "kid" to its parsed RSA public key.
keys map[string]*rsa.PublicKey
// expiresAt is the instant after which GetKey no longer serves keys from the cache.
expiresAt time.Time
// sfGroup is used by GetKey so that concurrent refreshes share a single fetch.
sfGroup singleflight.Group
}
// NewJwksCache returns an empty cache: the key map is allocated but holds no
// entries, and the expiry time is left at its zero value.
func NewJwksCache() *JwksCache {
	cache := new(JwksCache)
	cache.keys = map[string]*rsa.PublicKey{}
	return cache
}
// GetKey returns the RSA public key registered under kid.
//
// A cached key is returned directly when kid is present and the cache has not
// expired. Otherwise the key set is fetched again (concurrent callers share one
// fetch through the singleflight group), the cache is replaced with the fetched
// set and given a one-hour lifetime, and kid is looked up in that new set.
//
// It returns the fetch error if the refresh fails, or a "not found" error if
// the refreshed set does not contain kid.
func (c *JwksCache) GetKey(kid string) (*rsa.PublicKey, error) {
	// Fast path: serve from the cache under the read lock.
	c.mu.RLock()
	cached, hit := c.keys[kid]
	usable := hit && time.Now().Before(c.expiresAt)
	c.mu.RUnlock()
	if usable {
		return cached, nil
	}

	// Slow path: refresh the key set, collapsing concurrent refreshes into one.
	result, err, _ := c.sfGroup.Do("fetch_jwks", func() (interface{}, error) {
		return c.fetchKeys()
	})
	if err != nil {
		return nil, err
	}
	fetched := result.(map[string]*rsa.PublicKey)

	c.mu.Lock()
	c.keys = fetched
	c.expiresAt = time.Now().Add(1 * time.Hour) // cache for 1 hour
	c.mu.Unlock()

	if pub, found := fetched[kid]; found {
		return pub, nil
	}
	return nil, fmt.Errorf("key with kid %s not found", kid)
}
// fetchKeys downloads the JWKS document from JwksURL and returns its RSA keys
// indexed by "kid".
//
// Entries whose "kty" is not "RSA", or whose modulus/exponent cannot be parsed
// by parseRSAPublicKey, are skipped rather than failing the whole fetch.
//
// It returns an error when the request fails or times out, when the endpoint
// answers with a status other than 200 OK, or when the body is not valid JSON.
func (c *JwksCache) fetchKeys() (map[string]*rsa.PublicKey, error) {
	// http.Get goes through http.DefaultClient, which has no timeout; a stalled
	// identity provider would then block every request waiting on this fetch.
	// The client value is cheap: with a nil Transport it still uses the shared
	// http.DefaultTransport.
	client := &http.Client{Timeout: 10 * time.Second}

	resp, err := client.Get(JwksURL)
	if err != nil {
		return nil, fmt.Errorf("fetching JWKS: %w", err)
	}
	defer resp.Body.Close()

	// An error page must not be decoded as if it were a key set.
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("fetching JWKS: unexpected status %s", resp.Status)
	}

	var jwksData struct {
		Keys []struct {
			Kid string `json:"kid"`
			Kty string `json:"kty"`
			N   string `json:"n"`
			E   string `json:"e"`
		} `json:"keys"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&jwksData); err != nil {
		return nil, fmt.Errorf("decoding JWKS response: %w", err)
	}

	keys := make(map[string]*rsa.PublicKey, len(jwksData.Keys))
	for _, key := range jwksData.Keys {
		if key.Kty != "RSA" {
			continue
		}
		pubKey, err := parseRSAPublicKey(key.N, key.E)
		if err != nil {
			// Deliberate best effort: one malformed entry should not make
			// the remaining keys unavailable.
			continue
		}
		keys[key.Kid] = pubKey
	}
	return keys, nil
}
// parseRSAPublicKey builds an RSA public key from the base64url-encoded
// modulus (nStr) and public exponent (eStr) of a JWK.
//
// The exponent bytes are interpreted as a big-endian unsigned integer.
//
// It returns an error when either component fails to decode, when the modulus
// is empty or zero, or when the exponent is smaller than 2 or larger than
// 1<<31-1 (the largest value that fits in an int on every platform).
func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
	nBytes, err := base64UrlDecode(nStr)
	if err != nil {
		return nil, fmt.Errorf("decoding RSA modulus: %w", err)
	}
	eBytes, err := base64UrlDecode(eStr)
	if err != nil {
		return nil, fmt.Errorf("decoding RSA exponent: %w", err)
	}

	modulus := new(big.Int).SetBytes(nBytes)
	if modulus.Sign() == 0 {
		return nil, errors.New("RSA modulus is empty or zero")
	}

	// Accumulate in a uint64 and stop as soon as the bound is exceeded, so an
	// oversized exponent is rejected instead of silently wrapping around.
	const maxExponent = 1<<31 - 1
	var exponent uint64
	for _, b := range eBytes {
		exponent = exponent<<8 | uint64(b)
		if exponent > maxExponent {
			return nil, errors.New("RSA public exponent too large")
		}
	}
	if exponent < 2 {
		return nil, errors.New("RSA public exponent too small")
	}

	return &rsa.PublicKey{N: modulus, E: int(exponent)}, nil
}
// base64UrlDecode decodes a URL-safe base64 string. If the input length is not
// a multiple of four, "=" padding is appended to reach one before decoding with
// the padded URL alphabet.
func base64UrlDecode(s string) ([]byte, error) {
	// Number of "=" needed to reach the next multiple of four (0 when aligned).
	missing := (4 - len(s)%4) % 4
	return base64.URLEncoding.DecodeString(s + strings.Repeat("=", missing))
}
// jwksCache is the package-level JwksCache instance; the key lookup callback
// in AuthMiddleware resolves signing keys through it.
var jwksCache = NewJwksCache()
// AuthMiddleware returns a gin handler that authenticates requests carrying a
// JWT in an "Authorization: Bearer <token>" header.
//
// The request is aborted with 401 and a JSON error body when the header is
// missing, when it is not of the form "Bearer <token>" (scheme compared
// case-insensitively), or when the token fails parsing or validation. The
// token must use an RSA signing method, carry a string "kid" header that
// jwksCache can resolve, and match KeycloakIssuer and KeycloakAudience.
// On success the handler calls c.Next(). Progress is printed to stdout.
func AuthMiddleware() gin.HandlerFunc {
	// resolveKey selects the verification key for a parsed, not yet verified token.
	resolveKey := func(token *jwt.Token) (interface{}, error) {
		// Verify signing method
		if _, isRSA := token.Method.(*jwt.SigningMethodRSA); !isRSA {
			fmt.Printf("AuthMiddleware: Unexpected signing method: %v\n", token.Header["alg"]) // Debug log
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		kid, isString := token.Header["kid"].(string)
		if !isString {
			fmt.Println("AuthMiddleware: kid header not found") // Debug log
			return nil, errors.New("kid header not found")
		}
		return jwksCache.GetKey(kid)
	}

	// reject prints the debug line and aborts the request with a 401 JSON error.
	reject := func(c *gin.Context, logLine, message string) {
		fmt.Println(logLine) // Debug log
		c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": message})
	}

	return func(c *gin.Context) {
		fmt.Println("AuthMiddleware: Checking Authorization header") // Debug log

		header := c.GetHeader("Authorization")
		if header == "" {
			reject(c, "AuthMiddleware: Authorization header missing", "Authorization header missing")
			return
		}

		// Split once on the first space: "<scheme> <token>".
		scheme, tokenString, hasToken := strings.Cut(header, " ")
		if !hasToken || strings.ToLower(scheme) != "bearer" {
			reject(c, "AuthMiddleware: Invalid Authorization header format", "Authorization header format must be Bearer {token}")
			return
		}

		token, err := jwt.Parse(
			tokenString,
			resolveKey,
			jwt.WithIssuer(KeycloakIssuer),
			jwt.WithAudience(KeycloakAudience),
		)
		if err != nil || !token.Valid {
			fmt.Printf("AuthMiddleware: Invalid or expired token: %v\n", err) // Debug log
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or expired token"})
			return
		}

		fmt.Println("AuthMiddleware: Token valid, proceeding") // Debug log
		// Token is valid, proceed
		c.Next()
	}
}
/** JWT Bearer authentication middleware */
// import (
// "net/http"
// "strings"
// "github.com/gin-gonic/gin"
// )
// AuthJWTMiddleware returns a gin handler that lets a request through only when
// its Authorization header has the form "Bearer <token>" (scheme compared
// case-insensitively) and <token> equals a fixed string compiled into the
// handler. Any other request is aborted with 401 and a JSON error body.
//
// NOTE(review): despite the name, no JWT parsing happens here, and the fixed
// token looks like a placeholder meant to be replaced — confirm before use.
func AuthJWTMiddleware() gin.HandlerFunc {
	// For now, use a static token for validation. Replace with your logic.
	const validToken = "your-static-token"

	// unauthorized aborts the request with a 401 JSON error.
	unauthorized := func(c *gin.Context, message string) {
		c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": message})
	}

	return func(c *gin.Context) {
		header := c.GetHeader("Authorization")
		if header == "" {
			unauthorized(c, "Authorization header missing")
			return
		}

		// Split once on the first space: "<scheme> <token>".
		scheme, token, hasToken := strings.Cut(header, " ")
		switch {
		case !hasToken, strings.ToLower(scheme) != "bearer":
			unauthorized(c, "Authorization header format must be Bearer {token}")
		case token != validToken:
			unauthorized(c, "Invalid token")
		default:
			c.Next()
		}
	}
}