package middleware import ( "api-service/internal/config" "api-service/internal/models/auth" service "api-service/internal/services/auth" "api-service/pkg/logger" "errors" "fmt" "net/http" "strings" "sync" "time" "github.com/gin-gonic/gin" "golang.org/x/time/rate" ) var ( ErrInvalidToken = errors.New("invalid token") ErrTokenExpired = errors.New("token expired") ErrInvalidSignature = errors.New("invalid token signature") ErrInvalidIssuer = errors.New("invalid token issuer") ErrInvalidAudience = errors.New("invalid token audience") ErrMissingClaims = errors.New("required claims missing") ErrInvalidAuthHeader = errors.New("invalid authorization header format") ErrMissingAuthHeader = errors.New("authorization header missing") ) // TokenCache interface for token caching type TokenCache interface { Get(tokenString string) (*auth.JWTClaims, bool) Set(tokenString string, claims *auth.JWTClaims, expiration time.Duration) Delete(tokenString string) } // InMemoryTokenCache implements TokenCache with in-memory storage type InMemoryTokenCache struct { tokens map[string]cacheEntry mu sync.RWMutex } type cacheEntry struct { claims *auth.JWTClaims expiration time.Time } func NewInMemoryTokenCache() *InMemoryTokenCache { cache := &InMemoryTokenCache{ tokens: make(map[string]cacheEntry), } // Start cleanup goroutine go cache.cleanup() return cache } func (c *InMemoryTokenCache) Get(tokenString string) (*auth.JWTClaims, bool) { c.mu.RLock() defer c.mu.RUnlock() entry, exists := c.tokens[tokenString] if !exists || time.Now().After(entry.expiration) { return nil, false } return entry.claims, true } func (c *InMemoryTokenCache) Set(tokenString string, claims *auth.JWTClaims, expiration time.Duration) { c.mu.Lock() defer c.mu.Unlock() c.tokens[tokenString] = cacheEntry{ claims: claims, expiration: time.Now().Add(expiration), } } func (c *InMemoryTokenCache) Delete(tokenString string) { c.mu.Lock() defer c.mu.Unlock() delete(c.tokens, tokenString) } func (c *InMemoryTokenCache) cleanup() { ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() for range ticker.C { c.mu.Lock() now := time.Now() for token, entry := range c.tokens { if now.After(entry.expiration) { delete(c.tokens, token) } } c.mu.Unlock() } } // AuthMiddleware provides authentication with rate limiting and caching type AuthMiddleware struct { providers []AuthProvider tokenCache TokenCache rateLimiter *rate.Limiter config *config.Config } func NewAuthMiddleware( cfg *config.Config, authService *service.AuthService, tokenCache TokenCache, ) *AuthMiddleware { factory := NewProviderFactory(authService, cfg) providers := factory.CreateProviders() // Rate limit: 10 requests per second with burst of 20 limiter := rate.NewLimiter(10, 20) // Use default cache if none provided if tokenCache == nil { tokenCache = NewInMemoryTokenCache() } return &AuthMiddleware{ providers: providers, tokenCache: tokenCache, rateLimiter: limiter, config: cfg, } } // RequireAuth enforces authentication func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc { return m.authenticate(false) } // OptionalAuth allows both authenticated and unauthenticated requests func (m *AuthMiddleware) OptionalAuth() gin.HandlerFunc { return m.authenticate(true) } // authenticate is the core authentication logic func (m *AuthMiddleware) authenticate(optional bool) gin.HandlerFunc { return func(c *gin.Context) { reqLogger := logger.Default().WithService("auth-middleware") reqLogger.Info("Starting authentication", map[string]interface{}{ "path": c.Request.URL.Path, "optional": optional, }) // Apply rate limiting if !m.rateLimiter.Allow() { reqLogger.Warn("Rate limit exceeded") c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{ "error": "rate limit exceeded", }) return } authHeader := c.GetHeader("Authorization") if authHeader == "" { if optional { c.Next() return } reqLogger.Warn("Authorization header missing") c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": ErrMissingAuthHeader.Error(), }) return } parts := strings.SplitN(authHeader, " ", 2) if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" { if optional { c.Next() return } reqLogger.Warn("Invalid authorization header format") c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": ErrInvalidAuthHeader.Error(), }) return } tokenString := parts[1] // Check cache first if claims, found := m.tokenCache.Get(tokenString); found { reqLogger.Info("Token retrieved from cache", map[string]interface{}{ "user_id": claims.UserID, }) m.setUserInfo(c, claims, "cache") c.Next() return } // Try each provider until one succeeds var validatedClaims *auth.JWTClaims var err error var providerName string var providerErrors []string for _, provider := range m.providers { providerLog := reqLogger.WithField("provider", provider.Name()) providerLog.Info("Trying provider") validatedClaims, err = provider.ValidateToken(tokenString) if err == nil { providerName = provider.Name() providerLog.Info("Authentication successful", map[string]interface{}{ "user_id": validatedClaims.UserID, }) break } providerLog.Warn("Provider validation failed", map[string]interface{}{ "error": err.Error(), }) providerErrors = append(providerErrors, fmt.Sprintf("provider %s: %v", provider.Name(), err)) } if err != nil { if optional { c.Next() return } reqLogger.Error("All providers failed", map[string]interface{}{ "errors": strings.Join(providerErrors, "; "), }) // Return specific error message based on the error type errorMessage := "Token tidak valid" if errors.Is(err, ErrTokenExpired) { errorMessage = "Token telah kadaluarsa" } else if errors.Is(err, ErrInvalidSignature) { errorMessage = "Signature token tidak valid" } else if errors.Is(err, ErrInvalidIssuer) { errorMessage = "Issuer token tidak valid" } else if errors.Is(err, ErrInvalidAudience) { errorMessage = "Audience token tidak valid" } c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": errorMessage, "details": strings.Join(providerErrors, "; "), }) return } // Cache the validated token m.tokenCache.Set(tokenString, validatedClaims, 5*time.Minute) // Set user info in context m.setUserInfo(c, validatedClaims, providerName) c.Next() } } // setUserInfo sets user information in the Gin context func (m *AuthMiddleware) setUserInfo(c *gin.Context, claims *auth.JWTClaims, providerName string) { c.Set("user_id", claims.UserID) c.Set("username", claims.Username) c.Set("email", claims.Email) c.Set("role", claims.Role) c.Set("auth_provider", providerName) } // RequireRole creates a middleware that requires a specific role func (m *AuthMiddleware) RequireRole(requiredRole string) gin.HandlerFunc { return func(c *gin.Context) { role, exists := c.Get("role") if !exists { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "user role not found", }) return } userRole, ok := role.(string) if !ok { c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{ "error": "invalid role format", }) return } if userRole != requiredRole { c.AbortWithStatusJSON(http.StatusForbidden, gin.H{ "error": fmt.Sprintf("requires %s role", requiredRole), }) return } c.Next() } }