Update besar
This commit is contained in:
@@ -0,0 +1,305 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"api-service/internal/config"
|
||||
"api-service/internal/models/auth"
|
||||
service "api-service/internal/services/auth"
|
||||
"api-service/pkg/logger"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidToken = errors.New("invalid token")
|
||||
ErrTokenExpired = errors.New("token expired")
|
||||
ErrInvalidSignature = errors.New("invalid token signature")
|
||||
ErrInvalidIssuer = errors.New("invalid token issuer")
|
||||
ErrInvalidAudience = errors.New("invalid token audience")
|
||||
ErrMissingClaims = errors.New("required claims missing")
|
||||
ErrInvalidAuthHeader = errors.New("invalid authorization header format")
|
||||
ErrMissingAuthHeader = errors.New("authorization header missing")
|
||||
)
|
||||
|
||||
// TokenCache interface for token caching
|
||||
type TokenCache interface {
|
||||
Get(tokenString string) (*auth.JWTClaims, bool)
|
||||
Set(tokenString string, claims *auth.JWTClaims, expiration time.Duration)
|
||||
Delete(tokenString string)
|
||||
}
|
||||
|
||||
// InMemoryTokenCache implements TokenCache with in-memory storage
|
||||
type InMemoryTokenCache struct {
|
||||
tokens map[string]cacheEntry
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type cacheEntry struct {
|
||||
claims *auth.JWTClaims
|
||||
expiration time.Time
|
||||
}
|
||||
|
||||
func NewInMemoryTokenCache() *InMemoryTokenCache {
|
||||
cache := &InMemoryTokenCache{
|
||||
tokens: make(map[string]cacheEntry),
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
go cache.cleanup()
|
||||
|
||||
return cache
|
||||
}
|
||||
|
||||
func (c *InMemoryTokenCache) Get(tokenString string) (*auth.JWTClaims, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
entry, exists := c.tokens[tokenString]
|
||||
if !exists || time.Now().After(entry.expiration) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return entry.claims, true
|
||||
}
|
||||
|
||||
func (c *InMemoryTokenCache) Set(tokenString string, claims *auth.JWTClaims, expiration time.Duration) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.tokens[tokenString] = cacheEntry{
|
||||
claims: claims,
|
||||
expiration: time.Now().Add(expiration),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *InMemoryTokenCache) Delete(tokenString string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
delete(c.tokens, tokenString)
|
||||
}
|
||||
|
||||
func (c *InMemoryTokenCache) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
c.mu.Lock()
|
||||
now := time.Now()
|
||||
for token, entry := range c.tokens {
|
||||
if now.After(entry.expiration) {
|
||||
delete(c.tokens, token)
|
||||
}
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// AuthMiddleware provides authentication with rate limiting and caching
|
||||
type AuthMiddleware struct {
|
||||
providers []AuthProvider
|
||||
tokenCache TokenCache
|
||||
rateLimiter *rate.Limiter
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
func NewAuthMiddleware(
|
||||
cfg *config.Config,
|
||||
authService *service.AuthService,
|
||||
tokenCache TokenCache,
|
||||
) *AuthMiddleware {
|
||||
factory := NewProviderFactory(authService, cfg)
|
||||
providers := factory.CreateProviders()
|
||||
|
||||
// Rate limit: 10 requests per second with burst of 20
|
||||
limiter := rate.NewLimiter(10, 20)
|
||||
|
||||
// Use default cache if none provided
|
||||
if tokenCache == nil {
|
||||
tokenCache = NewInMemoryTokenCache()
|
||||
}
|
||||
|
||||
return &AuthMiddleware{
|
||||
providers: providers,
|
||||
tokenCache: tokenCache,
|
||||
rateLimiter: limiter,
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAuth enforces authentication
|
||||
func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc {
|
||||
return m.authenticate(false)
|
||||
}
|
||||
|
||||
// OptionalAuth allows both authenticated and unauthenticated requests
|
||||
func (m *AuthMiddleware) OptionalAuth() gin.HandlerFunc {
|
||||
return m.authenticate(true)
|
||||
}
|
||||
|
||||
// authenticate is the core authentication logic
|
||||
func (m *AuthMiddleware) authenticate(optional bool) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
reqLogger := logger.Default().WithService("auth-middleware")
|
||||
reqLogger.Info("Starting authentication", map[string]interface{}{
|
||||
"path": c.Request.URL.Path,
|
||||
"optional": optional,
|
||||
})
|
||||
|
||||
// Apply rate limiting
|
||||
if !m.rateLimiter.Allow() {
|
||||
reqLogger.Warn("Rate limit exceeded")
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "rate limit exceeded",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
if optional {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
reqLogger.Warn("Authorization header missing")
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": ErrMissingAuthHeader.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
|
||||
if optional {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
reqLogger.Warn("Invalid authorization header format")
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": ErrInvalidAuthHeader.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := parts[1]
|
||||
|
||||
// Check cache first
|
||||
if claims, found := m.tokenCache.Get(tokenString); found {
|
||||
reqLogger.Info("Token retrieved from cache", map[string]interface{}{
|
||||
"user_id": claims.UserID,
|
||||
})
|
||||
|
||||
m.setUserInfo(c, claims, "cache")
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Try each provider until one succeeds
|
||||
var validatedClaims *auth.JWTClaims
|
||||
var err error
|
||||
var providerName string
|
||||
var providerErrors []string
|
||||
|
||||
for _, provider := range m.providers {
|
||||
providerLog := reqLogger.WithField("provider", provider.Name())
|
||||
providerLog.Info("Trying provider")
|
||||
|
||||
validatedClaims, err = provider.ValidateToken(tokenString)
|
||||
if err == nil {
|
||||
providerName = provider.Name()
|
||||
providerLog.Info("Authentication successful", map[string]interface{}{
|
||||
"user_id": validatedClaims.UserID,
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
providerLog.Warn("Provider validation failed", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
providerErrors = append(providerErrors, fmt.Sprintf("provider %s: %v", provider.Name(), err))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if optional {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
reqLogger.Error("All providers failed", map[string]interface{}{
|
||||
"errors": strings.Join(providerErrors, "; "),
|
||||
})
|
||||
|
||||
// Return specific error message based on the error type
|
||||
errorMessage := "Token tidak valid"
|
||||
if errors.Is(err, ErrTokenExpired) {
|
||||
errorMessage = "Token telah kadaluarsa"
|
||||
} else if errors.Is(err, ErrInvalidSignature) {
|
||||
errorMessage = "Signature token tidak valid"
|
||||
} else if errors.Is(err, ErrInvalidIssuer) {
|
||||
errorMessage = "Issuer token tidak valid"
|
||||
} else if errors.Is(err, ErrInvalidAudience) {
|
||||
errorMessage = "Audience token tidak valid"
|
||||
}
|
||||
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": errorMessage,
|
||||
"details": strings.Join(providerErrors, "; "),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Cache the validated token
|
||||
m.tokenCache.Set(tokenString, validatedClaims, 5*time.Minute)
|
||||
|
||||
// Set user info in context
|
||||
m.setUserInfo(c, validatedClaims, providerName)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// setUserInfo sets user information in the Gin context
|
||||
func (m *AuthMiddleware) setUserInfo(c *gin.Context, claims *auth.JWTClaims, providerName string) {
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
c.Set("email", claims.Email)
|
||||
c.Set("role", claims.Role)
|
||||
c.Set("auth_provider", providerName)
|
||||
}
|
||||
|
||||
// RequireRole creates a middleware that requires a specific role
|
||||
func (m *AuthMiddleware) RequireRole(requiredRole string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
role, exists := c.Get("role")
|
||||
if !exists {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "user role not found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userRole, ok := role.(string)
|
||||
if !ok {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "invalid role format",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if userRole != requiredRole {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
||||
"error": fmt.Sprintf("requires %s role", requiredRole),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user