// Package middleware provides Gin middleware for security response headers,
// query/form input sanitization, SQL-injection pattern screening, per-client
// rate limiting, input-length validation and CORS.
package middleware

// NOTE(review): this copy of the file is damaged. It was flattened onto a few
// lines, and a region inside hasSQLInjectionPatterns / the rate-limiting
// middleware is missing (see the NOTE there). It is not valid Go until that
// region is restored from version control.

import (
	"fmt"
	"html"
	"net/http"
	"strings"
	"time"

	"github.com/gin-contrib/cors"
	"github.com/gin-gonic/gin"
)

// SecurityHeaders returns middleware that sets a fixed set of defensive
// response headers on every request and then calls the next handler.
// Strict-Transport-Security is added only when this server itself received
// the request over TLS (c.Request.TLS != nil).
func SecurityHeaders() gin.HandlerFunc {
	return gin.HandlerFunc(func(c *gin.Context) {
		// Forbid rendering this response inside any frame (clickjacking defense).
		c.Header("X-Frame-Options", "DENY")

		// Tell browsers not to MIME-sniff away from the declared Content-Type.
		c.Header("X-Content-Type-Options", "nosniff")

		// Legacy XSS-auditor switch.
		// NOTE(review): current browsers have removed the XSS auditor this header
		// controls, and OWASP now recommends sending "0" (or omitting the
		// header). Consider revisiting.
		c.Header("X-XSS-Protection", "1; mode=block")

		// Full URL for same-origin requests; origin only for cross-origin
		// requests; nothing when navigating from HTTPS to HTTP.
		c.Header("Referrer-Policy", "strict-origin-when-cross-origin")

		// Content Security Policy - adjust as needed.
		// NOTE(review): 'unsafe-inline' in script-src permits inline scripts and
		// event-handler attributes. That removes most of CSP's protection
		// against injected script. Prefer nonces or hashes if the frontend
		// allows it.
		c.Header("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'")

		// HSTS (HTTP Strict Transport Security) is only honored over HTTPS, so it
		// is sent only when TLS was terminated by this process.
		// NOTE(review): if TLS is terminated by a reverse proxy or load balancer,
		// c.Request.TLS is nil here and this header is never sent. Confirm the
		// deployment topology.
		if c.Request.TLS != nil {
			c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
		}

		c.Next()
	})
}

// InputSanitization returns middleware that rewrites URL query values and
// urlencoded form values in place via sanitizeString, then continues the
// chain. JSON bodies are not modified (see sanitizeJSONBody).
//
// NOTE(review): sanitizeString HTML-escapes its input, so handlers receive
// entity-encoded values (e.g. "a&b" arrives as "a&amp;b"). Escaping on input
// alters stored data and risks double-escaping if output is escaped again.
// Escaping is normally done at output time for the specific output context.
func InputSanitization() gin.HandlerFunc {
	return gin.HandlerFunc(func(c *gin.Context) {
		// Rewrite URL query values (re-encodes RawQuery).
		sanitizeQueryParams(c)

		// Rewrite form values parsed into PostForm.
		sanitizeFormData(c)

		// JSON requests go through sanitizeJSONBody, which leaves the body as is.
		// NOTE(review): sanitizeJSONBody calls c.Next() itself. For JSON requests
		// the remaining handlers therefore run inside that call, and the c.Next()
		// below finds no handlers left (gin's Next advances a shared index).
		// Dropping c.Next() from sanitizeJSONBody would make the flow explicit.
		if c.ContentType() == "application/json" {
			sanitizeJSONBody(c)
		}

		c.Next()
	})
}

// sanitizeQueryParams passes every URL query value through sanitizeString and
// writes the re-encoded result back to c.Request.URL.RawQuery.
//
// NOTE(review): url.Values.Encode sorts by key, so the parameter order in
// RawQuery may differ from what the client sent. Also, if an earlier
// middleware already read query values through gin (c.Query etc.), gin may
// have cached the unsanitized values. Verify middleware order.
func sanitizeQueryParams(c *gin.Context) {
	query := c.Request.URL.Query()
	for key, values := range query {
		for i, value := range values {
			query[key][i] = sanitizeString(value)
		}
	}
	c.Request.URL.RawQuery = query.Encode()
}

// sanitizeFormData calls ParseForm and passes every value in
// c.Request.PostForm through sanitizeString, in place. If ParseForm returns
// an error, nothing is sanitized and the error is not reported; the caller
// still continues the chain.
//
// NOTE(review): only PostForm is rewritten. ParseForm also fills
// c.Request.Form with its own copies of the body values, and those stay
// unsanitized (Request.FormValue reads from Form). ParseForm only reads the
// body for POST/PUT/PATCH with Content-Type
// application/x-www-form-urlencoded, so multipart/form-data fields are not
// covered either.
func sanitizeFormData(c *gin.Context) {
	if err := c.Request.ParseForm(); err != nil {
		return
	}

	for key, values := range c.Request.PostForm {
		for i, value := range values {
			c.Request.PostForm[key][i] = sanitizeString(value)
		}
	}
}

// sanitizeJSONBody is a placeholder: it neither reads nor modifies the JSON
// body. For JSON bodies, we let the JSON binding handle it and sanitize at
// the handler level if needed.
//
// NOTE(review): despite being a helper, it calls c.Next(). That runs the
// rest of the handler chain from inside InputSanitization (see the note
// there).
func sanitizeJSONBody(c *gin.Context) {
	c.Next()
}

// sanitizeString removes NUL bytes, HTML-escapes the input, removes the
// characters listed in dangerousChars, and trims leading/trailing whitespace.
//
// NOTE(review): html.EscapeString turns <, >, &, ' and " into entities
// (&lt; &gt; &amp; &#39; &#34;). By the time dangerousChars is applied, the
// "<", ">", "\"" and "'" entries can no longer match. In practice only the
// backtick, backslash, \n, \r and \t entries remove anything. Removing
// \n/\r/\t also silently alters legitimate multi-line input.
func sanitizeString(input string) string {
	// Drop NUL bytes.
	input = strings.ReplaceAll(input, "\x00", "")

	// Entity-encode the five HTML-special characters.
	input = html.EscapeString(input)

	// Remove characters considered dangerous. See the NOTE above about which
	// entries can still match after escaping.
	dangerousChars := []string{"<", ">", "\"", "'", "`", "\\", "\n", "\r", "\t"}
	for _, char := range dangerousChars {
		input = strings.ReplaceAll(input, char, "")
	}

	return strings.TrimSpace(input)
}

// SQLInjectionProtection returns middleware that aborts with 400 Bad Request
// and a JSON error body when hasSQLInjectionPatterns reports a match, and
// otherwise calls the next handler.
//
// NOTE(review): pattern blocklists are easy to bypass and can reject
// legitimate input. Parameterized queries are the actual defense; treat this
// as defense-in-depth only.
func SQLInjectionProtection() gin.HandlerFunc {
	return gin.HandlerFunc(func(c *gin.Context) {
		// Check for suspicious patterns, presumably in the query parameters.
		// The inspection code in hasSQLInjectionPatterns is missing from this
		// copy of the file, so what is actually checked cannot be confirmed.
		if hasSQLInjectionPatterns(c) {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
				"error":   "Invalid input detected",
				"message": "Request contains potentially malicious content",
			})
			return
		}

		c.Next()
	})
}

// hasSQLInjectionPatterns is meant to report whether the request contains
// one of suspiciousPatterns. Its matching logic is missing from this copy
// (see the NOTE below).
func hasSQLInjectionPatterns(c *gin.Context) bool {
	// NOTE(review): entries such as "select.*from" use regex syntax. Whether
	// they are matched as regular expressions or as plain substrings cannot be
	// determined from this copy; as plain substrings they only match a literal
	// ".*". If matched as substrings, short entries are likely to hit benign
	// input, e.g. "exp_date" contains "xp_", "resp_code" contains "sp_", and
	// "jpg_file" contains "pg_".
	suspiciousPatterns := []string{
		"union select", "union all select", "select.*from", "insert.*into",
		"update.*set", "delete.*from", "drop table", "drop database",
		"alter table", "create table", "exec(", "execute(",
		"xp_", "sp_", "information_schema", "sysobjects", "syscolumns",
		"sysdatabases", "mysql.", "pg_", "sqlite_", ";--", "/*", "*/",
		// NOTE(review): CORRUPTED — the next line is damaged. Everything after
		// the opening quote of the last pattern is missing. That includes the
		// rest of this pattern list, the matching loop and return statement of
		// hasSQLInjectionPatterns, and the start of the rate-limiting
		// middleware (its doc comment, signature, the declarations of client
		// and now, and presumably an `if now >` condition). It appears to have
		// been removed by an HTML sanitizer as a tag beginning at "<script".
		// Restore this region from version control; do not reconstruct it by
		// guessing.
		"@@", "script>", " client.resetTime {
			// Window boundary reached: reset the counter and start a new window
			// 60 units after now. NOTE(review): presumably seconds, to match
			// requestsPerMinute; the type of now is in the lost region.
			client.count = 0
			client.resetTime = now + 60
		}

		// Reject with 429 once this client has used its quota for the window.
		if client.count >= requestsPerMinute {
			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
				"error":   "Rate limit exceeded",
				"message": "Too many requests. Please try again later.",
			})
			return
		}

		// Count this request and continue the chain.
		// NOTE(review): how client is obtained, and whether a lock guards it, is
		// in the lost region. If client is shared state (e.g. a map entry per
		// IP), these reads and writes race, because requests are served
		// concurrently. Verify once the region is restored.
		client.count++
		c.Next()
	}
}

// ValidateInputLength returns middleware that rejects the request with
// 400 Bad Request when any URL query value, or any value in PostForm, is
// longer than maxLength. Otherwise it calls the next handler.
//
// NOTE(review): len() counts bytes, not characters. Multi-byte UTF-8 input
// therefore hits the limit sooner than the "characters" in the error message
// suggests (utf8.RuneCountInString counts characters). Go strings are
// bounds-checked, so this limits resource use and downstream sizes rather
// than preventing buffer overflows. JSON and multipart bodies are not
// checked.
func ValidateInputLength(maxLength int) gin.HandlerFunc {
	return func(c *gin.Context) {
		// Check query parameters.
		for key, values := range c.Request.URL.Query() {
			for _, value := range values {
				if len(value) > maxLength {
					c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
						"error":   "Input too long",
						"message": fmt.Sprintf("Parameter '%s' exceeds maximum length of %d characters", key, maxLength),
					})
					return
				}
			}
		}

		// Check form data. A ParseForm error is ignored here and the request
		// is let through without this check.
		if err := c.Request.ParseForm(); err == nil {
			for key, values := range c.Request.PostForm {
				for _, value := range values {
					if len(value) > maxLength {
						c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
							"error":   "Input too long",
							"message": fmt.Sprintf("Parameter '%s' exceeds maximum length of %d characters", key, maxLength),
						})
						return
					}
				}
			}
		}

		c.Next()
	}
}

// SecureCORSConfig returns gin-contrib/cors middleware with a fixed policy:
//   - only the listed origins, methods and request headers are allowed;
//   - Content-Length is exposed;
//   - credentials are permitted;
//   - preflight results may be cached for 12 hours.
//
// NOTE(review): the allowed origins are hard-coded local development
// addresses. Production origins must be added (ideally from configuration),
// or cross-origin browser requests from the real frontend will fail.
func SecureCORSConfig() gin.HandlerFunc {
	return cors.New(cors.Config{
		AllowOrigins:     []string{"http://localhost:3000", "http://localhost:8080"}, // Configure allowed origins
		AllowMethods:     []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
		AllowHeaders:     []string{"Origin", "Content-Type", "Authorization", "X-Requested-With"},
		ExposeHeaders:    []string{"Content-Length"},
		AllowCredentials: true,
		MaxAge:           12 * time.Hour,
	})
}