first commit
This commit is contained in:
254
internal/middleware/keycloak_middleware.go
Normal file
254
internal/middleware/keycloak_middleware.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package middleware
|
||||
|
||||
/** Keycloak Auth Middleware **/
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"api-service/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidToken = errors.New("invalid token")
|
||||
)
|
||||
|
||||
// JwksCache caches JWKS keys with expiration
|
||||
type JwksCache struct {
|
||||
mu sync.RWMutex
|
||||
keys map[string]*rsa.PublicKey
|
||||
expiresAt time.Time
|
||||
sfGroup singleflight.Group
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
func NewJwksCache(cfg *config.Config) *JwksCache {
|
||||
return &JwksCache{
|
||||
keys: make(map[string]*rsa.PublicKey),
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *JwksCache) GetKey(kid string) (*rsa.PublicKey, error) {
|
||||
c.mu.RLock()
|
||||
if key, ok := c.keys[kid]; ok && time.Now().Before(c.expiresAt) {
|
||||
c.mu.RUnlock()
|
||||
return key, nil
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
// Fetch keys with singleflight to avoid concurrent fetches
|
||||
v, err, _ := c.sfGroup.Do("fetch_jwks", func() (interface{}, error) {
|
||||
return c.fetchKeys()
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys := v.(map[string]*rsa.PublicKey)
|
||||
|
||||
c.mu.Lock()
|
||||
c.keys = keys
|
||||
c.expiresAt = time.Now().Add(1 * time.Hour) // cache for 1 hour
|
||||
c.mu.Unlock()
|
||||
|
||||
key, ok := keys[kid]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("key with kid %s not found", kid)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (c *JwksCache) fetchKeys() (map[string]*rsa.PublicKey, error) {
|
||||
if !c.config.Keycloak.Enabled {
|
||||
return nil, fmt.Errorf("keycloak authentication is disabled")
|
||||
}
|
||||
|
||||
jwksURL := c.config.Keycloak.JwksURL
|
||||
if jwksURL == "" {
|
||||
// Construct JWKS URL from issuer if not explicitly provided
|
||||
jwksURL = c.config.Keycloak.Issuer + "/protocol/openid-connect/certs"
|
||||
}
|
||||
|
||||
resp, err := http.Get(jwksURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var jwksData struct {
|
||||
Keys []struct {
|
||||
Kid string `json:"kid"`
|
||||
Kty string `json:"kty"`
|
||||
N string `json:"n"`
|
||||
E string `json:"e"`
|
||||
} `json:"keys"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&jwksData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys := make(map[string]*rsa.PublicKey)
|
||||
for _, key := range jwksData.Keys {
|
||||
if key.Kty != "RSA" {
|
||||
continue
|
||||
}
|
||||
pubKey, err := parseRSAPublicKey(key.N, key.E)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
keys[key.Kid] = pubKey
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// parseRSAPublicKey parses RSA public key components from base64url strings
|
||||
func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
|
||||
nBytes, err := base64UrlDecode(nStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
eBytes, err := base64UrlDecode(eStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var eInt int
|
||||
for _, b := range eBytes {
|
||||
eInt = eInt<<8 + int(b)
|
||||
}
|
||||
|
||||
pubKey := &rsa.PublicKey{
|
||||
N: new(big.Int).SetBytes(nBytes),
|
||||
E: eInt,
|
||||
}
|
||||
return pubKey, nil
|
||||
}
|
||||
|
||||
func base64UrlDecode(s string) ([]byte, error) {
|
||||
// Add padding if missing
|
||||
if m := len(s) % 4; m != 0 {
|
||||
s += strings.Repeat("=", 4-m)
|
||||
}
|
||||
return base64.URLEncoding.DecodeString(s)
|
||||
}
|
||||
|
||||
// Global config instance
|
||||
var appConfig *config.Config
|
||||
var jwksCacheInstance *JwksCache
|
||||
|
||||
// InitializeAuth initializes the auth middleware with config
|
||||
func InitializeAuth(cfg *config.Config) {
|
||||
appConfig = cfg
|
||||
jwksCacheInstance = NewJwksCache(cfg)
|
||||
}
|
||||
|
||||
// AuthMiddleware validates Bearer token as Keycloak JWT token
|
||||
func AuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if appConfig == nil {
|
||||
fmt.Println("AuthMiddleware: Config not initialized")
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "authentication service not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
if !appConfig.Keycloak.Enabled {
|
||||
// Skip authentication if Keycloak is disabled but log for debugging
|
||||
fmt.Println("AuthMiddleware: Keycloak authentication is disabled - allowing all requests")
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("AuthMiddleware: Checking Authorization header") // Debug log
|
||||
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
fmt.Println("AuthMiddleware: Authorization header missing") // Debug log
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header missing"})
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
|
||||
fmt.Println("AuthMiddleware: Invalid Authorization header format") // Debug log
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header format must be Bearer {token}"})
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := parts[1]
|
||||
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
// Verify signing method
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
fmt.Printf("AuthMiddleware: Unexpected signing method: %v\n", token.Header["alg"]) // Debug log
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
|
||||
kid, ok := token.Header["kid"].(string)
|
||||
if !ok {
|
||||
fmt.Println("AuthMiddleware: kid header not found") // Debug log
|
||||
return nil, errors.New("kid header not found")
|
||||
}
|
||||
|
||||
return jwksCacheInstance.GetKey(kid)
|
||||
}, jwt.WithIssuer(appConfig.Keycloak.Issuer), jwt.WithAudience(appConfig.Keycloak.Audience))
|
||||
|
||||
if err != nil || !token.Valid {
|
||||
fmt.Printf("AuthMiddleware: Invalid or expired token: %v\n", err) // Debug log
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or expired token"})
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("AuthMiddleware: Token valid, proceeding") // Debug log
|
||||
// Token is valid, proceed
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
/** JWT Bearer authentication middleware */
|
||||
// import (
|
||||
// "net/http"
|
||||
// "strings"
|
||||
|
||||
// "github.com/gin-gonic/gin"
|
||||
// )
|
||||
|
||||
// AuthMiddleware validates Bearer token in Authorization header
|
||||
func AuthJWTMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header missing"})
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header format must be Bearer {token}"})
|
||||
return
|
||||
}
|
||||
|
||||
token := parts[1]
|
||||
// For now, use a static token for validation. Replace with your logic.
|
||||
const validToken = "your-static-token"
|
||||
|
||||
if token != validToken {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user