306 lines
7.5 KiB
Go
306 lines
7.5 KiB
Go
package middleware
|
|
|
|
import (
|
|
"api-service/internal/config"
|
|
"api-service/internal/models/auth"
|
|
service "api-service/internal/services/auth"
|
|
"api-service/pkg/logger"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
// Authorization-header errors. authenticate sends their messages to the
// client when the header is absent or is not a usable Bearer credential.
var (
	// ErrMissingAuthHeader means the request had no Authorization header.
	ErrMissingAuthHeader = errors.New("authorization header missing")
	// ErrInvalidAuthHeader means the Authorization header was not of the
	// form "Bearer <token>".
	ErrInvalidAuthHeader = errors.New("invalid authorization header format")
)

// Token-validation errors. authenticate matches provider errors against
// ErrTokenExpired, ErrInvalidSignature, ErrInvalidIssuer and
// ErrInvalidAudience with errors.Is to pick a client-facing message.
// NOTE(review): ErrInvalidToken and ErrMissingClaims are not referenced in this
// file; presumably AuthProvider implementations return them — confirm.
var (
	// ErrInvalidToken is a generic token-validation failure.
	ErrInvalidToken = errors.New("invalid token")
	// ErrTokenExpired means the token is past its expiry.
	ErrTokenExpired = errors.New("token expired")
	// ErrInvalidSignature means the token signature did not verify.
	ErrInvalidSignature = errors.New("invalid token signature")
	// ErrInvalidIssuer means the token's issuer was not accepted.
	ErrInvalidIssuer = errors.New("invalid token issuer")
	// ErrInvalidAudience means the token's audience was not accepted.
	ErrInvalidAudience = errors.New("invalid token audience")
	// ErrMissingClaims means required claims were absent from the token.
	ErrMissingClaims = errors.New("required claims missing")
)
|
|
|
|
// TokenCache interface for token caching
|
|
type TokenCache interface {
|
|
Get(tokenString string) (*auth.JWTClaims, bool)
|
|
Set(tokenString string, claims *auth.JWTClaims, expiration time.Duration)
|
|
Delete(tokenString string)
|
|
}
|
|
|
|
// InMemoryTokenCache implements TokenCache with in-memory storage
|
|
type InMemoryTokenCache struct {
|
|
tokens map[string]cacheEntry
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
type cacheEntry struct {
|
|
claims *auth.JWTClaims
|
|
expiration time.Time
|
|
}
|
|
|
|
func NewInMemoryTokenCache() *InMemoryTokenCache {
|
|
cache := &InMemoryTokenCache{
|
|
tokens: make(map[string]cacheEntry),
|
|
}
|
|
|
|
// Start cleanup goroutine
|
|
go cache.cleanup()
|
|
|
|
return cache
|
|
}
|
|
|
|
func (c *InMemoryTokenCache) Get(tokenString string) (*auth.JWTClaims, bool) {
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
|
|
entry, exists := c.tokens[tokenString]
|
|
if !exists || time.Now().After(entry.expiration) {
|
|
return nil, false
|
|
}
|
|
|
|
return entry.claims, true
|
|
}
|
|
|
|
func (c *InMemoryTokenCache) Set(tokenString string, claims *auth.JWTClaims, expiration time.Duration) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
c.tokens[tokenString] = cacheEntry{
|
|
claims: claims,
|
|
expiration: time.Now().Add(expiration),
|
|
}
|
|
}
|
|
|
|
func (c *InMemoryTokenCache) Delete(tokenString string) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
delete(c.tokens, tokenString)
|
|
}
|
|
|
|
func (c *InMemoryTokenCache) cleanup() {
|
|
ticker := time.NewTicker(5 * time.Minute)
|
|
defer ticker.Stop()
|
|
|
|
for range ticker.C {
|
|
c.mu.Lock()
|
|
now := time.Now()
|
|
for token, entry := range c.tokens {
|
|
if now.After(entry.expiration) {
|
|
delete(c.tokens, token)
|
|
}
|
|
}
|
|
c.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
// AuthMiddleware provides authentication with rate limiting and caching
|
|
type AuthMiddleware struct {
|
|
providers []AuthProvider
|
|
tokenCache TokenCache
|
|
rateLimiter *rate.Limiter
|
|
config *config.Config
|
|
}
|
|
|
|
func NewAuthMiddleware(
|
|
cfg *config.Config,
|
|
authService *service.AuthService,
|
|
tokenCache TokenCache,
|
|
) *AuthMiddleware {
|
|
factory := NewProviderFactory(authService, cfg)
|
|
providers := factory.CreateProviders()
|
|
|
|
// Rate limit: 10 requests per second with burst of 20
|
|
limiter := rate.NewLimiter(10, 20)
|
|
|
|
// Use default cache if none provided
|
|
if tokenCache == nil {
|
|
tokenCache = NewInMemoryTokenCache()
|
|
}
|
|
|
|
return &AuthMiddleware{
|
|
providers: providers,
|
|
tokenCache: tokenCache,
|
|
rateLimiter: limiter,
|
|
config: cfg,
|
|
}
|
|
}
|
|
|
|
// RequireAuth enforces authentication
|
|
func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc {
|
|
return m.authenticate(false)
|
|
}
|
|
|
|
// OptionalAuth allows both authenticated and unauthenticated requests
|
|
func (m *AuthMiddleware) OptionalAuth() gin.HandlerFunc {
|
|
return m.authenticate(true)
|
|
}
|
|
|
|
// authenticate is the core authentication logic
|
|
func (m *AuthMiddleware) authenticate(optional bool) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
reqLogger := logger.Default().WithService("auth-middleware")
|
|
reqLogger.Info("Starting authentication", map[string]interface{}{
|
|
"path": c.Request.URL.Path,
|
|
"optional": optional,
|
|
})
|
|
|
|
// Apply rate limiting
|
|
if !m.rateLimiter.Allow() {
|
|
reqLogger.Warn("Rate limit exceeded")
|
|
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
|
"error": "rate limit exceeded",
|
|
})
|
|
return
|
|
}
|
|
|
|
authHeader := c.GetHeader("Authorization")
|
|
if authHeader == "" {
|
|
if optional {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
reqLogger.Warn("Authorization header missing")
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
|
"error": ErrMissingAuthHeader.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
parts := strings.SplitN(authHeader, " ", 2)
|
|
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
|
|
if optional {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
reqLogger.Warn("Invalid authorization header format")
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
|
"error": ErrInvalidAuthHeader.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
tokenString := parts[1]
|
|
|
|
// Check cache first
|
|
if claims, found := m.tokenCache.Get(tokenString); found {
|
|
reqLogger.Info("Token retrieved from cache", map[string]interface{}{
|
|
"user_id": claims.UserID,
|
|
})
|
|
|
|
m.setUserInfo(c, claims, "cache")
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
// Try each provider until one succeeds
|
|
var validatedClaims *auth.JWTClaims
|
|
var err error
|
|
var providerName string
|
|
var providerErrors []string
|
|
|
|
for _, provider := range m.providers {
|
|
providerLog := reqLogger.WithField("provider", provider.Name())
|
|
providerLog.Info("Trying provider")
|
|
|
|
validatedClaims, err = provider.ValidateToken(tokenString)
|
|
if err == nil {
|
|
providerName = provider.Name()
|
|
providerLog.Info("Authentication successful", map[string]interface{}{
|
|
"user_id": validatedClaims.UserID,
|
|
})
|
|
break
|
|
}
|
|
|
|
providerLog.Warn("Provider validation failed", map[string]interface{}{
|
|
"error": err.Error(),
|
|
})
|
|
providerErrors = append(providerErrors, fmt.Sprintf("provider %s: %v", provider.Name(), err))
|
|
}
|
|
|
|
if err != nil {
|
|
if optional {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
reqLogger.Error("All providers failed", map[string]interface{}{
|
|
"errors": strings.Join(providerErrors, "; "),
|
|
})
|
|
|
|
// Return specific error message based on the error type
|
|
errorMessage := "Token tidak valid"
|
|
if errors.Is(err, ErrTokenExpired) {
|
|
errorMessage = "Token telah kadaluarsa"
|
|
} else if errors.Is(err, ErrInvalidSignature) {
|
|
errorMessage = "Signature token tidak valid"
|
|
} else if errors.Is(err, ErrInvalidIssuer) {
|
|
errorMessage = "Issuer token tidak valid"
|
|
} else if errors.Is(err, ErrInvalidAudience) {
|
|
errorMessage = "Audience token tidak valid"
|
|
}
|
|
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
|
"error": errorMessage,
|
|
"details": strings.Join(providerErrors, "; "),
|
|
})
|
|
return
|
|
}
|
|
|
|
// Cache the validated token
|
|
m.tokenCache.Set(tokenString, validatedClaims, 5*time.Minute)
|
|
|
|
// Set user info in context
|
|
m.setUserInfo(c, validatedClaims, providerName)
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// setUserInfo sets user information in the Gin context
|
|
func (m *AuthMiddleware) setUserInfo(c *gin.Context, claims *auth.JWTClaims, providerName string) {
|
|
c.Set("user_id", claims.UserID)
|
|
c.Set("username", claims.Username)
|
|
c.Set("email", claims.Email)
|
|
c.Set("role", claims.Role)
|
|
c.Set("auth_provider", providerName)
|
|
}
|
|
|
|
// RequireRole creates a middleware that requires a specific role
|
|
func (m *AuthMiddleware) RequireRole(requiredRole string) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
role, exists := c.Get("role")
|
|
if !exists {
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
|
"error": "user role not found",
|
|
})
|
|
return
|
|
}
|
|
|
|
userRole, ok := role.(string)
|
|
if !ok {
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
|
|
"error": "invalid role format",
|
|
})
|
|
return
|
|
}
|
|
|
|
if userRole != requiredRole {
|
|
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
|
"error": fmt.Sprintf("requires %s role", requiredRole),
|
|
})
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|