package middleware import ( "api-service/internal/config" "api-service/internal/models/auth" models "api-service/internal/models/auth" service "api-service/internal/services/auth" "api-service/pkg/logger" "crypto/rsa" "encoding/base64" "encoding/json" "errors" "fmt" "math/big" "net/http" "strings" "sync" "time" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "golang.org/x/sync/singleflight" ) // AuthProvider interface for different authentication methods type AuthProvider interface { ValidateToken(tokenString string) (*models.JWTClaims, error) Name() string } // ProviderFactory creates authentication providers based on configuration type ProviderFactory struct { authService *service.AuthService config *config.Config } func NewProviderFactory(authService *service.AuthService, config *config.Config) *ProviderFactory { return &ProviderFactory{ authService: authService, config: config, } } func (f *ProviderFactory) CreateProviders() []AuthProvider { var providers []AuthProvider reqLogger := logger.Default().WithService("provider-factory") reqLogger.Info("Creating authentication providers", map[string]interface{}{ "auth_type": f.config.Auth.Type, "keycloak_enabled": f.config.Keycloak.Enabled, "keycloak_issuer": f.config.Keycloak.Issuer, "static_tokens_len": len(f.config.Auth.StaticTokens), "fallback_to": f.config.Auth.FallbackTo, }) switch f.config.Auth.Type { case "static": reqLogger.Info("Configuring static token provider") if len(f.config.Auth.StaticTokens) > 0 { providers = append(providers, NewStaticTokenProvider(f.config.Auth.StaticTokens)) reqLogger.Info("Static token provider added", map[string]interface{}{ "token_count": len(f.config.Auth.StaticTokens), }) } else { reqLogger.Warn("No static tokens configured for static auth type") } case "jwt": reqLogger.Info("Configuring JWT provider") providers = append(providers, NewJWTAuthProvider(f.authService)) reqLogger.Info("JWT provider added") case "keycloak": reqLogger.Info("Configuring Keycloak provider") if f.config.Keycloak.Issuer != "" { providers = append(providers, NewKeycloakAuthProvider(f.config)) reqLogger.Info("Keycloak provider added") } else { reqLogger.Warn("Keycloak issuer not configured for keycloak auth type") } case "hybrid": reqLogger.Info("Configuring hybrid providers") if f.config.Keycloak.Issuer != "" { providers = append(providers, NewKeycloakAuthProvider(f.config)) reqLogger.Info("Keycloak provider added for hybrid") } else { reqLogger.Warn("Keycloak issuer not configured for hybrid auth type") } switch f.config.Auth.FallbackTo { case "static": reqLogger.Info("Configuring static fallback for hybrid") if len(f.config.Auth.StaticTokens) > 0 { providers = append(providers, NewStaticTokenProvider(f.config.Auth.StaticTokens)) reqLogger.Info("Static fallback provider added", map[string]interface{}{ "token_count": len(f.config.Auth.StaticTokens), }) } else { reqLogger.Warn("No static tokens configured for hybrid fallback") } case "jwt": reqLogger.Info("Configuring JWT fallback for hybrid") providers = append(providers, NewJWTAuthProvider(f.authService)) reqLogger.Info("JWT fallback provider added") case "keycloak": reqLogger.Info("Configuring Keycloak fallback for hybrid") if f.config.Keycloak.Issuer != "" { providers = append(providers, NewKeycloakAuthProvider(f.config)) reqLogger.Info("Keycloak fallback provider added") } else { reqLogger.Warn("Keycloak issuer not configured for hybrid fallback") } default: reqLogger.Warn("Unknown fallback type for hybrid, using JWT", map[string]interface{}{ "fallback_to": f.config.Auth.FallbackTo, }) providers = append(providers, NewJWTAuthProvider(f.authService)) reqLogger.Info("JWT fallback provider added as default") } default: reqLogger.Warn("Unknown auth type, defaulting to JWT", map[string]interface{}{ "auth_type": f.config.Auth.Type, }) providers = append(providers, NewJWTAuthProvider(f.authService)) reqLogger.Info("JWT provider added as default") } reqLogger.Info("Provider creation completed", map[string]interface{}{ "provider_count": len(providers), }) return providers } // StaticTokenProvider handles static token authentication type StaticTokenProvider struct { tokens map[string]bool } func NewStaticTokenProvider(tokens []string) *StaticTokenProvider { tokenMap := make(map[string]bool) for _, token := range tokens { if token != "" { tokenMap[token] = true } } return &StaticTokenProvider{tokens: tokenMap} } func (s *StaticTokenProvider) ValidateToken(tokenString string) (*models.JWTClaims, error) { reqLogger := logger.Default().WithService("static-auth") if !s.tokens[tokenString] { reqLogger.Warn("Invalid static token provided") return nil, ErrInvalidToken } reqLogger.Info("Static token validation successful") return &models.JWTClaims{ UserID: "static-user", Username: "static-user", Email: "static@example.com", Role: "user", }, nil } func (s *StaticTokenProvider) Name() string { return "static" } // JWTAuthProvider handles JWT authentication using AuthService type JWTAuthProvider struct { authService *service.AuthService } func NewJWTAuthProvider(authService *service.AuthService) *JWTAuthProvider { return &JWTAuthProvider{authService: authService} } func (j *JWTAuthProvider) ValidateToken(tokenString string) (*models.JWTClaims, error) { reqLogger := logger.Default().WithService("jwt-auth") reqLogger.Info("Starting JWT token validation") claims, err := j.authService.ValidateToken(tokenString) if err != nil { reqLogger.Error("JWT validation failed", map[string]interface{}{ "error": err.Error(), }) return nil, err } reqLogger.Info("JWT validation successful", map[string]interface{}{ "user_id": claims.UserID, }) return claims, nil } func (j *JWTAuthProvider) Name() string { return "jwt" } // KeycloakAuthProvider handles Keycloak JWT authentication type KeycloakAuthProvider struct { jwksCache *JwksCache config *config.Config } func NewKeycloakAuthProvider(cfg *config.Config) *KeycloakAuthProvider { return &KeycloakAuthProvider{ jwksCache: NewJwksCache(cfg), config: cfg, } } func (k *KeycloakAuthProvider) ValidateToken(tokenString string) (*auth.JWTClaims, error) { reqLogger := logger.Default().WithService("keycloak-auth") reqLogger.Info("Starting Keycloak token validation") // Parse token without verification first to get claims for logging parsedToken, _, err := jwt.NewParser().ParseUnverified(tokenString, jwt.MapClaims{}) if err != nil { reqLogger.Error("Failed to parse token", map[string]interface{}{ "error": err.Error(), }) return nil, ErrInvalidToken } // Extract claims for logging claims, ok := parsedToken.Claims.(jwt.MapClaims) if !ok { reqLogger.Error("Invalid claims format") return nil, ErrMissingClaims } // Check if token is expired if exp, ok := claims["exp"].(float64); ok { if time.Now().Unix() > int64(exp) { reqLogger.Warn("Token expired", map[string]interface{}{ "exp": exp, "now": time.Now().Unix(), }) return nil, ErrTokenExpired } } // Now parse with verification token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { // Verify signing method if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { reqLogger.Warn("Unexpected signing method", map[string]interface{}{ "alg": token.Header["alg"], }) return nil, ErrInvalidSignature } kid, ok := token.Header["kid"].(string) if !ok { reqLogger.Warn("kid header not found in token") return nil, errors.New("kid header not found") } reqLogger.Info("Looking for key", map[string]interface{}{ "kid": kid, }) key, err := k.jwksCache.GetKey(kid) if err != nil { reqLogger.Error("Failed to get key", map[string]interface{}{ "kid": kid, "error": err.Error(), }) return nil, err } reqLogger.Info("Key retrieved successfully", map[string]interface{}{ "kid": kid, }) return key, nil }, jwt.WithIssuer(k.config.Keycloak.Issuer), jwt.WithAudience(k.config.Keycloak.Audience)) if err != nil { reqLogger.Error("JWT parse error", map[string]interface{}{ "error": err.Error(), }) // Return specific error based on the error type if strings.Contains(err.Error(), "expired") { return nil, ErrTokenExpired } else if strings.Contains(err.Error(), "signature") { return nil, ErrInvalidSignature } else if strings.Contains(err.Error(), "issuer") { return nil, ErrInvalidIssuer } else if strings.Contains(err.Error(), "audience") { return nil, ErrInvalidAudience } return nil, fmt.Errorf("invalid token: %v", err) } if !token.Valid { reqLogger.Warn("Token is not valid") return nil, ErrInvalidToken } reqLogger.Info("Token validation successful") // Extract claims claims, ok = token.Claims.(jwt.MapClaims) if !ok { reqLogger.Error("Invalid claims format") return nil, ErrMissingClaims } // Validate required claims userID := getClaimString(claims, "sub") if userID == "" { reqLogger.Error("Missing required claim: sub") return nil, ErrMissingClaims } return &auth.JWTClaims{ UserID: userID, Username: getClaimString(claims, "preferred_username"), Email: getClaimString(claims, "email"), Role: getClaimString(claims, "role"), }, nil } func (k *KeycloakAuthProvider) Name() string { return "keycloak" } // UnifiedAuthMiddleware provides flexible authentication based on configuration func UnifiedAuthMiddleware(cfg *config.Config, authService *service.AuthService) gin.HandlerFunc { factory := NewProviderFactory(authService, cfg) providers := factory.CreateProviders() // Validate that we have at least one provider if len(providers) == 0 { logger.Default().Error("No authentication providers configured", map[string]interface{}{ "auth_type": cfg.Auth.Type, }) return func(c *gin.Context) { c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "authentication service not configured"}) } } logger.Default().Info("UnifiedAuthMiddleware initialized", map[string]interface{}{ "provider_count": len(providers), "auth_type": cfg.Auth.Type, }) return func(c *gin.Context) { reqLogger := logger.Default().WithService("unified-auth") reqLogger.Info("Memulai proses autentikasi", map[string]interface{}{ "auth_type": cfg.Auth.Type, "path": c.Request.URL.Path, "method": c.Request.Method, }) authHeader := c.GetHeader("Authorization") if authHeader == "" { reqLogger.Warn("Header Authorization tidak ditemukan", map[string]interface{}{ "path": c.Request.URL.Path, "method": c.Request.Method, }) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": ErrMissingAuthHeader.Error()}) return } parts := strings.SplitN(authHeader, " ", 2) if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" { reqLogger.Warn("Format header Authorization tidak valid", map[string]interface{}{ "header_value": authHeader[:min(20, len(authHeader))], // Log first 20 chars for debugging "path": c.Request.URL.Path, "method": c.Request.Method, }) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": ErrInvalidAuthHeader.Error()}) return } tokenString := parts[1] reqLogger.Info("Token diterima", map[string]interface{}{ "token_length": len(tokenString), "path": c.Request.URL.Path, "method": c.Request.Method, }) // Coba setiap provider sampai salah satu berhasil var claims *auth.JWTClaims var err error var providerName string var providerErrors []string var triedProviders []string reqLogger.Info("Starting provider validation loop", map[string]interface{}{ "provider_count": len(providers), }) for _, provider := range providers { providerLog := reqLogger.WithField("provider", provider.Name()) triedProviders = append(triedProviders, provider.Name()) providerLog.Info("Mencoba validasi dengan provider", map[string]interface{}{ "path": c.Request.URL.Path, "method": c.Request.Method, }) claims, err = provider.ValidateToken(tokenString) if err == nil { providerName = provider.Name() providerLog.Info("Autentikasi berhasil", map[string]interface{}{ "user_id": claims.UserID, "username": claims.Username, "role": claims.Role, "path": c.Request.URL.Path, "method": c.Request.Method, }) break // Berhenti jika ada yang berhasil } providerLog.Warn("Validasi provider gagal", map[string]interface{}{ "error": err.Error(), "path": c.Request.URL.Path, "method": c.Request.Method, }) providerErrors = append(providerErrors, fmt.Sprintf("provider %s: %v", provider.Name(), err)) } if err != nil { reqLogger.Error("Semua provider gagal memvalidasi token", map[string]interface{}{ "errors": strings.Join(providerErrors, "; "), "tried_providers": strings.Join(triedProviders, ", "), "path": c.Request.URL.Path, "method": c.Request.Method, }) // Return specific error message based on the error type errorMessage := "Token tidak valid" if errors.Is(err, ErrTokenExpired) { errorMessage = "Token telah kadaluarsa" } else if errors.Is(err, ErrInvalidSignature) { errorMessage = "Signature token tidak valid" } else if errors.Is(err, ErrInvalidIssuer) { errorMessage = "Issuer token tidak valid" } else if errors.Is(err, ErrInvalidAudience) { errorMessage = "Audience token tidak valid" } c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": errorMessage, "details": strings.Join(providerErrors, "; "), }) return } // Set informasi pengguna di konteks if claims != nil { c.Set("user_id", claims.UserID) c.Set("username", claims.Username) c.Set("email", claims.Email) c.Set("role", claims.Role) c.Set("auth_provider", providerName) reqLogger.Info("User context set successfully", map[string]interface{}{ "user_id": claims.UserID, "username": claims.Username, "role": claims.Role, "auth_provider": providerName, "path": c.Request.URL.Path, "method": c.Request.Method, }) } else { reqLogger.Warn("Claims is nil after successful authentication", map[string]interface{}{ "provider": providerName, "path": c.Request.URL.Path, "method": c.Request.Method, }) } reqLogger.Info("Authentication completed successfully, proceeding to next handler", map[string]interface{}{ "path": c.Request.URL.Path, "method": c.Request.Method, }) c.Next() } } // InitializeAuth initializes authentication configuration func InitializeAuth(cfg *config.Config) { // This function can be used to initialize global auth settings if needed logger.Default().Info("Authentication initialized", map[string]interface{}{ "auth_type": cfg.Auth.Type, }) } // Helper functions func getClaimString(claims jwt.MapClaims, key string) string { if value, ok := claims[key]; ok && value != nil { if str, ok := value.(string); ok { return str } } return "" } // JwksCache and related functions type JwksCache struct { mu sync.RWMutex keys map[string]*rsa.PublicKey expiresAt time.Time sfGroup singleflight.Group config *config.Config } func NewJwksCache(cfg *config.Config) *JwksCache { return &JwksCache{ keys: make(map[string]*rsa.PublicKey), config: cfg, } } func (c *JwksCache) GetKey(kid string) (*rsa.PublicKey, error) { c.mu.RLock() if key, ok := c.keys[kid]; ok && time.Now().Before(c.expiresAt) { c.mu.RUnlock() return key, nil } c.mu.RUnlock() // Fetch keys with singleflight to avoid concurrent fetches v, err, _ := c.sfGroup.Do("fetch_jwks", func() (interface{}, error) { return c.fetchKeys() }) if err != nil { return nil, err } keys := v.(map[string]*rsa.PublicKey) c.mu.Lock() c.keys = keys c.expiresAt = time.Now().Add(1 * time.Hour) // cache for 1 hour c.mu.Unlock() key, ok := keys[kid] if !ok { return nil, fmt.Errorf("key with kid %s not found", kid) } return key, nil } func (c *JwksCache) fetchKeys() (map[string]*rsa.PublicKey, error) { if c.config.Keycloak.Issuer == "" { return nil, fmt.Errorf("keycloak issuer is not configured") } jwksURL := c.config.Keycloak.JwksURL if jwksURL == "" { // Construct JWKS URL from issuer if not explicitly provided jwksURL = c.config.Keycloak.Issuer + "/protocol/openid-connect/certs" } client := &http.Client{Timeout: 10 * time.Second} resp, err := client.Get(jwksURL) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("failed to fetch JWKS: HTTP %d", resp.StatusCode) } var jwksData struct { Keys []struct { Kid string `json:"kid"` Kty string `json:"kty"` N string `json:"n"` E string `json:"e"` } `json:"keys"` } if err := json.NewDecoder(resp.Body).Decode(&jwksData); err != nil { return nil, err } keys := make(map[string]*rsa.PublicKey) for _, key := range jwksData.Keys { if key.Kty != "RSA" { continue } pubKey, err := parseRSAPublicKey(key.N, key.E) if err != nil { continue } keys[key.Kid] = pubKey } return keys, nil } // parseRSAPublicKey parses RSA public key components from base64url strings func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) { nBytes, err := base64.RawURLEncoding.DecodeString(nStr) if err != nil { return nil, err } eBytes, err := base64.RawURLEncoding.DecodeString(eStr) if err != nil { return nil, err } n := new(big.Int).SetBytes(nBytes) e := int(new(big.Int).SetBytes(eBytes).Int64()) return &rsa.PublicKey{ N: n, E: e, }, nil }