Update Penggunaan Gorm

This commit is contained in:
2025-10-23 04:25:28 +07:00
parent 21f70f1d3f
commit a5523d11a3
9 changed files with 758 additions and 177 deletions

View File

@@ -821,7 +821,7 @@
);
const ipBased = document.getElementById("ipBasedCheck").checked;
let url = `ws://meninjar.dev.rssa.id:8030/api/v1/ws?user_id=${encodeURIComponent(
let url = `ws://localhost:8080/api/v1/ws?user_id=${encodeURIComponent(
userId
)}&room=${encodeURIComponent(room)}`;

5
go.mod
View File

@@ -18,6 +18,7 @@ require (
require (
github.com/daku10/go-lz-string v0.0.6
github.com/gin-contrib/cors v1.7.6
github.com/go-playground/validator/v10 v10.27.0
github.com/go-sql-driver/mysql v1.8.1
github.com/joho/godotenv v1.5.1
@@ -29,6 +30,8 @@ require (
github.com/swaggo/swag v1.16.6
github.com/tidwall/gjson v1.18.0
gopkg.in/yaml.v2 v2.4.0
gorm.io/driver/sqlite v1.6.0
gorm.io/gorm v1.30.0
)
require (
@@ -64,6 +67,7 @@ require (
github.com/mailru/easyjson v0.7.6 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v1.14.22 // indirect
github.com/microsoft/go-mssqldb v1.8.2 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
@@ -86,5 +90,4 @@ require (
golang.org/x/tools v0.35.0 // indirect
google.golang.org/protobuf v1.36.7 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gorm.io/gorm v1.30.0 // indirect
)

6
go.sum
View File

@@ -41,6 +41,8 @@ github.com/dnaeon/go-vcr v1.1.0/go.mod h1:M7tiix8f0r6mKKJ3Yq/kqU1OYf3MnfmBWVbPx/
github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
github.com/gabriel-vasile/mimetype v1.4.9 h1:5k+WDwEsD9eTLL8Tz3L0VnmVh9QxGjRmjBvAG7U/oYY=
github.com/gabriel-vasile/mimetype v1.4.9/go.mod h1:WnSQhFKJuBlRyLiKohA/2DtIlPFAbguNaG7QCHcyGok=
github.com/gin-contrib/cors v1.7.6 h1:3gQ8GMzs1Ylpf70y8bMw4fVpycXIeX1ZemuSQIsnQQY=
github.com/gin-contrib/cors v1.7.6/go.mod h1:Ulcl+xN4jel9t1Ry8vqph23a60FwH9xVLd+3ykmTjOk=
github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4=
github.com/gin-contrib/gzip v0.0.6/go.mod h1:QOJlmV2xmayAjkNS2Y8NQsMneuRShOU/kjovCXNuzzk=
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
@@ -148,6 +150,8 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/microsoft/go-mssqldb v1.8.2 h1:236sewazvC8FvG6Dr3bszrVhMkAl4KYImryLkRMCd0I=
github.com/microsoft/go-mssqldb v1.8.2/go.mod h1:vp38dT33FGfVotRiTmDo3bFyaHq+p3LektQrjTULowo=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -355,6 +359,8 @@ gorm.io/driver/mysql v1.6.0 h1:eNbLmNTpPpTOVZi8MMxCi2aaIm0ZpInbORNXDwyLGvg=
gorm.io/driver/mysql v1.6.0/go.mod h1:D/oCC2GWK3M/dqoLxnOlaNKmXz8WNTfcS9y5ovaSqKo=
gorm.io/driver/postgres v1.5.11 h1:ubBVAfbKEUld/twyKZ0IYn9rSQh448EdelLYk9Mv314=
gorm.io/driver/postgres v1.5.11/go.mod h1:DX3GReXH+3FPWGrrgffdvCk3DQ1dwDPdmbenSkweRGI=
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
gorm.io/driver/sqlserver v1.6.1 h1:XWISFsu2I2pqd1KJhhTZNJMx1jNQ+zVL/Q8ovDcUjtY=
gorm.io/driver/sqlserver v1.6.1/go.mod h1:VZeNn7hqX1aXoN5TPAFGWvxWG90xtA8erGn2gQmpc6U=
gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs=

View File

@@ -12,15 +12,17 @@ import (
"time"
"api-service/internal/config"
"api-service/internal/utils/validation"
_ "github.com/jackc/pgx/v5" // Import pgx driver
"github.com/lib/pq"
_ "gorm.io/driver/postgres" // Import GORM PostgreSQL driver
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/driver/sqlserver"
"gorm.io/gorm"
_ "github.com/go-sql-driver/mysql" // MySQL driver for database/sql
_ "gorm.io/driver/mysql" // GORM MySQL driver
_ "gorm.io/driver/sqlserver" // GORM SQL Server driver
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
@@ -40,6 +42,7 @@ const (
type Service interface {
Health() map[string]map[string]string
GetDB(name string) (*sql.DB, error)
GetGormDB(name string) (*gorm.DB, error) // New method for GORM
GetMongoClient(name string) (*mongo.Client, error)
GetReadDB(name string) (*sql.DB, error) // For read replicas
Close() error
@@ -49,18 +52,21 @@ type Service interface {
ListenForChanges(ctx context.Context, dbName string, channels []string, callback func(string, string)) error
NotifyChange(dbName, channel, payload string) error
GetPrimaryDB(name string) (*sql.DB, error) // Helper untuk get primary DB
GetSanitizer() *validation.InputSanitizer // Get sanitizer instance
}
type service struct {
sqlDatabases map[string]*sql.DB
mongoClients map[string]*mongo.Client
readReplicas map[string][]*sql.DB // Read replicas for load balancing
configs map[string]config.DatabaseConfig
readConfigs map[string][]config.DatabaseConfig
mu sync.RWMutex
readBalancer map[string]int // Round-robin counter for read replicas
listeners map[string]*pq.Listener // Tambahkan untuk tracking listeners
listenersMu sync.RWMutex
sqlDatabases map[string]*sql.DB
gormDatabases map[string]*gorm.DB // New field for GORM connections
mongoClients map[string]*mongo.Client
readReplicas map[string][]*sql.DB // Read replicas for load balancing
configs map[string]config.DatabaseConfig
readConfigs map[string][]config.DatabaseConfig
mu sync.RWMutex
readBalancer map[string]int // Round-robin counter for read replicas
listeners map[string]*pq.Listener // Tambahkan untuk tracking listeners
listenersMu sync.RWMutex
sanitizer *validation.InputSanitizer // Input sanitizer for security
}
var (
@@ -72,13 +78,15 @@ var (
func New(cfg *config.Config) Service {
once.Do(func() {
dbManager = &service{
sqlDatabases: make(map[string]*sql.DB),
mongoClients: make(map[string]*mongo.Client),
readReplicas: make(map[string][]*sql.DB),
configs: make(map[string]config.DatabaseConfig),
readConfigs: make(map[string][]config.DatabaseConfig),
readBalancer: make(map[string]int),
listeners: make(map[string]*pq.Listener),
sqlDatabases: make(map[string]*sql.DB),
gormDatabases: make(map[string]*gorm.DB),
mongoClients: make(map[string]*mongo.Client),
readReplicas: make(map[string][]*sql.DB),
configs: make(map[string]config.DatabaseConfig),
readConfigs: make(map[string][]config.DatabaseConfig),
readBalancer: make(map[string]int),
listeners: make(map[string]*pq.Listener),
sanitizer: validation.NewInputSanitizer(1000), // Initialize sanitizer with max length 1000
}
log.Println("Initializing database service...") // Log when the initialization starts
@@ -161,7 +169,13 @@ func (s *service) addDatabase(name string, config config.DatabaseConfig) error {
}
log.Printf("✅ Successfully connected to database: %s", name)
return s.configureSQLDB(name, db, config.MaxOpenConns, config.MaxIdleConns, config.ConnMaxLifetime)
err = s.configureSQLDB(name, db, config.MaxOpenConns, config.MaxIdleConns, config.ConnMaxLifetime)
if err != nil {
return err
}
// Initialize GORM for SQL databases
return s.configureGormDB(name, db, config)
}
func (s *service) addReadReplica(name string, index int, config config.DatabaseConfig) error {
@@ -317,6 +331,41 @@ func (s *service) configureSQLDB(name string, db *sql.DB, maxOpenConns, maxIdleC
return nil
}
func (s *service) configureGormDB(name string, db *sql.DB, config config.DatabaseConfig) error {
var gormDB *gorm.DB
var err error
dbType := DatabaseType(config.Type)
switch dbType {
case Postgres:
gormDB, err = gorm.Open(postgres.New(postgres.Config{
Conn: db,
}), &gorm.Config{})
case MySQL:
gormDB, err = gorm.Open(mysql.New(mysql.Config{
Conn: db,
}), &gorm.Config{})
case SQLServer:
gormDB, err = gorm.Open(sqlserver.New(sqlserver.Config{
Conn: db,
}), &gorm.Config{})
case SQLite:
gormDB, err = gorm.Open(sqlite.Open(config.Path), &gorm.Config{})
default:
return fmt.Errorf("unsupported database type for GORM: %s", config.Type)
}
if err != nil {
return fmt.Errorf("failed to initialize GORM for %s: %w", name, err)
}
s.gormDatabases[name] = gormDB
log.Printf("Successfully initialized GORM for database: %s", name)
return nil
}
// Health checks the health of all database connections by pinging each database.
func (s *service) Health() map[string]map[string]string {
s.mu.RLock()
@@ -486,6 +535,19 @@ func (s *service) GetReadDB(name string) (*sql.DB, error) {
return selected, nil
}
// GetGormDB returns a specific GORM database connection by name
func (s *service) GetGormDB(name string) (*gorm.DB, error) {
s.mu.RLock()
defer s.mu.RUnlock()
gormDB, exists := s.gormDatabases[name]
if !exists {
return nil, fmt.Errorf("GORM database %s not found", name)
}
return gormDB, nil
}
// GetMongoClient returns a specific MongoDB client by name
func (s *service) GetMongoClient(name string) (*mongo.Client, error) {
s.mu.RLock()
@@ -583,6 +645,11 @@ func (s *service) GetPrimaryDB(name string) (*sql.DB, error) {
return s.GetDB(name)
}
// GetSanitizer returns the input sanitizer instance
func (s *service) GetSanitizer() *validation.InputSanitizer {
return s.sanitizer
}
// ListenForChanges implements PostgreSQL LISTEN/NOTIFY for real-time updates
func (s *service) ListenForChanges(ctx context.Context, dbName string, channels []string, callback func(string, string)) error {
s.mu.RLock()

View File

@@ -20,6 +20,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10"
"github.com/google/uuid"
"gorm.io/gorm"
)
var (
@@ -90,7 +91,7 @@ func (h *RetribusiHandler) GetRetribusi(c *gin.Context) {
includeAggregation := c.Query("include_summary") == "true"
// Get database connection
dbConn, err := h.db.GetDB("postgres_satudata")
gormDB, err := h.db.GetGormDB("postgres_satudata")
if err != nil {
h.logAndRespondError(c, "Database connection failed", err, http.StatusInternalServerError)
return
@@ -103,18 +104,25 @@ func (h *RetribusiHandler) GetRetribusi(c *gin.Context) {
// Execute concurrent operations
var (
retribusis []retribusi.Retribusi
total int
total int64
aggregateData *models.AggregateData
wg sync.WaitGroup
errChan = make(chan error, 3)
mu sync.Mutex
)
// Get underlying SQL DB from GORM
sqlDB, err := gormDB.DB()
if err != nil {
h.logAndRespondError(c, "Failed to get SQL DB from GORM", err, http.StatusInternalServerError)
return
}
// Fetch total count
wg.Add(1)
go func() {
defer wg.Done()
if err := h.getTotalCount(ctx, dbConn, filter, &total); err != nil {
if err := h.getTotalCount(ctx, gormDB, filter, &total); err != nil {
mu.Lock()
errChan <- fmt.Errorf("failed to get total count: %w", err)
mu.Unlock()
@@ -125,7 +133,7 @@ func (h *RetribusiHandler) GetRetribusi(c *gin.Context) {
wg.Add(1)
go func() {
defer wg.Done()
result, err := h.fetchRetribusis(ctx, dbConn, filter, limit, offset)
result, err := h.fetchRetribusis(ctx, sqlDB, filter, limit, offset)
mu.Lock()
if err != nil {
errChan <- fmt.Errorf("failed to fetch data: %w", err)
@@ -140,7 +148,7 @@ func (h *RetribusiHandler) GetRetribusi(c *gin.Context) {
wg.Add(1)
go func() {
defer wg.Done()
result, err := h.getAggregateData(ctx, dbConn, filter)
result, err := h.getAggregateData(ctx, sqlDB, filter)
mu.Lock()
if err != nil {
errChan <- fmt.Errorf("failed to get aggregate data: %w", err)
@@ -164,7 +172,7 @@ func (h *RetribusiHandler) GetRetribusi(c *gin.Context) {
}
// Build response
meta := h.calculateMeta(limit, offset, total)
meta := h.calculateMeta(limit, offset, int(total))
response := retribusi.RetribusiGetResponse{
Message: "Data retribusi berhasil diambil",
Data: retribusis,
@@ -199,7 +207,7 @@ func (h *RetribusiHandler) GetRetribusiByID(c *gin.Context) {
return
}
dbConn, err := h.db.GetDB("postgres_satudata")
gormDB, err := h.db.GetGormDB("postgres_satudata")
if err != nil {
h.logAndRespondError(c, "Database connection failed", err, http.StatusInternalServerError)
return
@@ -208,7 +216,7 @@ func (h *RetribusiHandler) GetRetribusiByID(c *gin.Context) {
ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
defer cancel()
dataretribusi, err := h.getRetribusiByID(ctx, dbConn, id)
dataretribusi, err := h.getRetribusiByID(ctx, gormDB, id)
if err != nil {
if err == sql.ErrNoRows {
h.respondError(c, "Retribusi not found", err, http.StatusNotFound)
@@ -739,33 +747,13 @@ func (h *RetribusiHandler) GetRetribusiStats(c *gin.Context) {
})
}
// Get retribusi by ID
func (h *RetribusiHandler) getRetribusiByID(ctx context.Context, dbConn *sql.DB, id string) (*retribusi.Retribusi, error) {
query := `
SELECT
id, status, sort, user_created, date_created, user_updated, date_updated,
"Jenis", "Pelayanan", "Dinas", "Kelompok_obyek", "Kode_tarif",
"Tarif", "Satuan", "Tarif_overtime", "Satuan_overtime",
"Rekening_pokok", "Rekening_denda", "Uraian_1", "Uraian_2", "Uraian_3"
FROM data_retribusi
WHERE id = $1 AND status != 'deleted'`
row := dbConn.QueryRowContext(ctx, query, id)
// Get retribusi by ID using GORM
func (h *RetribusiHandler) getRetribusiByID(ctx context.Context, gormDB *gorm.DB, id string) (*retribusi.Retribusi, error) {
var retribusi retribusi.Retribusi
err := row.Scan(
&retribusi.ID, &retribusi.Status, &retribusi.Sort, &retribusi.UserCreated,
&retribusi.DateCreated, &retribusi.UserUpdated, &retribusi.DateUpdated,
&retribusi.Jenis, &retribusi.Pelayanan, &retribusi.Dinas, &retribusi.KelompokObyek,
&retribusi.KodeTarif, &retribusi.Tarif, &retribusi.Satuan, &retribusi.TarifOvertime,
&retribusi.SatuanOvertime, &retribusi.RekeningPokok, &retribusi.RekeningDenda,
&retribusi.Uraian1, &retribusi.Uraian2, &retribusi.Uraian3,
)
err := gormDB.WithContext(ctx).Where("id = ? AND status != ?", id, "deleted").First(&retribusi).Error
if err != nil {
return nil, err
}
return &retribusi, nil
}
@@ -1258,11 +1246,39 @@ func (h *RetribusiHandler) getAggregateData(ctx context.Context, dbConn *sql.DB,
}
// Get total count dengan filter support
func (h *RetribusiHandler) getTotalCount(ctx context.Context, dbConn *sql.DB, filter retribusi.RetribusiFilter, total *int) error {
whereClause, args := h.buildWhereClause(filter)
countQuery := fmt.Sprintf(`SELECT COUNT(*) FROM data_retribusi WHERE %s`, whereClause)
func (h *RetribusiHandler) getTotalCount(ctx context.Context, db *gorm.DB, filter retribusi.RetribusiFilter, total *int64) error {
query := db.Model(&retribusi.Retribusi{}).Where("status != ?", "deleted")
if err := dbConn.QueryRowContext(ctx, countQuery, args...).Scan(total); err != nil {
if filter.Status != nil {
query = query.Where("status = ?", *filter.Status)
}
if filter.Jenis != nil {
query = query.Where("\"Jenis\" ILIKE ?", "%"+*filter.Jenis+"%")
}
if filter.Dinas != nil {
query = query.Where("\"Dinas\" ILIKE ?", "%"+*filter.Dinas+"%")
}
if filter.KelompokObyek != nil {
query = query.Where("\"Kelompok_obyek\" ILIKE ?", "%"+*filter.KelompokObyek+"%")
}
if filter.Search != nil {
searchTerm := "%" + *filter.Search + "%"
query = query.Where("\"Jenis\" ILIKE ? OR \"Pelayanan\" ILIKE ? OR \"Dinas\" ILIKE ? OR \"Kode_tarif\" ILIKE ? OR \"Uraian_1\" ILIKE ? OR \"Uraian_2\" ILIKE ? OR \"Uraian_3\" ILIKE ?", searchTerm, searchTerm, searchTerm, searchTerm, searchTerm, searchTerm, searchTerm)
}
if filter.DateFrom != nil {
query = query.Where("date_created >= ?", *filter.DateFrom)
}
if filter.DateTo != nil {
query = query.Where("date_created <= ?", filter.DateTo.Add(24*time.Hour-time.Nanosecond))
}
if err := query.Count(total).Error; err != nil {
return fmt.Errorf("total count query failed: %w", err)
}

View File

@@ -0,0 +1,272 @@
package middleware
import (
"fmt"
"html"
"net/http"
"strings"
"time"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
)
// SecurityHeaders adds security headers to all responses
func SecurityHeaders() gin.HandlerFunc {
return gin.HandlerFunc(func(c *gin.Context) {
// Prevent clickjacking
c.Header("X-Frame-Options", "DENY")
// Prevent MIME type sniffing
c.Header("X-Content-Type-Options", "nosniff")
// Enable XSS protection
c.Header("X-XSS-Protection", "1; mode=block")
// Referrer policy
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
// Content Security Policy - adjust as needed
c.Header("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'")
// HSTS (HTTP Strict Transport Security) - only for HTTPS
if c.Request.TLS != nil {
c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
}
c.Next()
})
}
// InputSanitization sanitizes user inputs to prevent XSS and injection attacks
func InputSanitization() gin.HandlerFunc {
return gin.HandlerFunc(func(c *gin.Context) {
// Sanitize query parameters
sanitizeQueryParams(c)
// Sanitize form data
sanitizeFormData(c)
// Sanitize JSON body if present
if c.ContentType() == "application/json" {
sanitizeJSONBody(c)
}
c.Next()
})
}
// sanitizeQueryParams sanitizes all query parameters
func sanitizeQueryParams(c *gin.Context) {
query := c.Request.URL.Query()
for key, values := range query {
for i, value := range values {
query[key][i] = sanitizeString(value)
}
}
c.Request.URL.RawQuery = query.Encode()
}
// sanitizeFormData sanitizes form data
func sanitizeFormData(c *gin.Context) {
if err := c.Request.ParseForm(); err != nil {
return
}
for key, values := range c.Request.PostForm {
for i, value := range values {
c.Request.PostForm[key][i] = sanitizeString(value)
}
}
}
// sanitizeJSONBody sanitizes JSON request body
func sanitizeJSONBody(c *gin.Context) {
// For JSON bodies, we'll let the JSON binding handle it
// and sanitize at the handler level if needed
c.Next()
}
// sanitizeString performs basic sanitization on a string
func sanitizeString(input string) string {
// Remove null bytes
input = strings.ReplaceAll(input, "\x00", "")
// HTML escape
input = html.EscapeString(input)
// Remove potentially dangerous characters
dangerousChars := []string{"<", ">", "\"", "'", "`", "\\", "\n", "\r", "\t"}
for _, char := range dangerousChars {
input = strings.ReplaceAll(input, char, "")
}
return strings.TrimSpace(input)
}
// SQLInjectionProtection provides additional SQL injection protection
func SQLInjectionProtection() gin.HandlerFunc {
return gin.HandlerFunc(func(c *gin.Context) {
// Check for suspicious patterns in query parameters
if hasSQLInjectionPatterns(c) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
"error": "Invalid input detected",
"message": "Request contains potentially malicious content",
})
return
}
c.Next()
})
}
// hasSQLInjectionPatterns checks for common SQL injection patterns
func hasSQLInjectionPatterns(c *gin.Context) bool {
suspiciousPatterns := []string{
"union select",
"union all select",
"select.*from",
"insert.*into",
"update.*set",
"delete.*from",
"drop table",
"drop database",
"alter table",
"create table",
"exec(",
"execute(",
"xp_",
"sp_",
"information_schema",
"sysobjects",
"syscolumns",
"sysdatabases",
"mysql.",
"pg_",
"sqlite_",
";--",
"/*",
"*/",
"@@",
"script>",
"<script",
"javascript:",
"vbscript:",
"onload=",
"onerror=",
"eval(",
"alert(",
}
query := strings.ToLower(c.Request.URL.RawQuery)
for _, pattern := range suspiciousPatterns {
if strings.Contains(query, pattern) {
return true
}
}
// Check form data
if err := c.Request.ParseForm(); err == nil {
for _, values := range c.Request.Form {
for _, value := range values {
lowerValue := strings.ToLower(value)
for _, pattern := range suspiciousPatterns {
if strings.Contains(lowerValue, pattern) {
return true
}
}
}
}
}
return false
}
// RateLimitByIP provides basic rate limiting by IP
func RateLimitByIP(requestsPerMinute int) gin.HandlerFunc {
// Simple in-memory rate limiter
// In production, use Redis or similar
type client struct {
count int
resetTime int64
}
clients := make(map[string]*client)
return func(c *gin.Context) {
ip := c.ClientIP()
now := time.Now().Unix()
if clients[ip] == nil {
clients[ip] = &client{count: 0, resetTime: now + 60}
}
client := clients[ip]
// Reset counter if time window passed
if now > client.resetTime {
client.count = 0
client.resetTime = now + 60
}
if client.count >= requestsPerMinute {
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
"error": "Rate limit exceeded",
"message": "Too many requests. Please try again later.",
})
return
}
client.count++
c.Next()
}
}
// ValidateInputLength validates input length to prevent buffer overflow
func ValidateInputLength(maxLength int) gin.HandlerFunc {
return func(c *gin.Context) {
// Check query parameters
for key, values := range c.Request.URL.Query() {
for _, value := range values {
if len(value) > maxLength {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
"error": "Input too long",
"message": fmt.Sprintf("Parameter '%s' exceeds maximum length of %d characters", key, maxLength),
})
return
}
}
}
// Check form data
if err := c.Request.ParseForm(); err == nil {
for key, values := range c.Request.PostForm {
for _, value := range values {
if len(value) > maxLength {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
"error": "Input too long",
"message": fmt.Sprintf("Parameter '%s' exceeds maximum length of %d characters", key, maxLength),
})
return
}
}
}
}
c.Next()
}
}
// SecureCORSConfig provides secure CORS configuration
func SecureCORSConfig() gin.HandlerFunc {
return cors.New(cors.Config{
AllowOrigins: []string{"http://localhost:3000", "http://localhost:8080"}, // Configure allowed origins
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowHeaders: []string{"Origin", "Content-Type", "Authorization", "X-Requested-With"},
ExposeHeaders: []string{"Content-Length"},
AllowCredentials: true,
MaxAge: 12 * time.Hour,
})
}

View File

@@ -2,34 +2,38 @@ package retribusi
import (
"api-service/internal/models"
"database/sql"
"encoding/json"
"time"
"gorm.io/gorm"
)
// Retribusi represents the data structure for the retribusi table
// with proper null handling and optimized JSON marshaling
type Retribusi struct {
ID string `json:"id" db:"id"`
Status string `json:"status" db:"status"`
Sort models.NullableInt32 `json:"sort,omitempty" db:"sort"`
UserCreated models.NullableString `json:"user_created,omitempty" db:"user_created"`
DateCreated models.NullableTime `json:"date_created,omitempty" db:"date_created"`
UserUpdated models.NullableString `json:"user_updated,omitempty" db:"user_updated"`
DateUpdated models.NullableTime `json:"date_updated,omitempty" db:"date_updated"`
Jenis models.NullableString `json:"jenis,omitempty" db:"Jenis"`
Pelayanan models.NullableString `json:"pelayanan,omitempty" db:"Pelayanan"`
Dinas models.NullableString `json:"dinas,omitempty" db:"Dinas"`
KelompokObyek models.NullableString `json:"kelompok_obyek,omitempty" db:"Kelompok_obyek"`
KodeTarif models.NullableString `json:"kode_tarif,omitempty" db:"Kode_tarif"`
Tarif models.NullableString `json:"tarif,omitempty" db:"Tarif"`
Satuan models.NullableString `json:"satuan,omitempty" db:"Satuan"`
TarifOvertime models.NullableString `json:"tarif_overtime,omitempty" db:"Tarif_overtime"`
SatuanOvertime models.NullableString `json:"satuan_overtime,omitempty" db:"Satuan_overtime"`
RekeningPokok models.NullableString `json:"rekening_pokok,omitempty" db:"Rekening_pokok"`
RekeningDenda models.NullableString `json:"rekening_denda,omitempty" db:"Rekening_denda"`
Uraian1 models.NullableString `json:"uraian_1,omitempty" db:"Uraian_1"`
Uraian2 models.NullableString `json:"uraian_2,omitempty" db:"Uraian_2"`
Uraian3 models.NullableString `json:"uraian_3,omitempty" db:"Uraian_3"`
gorm.Model
ID string `json:"id" gorm:"column:id;primaryKey;type:varchar(255)"`
Status string `json:"status" gorm:"column:status;type:varchar(50);not null"`
Sort sql.NullInt32 `json:"sort,omitempty" gorm:"column:sort"`
UserCreated sql.NullString `json:"user_created,omitempty" gorm:"column:user_created"`
DateCreated sql.NullTime `json:"date_created,omitempty" gorm:"column:date_created"`
UserUpdated sql.NullString `json:"user_updated,omitempty" gorm:"column:user_updated"`
DateUpdated sql.NullTime `json:"date_updated,omitempty" gorm:"column:date_updated"`
Jenis sql.NullString `json:"jenis,omitempty" gorm:"column:Jenis"`
Pelayanan sql.NullString `json:"pelayanan,omitempty" gorm:"column:Pelayanan"`
Dinas sql.NullString `json:"dinas,omitempty" gorm:"column:Dinas"`
KelompokObyek sql.NullString `json:"kelompok_obyek,omitempty" gorm:"column:Kelompok_obyek"`
KodeTarif sql.NullString `json:"kode_tarif,omitempty" gorm:"column:Kode_tarif"`
Tarif sql.NullString `json:"tarif,omitempty" gorm:"column:Tarif"`
Satuan sql.NullString `json:"satuan,omitempty" gorm:"column:Satuan"`
TarifOvertime sql.NullString `json:"tarif_overtime,omitempty" gorm:"column:Tarif_overtime"`
SatuanOvertime sql.NullString `json:"satuan_overtime,omitempty" gorm:"column:Satuan_overtime"`
RekeningPokok sql.NullString `json:"rekening_pokok,omitempty" gorm:"column:Rekening_pokok"`
RekeningDenda sql.NullString `json:"rekening_denda,omitempty" gorm:"column:Rekening_denda"`
Uraian1 sql.NullString `json:"uraian_1,omitempty" gorm:"column:Uraian_1"`
Uraian2 sql.NullString `json:"uraian_2,omitempty" gorm:"column:Uraian_2"`
Uraian3 sql.NullString `json:"uraian_3,omitempty" gorm:"column:Uraian_3"`
}
// Custom JSON marshaling untuk Retribusi agar NULL values tidak muncul di response

View File

@@ -4,7 +4,6 @@ import (
"fmt"
"reflect"
"strings"
"sync"
)
// FilterOperator represents supported filter operators
@@ -67,8 +66,7 @@ type QueryBuilder struct {
tableName string
columnMapping map[string]string // Maps API field names to DB column names
allowedColumns map[string]bool // Security: only allow specified columns
paramCounter int
mu *sync.RWMutex
// PERUBAHAN 1: Hapus paramCounter dan mu untuk membuat QueryBuilder stateless dan thread-safe.
}
// NewQueryBuilder creates a new query builder instance
@@ -77,7 +75,6 @@ func NewQueryBuilder(tableName string) *QueryBuilder {
tableName: tableName,
columnMapping: make(map[string]string),
allowedColumns: make(map[string]bool),
paramCounter: 0,
}
}
@@ -88,6 +85,8 @@ func (qb *QueryBuilder) SetColumnMapping(mapping map[string]string) *QueryBuilde
}
// SetAllowedColumns sets the list of allowed columns for security
// PERUBAHAN 3: Nama kolom di sini seharusnya adalah nama kolom ASLI di database
// untuk pemeriksaan keamanan yang lebih konsisten.
func (qb *QueryBuilder) SetAllowedColumns(columns []string) *QueryBuilder {
qb.allowedColumns = make(map[string]bool)
for _, col := range columns {
@@ -98,7 +97,10 @@ func (qb *QueryBuilder) SetAllowedColumns(columns []string) *QueryBuilder {
// BuildQuery builds the complete SQL query
func (qb *QueryBuilder) BuildQuery(query DynamicQuery) (string, []interface{}, error) {
qb.paramCounter = 0
// PERUBAHAN 1: paramCounter sekarang lokal untuk fungsi ini.
// Ini membuat QueryBuilder aman untuk digunakan secara konkuren (thread-safe).
paramCounter := 0
args := []interface{}{}
// Build SELECT clause
selectClause := qb.buildSelectClause(query.Fields)
@@ -107,30 +109,30 @@ func (qb *QueryBuilder) BuildQuery(query DynamicQuery) (string, []interface{}, e
fromClause := fmt.Sprintf("FROM %s", qb.tableName)
// Build WHERE clause
whereClause, whereArgs, err := qb.buildWhereClause(query.Filters)
whereClause, whereArgs, err := qb.buildWhereClause(query.Filters, &paramCounter)
if err != nil {
return "", nil, err
}
// Build ORDER BY clause
orderClause := qb.buildOrderClause(query.Sort)
args = append(args, whereArgs...)
// Build GROUP BY clause
groupClause := qb.buildGroupByClause(query.GroupBy)
// Build HAVING clause
havingClause, havingArgs, err := qb.buildHavingClause(query.Having)
havingClause, havingArgs, err := qb.buildHavingClause(query.Having, &paramCounter)
if err != nil {
return "", nil, err
}
args = append(args, havingArgs...)
// Build ORDER BY clause
orderClause := qb.buildOrderClause(query.Sort)
// Combine all parts
sqlParts := []string{selectClause, fromClause}
args := []interface{}{}
if whereClause != "" {
sqlParts = append(sqlParts, "WHERE "+whereClause)
args = append(args, whereArgs...)
}
if groupClause != "" {
@@ -139,7 +141,6 @@ func (qb *QueryBuilder) BuildQuery(query DynamicQuery) (string, []interface{}, e
if havingClause != "" {
sqlParts = append(sqlParts, "HAVING "+havingClause)
args = append(args, havingArgs...)
}
if orderClause != "" {
@@ -148,14 +149,14 @@ func (qb *QueryBuilder) BuildQuery(query DynamicQuery) (string, []interface{}, e
// Add pagination
if query.Limit > 0 {
qb.paramCounter++
sqlParts = append(sqlParts, fmt.Sprintf("LIMIT $%d", qb.paramCounter))
paramCounter++
sqlParts = append(sqlParts, fmt.Sprintf("LIMIT $%d", paramCounter))
args = append(args, query.Limit)
}
if query.Offset > 0 {
qb.paramCounter++
sqlParts = append(sqlParts, fmt.Sprintf("OFFSET $%d", qb.paramCounter))
paramCounter++
sqlParts = append(sqlParts, fmt.Sprintf("OFFSET $%d", paramCounter))
args = append(args, query.Offset)
}
@@ -176,24 +177,23 @@ func (qb *QueryBuilder) buildSelectClause(fields []string) string {
continue
}
// Check if it's an expression (contains spaces, parentheses, etc.)
if strings.Contains(field, " ") || strings.Contains(field, "(") || strings.Contains(field, ")") {
// Expression, add as is
selectedFields = append(selectedFields, field)
continue
}
// Security check: only allow specified columns (check original field name)
if len(qb.allowedColumns) > 0 && !qb.allowedColumns[field] {
continue
// PERUBAHAN 3: Lakukan mapping terlebih dahulu, lalu pemeriksaan keamanan.
mappedCol := field
if mapped, exists := qb.columnMapping[field]; exists {
mappedCol = mapped
}
// Map field name if mapping exists
if mappedCol, exists := qb.columnMapping[field]; exists {
field = mappedCol
// Security check: hanya izinkan kolom yang sudah ditentukan (cek nama kolom DB)
if len(qb.allowedColumns) > 0 && !qb.allowedColumns[mappedCol] {
continue // Lewati kolom yang tidak diizinkan
}
selectedFields = append(selectedFields, fmt.Sprintf(`"%s"`, field))
selectedFields = append(selectedFields, fmt.Sprintf(`"%s"`, mappedCol))
}
if len(selectedFields) == 0 {
@@ -204,16 +204,17 @@ func (qb *QueryBuilder) buildSelectClause(fields []string) string {
}
// buildWhereClause builds the WHERE part of the query
func (qb *QueryBuilder) buildWhereClause(filterGroups []FilterGroup) (string, []interface{}, error) {
func (qb *QueryBuilder) buildWhereClause(filterGroups []FilterGroup, paramCounter *int) (string, []interface{}, error) {
if len(filterGroups) == 0 {
return "", nil, nil
}
var conditions []string
var args []interface{}
var groupConditions []string
var allArgs []interface{}
for i, group := range filterGroups {
groupCondition, groupArgs, err := qb.buildFilterGroup(group)
// PERUBAHAN 2: Tambahkan tanda kurung untuk setiap grup untuk memastikan urutan operasi yang benar.
groupCondition, groupArgs, err := qb.buildFilterGroup(group, paramCounter)
if err != nil {
return "", nil, err
}
@@ -224,19 +225,18 @@ func (qb *QueryBuilder) buildWhereClause(filterGroups []FilterGroup) (string, []
if group.LogicOp != "" {
logicOp = strings.ToUpper(group.LogicOp)
}
conditions = append(conditions, logicOp)
groupConditions = append(groupConditions, logicOp)
}
conditions = append(conditions, groupCondition)
args = append(args, groupArgs...)
groupConditions = append(groupConditions, fmt.Sprintf("(%s)", groupCondition))
allArgs = append(allArgs, groupArgs...)
}
}
return strings.Join(conditions, " "), args, nil
return strings.Join(groupConditions, " "), allArgs, nil
}
// buildFilterGroup builds conditions for a filter group
func (qb *QueryBuilder) buildFilterGroup(group FilterGroup) (string, []interface{}, error) {
func (qb *QueryBuilder) buildFilterGroup(group FilterGroup, paramCounter *int) (string, []interface{}, error) {
if len(group.Filters) == 0 {
return "", nil, nil
}
@@ -245,7 +245,7 @@ func (qb *QueryBuilder) buildFilterGroup(group FilterGroup) (string, []interface
var args []interface{}
for i, filter := range group.Filters {
condition, filterArgs, err := qb.buildFilterCondition(filter)
condition, filterArgs, err := qb.buildFilterCondition(filter, paramCounter)
if err != nil {
return "", nil, err
}
@@ -270,49 +270,56 @@ func (qb *QueryBuilder) buildFilterGroup(group FilterGroup) (string, []interface
}
// buildFilterCondition builds a single filter condition
func (qb *QueryBuilder) buildFilterCondition(filter DynamicFilter) (string, []interface{}, error) {
// Security check (check original field name)
if len(qb.allowedColumns) > 0 && !qb.allowedColumns[filter.Column] {
return "", nil, nil
}
// Map column name if mapping exists
func (qb *QueryBuilder) buildFilterCondition(filter DynamicFilter, paramCounter *int) (string, []interface{}, error) {
// PERUBAHAN 3: Lakukan mapping terlebih dahulu, lalu pemeriksaan keamanan.
column := filter.Column
if mappedCol, exists := qb.columnMapping[column]; exists {
column = mappedCol
}
// Security check (cek nama kolom DB hasil mapping)
if len(qb.allowedColumns) > 0 && !qb.allowedColumns[column] {
return "", nil, nil
}
// Additional security: Validate column name format
if !qb.isValidColumnName(column) {
return "", nil, fmt.Errorf("invalid column name: %s", column)
}
// Wrap column name in quotes for PostgreSQL
column = fmt.Sprintf(`"%s"`, column)
switch filter.Operator {
case OpEqual:
// PERUBAHAN 4: Tangani nilai nil secara eksplisit untuk operator kesetaraan.
if filter.Value == nil {
return "", nil, nil
return fmt.Sprintf("%s IS NULL", column), nil, nil
}
qb.paramCounter++
return fmt.Sprintf("%s = $%d", column, qb.paramCounter), []interface{}{filter.Value}, nil
*paramCounter++
return fmt.Sprintf("%s = $%d", column, *paramCounter), []interface{}{filter.Value}, nil
case OpNotEqual:
// PERUBAHAN 4: Tangani nilai nil secara eksplisit untuk operator ketidaksamaan.
if filter.Value == nil {
return "", nil, nil
return fmt.Sprintf("%s IS NOT NULL", column), nil, nil
}
qb.paramCounter++
return fmt.Sprintf("%s != $%d", column, qb.paramCounter), []interface{}{filter.Value}, nil
*paramCounter++
return fmt.Sprintf("%s != $%d", column, *paramCounter), []interface{}{filter.Value}, nil
case OpLike:
if filter.Value == nil {
return "", nil, nil
}
qb.paramCounter++
return fmt.Sprintf("%s LIKE $%d", column, qb.paramCounter), []interface{}{filter.Value}, nil
*paramCounter++
return fmt.Sprintf("%s LIKE $%d", column, *paramCounter), []interface{}{filter.Value}, nil
case OpILike:
if filter.Value == nil {
return "", nil, nil
}
qb.paramCounter++
return fmt.Sprintf("%s ILIKE $%d", column, qb.paramCounter), []interface{}{filter.Value}, nil
*paramCounter++
return fmt.Sprintf("%s ILIKE $%d", column, *paramCounter), []interface{}{filter.Value}, nil
case OpIn:
values := qb.parseArrayValue(filter.Value)
@@ -323,8 +330,8 @@ func (qb *QueryBuilder) buildFilterCondition(filter DynamicFilter) (string, []in
var placeholders []string
var args []interface{}
for _, val := range values {
qb.paramCounter++
placeholders = append(placeholders, fmt.Sprintf("$%d", qb.paramCounter))
*paramCounter++
placeholders = append(placeholders, fmt.Sprintf("$%d", *paramCounter))
args = append(args, val)
}
@@ -339,8 +346,8 @@ func (qb *QueryBuilder) buildFilterCondition(filter DynamicFilter) (string, []in
var placeholders []string
var args []interface{}
for _, val := range values {
qb.paramCounter++
placeholders = append(placeholders, fmt.Sprintf("$%d", qb.paramCounter))
*paramCounter++
placeholders = append(placeholders, fmt.Sprintf("$%d", *paramCounter))
args = append(args, val)
}
@@ -350,29 +357,29 @@ func (qb *QueryBuilder) buildFilterCondition(filter DynamicFilter) (string, []in
if filter.Value == nil {
return "", nil, nil
}
qb.paramCounter++
return fmt.Sprintf("%s > $%d", column, qb.paramCounter), []interface{}{filter.Value}, nil
*paramCounter++
return fmt.Sprintf("%s > $%d", column, *paramCounter), []interface{}{filter.Value}, nil
case OpGreaterThanEqual:
if filter.Value == nil {
return "", nil, nil
}
qb.paramCounter++
return fmt.Sprintf("%s >= $%d", column, qb.paramCounter), []interface{}{filter.Value}, nil
*paramCounter++
return fmt.Sprintf("%s >= $%d", column, *paramCounter), []interface{}{filter.Value}, nil
case OpLessThan:
if filter.Value == nil {
return "", nil, nil
}
qb.paramCounter++
return fmt.Sprintf("%s < $%d", column, qb.paramCounter), []interface{}{filter.Value}, nil
*paramCounter++
return fmt.Sprintf("%s < $%d", column, *paramCounter), []interface{}{filter.Value}, nil
case OpLessThanEqual:
if filter.Value == nil {
return "", nil, nil
}
qb.paramCounter++
return fmt.Sprintf("%s <= $%d", column, qb.paramCounter), []interface{}{filter.Value}, nil
*paramCounter++
return fmt.Sprintf("%s <= $%d", column, *paramCounter), []interface{}{filter.Value}, nil
case OpBetween:
if filter.Value == nil {
@@ -382,10 +389,10 @@ func (qb *QueryBuilder) buildFilterCondition(filter DynamicFilter) (string, []in
if len(values) != 2 {
return "", nil, fmt.Errorf("between operator requires exactly 2 values")
}
qb.paramCounter++
param1 := qb.paramCounter
qb.paramCounter++
param2 := qb.paramCounter
*paramCounter++
param1 := *paramCounter
*paramCounter++
param2 := *paramCounter
return fmt.Sprintf("%s BETWEEN $%d AND $%d", column, param1, param2), []interface{}{values[0], values[1]}, nil
case OpNotBetween:
@@ -396,10 +403,10 @@ func (qb *QueryBuilder) buildFilterCondition(filter DynamicFilter) (string, []in
if len(values) != 2 {
return "", nil, fmt.Errorf("not between operator requires exactly 2 values")
}
qb.paramCounter++
param1 := qb.paramCounter
qb.paramCounter++
param2 := qb.paramCounter
*paramCounter++
param1 := *paramCounter
*paramCounter++
param2 := *paramCounter
return fmt.Sprintf("%s NOT BETWEEN $%d AND $%d", column, param1, param2), []interface{}{values[0], values[1]}, nil
case OpNull:
@@ -412,33 +419,33 @@ func (qb *QueryBuilder) buildFilterCondition(filter DynamicFilter) (string, []in
if filter.Value == nil {
return "", nil, nil
}
qb.paramCounter++
*paramCounter++
value := fmt.Sprintf("%%%v%%", filter.Value)
return fmt.Sprintf("%s ILIKE $%d", column, qb.paramCounter), []interface{}{value}, nil
return fmt.Sprintf("%s ILIKE $%d", column, *paramCounter), []interface{}{value}, nil
case OpNotContains:
if filter.Value == nil {
return "", nil, nil
}
qb.paramCounter++
*paramCounter++
value := fmt.Sprintf("%%%v%%", filter.Value)
return fmt.Sprintf("%s NOT ILIKE $%d", column, qb.paramCounter), []interface{}{value}, nil
return fmt.Sprintf("%s NOT ILIKE $%d", column, *paramCounter), []interface{}{value}, nil
case OpStartsWith:
if filter.Value == nil {
return "", nil, nil
}
qb.paramCounter++
*paramCounter++
value := fmt.Sprintf("%v%%", filter.Value)
return fmt.Sprintf("%s ILIKE $%d", column, qb.paramCounter), []interface{}{value}, nil
return fmt.Sprintf("%s ILIKE $%d", column, *paramCounter), []interface{}{value}, nil
case OpEndsWith:
if filter.Value == nil {
return "", nil, nil
}
qb.paramCounter++
*paramCounter++
value := fmt.Sprintf("%%%v", filter.Value)
return fmt.Sprintf("%s ILIKE $%d", column, qb.paramCounter), []interface{}{value}, nil
return fmt.Sprintf("%s ILIKE $%d", column, *paramCounter), []interface{}{value}, nil
default:
return "", nil, fmt.Errorf("unsupported operator: %s", filter.Operator)
@@ -485,17 +492,17 @@ func (qb *QueryBuilder) buildOrderClause(sortFields []SortField) string {
var orderParts []string
for _, sort := range sortFields {
// PERUBAHAN 3: Lakukan mapping dan pemeriksaan keamanan.
column := sort.Column
// Security check (check original field name)
if len(qb.allowedColumns) > 0 && !qb.allowedColumns[column] {
continue
}
if mappedCol, exists := qb.columnMapping[column]; exists {
column = mappedCol
}
// Security check (cek nama kolom DB hasil mapping)
if len(qb.allowedColumns) > 0 && !qb.allowedColumns[column] {
continue
}
order := "ASC"
if sort.Order != "" {
order = strings.ToUpper(sort.Order)
@@ -519,12 +526,13 @@ func (qb *QueryBuilder) buildGroupByClause(groupFields []string) string {
var groupParts []string
for _, field := range groupFields {
// PERUBAHAN 3: Lakukan mapping dan pemeriksaan keamanan.
column := field
if mappedCol, exists := qb.columnMapping[column]; exists {
column = mappedCol
}
// Security check
// Security check (cek nama kolom DB hasil mapping)
if len(qb.allowedColumns) > 0 && !qb.allowedColumns[column] {
continue
}
@@ -540,43 +548,45 @@ func (qb *QueryBuilder) buildGroupByClause(groupFields []string) string {
}
// buildHavingClause builds the HAVING clause
func (qb *QueryBuilder) buildHavingClause(havingGroups []FilterGroup) (string, []interface{}, error) {
func (qb *QueryBuilder) buildHavingClause(havingGroups []FilterGroup, paramCounter *int) (string, []interface{}, error) {
if len(havingGroups) == 0 {
return "", nil, nil
}
return qb.buildWhereClause(havingGroups)
// Reuse buildWhereClause logic for HAVING
return qb.buildWhereClause(havingGroups, paramCounter)
}
// BuildCountQuery builds a count query
func (qb *QueryBuilder) BuildCountQuery(query DynamicQuery) (string, []interface{}, error) {
qb.paramCounter = 0
// PERUBAHAN 1: paramCounter lokal.
paramCounter := 0
args := []interface{}{}
// Build FROM clause
fromClause := fmt.Sprintf("FROM %s", qb.tableName)
// Build WHERE clause
whereClause, whereArgs, err := qb.buildWhereClause(query.Filters)
whereClause, whereArgs, err := qb.buildWhereClause(query.Filters, &paramCounter)
if err != nil {
return "", nil, err
}
args = append(args, whereArgs...)
// Build GROUP BY clause
groupClause := qb.buildGroupByClause(query.GroupBy)
// Build HAVING clause
havingClause, havingArgs, err := qb.buildHavingClause(query.Having)
havingClause, havingArgs, err := qb.buildHavingClause(query.Having, &paramCounter)
if err != nil {
return "", nil, err
}
args = append(args, havingArgs...)
// Combine parts
sqlParts := []string{"SELECT COUNT(*)", fromClause}
args := []interface{}{}
if whereClause != "" {
sqlParts = append(sqlParts, "WHERE "+whereClause)
args = append(args, whereArgs...)
}
if groupClause != "" {
@@ -585,9 +595,40 @@ func (qb *QueryBuilder) BuildCountQuery(query DynamicQuery) (string, []interface
if havingClause != "" {
sqlParts = append(sqlParts, "HAVING "+havingClause)
args = append(args, havingArgs...)
}
sql := strings.Join(sqlParts, " ")
return sql, args, nil
}
// isValidColumnName validates column name format to prevent SQL injection
func (qb *QueryBuilder) isValidColumnName(column string) bool {
if column == "" {
return false
}
// Allow only alphanumeric characters, underscores, and dots (for table.column format)
// This is more restrictive than before for better security
for _, r := range column {
if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') ||
(r >= '0' && r <= '9') || r == '_' || r == '.') {
return false
}
}
// Prevent common SQL injection patterns
suspiciousPatterns := []string{
" ", ";", "--", "/*", "*/", "union", "select", "insert", "update", "delete",
"drop", "alter", "create", "exec", "execute", "xp_", "sp_", "information_schema",
"sysobjects", "syscolumns", "sysdatabases", "mysql", "pg_", "sqlite",
}
lowerColumn := strings.ToLower(column)
for _, pattern := range suspiciousPatterns {
if strings.Contains(lowerColumn, pattern) {
return false
}
}
return true
}

View File

@@ -0,0 +1,172 @@
package validation
import (
"html"
"regexp"
"strings"
"unicode/utf8"
)
// InputSanitizer provides comprehensive input sanitization
type InputSanitizer struct {
maxLength int
}
// NewInputSanitizer creates a new input sanitizer
func NewInputSanitizer(maxLength int) *InputSanitizer {
return &InputSanitizer{maxLength: maxLength}
}
// SanitizeString performs comprehensive sanitization on a string
func (s *InputSanitizer) SanitizeString(input string) string {
if input == "" {
return input
}
// Check length limit
if utf8.RuneCountInString(input) > s.maxLength {
input = string([]rune(input)[:s.maxLength])
}
// Remove null bytes and control characters
input = strings.Map(func(r rune) rune {
if r < 32 && r != 9 && r != 10 && r != 13 { // Allow tab, LF, CR
return -1
}
return r
}, input)
// HTML escape to prevent XSS
input = html.EscapeString(input)
// Remove potentially dangerous patterns
dangerousPatterns := []string{
`<script[^>]*>.*?</script>`,
`<iframe[^>]*>.*?</iframe>`,
`<object[^>]*>.*?</object>`,
`<embed[^>]*>.*?</embed>`,
`javascript:`,
`vbscript:`,
`data:`,
`on\w+\s*=`,
}
for _, pattern := range dangerousPatterns {
re := regexp.MustCompile(`(?i)` + pattern)
input = re.ReplaceAllString(input, "")
}
// Trim whitespace
return strings.TrimSpace(input)
}
// SanitizeSQLInput sanitizes input specifically for SQL queries
func (s *InputSanitizer) SanitizeSQLInput(input string) string {
input = s.SanitizeString(input)
// Additional SQL-specific sanitization
sqlPatterns := []string{
`;`, `--`, `/*`, `*/`, `@@`, `@`,
`xp_`, `sp_`, `exec`, `execute`,
`information_schema`, `sysobjects`,
`syscolumns`, `sysdatabases`,
}
for _, pattern := range sqlPatterns {
input = strings.ReplaceAll(input, pattern, "")
}
return input
}
// ValidateSQLSafe checks if input is safe for SQL queries
func (s *InputSanitizer) ValidateSQLSafe(input string) bool {
if input == "" {
return true
}
// Check for SQL injection patterns
suspiciousPatterns := []string{
"union select", "union all select",
"select.*from", "insert.*into", "update.*set", "delete.*from",
"drop table", "drop database", "alter table", "create table",
"information_schema", "sysobjects", "syscolumns", "sysdatabases",
"mysql.", "pg_", "sqlite_",
";--", "/*", "*/", "@@",
"script>", "<script",
"javascript:", "vbscript:",
"onload=", "onerror=", "eval(", "alert(",
}
lowerInput := strings.ToLower(input)
for _, pattern := range suspiciousPatterns {
if strings.Contains(lowerInput, pattern) {
return false
}
}
return true
}
// SanitizeJSON sanitizes JSON input
func (s *InputSanitizer) SanitizeJSON(input string) string {
input = s.SanitizeString(input)
// Remove JSON-specific dangerous patterns
jsonPatterns := []string{
`{"\w+":\s*"[^"]*javascript:[^"]*"}`,
`{"\w+":\s*"[^"]*vbscript:[^"]*"}`,
`{"\w+":\s*"[^"]*data:[^"]*"}`,
}
for _, pattern := range jsonPatterns {
re := regexp.MustCompile(`(?i)` + pattern)
input = re.ReplaceAllString(input, "")
}
return input
}
// SanitizeFilename sanitizes filename inputs
func (s *InputSanitizer) SanitizeFilename(filename string) string {
filename = s.SanitizeString(filename)
// Remove path traversal attempts
filename = strings.ReplaceAll(filename, "../", "")
filename = strings.ReplaceAll(filename, "..\\", "")
// Remove dangerous characters for filenames
dangerousChars := []string{"/", "\\", ":", "*", "?", "\"", "<", ">", "|"}
for _, char := range dangerousChars {
filename = strings.ReplaceAll(filename, char, "")
}
return filename
}
// BatchSanitize sanitizes multiple inputs at once
func (s *InputSanitizer) BatchSanitize(inputs map[string]string) map[string]string {
sanitized := make(map[string]string)
for key, value := range inputs {
sanitized[key] = s.SanitizeString(value)
}
return sanitized
}
// IsValidInputLength checks if input length is within acceptable limits
func (s *InputSanitizer) IsValidInputLength(input string, minLen, maxLen int) bool {
length := utf8.RuneCountInString(input)
return length >= minLen && length <= maxLen
}
// ContainsHTML checks if input contains HTML tags
func (s *InputSanitizer) ContainsHTML(input string) bool {
htmlRegex := regexp.MustCompile(`<[^>]+>`)
return htmlRegex.MatchString(input)
}
// StripHTML removes HTML tags from input
func (s *InputSanitizer) StripHTML(input string) string {
htmlRegex := regexp.MustCompile(`<[^>]+>`)
return htmlRegex.ReplaceAllString(input, "")
}