package middleware
import (
"fmt"
"html"
"net/http"
"strings"
"time"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
)
// SecurityHeaders returns middleware that stamps a fixed set of defensive
// HTTP response headers onto every response and then continues the chain.
//
// Strict-Transport-Security is only emitted when this server terminated TLS
// itself (c.Request.TLS is non-nil); behind a TLS-terminating proxy it will
// not be sent.
func SecurityHeaders() gin.HandlerFunc {
	// Headers applied unconditionally, in the order they are written.
	always := []struct{ name, value string }{
		// Forbid rendering inside frames (clickjacking).
		{"X-Frame-Options", "DENY"},
		// Disable MIME-type sniffing.
		{"X-Content-Type-Options", "nosniff"},
		// Legacy browser XSS filter.
		// NOTE(review): current OWASP/MDN guidance is to send "0" or omit this
		// header; value kept unchanged for compatibility.
		{"X-XSS-Protection", "1; mode=block"},
		{"Referrer-Policy", "strict-origin-when-cross-origin"},
		// Adjust per deployment; 'unsafe-inline' weakens script/style policy.
		{"Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'"},
	}

	return func(c *gin.Context) {
		for _, h := range always {
			c.Header(h.name, h.value)
		}
		if c.Request.TLS != nil {
			c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
		}
		c.Next()
	}
}
// InputSanitization returns middleware that rewrites request input through
// the sanitizers in this file before handing off to the next handler:
// query parameters first, then URL-encoded form data, and, when
// c.ContentType() reports "application/json", the JSON body hook.
func InputSanitization() gin.HandlerFunc {
	return func(c *gin.Context) {
		sanitizeQueryParams(c)
		sanitizeFormData(c)

		isJSON := c.ContentType() == "application/json"
		if isJSON {
			sanitizeJSONBody(c)
		}

		c.Next()
	}
}
// sanitizeQueryParams replaces the request's raw query string with a
// re-encoded copy in which every value has passed through sanitizeString.
// Keys are not modified. Because the values come from URL.Query(), pairs it
// cannot decode are discarded, and url.Values.Encode emits keys in sorted
// order.
func sanitizeQueryParams(c *gin.Context) {
	u := c.Request.URL
	params := u.Query()
	for _, vals := range params {
		for idx := range vals {
			vals[idx] = sanitizeString(vals[idx])
		}
	}
	u.RawQuery = params.Encode()
}
// sanitizeFormData runs every URL-encoded request-body value through
// sanitizeString in both maps net/http exposes after ParseForm:
//   - r.PostForm, which holds only body values;
//   - r.Form, which holds its own copies of the body values followed by the
//     query-string values.
//
// NOTE(review): gin's default form binding and Request.FormValue read r.Form,
// so leaving it untouched would expose raw input; verify against the pinned
// gin version.
//
// Sanitization happens even when ParseForm reports an error. url.ParseQuery
// returns every pair it managed to decode alongside the error, and ParseForm
// keeps those partial results on the request, so returning early would let a
// single malformed pair smuggle the remaining body values through unsanitized.
func sanitizeFormData(c *gin.Context) {
	req := c.Request

	// Best effort: the parse error is deliberately not acted on (the original
	// code also dropped it); whatever was decoded is still cleaned below.
	_ = req.ParseForm()

	for key, values := range req.PostForm {
		for i, value := range values {
			values[i] = sanitizeString(value)
		}
		// ParseForm places body values ahead of query values in r.Form, so the
		// first len(values) entries are raw duplicates of what was just cleaned.
		// The trailing query values are left alone: when called from
		// InputSanitization they were already sanitized by sanitizeQueryParams,
		// and sanitizing again would double-escape them.
		if merged := req.Form[key]; len(merged) >= len(values) {
			copy(merged, values)
		}
	}
}
// sanitizeJSONBody is the hook for JSON request bodies. It intentionally does
// nothing: JSON payloads are left to the binding layer, and handlers are
// expected to sanitize bound fields themselves where needed.
//
// It must not call c.Next(). The caller advances the handler chain itself
// once every sanitization step has run; calling Next here would execute all
// downstream handlers in the middle of the sanitization pass.
func sanitizeJSONBody(c *gin.Context) {
}
// sanitizeStripper deletes the characters sanitizeString never lets through:
// NUL bytes, backticks, backslashes, and tab/newline/carriage-return. Built
// once so each call does a single pass instead of one ReplaceAll per rune.
var sanitizeStripper = strings.NewReplacer(
	"\x00", "",
	"`", "",
	"\\", "",
	"\n", "",
	"\r", "",
	"\t", "",
)

// sanitizeString performs basic sanitization on a string:
//  1. removes NUL, backtick, backslash, tab, LF and CR characters;
//  2. HTML-escapes the result (&, <, >, ", ' become entities);
//  3. trims leading and trailing whitespace.
//
// Stripping before escaping gives exactly the same output as the reverse
// order: html.EscapeString only rewrites &<>"' and only ever emits letters,
// digits, '&', '#' and ';', none of which are stripped. There is no need to
// also strip < > " ' — after escaping they can no longer appear.
//
// The output is HTML-escaped and therefore not idempotent: sanitizing an
// already-sanitized value escapes its '&' again.
func sanitizeString(input string) string {
	return strings.TrimSpace(html.EscapeString(sanitizeStripper.Replace(input)))
}
// SQLInjectionProtection returns middleware that aborts the request with
// 400 Bad Request and a JSON error body whenever hasSQLInjectionPatterns
// reports a match; otherwise the handler chain continues normally.
//
// NOTE(review): pattern blocklists are a defense-in-depth measure only;
// parameterized queries remain the actual protection against SQL injection.
func SQLInjectionProtection() gin.HandlerFunc {
	return func(c *gin.Context) {
		if !hasSQLInjectionPatterns(c) {
			c.Next()
			return
		}
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
			"error":   "Invalid input detected",
			"message": "Request contains potentially malicious content",
		})
	}
}
// hasSQLInjectionPatterns checks for common SQL injection patterns
func hasSQLInjectionPatterns(c *gin.Context) bool {
suspiciousPatterns := []string{
"union select",
"union all select",
"select.*from",
"insert.*into",
"update.*set",
"delete.*from",
"drop table",
"drop database",
"alter table",
"create table",
"exec(",
"execute(",
"xp_",
"sp_",
"information_schema",
"sysobjects",
"syscolumns",
"sysdatabases",
"mysql.",
"pg_",
"sqlite_",
";--",
"/*",
"*/",
"@@",
"script>",
"