Files
api_antrean/internal/middleware/auth.go
2025-10-31 02:30:27 +00:00

306 lines
7.5 KiB
Go

package middleware
import (
"api-service/internal/config"
"api-service/internal/models/auth"
service "api-service/internal/services/auth"
"api-service/pkg/logger"
"errors"
"fmt"
"net/http"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
)
// Sentinel errors for authentication failures. They are compared with
// errors.Is (authenticate maps several of them to user-facing messages).
var (
	// Problems with the Authorization header itself.
	ErrMissingAuthHeader = errors.New("authorization header missing")
	ErrInvalidAuthHeader = errors.New("invalid authorization header format")

	// Problems found while validating the token.
	ErrInvalidToken     = errors.New("invalid token")
	ErrTokenExpired     = errors.New("token expired")
	ErrInvalidSignature = errors.New("invalid token signature")
	ErrInvalidIssuer    = errors.New("invalid token issuer")
	ErrInvalidAudience  = errors.New("invalid token audience")
	ErrMissingClaims    = errors.New("required claims missing")
)
// TokenCache stores validated JWT claims keyed by the raw bearer token so a
// repeated request can skip provider validation.
type TokenCache interface {
	// Get returns the claims cached for token and whether an entry was found.
	Get(token string) (*auth.JWTClaims, bool)
	// Set caches claims for token for the duration ttl.
	Set(token string, claims *auth.JWTClaims, ttl time.Duration)
	// Delete drops any entry cached for token.
	Delete(token string)
}
// InMemoryTokenCache is a TokenCache backed by a map guarded by an RWMutex.
//
// An entry counts as absent once its expiration has passed. The goroutine
// started by NewInMemoryTokenCache removes expired entries every five minutes
// until Close is called. Create instances with NewInMemoryTokenCache; the zero
// value has no map and panics on Set.
type InMemoryTokenCache struct {
	tokens map[string]cacheEntry
	mu     sync.RWMutex

	// stop is closed by Close to terminate the cleanup goroutine.
	stop      chan struct{}
	closeOnce sync.Once
}

// cacheEntry pairs cached claims with the absolute time they stop being valid.
type cacheEntry struct {
	claims     *auth.JWTClaims
	expiration time.Time
}

// NewInMemoryTokenCache returns an empty cache and starts its background
// cleanup goroutine. Call Close when the cache is no longer needed so that
// goroutine can exit.
func NewInMemoryTokenCache() *InMemoryTokenCache {
	cache := &InMemoryTokenCache{
		tokens: make(map[string]cacheEntry),
		stop:   make(chan struct{}),
	}
	go cache.cleanup()
	return cache
}

// Get returns the claims cached for tokenString and true, or nil and false
// when there is no entry or the entry has expired.
func (c *InMemoryTokenCache) Get(tokenString string) (*auth.JWTClaims, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	entry, exists := c.tokens[tokenString]
	if !exists || time.Now().After(entry.expiration) {
		return nil, false
	}
	return entry.claims, true
}

// Set caches claims for tokenString for the given duration, replacing any
// existing entry.
//
// Nil claims and non-positive durations are not stored. Instead, any existing
// entry for tokenString is removed. This keeps Get from ever reporting a hit
// that carries nil claims.
func (c *InMemoryTokenCache) Set(tokenString string, claims *auth.JWTClaims, expiration time.Duration) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if claims == nil || expiration <= 0 {
		delete(c.tokens, tokenString)
		return
	}
	c.tokens[tokenString] = cacheEntry{
		claims:     claims,
		expiration: time.Now().Add(expiration),
	}
}

// Delete removes any entry cached for tokenString.
func (c *InMemoryTokenCache) Delete(tokenString string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	delete(c.tokens, tokenString)
}

// Close stops the background cleanup goroutine. Cached entries stay readable
// afterwards, but expired ones are no longer swept. Close is safe to call
// more than once.
func (c *InMemoryTokenCache) Close() {
	c.closeOnce.Do(func() {
		// A zero-value cache has no stop channel; closing nil would panic.
		if c.stop != nil {
			close(c.stop)
		}
	})
}

// cleanup evicts expired entries every five minutes until stop is closed.
func (c *InMemoryTokenCache) cleanup() {
	const interval = 5 * time.Minute
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-c.stop:
			return
		case <-ticker.C:
			c.evictExpired()
		}
	}
}

// evictExpired deletes every entry whose expiration lies in the past.
func (c *InMemoryTokenCache) evictExpired() {
	c.mu.Lock()
	defer c.mu.Unlock()
	now := time.Now()
	for token, entry := range c.tokens {
		if now.After(entry.expiration) {
			delete(c.tokens, token)
		}
	}
}
// AuthMiddleware provides authentication with rate limiting and caching.
// Build it with NewAuthMiddleware. Its Gin handlers come from RequireAuth,
// OptionalAuth and RequireRole.
type AuthMiddleware struct {
	// providers are tried in order; the first one that validates a token wins.
	providers []AuthProvider
	// tokenCache holds validated claims keyed by the raw bearer token string.
	tokenCache TokenCache
	// rateLimiter is a single limiter shared by every request through this middleware.
	rateLimiter *rate.Limiter
	// config is kept from construction. NOTE(review): it is not read by the
	// handlers in this file; confirm whether other code depends on it.
	config *config.Config
}
// NewAuthMiddleware builds an AuthMiddleware from configuration.
//
// Providers are produced by a provider factory built over authService and
// cfg. All requests share one limiter that allows 10 requests per second
// with bursts of up to 20. When tokenCache is nil, a new InMemoryTokenCache
// is used.
//
// Only an untyped nil is replaced. A nil *InMemoryTokenCache stored in the
// TokenCache interface is not nil and is kept as-is.
func NewAuthMiddleware(
	cfg *config.Config,
	authService *service.AuthService,
	tokenCache TokenCache,
) *AuthMiddleware {
	mw := &AuthMiddleware{
		providers:   NewProviderFactory(authService, cfg).CreateProviders(),
		rateLimiter: rate.NewLimiter(10, 20),
		tokenCache:  tokenCache,
		config:      cfg,
	}
	if mw.tokenCache == nil {
		mw.tokenCache = NewInMemoryTokenCache()
	}
	return mw
}
// RequireAuth returns a handler that only lets requests through when they
// carry a bearer token accepted from cache or by a provider. Anything else
// is aborted (401, or 429 when rate limited).
func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc {
	const allowAnonymous = false
	return m.authenticate(allowAnonymous)
}

// OptionalAuth returns a handler that attaches user info when a valid bearer
// token is present. Requests without one, or with a rejected token, continue
// anonymously. The rate limit still applies.
func (m *AuthMiddleware) OptionalAuth() gin.HandlerFunc {
	const allowAnonymous = true
	return m.authenticate(allowAnonymous)
}
// authenticate is the core authentication logic shared by RequireAuth and
// OptionalAuth.
//
// For each request it:
//  1. applies the middleware-wide rate limiter (429 when exhausted);
//  2. reads "Authorization: Bearer <token>", matching the scheme
//     case-insensitively and trimming whitespace around the token;
//  3. uses cached claims when tokenCache has the token;
//  4. otherwise tries each provider in order until one validates the token,
//     then caches the claims for five minutes;
//  5. stores the user info in the Gin context via setUserInfo.
//
// When optional is true, the request continues without user info instead of
// aborting with 401 if the header is missing or malformed, or if every
// provider rejects the token. An empty provider list counts as rejection.
func (m *AuthMiddleware) authenticate(optional bool) gin.HandlerFunc {
	return func(c *gin.Context) {
		reqLogger := logger.Default().WithService("auth-middleware")
		reqLogger.Info("Starting authentication", map[string]interface{}{
			"path":     c.Request.URL.Path,
			"optional": optional,
		})

		// NOTE(review): one limiter is shared by every caller of this
		// middleware, so a single busy client can throttle all others.
		// Consider keying limiters per client.
		if !m.rateLimiter.Allow() {
			reqLogger.Warn("Rate limit exceeded")
			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
				"error": "rate limit exceeded",
			})
			return
		}

		authHeader := c.GetHeader("Authorization")
		if authHeader == "" {
			if optional {
				c.Next()
				return
			}
			reqLogger.Warn("Authorization header missing")
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
				"error": ErrMissingAuthHeader.Error(),
			})
			return
		}

		// The auth scheme is case-insensitive (RFC 7235). A header of
		// "Bearer " with no token is malformed and is rejected here, so an
		// empty token never reaches the providers.
		var tokenString string
		parts := strings.SplitN(authHeader, " ", 2)
		if len(parts) == 2 && strings.EqualFold(parts[0], "bearer") {
			tokenString = strings.TrimSpace(parts[1])
		}
		if tokenString == "" {
			if optional {
				c.Next()
				return
			}
			reqLogger.Warn("Invalid authorization header format")
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
				"error": ErrInvalidAuthHeader.Error(),
			})
			return
		}

		// Check the cache first. A nil-claims hit (possible with a custom
		// TokenCache) is treated as a miss, because setUserInfo dereferences
		// the claims.
		if claims, found := m.tokenCache.Get(tokenString); found && claims != nil {
			reqLogger.Info("Token retrieved from cache", map[string]interface{}{
				"user_id": claims.UserID,
			})
			m.setUserInfo(c, claims, "cache")
			c.Next()
			return
		}

		// Try each provider until one succeeds.
		var (
			validatedClaims *auth.JWTClaims
			providerName    string
			lastErr         error
			providerErrors  []string
		)
		for _, provider := range m.providers {
			name := provider.Name()
			providerLog := reqLogger.WithField("provider", name)
			providerLog.Info("Trying provider")

			claims, err := provider.ValidateToken(tokenString)
			if err == nil && claims == nil {
				// Success without claims would make the code below dereference nil.
				err = ErrMissingClaims
			}
			if err == nil {
				validatedClaims, providerName = claims, name
				providerLog.Info("Authentication successful", map[string]interface{}{
					"user_id": claims.UserID,
				})
				break
			}

			providerLog.Warn("Provider validation failed", map[string]interface{}{
				"error": err.Error(),
			})
			providerErrors = append(providerErrors, fmt.Sprintf("provider %s: %v", name, err))
			lastErr = err
		}

		if validatedClaims == nil {
			if optional {
				c.Next()
				return
			}
			if lastErr == nil {
				// The loop never ran, so no provider is configured. Reject
				// instead of treating the token as valid.
				lastErr = ErrInvalidToken
				providerErrors = append(providerErrors, "no authentication providers configured")
			}
			reqLogger.Error("All providers failed", map[string]interface{}{
				"errors": strings.Join(providerErrors, "; "),
			})

			// Pick a specific (Indonesian) message from the last provider's error.
			errorMessage := "Token tidak valid"
			switch {
			case errors.Is(lastErr, ErrTokenExpired):
				errorMessage = "Token telah kadaluarsa"
			case errors.Is(lastErr, ErrInvalidSignature):
				errorMessage = "Signature token tidak valid"
			case errors.Is(lastErr, ErrInvalidIssuer):
				errorMessage = "Issuer token tidak valid"
			case errors.Is(lastErr, ErrInvalidAudience):
				errorMessage = "Audience token tidak valid"
			}

			// NOTE(review): "details" sends internal provider errors to the
			// client. Consider dropping it outside development builds.
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
				"error":   errorMessage,
				"details": strings.Join(providerErrors, "; "),
			})
			return
		}

		// NOTE(review): this TTL is not capped by the token's own expiry. A
		// token that expires inside the window keeps authenticating from the
		// cache until the entry lapses. Cap it if JWTClaims exposes an expiry.
		m.tokenCache.Set(tokenString, validatedClaims, 5*time.Minute)
		m.setUserInfo(c, validatedClaims, providerName)
		c.Next()
	}
}
// setUserInfo copies the identity fields of claims into the Gin context under
// the keys "user_id", "username", "email" and "role". It records the source
// that authenticated the request (a provider name, or "cache") under
// "auth_provider". claims must be non-nil because its fields are read
// unconditionally.
func (m *AuthMiddleware) setUserInfo(c *gin.Context, claims *auth.JWTClaims, providerName string) {
	entries := []struct {
		key   string
		value interface{}
	}{
		{"user_id", claims.UserID},
		{"username", claims.Username},
		{"email", claims.Email},
		{"role", claims.Role},
		{"auth_provider", providerName},
	}
	for _, e := range entries {
		c.Set(e.key, e.value)
	}
}
// RequireRole returns a handler that only passes requests whose context
// "role" value equals requiredRole. It expects an earlier handler, such as
// RequireAuth, to have stored "role".
//
// Responses:
//   - 401 when no role is present;
//   - 500 when the stored role is not a string;
//   - 403 when the role differs from requiredRole.
func (m *AuthMiddleware) RequireRole(requiredRole string) gin.HandlerFunc {
	return func(c *gin.Context) {
		stored, present := c.Get("role")
		if !present {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "user role not found"})
			return
		}

		switch userRole, isString := stored.(string); {
		case !isString:
			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "invalid role format"})
		case userRole != requiredRole:
			c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": fmt.Sprintf("requires %s role", requiredRole)})
		default:
			c.Next()
		}
	}
}