// Package middleware provides gin authentication middleware. AuthMiddleware
// validates RSA-signed Keycloak JWTs against the realm's JWKS; AuthJWTMiddleware
// checks a static bearer token.
package middleware

import (
	"crypto/rsa"
	"crypto/subtle"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"math/big"
	"net/http"
	"strings"
	"sync"
	"time"

	"api-service/internal/config"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
	"golang.org/x/sync/singleflight"
)

var (
	// ErrInvalidToken is a sentinel for token validation failures.
	// NOTE(review): not referenced in this file; kept for external callers.
	ErrInvalidToken = errors.New("invalid token")
)

const (
	// jwksCacheTTL is how long a fetched key set is served without refetching.
	jwksCacheTTL = time.Hour
	// jwksMinRefreshInterval limits how often a token carrying an unknown kid
	// may force a JWKS refetch, so forged kids cannot flood the identity provider.
	jwksMinRefreshInterval = 30 * time.Second
	// jwksMaxBodyBytes caps how much of a JWKS response is read into memory.
	jwksMaxBodyBytes = 1 << 20
)

// jwksHTTPClient fetches JWKS documents. http.Get would use http.DefaultClient,
// which has no timeout, so a stalled identity provider could block every
// authenticating request indefinitely.
var jwksHTTPClient = &http.Client{Timeout: 10 * time.Second}

// JwksCache caches RSA public keys from a JWKS endpoint, keyed by kid, with
// expiration. Its fields are guarded by mu, and fetches are coalesced through
// sfGroup.
type JwksCache struct {
	mu        sync.RWMutex
	keys      map[string]*rsa.PublicKey
	expiresAt time.Time // end of the current cache validity window
	fetchedAt time.Time // time of the last successful fetch
	sfGroup   singleflight.Group
	config    *config.Config
}

// NewJwksCache returns an empty cache that reads its JWKS location from cfg.
func NewJwksCache(cfg *config.Config) *JwksCache {
	return &JwksCache{
		keys:   make(map[string]*rsa.PublicKey),
		config: cfg,
	}
}

// GetKey returns the RSA public key whose JWKS "kid" equals kid.
//
// Keys are served from cache until jwksCacheTTL after the last successful
// fetch. The JWKS is refetched when the cache has expired or kid is unknown,
// but an unknown kid can trigger at most one refetch per
// jwksMinRefreshInterval. Concurrent fetches are collapsed into one via
// singleflight. A failed fetch leaves the previous cache contents untouched.
func (c *JwksCache) GetKey(kid string) (*rsa.PublicKey, error) {
	c.mu.RLock()
	key, found := c.keys[kid]
	fresh := time.Now().Before(c.expiresAt)
	throttled := time.Since(c.fetchedAt) < jwksMinRefreshInterval
	c.mu.RUnlock()

	if found && fresh {
		return key, nil
	}
	if fresh && throttled {
		// The key set was refreshed moments ago and still lacks this kid;
		// refetching again would only put load on the identity provider.
		return nil, fmt.Errorf("key with kid %q not found", kid)
	}

	v, err, _ := c.sfGroup.Do("fetch_jwks", func() (any, error) {
		fetched, fetchErr := c.fetchKeys()
		if fetchErr != nil {
			return nil, fetchErr
		}
		now := time.Now()
		c.mu.Lock()
		c.keys = fetched
		c.fetchedAt = now
		c.expiresAt = now.Add(jwksCacheTTL)
		c.mu.Unlock()
		return fetched, nil
	})
	if err != nil {
		return nil, err
	}

	keys := v.(map[string]*rsa.PublicKey)
	key, found = keys[kid]
	if !found {
		return nil, fmt.Errorf("key with kid %q not found", kid)
	}
	return key, nil
}

// fetchKeys downloads and parses the JWKS document.
//
// The URL is Keycloak.JwksURL, or "<Issuer>/protocol/openid-connect/certs"
// when that is empty. Only RSA keys whose "use" is absent or "sig" are kept.
// Malformed individual keys are skipped. Fetching fails if Keycloak is
// disabled, the response is not 200 OK, the body cannot be decoded, or no
// usable key remains.
func (c *JwksCache) fetchKeys() (map[string]*rsa.PublicKey, error) {
	if !c.config.Keycloak.Enabled {
		return nil, errors.New("keycloak authentication is disabled")
	}

	jwksURL := c.config.Keycloak.JwksURL
	if jwksURL == "" {
		// Derive the JWKS URL from the issuer; trim a trailing slash so the
		// result does not contain "//protocol".
		jwksURL = strings.TrimSuffix(c.config.Keycloak.Issuer, "/") + "/protocol/openid-connect/certs"
	}

	resp, err := jwksHTTPClient.Get(jwksURL)
	if err != nil {
		return nil, fmt.Errorf("fetching JWKS from %s: %w", jwksURL, err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("fetching JWKS from %s: unexpected status %s", jwksURL, resp.Status)
	}

	var jwksData struct {
		Keys []struct {
			Kid string `json:"kid"`
			Kty string `json:"kty"`
			Use string `json:"use"`
			N   string `json:"n"`
			E   string `json:"e"`
		} `json:"keys"`
	}
	if err := json.NewDecoder(io.LimitReader(resp.Body, jwksMaxBodyBytes)).Decode(&jwksData); err != nil {
		return nil, fmt.Errorf("decoding JWKS from %s: %w", jwksURL, err)
	}

	keys := make(map[string]*rsa.PublicKey, len(jwksData.Keys))
	for _, k := range jwksData.Keys {
		// Skip non-RSA keys and keys published for encryption ("use": "enc");
		// only signing keys may verify token signatures.
		if k.Kty != "RSA" || (k.Use != "" && k.Use != "sig") {
			continue
		}
		pubKey, err := parseRSAPublicKey(k.N, k.E)
		if err != nil {
			// One malformed entry should not invalidate the rest of the set.
			continue
		}
		keys[k.Kid] = pubKey
	}

	if len(keys) == 0 {
		// Refuse to replace a possibly-good cache with an empty key set.
		return nil, fmt.Errorf("JWKS from %s contains no usable RSA signing keys", jwksURL)
	}
	return keys, nil
}

// parseRSAPublicKey builds an RSA public key from the base64url-encoded
// big-endian modulus (nStr) and exponent (eStr) of a JWK.
//
// It returns an error if either component does not decode, the modulus is
// empty, or the exponent is empty, longer than 4 bytes, or below 2.
func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
	nBytes, err := base64UrlDecode(nStr)
	if err != nil {
		return nil, fmt.Errorf("decoding modulus: %w", err)
	}
	eBytes, err := base64UrlDecode(eStr)
	if err != nil {
		return nil, fmt.Errorf("decoding exponent: %w", err)
	}
	if len(nBytes) == 0 {
		return nil, errors.New("empty RSA modulus")
	}
	// Bound the exponent length so accumulating it cannot overflow int on
	// 64-bit platforms. On 32-bit platforms a wrapped (negative) value is
	// caught by the e < 2 check below. Real keys use small exponents such as
	// 65537 (3 bytes).
	if len(eBytes) == 0 || len(eBytes) > 4 {
		return nil, fmt.Errorf("invalid RSA exponent length %d", len(eBytes))
	}
	e := 0
	for _, b := range eBytes {
		e = e<<8 | int(b)
	}
	if e < 2 {
		return nil, fmt.Errorf("invalid RSA exponent %d", e)
	}

	return &rsa.PublicKey{
		N: new(big.Int).SetBytes(nBytes),
		E: e,
	}, nil
}

// base64UrlDecode decodes base64url text with or without "=" padding. JWKs
// normally carry unpadded values.
func base64UrlDecode(s string) ([]byte, error) {
	return base64.RawURLEncoding.DecodeString(strings.TrimRight(s, "="))
}

// appConfig and jwksCacheInstance are package-level state set by
// InitializeAuth and read by AuthMiddleware.
var appConfig *config.Config
var jwksCacheInstance *JwksCache

// InitializeAuth stores cfg for the auth middleware and creates its JWKS cache.
// Call it before serving requests through AuthMiddleware.
func InitializeAuth(cfg *config.Config) {
	appConfig = cfg
	jwksCacheInstance = NewJwksCache(cfg)
}

// extractBearerToken parses an Authorization header value of the form
// "Bearer <token>", matching the scheme case-insensitively.
//
// On success it returns the token and an empty problem string. Otherwise it
// returns the client-facing error message describing what is wrong.
func extractBearerToken(authHeader string) (token, problem string) {
	if authHeader == "" {
		return "", "Authorization header missing"
	}
	scheme, rest, found := strings.Cut(authHeader, " ")
	token = strings.TrimSpace(rest)
	if !found || !strings.EqualFold(scheme, "bearer") || token == "" {
		return "", "Authorization header format must be Bearer {token}"
	}
	return token, ""
}

// AuthMiddleware validates the request's Bearer token as a Keycloak JWT.
//
// It responds 500 if InitializeAuth was not called and passes every request
// through when Keycloak is disabled in config. Otherwise it responds 401 when
// the Authorization header is missing or malformed, or when the token fails
// signature, issuer (Keycloak.Issuer) or audience (Keycloak.Audience)
// validation.
func AuthMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		if appConfig == nil {
			fmt.Println("AuthMiddleware: Config not initialized")
			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "authentication service not configured"})
			return
		}
		if !appConfig.Keycloak.Enabled {
			// Authentication is switched off by config: allow the request but
			// log it so the open state is visible.
			fmt.Println("AuthMiddleware: Keycloak authentication is disabled - allowing all requests")
			c.Next()
			return
		}

		fmt.Println("AuthMiddleware: Checking Authorization header") // Debug log
		tokenString, problem := extractBearerToken(c.GetHeader("Authorization"))
		if problem != "" {
			fmt.Printf("AuthMiddleware: %s\n", problem) // Debug log
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": problem})
			return
		}

		token, err := jwt.Parse(tokenString, keycloakKeyFunc,
			jwt.WithIssuer(appConfig.Keycloak.Issuer),
			jwt.WithAudience(appConfig.Keycloak.Audience))
		if err != nil || !token.Valid {
			fmt.Printf("AuthMiddleware: Invalid or expired token: %v\n", err) // Debug log
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or expired token"})
			return
		}

		fmt.Println("AuthMiddleware: Token valid, proceeding") // Debug log
		c.Next()
	}
}

// keycloakKeyFunc is the jwt.Keyfunc used by AuthMiddleware. It accepts only
// RSA (RS*) signing methods and looks up the verification key by the token's
// "kid" header in the shared JWKS cache.
func keycloakKeyFunc(t *jwt.Token) (any, error) {
	if _, ok := t.Method.(*jwt.SigningMethodRSA); !ok {
		return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
	}
	kid, ok := t.Header["kid"].(string)
	if !ok {
		return nil, errors.New("kid header not found")
	}
	key, err := jwksCacheInstance.GetKey(kid)
	if err != nil {
		return nil, err
	}
	return key, nil
}

// AuthJWTMiddleware validates the Bearer token in the Authorization header
// against a single hard-coded static token, responding 401 on mismatch.
//
// NOTE(review): despite its name this performs no JWT validation, and the
// token below is a placeholder; replace it before relying on this middleware.
func AuthJWTMiddleware() gin.HandlerFunc {
	const validToken = "your-static-token"

	return func(c *gin.Context) {
		token, problem := extractBearerToken(c.GetHeader("Authorization"))
		if problem != "" {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": problem})
			return
		}

		// Constant-time comparison avoids leaking how many leading bytes of a
		// guessed token were correct.
		if subtle.ConstantTimeCompare([]byte(token), []byte(validToken)) != 1 {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
			return
		}

		c.Next()
	}
}