616 lines
18 KiB
Go
616 lines
18 KiB
Go
package middleware
|
|
|
|
import (
|
|
"api-service/internal/config"
|
|
"api-service/internal/models/auth"
|
|
models "api-service/internal/models/auth"
|
|
service "api-service/internal/services/auth"
|
|
"api-service/pkg/logger"
|
|
"crypto/rsa"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"math/big"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"golang.org/x/sync/singleflight"
|
|
)
|
|
|
|
// AuthProvider interface for different authentication methods
|
|
type AuthProvider interface {
|
|
ValidateToken(tokenString string) (*models.JWTClaims, error)
|
|
Name() string
|
|
}
|
|
|
|
// ProviderFactory creates authentication providers based on configuration
|
|
type ProviderFactory struct {
|
|
authService *service.AuthService
|
|
config *config.Config
|
|
}
|
|
|
|
func NewProviderFactory(authService *service.AuthService, config *config.Config) *ProviderFactory {
|
|
return &ProviderFactory{
|
|
authService: authService,
|
|
config: config,
|
|
}
|
|
}
|
|
|
|
func (f *ProviderFactory) CreateProviders() []AuthProvider {
|
|
var providers []AuthProvider
|
|
|
|
reqLogger := logger.Default().WithService("provider-factory")
|
|
reqLogger.Info("Creating authentication providers", map[string]interface{}{
|
|
"auth_type": f.config.Auth.Type,
|
|
"keycloak_enabled": f.config.Keycloak.Enabled,
|
|
"keycloak_issuer": f.config.Keycloak.Issuer,
|
|
"static_tokens_len": len(f.config.Auth.StaticTokens),
|
|
"fallback_to": f.config.Auth.FallbackTo,
|
|
})
|
|
|
|
switch f.config.Auth.Type {
|
|
case "static":
|
|
reqLogger.Info("Configuring static token provider")
|
|
if len(f.config.Auth.StaticTokens) > 0 {
|
|
providers = append(providers, NewStaticTokenProvider(f.config.Auth.StaticTokens))
|
|
reqLogger.Info("Static token provider added", map[string]interface{}{
|
|
"token_count": len(f.config.Auth.StaticTokens),
|
|
})
|
|
} else {
|
|
reqLogger.Warn("No static tokens configured for static auth type")
|
|
}
|
|
case "jwt":
|
|
reqLogger.Info("Configuring JWT provider")
|
|
providers = append(providers, NewJWTAuthProvider(f.authService))
|
|
reqLogger.Info("JWT provider added")
|
|
case "keycloak":
|
|
reqLogger.Info("Configuring Keycloak provider")
|
|
if f.config.Keycloak.Issuer != "" {
|
|
providers = append(providers, NewKeycloakAuthProvider(f.config))
|
|
reqLogger.Info("Keycloak provider added")
|
|
} else {
|
|
reqLogger.Warn("Keycloak issuer not configured for keycloak auth type")
|
|
}
|
|
case "hybrid":
|
|
reqLogger.Info("Configuring hybrid providers")
|
|
if f.config.Keycloak.Issuer != "" {
|
|
providers = append(providers, NewKeycloakAuthProvider(f.config))
|
|
reqLogger.Info("Keycloak provider added for hybrid")
|
|
} else {
|
|
reqLogger.Warn("Keycloak issuer not configured for hybrid auth type")
|
|
}
|
|
switch f.config.Auth.FallbackTo {
|
|
case "static":
|
|
reqLogger.Info("Configuring static fallback for hybrid")
|
|
if len(f.config.Auth.StaticTokens) > 0 {
|
|
providers = append(providers, NewStaticTokenProvider(f.config.Auth.StaticTokens))
|
|
reqLogger.Info("Static fallback provider added", map[string]interface{}{
|
|
"token_count": len(f.config.Auth.StaticTokens),
|
|
})
|
|
} else {
|
|
reqLogger.Warn("No static tokens configured for hybrid fallback")
|
|
}
|
|
case "jwt":
|
|
reqLogger.Info("Configuring JWT fallback for hybrid")
|
|
providers = append(providers, NewJWTAuthProvider(f.authService))
|
|
reqLogger.Info("JWT fallback provider added")
|
|
case "keycloak":
|
|
reqLogger.Info("Configuring Keycloak fallback for hybrid")
|
|
if f.config.Keycloak.Issuer != "" {
|
|
providers = append(providers, NewKeycloakAuthProvider(f.config))
|
|
reqLogger.Info("Keycloak fallback provider added")
|
|
} else {
|
|
reqLogger.Warn("Keycloak issuer not configured for hybrid fallback")
|
|
}
|
|
default:
|
|
reqLogger.Warn("Unknown fallback type for hybrid, using JWT", map[string]interface{}{
|
|
"fallback_to": f.config.Auth.FallbackTo,
|
|
})
|
|
providers = append(providers, NewJWTAuthProvider(f.authService))
|
|
reqLogger.Info("JWT fallback provider added as default")
|
|
}
|
|
default:
|
|
reqLogger.Warn("Unknown auth type, defaulting to JWT", map[string]interface{}{
|
|
"auth_type": f.config.Auth.Type,
|
|
})
|
|
providers = append(providers, NewJWTAuthProvider(f.authService))
|
|
reqLogger.Info("JWT provider added as default")
|
|
}
|
|
|
|
reqLogger.Info("Provider creation completed", map[string]interface{}{
|
|
"provider_count": len(providers),
|
|
})
|
|
|
|
return providers
|
|
}
|
|
|
|
// StaticTokenProvider authenticates bearer tokens that exactly match one of a
// fixed set of configured tokens.
type StaticTokenProvider struct {
	tokens map[string]bool
}

// NewStaticTokenProvider builds a StaticTokenProvider from the configured
// token list. Empty strings are dropped so a blank entry can never match an
// empty bearer token; duplicate entries collapse into one.
func NewStaticTokenProvider(tokens []string) *StaticTokenProvider {
	allowed := make(map[string]bool, len(tokens))
	for i := range tokens {
		if t := tokens[i]; len(t) > 0 {
			allowed[t] = true
		}
	}
	return &StaticTokenProvider{tokens: allowed}
}
|
|
|
|
func (s *StaticTokenProvider) ValidateToken(tokenString string) (*models.JWTClaims, error) {
|
|
reqLogger := logger.Default().WithService("static-auth")
|
|
|
|
if !s.tokens[tokenString] {
|
|
reqLogger.Warn("Invalid static token provided")
|
|
return nil, ErrInvalidToken
|
|
}
|
|
|
|
reqLogger.Info("Static token validation successful")
|
|
return &models.JWTClaims{
|
|
UserID: "static-user",
|
|
Username: "static-user",
|
|
Email: "static@example.com",
|
|
Role: "user",
|
|
}, nil
|
|
}
|
|
|
|
func (s *StaticTokenProvider) Name() string {
|
|
return "static"
|
|
}
|
|
|
|
// JWTAuthProvider handles JWT authentication using AuthService
|
|
type JWTAuthProvider struct {
|
|
authService *service.AuthService
|
|
}
|
|
|
|
func NewJWTAuthProvider(authService *service.AuthService) *JWTAuthProvider {
|
|
return &JWTAuthProvider{authService: authService}
|
|
}
|
|
|
|
func (j *JWTAuthProvider) ValidateToken(tokenString string) (*models.JWTClaims, error) {
|
|
reqLogger := logger.Default().WithService("jwt-auth")
|
|
reqLogger.Info("Starting JWT token validation")
|
|
|
|
claims, err := j.authService.ValidateToken(tokenString)
|
|
if err != nil {
|
|
reqLogger.Error("JWT validation failed", map[string]interface{}{
|
|
"error": err.Error(),
|
|
})
|
|
return nil, err
|
|
}
|
|
|
|
reqLogger.Info("JWT validation successful", map[string]interface{}{
|
|
"user_id": claims.UserID,
|
|
})
|
|
return claims, nil
|
|
}
|
|
|
|
func (j *JWTAuthProvider) Name() string {
|
|
return "jwt"
|
|
}
|
|
|
|
// KeycloakAuthProvider handles Keycloak JWT authentication
|
|
type KeycloakAuthProvider struct {
|
|
jwksCache *JwksCache
|
|
config *config.Config
|
|
}
|
|
|
|
func NewKeycloakAuthProvider(cfg *config.Config) *KeycloakAuthProvider {
|
|
return &KeycloakAuthProvider{
|
|
jwksCache: NewJwksCache(cfg),
|
|
config: cfg,
|
|
}
|
|
}
|
|
|
|
func (k *KeycloakAuthProvider) ValidateToken(tokenString string) (*auth.JWTClaims, error) {
|
|
reqLogger := logger.Default().WithService("keycloak-auth")
|
|
reqLogger.Info("Starting Keycloak token validation")
|
|
|
|
// Parse token without verification first to get claims for logging
|
|
parsedToken, _, err := jwt.NewParser().ParseUnverified(tokenString, jwt.MapClaims{})
|
|
if err != nil {
|
|
reqLogger.Error("Failed to parse token", map[string]interface{}{
|
|
"error": err.Error(),
|
|
})
|
|
return nil, ErrInvalidToken
|
|
}
|
|
|
|
// Extract claims for logging
|
|
claims, ok := parsedToken.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
reqLogger.Error("Invalid claims format")
|
|
return nil, ErrMissingClaims
|
|
}
|
|
|
|
// Check if token is expired
|
|
if exp, ok := claims["exp"].(float64); ok {
|
|
if time.Now().Unix() > int64(exp) {
|
|
reqLogger.Warn("Token expired", map[string]interface{}{
|
|
"exp": exp,
|
|
"now": time.Now().Unix(),
|
|
})
|
|
return nil, ErrTokenExpired
|
|
}
|
|
}
|
|
|
|
// Now parse with verification
|
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
|
// Verify signing method
|
|
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
|
reqLogger.Warn("Unexpected signing method", map[string]interface{}{
|
|
"alg": token.Header["alg"],
|
|
})
|
|
return nil, ErrInvalidSignature
|
|
}
|
|
|
|
kid, ok := token.Header["kid"].(string)
|
|
if !ok {
|
|
reqLogger.Warn("kid header not found in token")
|
|
return nil, errors.New("kid header not found")
|
|
}
|
|
|
|
reqLogger.Info("Looking for key", map[string]interface{}{
|
|
"kid": kid,
|
|
})
|
|
key, err := k.jwksCache.GetKey(kid)
|
|
if err != nil {
|
|
reqLogger.Error("Failed to get key", map[string]interface{}{
|
|
"kid": kid,
|
|
"error": err.Error(),
|
|
})
|
|
return nil, err
|
|
}
|
|
reqLogger.Info("Key retrieved successfully", map[string]interface{}{
|
|
"kid": kid,
|
|
})
|
|
return key, nil
|
|
}, jwt.WithIssuer(k.config.Keycloak.Issuer), jwt.WithAudience(k.config.Keycloak.Audience))
|
|
|
|
if err != nil {
|
|
reqLogger.Error("JWT parse error", map[string]interface{}{
|
|
"error": err.Error(),
|
|
})
|
|
|
|
// Return specific error based on the error type
|
|
if strings.Contains(err.Error(), "expired") {
|
|
return nil, ErrTokenExpired
|
|
} else if strings.Contains(err.Error(), "signature") {
|
|
return nil, ErrInvalidSignature
|
|
} else if strings.Contains(err.Error(), "issuer") {
|
|
return nil, ErrInvalidIssuer
|
|
} else if strings.Contains(err.Error(), "audience") {
|
|
return nil, ErrInvalidAudience
|
|
}
|
|
|
|
return nil, fmt.Errorf("invalid token: %v", err)
|
|
}
|
|
|
|
if !token.Valid {
|
|
reqLogger.Warn("Token is not valid")
|
|
return nil, ErrInvalidToken
|
|
}
|
|
|
|
reqLogger.Info("Token validation successful")
|
|
|
|
// Extract claims
|
|
claims, ok = token.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
reqLogger.Error("Invalid claims format")
|
|
return nil, ErrMissingClaims
|
|
}
|
|
|
|
// Validate required claims
|
|
userID := getClaimString(claims, "sub")
|
|
if userID == "" {
|
|
reqLogger.Error("Missing required claim: sub")
|
|
return nil, ErrMissingClaims
|
|
}
|
|
|
|
return &auth.JWTClaims{
|
|
UserID: userID,
|
|
Username: getClaimString(claims, "preferred_username"),
|
|
Email: getClaimString(claims, "email"),
|
|
Role: getClaimString(claims, "role"),
|
|
}, nil
|
|
}
|
|
|
|
func (k *KeycloakAuthProvider) Name() string {
|
|
return "keycloak"
|
|
}
|
|
|
|
// UnifiedAuthMiddleware provides flexible authentication based on configuration
|
|
func UnifiedAuthMiddleware(cfg *config.Config, authService *service.AuthService) gin.HandlerFunc {
|
|
factory := NewProviderFactory(authService, cfg)
|
|
providers := factory.CreateProviders()
|
|
|
|
// Validate that we have at least one provider
|
|
if len(providers) == 0 {
|
|
logger.Default().Error("No authentication providers configured", map[string]interface{}{
|
|
"auth_type": cfg.Auth.Type,
|
|
})
|
|
return func(c *gin.Context) {
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "authentication service not configured"})
|
|
}
|
|
}
|
|
|
|
logger.Default().Info("UnifiedAuthMiddleware initialized", map[string]interface{}{
|
|
"provider_count": len(providers),
|
|
"auth_type": cfg.Auth.Type,
|
|
})
|
|
|
|
return func(c *gin.Context) {
|
|
reqLogger := logger.Default().WithService("unified-auth")
|
|
reqLogger.Info("Memulai proses autentikasi", map[string]interface{}{
|
|
"auth_type": cfg.Auth.Type,
|
|
"path": c.Request.URL.Path,
|
|
"method": c.Request.Method,
|
|
})
|
|
|
|
authHeader := c.GetHeader("Authorization")
|
|
if authHeader == "" {
|
|
reqLogger.Warn("Header Authorization tidak ditemukan", map[string]interface{}{
|
|
"path": c.Request.URL.Path,
|
|
"method": c.Request.Method,
|
|
})
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": ErrMissingAuthHeader.Error()})
|
|
return
|
|
}
|
|
|
|
parts := strings.SplitN(authHeader, " ", 2)
|
|
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
|
|
reqLogger.Warn("Format header Authorization tidak valid", map[string]interface{}{
|
|
"header_value": authHeader[:min(20, len(authHeader))], // Log first 20 chars for debugging
|
|
"path": c.Request.URL.Path,
|
|
"method": c.Request.Method,
|
|
})
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": ErrInvalidAuthHeader.Error()})
|
|
return
|
|
}
|
|
|
|
tokenString := parts[1]
|
|
reqLogger.Info("Token diterima", map[string]interface{}{
|
|
"token_length": len(tokenString),
|
|
"path": c.Request.URL.Path,
|
|
"method": c.Request.Method,
|
|
})
|
|
|
|
// Coba setiap provider sampai salah satu berhasil
|
|
var claims *auth.JWTClaims
|
|
var err error
|
|
var providerName string
|
|
var providerErrors []string
|
|
var triedProviders []string
|
|
|
|
reqLogger.Info("Starting provider validation loop", map[string]interface{}{
|
|
"provider_count": len(providers),
|
|
})
|
|
|
|
for _, provider := range providers {
|
|
providerLog := reqLogger.WithField("provider", provider.Name())
|
|
triedProviders = append(triedProviders, provider.Name())
|
|
providerLog.Info("Mencoba validasi dengan provider", map[string]interface{}{
|
|
"path": c.Request.URL.Path,
|
|
"method": c.Request.Method,
|
|
})
|
|
|
|
claims, err = provider.ValidateToken(tokenString)
|
|
if err == nil {
|
|
providerName = provider.Name()
|
|
providerLog.Info("Autentikasi berhasil", map[string]interface{}{
|
|
"user_id": claims.UserID,
|
|
"username": claims.Username,
|
|
"role": claims.Role,
|
|
"path": c.Request.URL.Path,
|
|
"method": c.Request.Method,
|
|
})
|
|
break // Berhenti jika ada yang berhasil
|
|
}
|
|
|
|
providerLog.Warn("Validasi provider gagal", map[string]interface{}{
|
|
"error": err.Error(),
|
|
"path": c.Request.URL.Path,
|
|
"method": c.Request.Method,
|
|
})
|
|
providerErrors = append(providerErrors, fmt.Sprintf("provider %s: %v", provider.Name(), err))
|
|
}
|
|
|
|
if err != nil {
|
|
reqLogger.Error("Semua provider gagal memvalidasi token", map[string]interface{}{
|
|
"errors": strings.Join(providerErrors, "; "),
|
|
"tried_providers": strings.Join(triedProviders, ", "),
|
|
"path": c.Request.URL.Path,
|
|
"method": c.Request.Method,
|
|
})
|
|
|
|
// Return specific error message based on the error type
|
|
errorMessage := "Token tidak valid"
|
|
if errors.Is(err, ErrTokenExpired) {
|
|
errorMessage = "Token telah kadaluarsa"
|
|
} else if errors.Is(err, ErrInvalidSignature) {
|
|
errorMessage = "Signature token tidak valid"
|
|
} else if errors.Is(err, ErrInvalidIssuer) {
|
|
errorMessage = "Issuer token tidak valid"
|
|
} else if errors.Is(err, ErrInvalidAudience) {
|
|
errorMessage = "Audience token tidak valid"
|
|
}
|
|
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
|
"error": errorMessage,
|
|
"details": strings.Join(providerErrors, "; "),
|
|
})
|
|
return
|
|
}
|
|
|
|
// Set informasi pengguna di konteks
|
|
if claims != nil {
|
|
c.Set("user_id", claims.UserID)
|
|
c.Set("username", claims.Username)
|
|
c.Set("email", claims.Email)
|
|
c.Set("role", claims.Role)
|
|
c.Set("auth_provider", providerName)
|
|
|
|
reqLogger.Info("User context set successfully", map[string]interface{}{
|
|
"user_id": claims.UserID,
|
|
"username": claims.Username,
|
|
"role": claims.Role,
|
|
"auth_provider": providerName,
|
|
"path": c.Request.URL.Path,
|
|
"method": c.Request.Method,
|
|
})
|
|
} else {
|
|
reqLogger.Warn("Claims is nil after successful authentication", map[string]interface{}{
|
|
"provider": providerName,
|
|
"path": c.Request.URL.Path,
|
|
"method": c.Request.Method,
|
|
})
|
|
}
|
|
|
|
reqLogger.Info("Authentication completed successfully, proceeding to next handler", map[string]interface{}{
|
|
"path": c.Request.URL.Path,
|
|
"method": c.Request.Method,
|
|
})
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// InitializeAuth initializes authentication configuration
|
|
func InitializeAuth(cfg *config.Config) {
|
|
// This function can be used to initialize global auth settings if needed
|
|
logger.Default().Info("Authentication initialized", map[string]interface{}{
|
|
"auth_type": cfg.Auth.Type,
|
|
})
|
|
}
|
|
|
|
// Helper functions
|
|
func getClaimString(claims jwt.MapClaims, key string) string {
|
|
if value, ok := claims[key]; ok && value != nil {
|
|
if str, ok := value.(string); ok {
|
|
return str
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// JwksCache and related functions
|
|
type JwksCache struct {
|
|
mu sync.RWMutex
|
|
keys map[string]*rsa.PublicKey
|
|
expiresAt time.Time
|
|
sfGroup singleflight.Group
|
|
config *config.Config
|
|
}
|
|
|
|
func NewJwksCache(cfg *config.Config) *JwksCache {
|
|
return &JwksCache{
|
|
keys: make(map[string]*rsa.PublicKey),
|
|
config: cfg,
|
|
}
|
|
}
|
|
|
|
func (c *JwksCache) GetKey(kid string) (*rsa.PublicKey, error) {
|
|
c.mu.RLock()
|
|
if key, ok := c.keys[kid]; ok && time.Now().Before(c.expiresAt) {
|
|
c.mu.RUnlock()
|
|
return key, nil
|
|
}
|
|
c.mu.RUnlock()
|
|
|
|
// Fetch keys with singleflight to avoid concurrent fetches
|
|
v, err, _ := c.sfGroup.Do("fetch_jwks", func() (interface{}, error) {
|
|
return c.fetchKeys()
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
keys := v.(map[string]*rsa.PublicKey)
|
|
|
|
c.mu.Lock()
|
|
c.keys = keys
|
|
c.expiresAt = time.Now().Add(1 * time.Hour) // cache for 1 hour
|
|
c.mu.Unlock()
|
|
|
|
key, ok := keys[kid]
|
|
if !ok {
|
|
return nil, fmt.Errorf("key with kid %s not found", kid)
|
|
}
|
|
return key, nil
|
|
}
|
|
|
|
func (c *JwksCache) fetchKeys() (map[string]*rsa.PublicKey, error) {
|
|
if c.config.Keycloak.Issuer == "" {
|
|
return nil, fmt.Errorf("keycloak issuer is not configured")
|
|
}
|
|
|
|
jwksURL := c.config.Keycloak.JwksURL
|
|
if jwksURL == "" {
|
|
// Construct JWKS URL from issuer if not explicitly provided
|
|
jwksURL = c.config.Keycloak.Issuer + "/protocol/openid-connect/certs"
|
|
}
|
|
|
|
client := &http.Client{Timeout: 10 * time.Second}
|
|
resp, err := client.Get(jwksURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("failed to fetch JWKS: HTTP %d", resp.StatusCode)
|
|
}
|
|
|
|
var jwksData struct {
|
|
Keys []struct {
|
|
Kid string `json:"kid"`
|
|
Kty string `json:"kty"`
|
|
N string `json:"n"`
|
|
E string `json:"e"`
|
|
} `json:"keys"`
|
|
}
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&jwksData); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
keys := make(map[string]*rsa.PublicKey)
|
|
for _, key := range jwksData.Keys {
|
|
if key.Kty != "RSA" {
|
|
continue
|
|
}
|
|
pubKey, err := parseRSAPublicKey(key.N, key.E)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
keys[key.Kid] = pubKey
|
|
}
|
|
return keys, nil
|
|
}
|
|
|
|
// parseRSAPublicKey builds an RSA public key from the base64url-encoded
// modulus (nStr) and exponent (eStr) members of a JWK.
//
// Trailing "=" padding is tolerated, although JWKs should use unpadded
// base64url. An error is returned in three cases:
//   - either value is not valid base64url;
//   - the modulus is empty or zero;
//   - the exponent lies outside [2, 2^31-1], the range crypto/rsa accepts.
//
// Malformed keys are thus rejected up front instead of producing a key that
// can never verify a signature.
func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
	nBytes, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(nStr, "="))
	if err != nil {
		return nil, fmt.Errorf("decoding RSA modulus: %w", err)
	}
	eBytes, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(eStr, "="))
	if err != nil {
		return nil, fmt.Errorf("decoding RSA exponent: %w", err)
	}

	n := new(big.Int).SetBytes(nBytes)
	if n.Sign() == 0 {
		return nil, errors.New("empty or zero RSA modulus")
	}

	// Int64 is undefined for values that do not fit in an int64, so check
	// IsInt64 before reading it.
	eBig := new(big.Int).SetBytes(eBytes)
	if !eBig.IsInt64() || eBig.Int64() < 2 || eBig.Int64() > 1<<31-1 {
		return nil, errors.New("exponent out of range for RSA key")
	}

	return &rsa.PublicKey{
		N: n,
		E: int(eBig.Int64()),
	}, nil
}
|