Files
antrean-anjungan/internal/middleware/security.go
2025-10-23 04:25:28 +07:00

273 lines
6.4 KiB
Go

package middleware
import (
	"fmt"
	"html"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/gin-contrib/cors"
	"github.com/gin-gonic/gin"
)
// SecurityHeaders returns middleware that stamps a fixed set of
// browser-hardening headers onto every response and then continues with the
// rest of the handler chain.
func SecurityHeaders() gin.HandlerFunc {
	// Name/value pairs that are sent on every response.
	alwaysSet := [][2]string{
		// Refuse to be rendered inside a frame (clickjacking).
		{"X-Frame-Options", "DENY"},
		// Stop browsers from MIME-sniffing the response body.
		{"X-Content-Type-Options", "nosniff"},
		// Ask the browser's XSS filter to block the page on detection.
		{"X-XSS-Protection", "1; mode=block"},
		// Limit how much referrer information leaves the origin.
		{"Referrer-Policy", "strict-origin-when-cross-origin"},
		// Content Security Policy; tune per deployment.
		{"Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'"},
	}

	return func(c *gin.Context) {
		for _, pair := range alwaysSet {
			c.Header(pair[0], pair[1])
		}

		// HSTS is only emitted when this request itself arrived over TLS.
		if c.Request.TLS != nil {
			c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
		}

		c.Next()
	}
}
// InputSanitization returns middleware that scrubs the request's query
// parameters and form fields in place before the remaining handlers run.
// Requests whose content type is application/json are additionally passed
// to sanitizeJSONBody.
func InputSanitization() gin.HandlerFunc {
	return func(c *gin.Context) {
		// Query string first, then any posted form fields.
		sanitizeQueryParams(c)
		sanitizeFormData(c)

		// JSON payloads get their own hook.
		switch c.ContentType() {
		case "application/json":
			sanitizeJSONBody(c)
		}

		c.Next()
	}
}
// sanitizeQueryParams rewrites the request's raw query string so that every
// parameter value has been passed through sanitizeString. Keys are left
// untouched, and the query is re-encoded even when no value changed.
func sanitizeQueryParams(c *gin.Context) {
	params := c.Request.URL.Query()
	for _, vals := range params {
		// vals shares its backing array with the map entry, so writing
		// through it updates params directly.
		for idx := range vals {
			vals[idx] = sanitizeString(vals[idx])
		}
	}
	c.Request.URL.RawQuery = params.Encode()
}
// sanitizeFormData parses the request's form and replaces each value in
// Request.PostForm with its sanitizeString result. If parsing fails the
// request is left as it is. Only PostForm is rewritten here.
func sanitizeFormData(c *gin.Context) {
	req := c.Request
	if req.ParseForm() != nil {
		return
	}

	for _, vals := range req.PostForm {
		// vals aliases the slice stored in PostForm, so this edits it in place.
		for idx := range vals {
			vals[idx] = sanitizeString(vals[idx])
		}
	}
}
// sanitizeJSONBody is the hook for sanitizing JSON request bodies.
//
// It is intentionally a no-op: JSON payloads are left for the handler's
// JSON binding to decode, and any sanitization of the decoded fields is
// expected to happen at the handler level.
//
// It must not advance the handler chain. The calling middleware is
// responsible for invoking c.Next() exactly once; calling it here as well
// would run the remaining handlers from inside this helper and turn the
// caller's own c.Next() into a redundant second call.
func sanitizeJSONBody(c *gin.Context) {
	// Deliberately empty; see the function comment.
}
// dangerousCharStripper deletes, in a single pass, every character that
// sanitizeString treats as unsafe: the NUL byte, angle brackets, double and
// single quotes, backtick, backslash, and newline / carriage return / tab.
var dangerousCharStripper = strings.NewReplacer(
	"\x00", "",
	"<", "",
	">", "",
	"\"", "",
	"'", "",
	"`", "",
	"\\", "",
	"\n", "",
	"\r", "",
	"\t", "",
)

// sanitizeString performs basic sanitization on a string.
//
// The steps are, in order:
//  1. remove the dangerous characters listed in dangerousCharStripper;
//  2. HTML-escape what is left (after step 1 only '&' is affected);
//  3. trim leading and trailing whitespace.
//
// Stripping has to happen before escaping. If the input were escaped first,
// '<', '>', '"' and '\” would already have been turned into entities such
// as "&lt;" and the removal step would never see them.
func sanitizeString(input string) string {
	stripped := dangerousCharStripper.Replace(input)
	return strings.TrimSpace(html.EscapeString(stripped))
}
// SQLInjectionProtection returns middleware that rejects, with HTTP 400 and
// a JSON error body, any request that hasSQLInjectionPatterns flags as
// suspicious. Unflagged requests continue down the handler chain.
func SQLInjectionProtection() gin.HandlerFunc {
	return func(c *gin.Context) {
		// Clean request: carry on with the remaining handlers.
		if !hasSQLInjectionPatterns(c) {
			c.Next()
			return
		}

		// Suspicious request: stop the chain and report a bad request.
		body := gin.H{
			"error":   "Invalid input detected",
			"message": "Request contains potentially malicious content",
		}
		c.AbortWithStatusJSON(http.StatusBadRequest, body)
	}
}
// sqlSuspiciousSubstrings lists lower-case fragments whose mere presence in
// an input marks it as suspicious.
var sqlSuspiciousSubstrings = []string{
	"union select",
	"union all select",
	"drop table",
	"drop database",
	"alter table",
	"create table",
	"exec(",
	"execute(",
	"xp_",
	"sp_",
	"information_schema",
	"sysobjects",
	"syscolumns",
	"sysdatabases",
	"mysql.",
	"pg_",
	"sqlite_",
	";--",
	"/*",
	"*/",
	"@@",
	"script>",
	"<script",
	"javascript:",
	"vbscript:",
	"onload=",
	"onerror=",
	"eval(",
	"alert(",
}

// sqlSuspiciousSequences lists lower-case keyword pairs that are suspicious
// when the second keyword appears anywhere after the first, i.e. the
// "first.*second" shape (select ... from, insert ... into, and so on).
var sqlSuspiciousSequences = [][2]string{
	{"select", "from"},
	{"insert", "into"},
	{"update", "set"},
	{"delete", "from"},
}

// matchesSuspiciousPattern reports whether s, which must already be
// lower-cased, contains any fragment from sqlSuspiciousSubstrings or any
// ordered keyword pair from sqlSuspiciousSequences.
func matchesSuspiciousPattern(s string) bool {
	for _, frag := range sqlSuspiciousSubstrings {
		if strings.Contains(s, frag) {
			return true
		}
	}
	for _, seq := range sqlSuspiciousSequences {
		// Find the first keyword, then look for the second strictly after it.
		// Using the earliest occurrence of the first keyword leaves the
		// longest possible tail, so no later match can be missed.
		at := strings.Index(s, seq[0])
		if at >= 0 && strings.Contains(s[at+len(seq[0]):], seq[1]) {
			return true
		}
	}
	return false
}

// hasSQLInjectionPatterns checks the request for common SQL injection and
// script injection patterns.
//
// It inspects the raw (still URL-encoded) query string and then every value
// in Request.Form, which after ParseForm holds both the decoded URL query
// parameters and any parsed body parameters. Matching is case-insensitive.
// It returns true as soon as any pattern matches. If the form cannot be
// parsed, only the raw query string is checked.
//
// The keyword-pair patterns ("select ... from" etc.) are matched as ordered
// sequences rather than as the literal text "select.*from", which a plain
// substring search would require.
func hasSQLInjectionPatterns(c *gin.Context) bool {
	if matchesSuspiciousPattern(strings.ToLower(c.Request.URL.RawQuery)) {
		return true
	}

	if err := c.Request.ParseForm(); err != nil {
		// Unparseable form data cannot be inspected value by value.
		return false
	}
	for _, values := range c.Request.Form {
		for _, value := range values {
			if matchesSuspiciousPattern(strings.ToLower(value)) {
				return true
			}
		}
	}
	return false
}
// RateLimitByIP returns middleware that allows each client IP at most
// requestsPerMinute requests per 60-second window and answers further
// requests in that window with HTTP 429 and a JSON error body.
//
// The limiter is a simple in-memory, per-process counter keyed by
// c.ClientIP(); for multi-instance deployments use a shared store such as
// Redis instead. State is guarded by a mutex because gin runs handlers for
// different requests concurrently, and expired entries are swept about once
// per window so the map does not grow without bound.
func RateLimitByIP(requestsPerMinute int) gin.HandlerFunc {
	const windowSeconds = 60

	// clientWindow tracks one IP's request count for its current window.
	type clientWindow struct {
		count     int
		resetTime int64 // Unix second after which the window has expired
	}

	var (
		mu        sync.Mutex
		clients   = make(map[string]*clientWindow)
		nextSweep = time.Now().Unix() + windowSeconds
	)

	// allow records a request from ip at Unix time now and reports whether
	// it fits within the limit.
	allow := func(ip string, now int64) bool {
		mu.Lock()
		defer mu.Unlock()

		// Periodically drop windows that have expired. An expired entry
		// would be reset on its next use anyway, so removing it does not
		// change the limiter's decisions.
		if now > nextSweep {
			for key, w := range clients {
				if now > w.resetTime {
					delete(clients, key)
				}
			}
			nextSweep = now + windowSeconds
		}

		// Start a fresh window for unseen clients and for expired windows.
		w, ok := clients[ip]
		if !ok || now > w.resetTime {
			w = &clientWindow{resetTime: now + windowSeconds}
			clients[ip] = w
		}

		if w.count >= requestsPerMinute {
			return false
		}
		w.count++
		return true
	}

	return func(c *gin.Context) {
		// The lock is held only inside allow, never while downstream
		// handlers run.
		if !allow(c.ClientIP(), time.Now().Unix()) {
			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
				"error":   "Rate limit exceeded",
				"message": "Too many requests. Please try again later.",
			})
			return
		}
		c.Next()
	}
}
// ValidateInputLength returns middleware that rejects, with HTTP 400 and a
// JSON error body, any request containing a query-string or posted-form
// value whose len() exceeds maxLength. Query parameters are checked first;
// form values are checked only if the form parses successfully.
func ValidateInputLength(maxLength int) gin.HandlerFunc {
	return func(c *gin.Context) {
		// rejectIfTooLong aborts the request on the first oversized value
		// it meets in params and reports whether it did so.
		rejectIfTooLong := func(params map[string][]string) bool {
			for name, vals := range params {
				for _, v := range vals {
					if len(v) <= maxLength {
						continue
					}
					c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
						"error":   "Input too long",
						"message": fmt.Sprintf("Parameter '%s' exceeds maximum length of %d characters", name, maxLength),
					})
					return true
				}
			}
			return false
		}

		// Query-string values.
		if rejectIfTooLong(c.Request.URL.Query()) {
			return
		}

		// Posted form values; a form that fails to parse is skipped.
		if parseErr := c.Request.ParseForm(); parseErr == nil && rejectIfTooLong(c.Request.PostForm) {
			return
		}

		c.Next()
	}
}
// SecureCORSConfig builds the CORS middleware used by the service. It
// permits credentialed requests from a fixed list of origins, a fixed set of
// methods and request headers, exposes Content-Length, and lets browsers
// cache preflight results for twelve hours.
func SecureCORSConfig() gin.HandlerFunc {
	var cfg cors.Config

	// Origins allowed to call the API; configure per environment.
	cfg.AllowOrigins = []string{"http://localhost:3000", "http://localhost:8080"}
	cfg.AllowMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}
	cfg.AllowHeaders = []string{"Origin", "Content-Type", "Authorization", "X-Requested-With"}
	cfg.ExposeHeaders = []string{"Content-Length"}
	cfg.AllowCredentials = true
	cfg.MaxAge = 12 * time.Hour

	return cors.New(cfg)
}