273 lines
6.4 KiB
Go
273 lines
6.4 KiB
Go
package middleware
|
|
|
|
import (
	"fmt"
	"html"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/gin-contrib/cors"
	"github.com/gin-gonic/gin"
)
|
|
|
|
// SecurityHeaders adds security headers to all responses
|
|
func SecurityHeaders() gin.HandlerFunc {
|
|
return gin.HandlerFunc(func(c *gin.Context) {
|
|
// Prevent clickjacking
|
|
c.Header("X-Frame-Options", "DENY")
|
|
|
|
// Prevent MIME type sniffing
|
|
c.Header("X-Content-Type-Options", "nosniff")
|
|
|
|
// Enable XSS protection
|
|
c.Header("X-XSS-Protection", "1; mode=block")
|
|
|
|
// Referrer policy
|
|
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
|
|
|
// Content Security Policy - adjust as needed
|
|
c.Header("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'")
|
|
|
|
// HSTS (HTTP Strict Transport Security) - only for HTTPS
|
|
if c.Request.TLS != nil {
|
|
c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
|
}
|
|
|
|
c.Next()
|
|
})
|
|
}
|
|
|
|
// InputSanitization sanitizes user inputs to prevent XSS and injection attacks
|
|
func InputSanitization() gin.HandlerFunc {
|
|
return gin.HandlerFunc(func(c *gin.Context) {
|
|
// Sanitize query parameters
|
|
sanitizeQueryParams(c)
|
|
|
|
// Sanitize form data
|
|
sanitizeFormData(c)
|
|
|
|
// Sanitize JSON body if present
|
|
if c.ContentType() == "application/json" {
|
|
sanitizeJSONBody(c)
|
|
}
|
|
|
|
c.Next()
|
|
})
|
|
}
|
|
|
|
// sanitizeQueryParams sanitizes all query parameters
|
|
func sanitizeQueryParams(c *gin.Context) {
|
|
query := c.Request.URL.Query()
|
|
|
|
for key, values := range query {
|
|
for i, value := range values {
|
|
query[key][i] = sanitizeString(value)
|
|
}
|
|
}
|
|
|
|
c.Request.URL.RawQuery = query.Encode()
|
|
}
|
|
|
|
// sanitizeFormData sanitizes form data
|
|
func sanitizeFormData(c *gin.Context) {
|
|
if err := c.Request.ParseForm(); err != nil {
|
|
return
|
|
}
|
|
|
|
for key, values := range c.Request.PostForm {
|
|
for i, value := range values {
|
|
c.Request.PostForm[key][i] = sanitizeString(value)
|
|
}
|
|
}
|
|
}
|
|
|
|
// sanitizeJSONBody sanitizes JSON request body
|
|
func sanitizeJSONBody(c *gin.Context) {
|
|
// For JSON bodies, we'll let the JSON binding handle it
|
|
// and sanitize at the handler level if needed
|
|
c.Next()
|
|
}
|
|
|
|
// sanitizeStripper deletes the bytes sanitizeString discards outright: NUL,
// backtick, backslash, newline, carriage return and tab. strings.Replacer
// works byte-wise here, so invalid UTF-8 elsewhere in the input is preserved.
// It is safe for concurrent use.
var sanitizeStripper = strings.NewReplacer(
	"\x00", "",
	"`", "",
	"\\", "",
	"\n", "",
	"\r", "",
	"\t", "",
)

// sanitizeString performs basic, lossy sanitization on a string. It removes
// NUL, backtick, backslash, newline, carriage return and tab; it
// HTML-escapes &, <, >, " and '; and it trims surrounding whitespace.
//
// Note: the result is already HTML-escaped and the function is not
// idempotent ("&" becomes "&amp;" on every pass). Output layers that escape
// again, such as html/template, will therefore double-escape these values.
func sanitizeString(input string) string {
	// Stripping before escaping gives the same result as stripping after it:
	// html.EscapeString neither produces nor consumes any of the stripped
	// bytes. There is no separate removal step for <, >, " or ', because
	// EscapeString already turns them into entities.
	return strings.TrimSpace(html.EscapeString(sanitizeStripper.Replace(input)))
}
|
|
|
|
// SQLInjectionProtection provides additional SQL injection protection
|
|
func SQLInjectionProtection() gin.HandlerFunc {
|
|
return gin.HandlerFunc(func(c *gin.Context) {
|
|
// Check for suspicious patterns in query parameters
|
|
if hasSQLInjectionPatterns(c) {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
|
|
"error": "Invalid input detected",
|
|
"message": "Request contains potentially malicious content",
|
|
})
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
})
|
|
}
|
|
|
|
// hasSQLInjectionPatterns checks for common SQL injection patterns
|
|
func hasSQLInjectionPatterns(c *gin.Context) bool {
|
|
suspiciousPatterns := []string{
|
|
"union select",
|
|
"union all select",
|
|
"select.*from",
|
|
"insert.*into",
|
|
"update.*set",
|
|
"delete.*from",
|
|
"drop table",
|
|
"drop database",
|
|
"alter table",
|
|
"create table",
|
|
"exec(",
|
|
"execute(",
|
|
"xp_",
|
|
"sp_",
|
|
"information_schema",
|
|
"sysobjects",
|
|
"syscolumns",
|
|
"sysdatabases",
|
|
"mysql.",
|
|
"pg_",
|
|
"sqlite_",
|
|
";--",
|
|
"/*",
|
|
"*/",
|
|
"@@",
|
|
"script>",
|
|
"<script",
|
|
"javascript:",
|
|
"vbscript:",
|
|
"onload=",
|
|
"onerror=",
|
|
"eval(",
|
|
"alert(",
|
|
}
|
|
|
|
query := strings.ToLower(c.Request.URL.RawQuery)
|
|
for _, pattern := range suspiciousPatterns {
|
|
if strings.Contains(query, pattern) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
// Check form data
|
|
if err := c.Request.ParseForm(); err == nil {
|
|
for _, values := range c.Request.Form {
|
|
for _, value := range values {
|
|
lowerValue := strings.ToLower(value)
|
|
for _, pattern := range suspiciousPatterns {
|
|
if strings.Contains(lowerValue, pattern) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// RateLimitByIP provides basic rate limiting by IP
|
|
func RateLimitByIP(requestsPerMinute int) gin.HandlerFunc {
|
|
// Simple in-memory rate limiter
|
|
// In production, use Redis or similar
|
|
type client struct {
|
|
count int
|
|
resetTime int64
|
|
}
|
|
|
|
clients := make(map[string]*client)
|
|
|
|
return func(c *gin.Context) {
|
|
ip := c.ClientIP()
|
|
now := time.Now().Unix()
|
|
|
|
if clients[ip] == nil {
|
|
clients[ip] = &client{count: 0, resetTime: now + 60}
|
|
}
|
|
|
|
client := clients[ip]
|
|
|
|
// Reset counter if time window passed
|
|
if now > client.resetTime {
|
|
client.count = 0
|
|
client.resetTime = now + 60
|
|
}
|
|
|
|
if client.count >= requestsPerMinute {
|
|
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
|
"error": "Rate limit exceeded",
|
|
"message": "Too many requests. Please try again later.",
|
|
})
|
|
return
|
|
}
|
|
|
|
client.count++
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// ValidateInputLength validates input length to prevent buffer overflow
|
|
func ValidateInputLength(maxLength int) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
// Check query parameters
|
|
for key, values := range c.Request.URL.Query() {
|
|
for _, value := range values {
|
|
if len(value) > maxLength {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
|
|
"error": "Input too long",
|
|
"message": fmt.Sprintf("Parameter '%s' exceeds maximum length of %d characters", key, maxLength),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check form data
|
|
if err := c.Request.ParseForm(); err == nil {
|
|
for key, values := range c.Request.PostForm {
|
|
for _, value := range values {
|
|
if len(value) > maxLength {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
|
|
"error": "Input too long",
|
|
"message": fmt.Sprintf("Parameter '%s' exceeds maximum length of %d characters", key, maxLength),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// SecureCORSConfig provides secure CORS configuration
|
|
func SecureCORSConfig() gin.HandlerFunc {
|
|
return cors.New(cors.Config{
|
|
AllowOrigins: []string{"http://localhost:3000", "http://localhost:8080"}, // Configure allowed origins
|
|
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
|
AllowHeaders: []string{"Origin", "Content-Type", "Authorization", "X-Requested-With"},
|
|
ExposeHeaders: []string{"Content-Length"},
|
|
AllowCredentials: true,
|
|
MaxAge: 12 * time.Hour,
|
|
})
|
|
}
|