// Source: "Files" view snapshot, retrieved 2025-11-26 07:24:49 +00:00.
// Original file: 2729 lines, 86 KiB, Go.
package utils
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"net/url"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"github.com/Masterminds/squirrel"
"github.com/jmoiron/sqlx"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// DBType identifies the SQL dialect the builder generates queries for.
type DBType string
// Supported database dialects. DBTypeMongoDB is declared here but the SQL
// builder paths in this file only handle the four SQL dialects.
const (
	DBTypePostgreSQL DBType = "postgres"
	DBTypeMySQL      DBType = "mysql"
	DBTypeSQLite     DBType = "sqlite"
	DBTypeSQLServer  DBType = "sqlserver"
	DBTypeMongoDB    DBType = "mongodb"
)
// FilterOperator names a supported filter operator; the string values are
// the tokens accepted in DynamicFilter.Operator (e.g. "_eq", "_in").
type FilterOperator string
// Filter operators accepted in DynamicFilter.Operator.
const (
	// Equality (nil value maps to IS NULL / IS NOT NULL).
	OpEqual    FilterOperator = "_eq"
	OpNotEqual FilterOperator = "_neq"
	// Pattern matching; the "i" variants are case-insensitive.
	OpLike     FilterOperator = "_like"
	OpILike    FilterOperator = "_ilike"
	OpNotLike  FilterOperator = "_nlike"
	OpNotILike FilterOperator = "_nilike"
	// Set membership (value is parsed as an array).
	OpIn    FilterOperator = "_in"
	OpNotIn FilterOperator = "_nin"
	// Ordering comparisons.
	OpGreaterThan      FilterOperator = "_gt"
	OpGreaterThanEqual FilterOperator = "_gte"
	OpLessThan         FilterOperator = "_lt"
	OpLessThanEqual    FilterOperator = "_lte"
	// Range (value must parse to exactly two elements).
	OpBetween    FilterOperator = "_between"
	OpNotBetween FilterOperator = "_nbetween"
	// NULL tests (value is ignored).
	OpNull    FilterOperator = "_null"
	OpNotNull FilterOperator = "_nnull"
	// Substring / prefix / suffix matching.
	OpContains    FilterOperator = "_contains"
	OpNotContains FilterOperator = "_ncontains"
	OpStartsWith  FilterOperator = "_starts_with"
	OpEndsWith    FilterOperator = "_ends_with"
	// JSON column operators (dialect-specific SQL; see buildJsonFilterCondition).
	OpJsonContains    FilterOperator = "_json_contains"
	OpJsonNotContains FilterOperator = "_json_ncontains"
	OpJsonExists      FilterOperator = "_json_exists"
	OpJsonNotExists   FilterOperator = "_json_nexists"
	OpJsonEqual       FilterOperator = "_json_eq"
	OpJsonNotEqual    FilterOperator = "_json_neq"
	// Array column operators (see buildArrayFilterCondition).
	OpArrayContains    FilterOperator = "_array_contains"
	OpArrayNotContains FilterOperator = "_array_ncontains"
	OpArrayLength      FilterOperator = "_array_length"
)
// DynamicFilter represents a single filter condition applied to one column.
type DynamicFilter struct {
	Column   string         `json:"column"`   // target column; validated/escaped before use
	Operator FilterOperator `json:"operator"` // one of the Op* constants
	Value    interface{}    `json:"value"`    // comparison value; nil means IS (NOT) NULL for _eq/_neq
	// Options carries extra settings for complex filters, e.g. "path" for
	// JSON operators and "length" for _array_length.
	Options map[string]interface{} `json:"options,omitempty"`
}
// FilterGroup represents a group of filters combined with a logical operator.
type FilterGroup struct {
	Filters []DynamicFilter `json:"filters"`
	LogicOp string          `json:"logic_op"` // "AND" (default when empty) or "OR"
}
// SelectField represents one entry of the SELECT clause, supporting raw
// expressions, aliases and window functions.
type SelectField struct {
	Expression string `json:"expression"` // e.g. "TMLogBarang.Nama", "COUNT(*)"
	Alias      string `json:"alias"`      // e.g. "obat_nama", "total_count"
	// WindowFunction, when set, replaces Expression with the rendered
	// window-function expression.
	WindowFunction *WindowFunction `json:"window_function,omitempty"`
}
// WindowFunction represents a window function with its OVER configuration.
// The function is always rendered with empty parentheses, e.g. "ROW_NUMBER()".
type WindowFunction struct {
	Function string `json:"function"` // e.g. "ROW_NUMBER", "RANK", "DENSE_RANK"
	Over     string `json:"over"`     // PARTITION BY expression (without the keyword)
	OrderBy  string `json:"order_by"` // ORDER BY expression (without the keyword)
	Frame    string `json:"frame"`    // ROWS/RANGE frame clause, verbatim
	Alias    string `json:"alias"`    // optional alias for the window expression
}
// Join represents a JOIN clause of the main query.
type Join struct {
	Type         string      `json:"type"`          // "INNER" (default), "LEFT", "RIGHT", "FULL"
	Table        string      `json:"table"`         // table to join; checked against the table allow-list
	Alias        string      `json:"alias"`         // optional table alias
	OnConditions FilterGroup `json:"on_conditions"` // rendered as the ON clause
	// Lateral prefixes the joined relation with LATERAL (honored for LEFT
	// and INNER joins).
	Lateral bool `json:"lateral,omitempty"`
}
// Union represents a UNION clause appended after the main query.
type Union struct {
	Type  string       `json:"type"`  // "UNION" (default when empty) or "UNION ALL"
	Query DynamicQuery `json:"query"` // the subquery to union with
}
// CTE (Common Table Expression) represents one entry of a WITH clause.
type CTE struct {
	Name  string       `json:"name"`  // CTE alias name
	Query DynamicQuery `json:"query"` // query defining the CTE
	// Recursive, when true on any CTE, switches the whole clause to
	// "WITH RECURSIVE".
	Recursive bool `json:"recursive,omitempty"`
}
// DynamicQuery represents the complete declarative query structure consumed
// by QueryBuilder.BuildQuery.
type DynamicQuery struct {
	Fields  []SelectField `json:"fields,omitempty"` // SELECT list; empty means "*"
	From    string        `json:"from"`             // main table name
	Aliases string        `json:"aliases"`          // main table alias
	Joins   []Join        `json:"joins,omitempty"`
	Filters []FilterGroup `json:"filters,omitempty"` // WHERE conditions
	GroupBy []string      `json:"group_by,omitempty"`
	Having  []FilterGroup `json:"having,omitempty"`
	Unions  []Union       `json:"unions,omitempty"`
	CTEs    []CTE         `json:"ctes,omitempty"`
	Sort    []SortField   `json:"sort,omitempty"`
	Limit   int           `json:"limit"`  // 0 disables LIMIT; capped by maxAllowedRows
	Offset  int           `json:"offset"` // 0 disables OFFSET
	// WindowFunctions are appended to the SELECT list as extra expressions.
	WindowFunctions []WindowFunction `json:"window_functions,omitempty"`
	// JsonOperations are appended to the SELECT list as extra expressions.
	JsonOperations []JsonOperation `json:"json_operations,omitempty"`
}
// JsonOperation represents a JSON expression added to the SELECT list.
type JsonOperation struct {
	Type   string      `json:"type"`            // "extract", "exists" or "contains" (see buildJsonOperation)
	Column string      `json:"column"`          // JSON column; validated/escaped before use
	Path   string      `json:"path"`            // JSON path; defaults to "$" when empty
	Value  interface{} `json:"value,omitempty"` // comparison value for "contains"
	Alias  string      `json:"alias,omitempty"` // optional alias for the result expression
}
// SortField represents one ORDER BY entry.
type SortField struct {
	Column string `json:"column"` // validated/escaped before use
	Order  string `json:"order"`  // "DESC" (case-insensitive) or anything else for ASC
}
// UpdateData represents data for UPDATE operations. Columns and Values are
// positional pairs; presumably consumed by an UPDATE builder defined later
// in this file — confirm against that builder.
type UpdateData struct {
	Columns []string      `json:"columns"`
	Values  []interface{} `json:"values"`
	// JsonUpdates maps a JSON column name to an in-place JSON update.
	JsonUpdates map[string]JsonUpdate `json:"json_updates,omitempty"`
}
// JsonUpdate represents a JSON update operation targeting one path.
type JsonUpdate struct {
	Path  string      `json:"path"`  // JSON path
	Value interface{} `json:"value"` // new value to set at Path
}
// InsertData represents data for INSERT operations. Columns and Values are
// positional pairs; presumably consumed by an INSERT builder defined later
// in this file — confirm against that builder.
type InsertData struct {
	Columns []string      `json:"columns"`
	Values  []interface{} `json:"values"`
	// JsonValues maps a JSON column name to the value to insert.
	JsonValues map[string]interface{} `json:"json_values,omitempty"`
}
// QueryBuilder builds SQL queries from dynamic filter structures using
// squirrel, with per-dialect placeholder formats and opt-in security checks.
type QueryBuilder struct {
	dbType         DBType
	sqlBuilder     squirrel.StatementBuilderType
	allowedColumns map[string]bool // security: column allow-list (presumably enforced by validateAndEscapeColumn — confirm)
	allowedTables  map[string]bool // security: when non-empty, FROM/JOIN tables must be listed here
	// Security settings.
	enableSecurityChecks bool
	maxAllowedRows       int // upper bound accepted for DynamicQuery.Limit
	// Compiled patterns used to scan query argument values for SQL
	// injection payloads (see checkForSqlInjectionInArgs).
	dangerousPatterns []*regexp.Regexp
	// When true, the final SQL and args are printed via fmt.Printf.
	enableQueryLogging bool
	// Default timeout intended for query execution.
	queryTimeout time.Duration
}
// NewQueryBuilder creates a query builder for the given database type. It
// picks the dialect's placeholder format ($n for PostgreSQL, @pN for SQL
// Server, ? otherwise) and precompiles the SQL-injection detection patterns
// used by the security checks. Security checks and query logging start
// enabled, with a 10000-row limit and a 30s default timeout.
func NewQueryBuilder(dbType DBType) *QueryBuilder {
	var format squirrel.PlaceholderFormat = squirrel.Question
	switch dbType {
	case DBTypePostgreSQL:
		format = squirrel.Dollar
	case DBTypeSQLServer:
		format = squirrel.AtP
	}
	// Patterns that flag likely SQL-injection payloads in argument values.
	injectionPatterns := []*regexp.Regexp{
		regexp.MustCompile(`(?i)(union|select|insert|update|delete|drop|alter|create|exec|execute)\s`),
		regexp.MustCompile(`(?i)(--|\/\*|\*\/)`),
		regexp.MustCompile(`(?i)(or|and)\s+1\s*=\s*1`),
		regexp.MustCompile(`(?i)(or|and)\s+true`),
		regexp.MustCompile(`(?i)(xp_|sp_)\w+`),                               // SQL Server extended procedures
		regexp.MustCompile(`(?i)(waitfor\s+delay)`),                          // SQL Server time-based attack
		regexp.MustCompile(`(?i)(benchmark|sleep)\s*\(`),                     // MySQL time-based attack
		regexp.MustCompile(`(?i)(pg_sleep)\s*\(`),                            // PostgreSQL time-based attack
		regexp.MustCompile(`(?i)(load_file|into\s+outfile)`),                 // file operations
		regexp.MustCompile(`(?i)(information_schema|sysobjects|syscolumns)`), // system tables
	}
	return &QueryBuilder{
		dbType:               dbType,
		sqlBuilder:           squirrel.StatementBuilder.PlaceholderFormat(format),
		allowedColumns:       map[string]bool{},
		allowedTables:        map[string]bool{},
		enableSecurityChecks: true,
		maxAllowedRows:       10000,
		dangerousPatterns:    injectionPatterns,
		enableQueryLogging:   true,
		queryTimeout:         30 * time.Second,
	}
}
// SetSecurityOptions toggles security validation and sets the maximum row
// limit a query may request. Returns the builder for chaining.
func (qb *QueryBuilder) SetSecurityOptions(enableChecks bool, maxRows int) *QueryBuilder {
	qb.maxAllowedRows = maxRows
	qb.enableSecurityChecks = enableChecks
	return qb
}
// SetAllowedColumns replaces the column allow-list used by the security
// checks. Returns the builder for chaining.
func (qb *QueryBuilder) SetAllowedColumns(columns []string) *QueryBuilder {
	allowed := make(map[string]bool, len(columns))
	for _, name := range columns {
		allowed[name] = true
	}
	qb.allowedColumns = allowed
	return qb
}
// SetAllowedTables replaces the table allow-list used by the security
// checks; when non-empty, FROM and JOIN tables must appear in it.
// Returns the builder for chaining.
func (qb *QueryBuilder) SetAllowedTables(tables []string) *QueryBuilder {
	allowed := make(map[string]bool, len(tables))
	for _, name := range tables {
		allowed[name] = true
	}
	qb.allowedTables = allowed
	return qb
}
// SetQueryLogging enables or disables printing of the final SQL and its
// arguments. Returns the builder for chaining.
func (qb *QueryBuilder) SetQueryLogging(enable bool) *QueryBuilder {
	qb.enableQueryLogging = enable
	return qb
}
// SetQueryTimeout sets the default timeout intended for query execution.
// Returns the builder for chaining.
func (qb *QueryBuilder) SetQueryTimeout(timeout time.Duration) *QueryBuilder {
	qb.queryTimeout = timeout
	return qb
}
// BuildQuery builds a complete SQL SELECT statement from query, including
// CTEs (WITH), JOINs, WHERE/GROUP BY/HAVING/ORDER BY, window functions,
// JSON select expressions, pagination and UNIONs.
//
// Returned args are ordered: CTE args, JSON-operation args (if any), main
// query args (JOIN ON, WHERE, HAVING), then UNION args.
//
// Changes from the previous version:
//   - removed a dead if/else whose two branches were identical;
//   - the JOIN/WHERE/GROUP BY/HAVING/ORDER BY logic, previously duplicated
//     verbatim for the window-function/JSON rebuild path, now lives once in
//     applyQueryClauses.
//
// NOTE(review): placeholders inside window/JSON select expressions are raw
// "?" whose args are appended outside squirrel, so they are not rewritten
// for Dollar/AtP formats — confirm before using JsonOperations with
// PostgreSQL or SQL Server.
func (qb *QueryBuilder) BuildQuery(query DynamicQuery) (string, []interface{}, error) {
	var allArgs []interface{}
	var queryParts []string

	// Security check: cap how many rows a single query may request.
	if qb.enableSecurityChecks && query.Limit > qb.maxAllowedRows {
		return "", nil, fmt.Errorf("requested limit %d exceeds maximum allowed %d", query.Limit, qb.maxAllowedRows)
	}
	// Security check: the main table must be on the allow-list (when set).
	if qb.enableSecurityChecks && len(qb.allowedTables) > 0 && !qb.allowedTables[query.From] {
		return "", nil, fmt.Errorf("disallowed table: %s", query.From)
	}

	// 1. WITH clause (CTEs).
	if len(query.CTEs) > 0 {
		cteClause, cteArgs, err := qb.buildCTEClause(query.CTEs)
		if err != nil {
			return "", nil, err
		}
		queryParts = append(queryParts, cteClause)
		allArgs = append(allArgs, cteArgs...)
	}

	// 2. SELECT list: base fields plus window functions and JSON expressions.
	fromClause := qb.buildFromClause(query.From, query.Aliases)
	selectFields := qb.buildSelectFields(query.Fields)
	for _, wf := range query.WindowFunctions {
		windowFunc, err := qb.buildWindowFunction(wf)
		if err != nil {
			return "", nil, err
		}
		selectFields = append(selectFields, windowFunc)
	}
	for _, jo := range query.JsonOperations {
		jsonExpr, jsonArgs, err := qb.buildJsonOperation(jo)
		if err != nil {
			return "", nil, err
		}
		if jo.Alias != "" {
			jsonExpr += " AS " + qb.escapeIdentifier(jo.Alias)
		}
		selectFields = append(selectFields, jsonExpr)
		allArgs = append(allArgs, jsonArgs...)
	}

	// 3. Main query with JOIN/WHERE/GROUP BY/HAVING/ORDER BY applied.
	mainQuery, err := qb.applyQueryClauses(qb.sqlBuilder.Select(selectFields...).From(fromClause), query)
	if err != nil {
		return "", nil, err
	}

	// 4. Pagination with dialect-specific syntax.
	if query.Limit > 0 {
		if qb.dbType == DBTypeSQLServer {
			// SQL Server's OFFSET ... FETCH requires an ORDER BY clause.
			if len(query.Sort) == 0 {
				mainQuery = mainQuery.OrderBy("(SELECT 1)")
			}
			mainQuery = mainQuery.Suffix(fmt.Sprintf("OFFSET %d ROWS FETCH NEXT %d ROWS ONLY", query.Offset, query.Limit))
		} else {
			mainQuery = mainQuery.Limit(uint64(query.Limit))
			if query.Offset > 0 {
				mainQuery = mainQuery.Offset(uint64(query.Offset))
			}
		}
	} else if query.Offset > 0 && qb.dbType != DBTypeSQLServer {
		mainQuery = mainQuery.Offset(uint64(query.Offset))
	}

	sql, args, err := mainQuery.ToSql()
	if err != nil {
		return "", nil, fmt.Errorf("failed to build main query: %w", err)
	}
	queryParts = append(queryParts, sql)
	allArgs = append(allArgs, args...)

	// 5. UNIONs are appended after the main query.
	if len(query.Unions) > 0 {
		unionClause, unionArgs, err := qb.buildUnionClause(query.Unions)
		if err != nil {
			return "", nil, err
		}
		queryParts = append(queryParts, unionClause)
		allArgs = append(allArgs, unionArgs...)
	}

	finalSQL := strings.Join(queryParts, " ")

	// Scan user-supplied argument values for injection payloads. (Scanning
	// the final SQL text itself was deliberately disabled: generated SQL
	// necessarily contains keywords like SELECT that the patterns flag.)
	if qb.enableSecurityChecks {
		if err := qb.checkForSqlInjectionInArgs(allArgs); err != nil {
			return "", nil, err
		}
	}
	if qb.enableQueryLogging {
		fmt.Printf("[DEBUG BuilderQuery] Final SQL query: %s\n", finalSQL)
		fmt.Printf("[DEBUG] Query args: %v\n", allArgs)
	}
	return finalSQL, allArgs, nil
}

// applyQueryClauses attaches the JOIN, WHERE, GROUP BY, HAVING and ORDER BY
// clauses described by query to b. Joined tables are validated against the
// table allow-list when security checks are enabled.
func (qb *QueryBuilder) applyQueryClauses(b squirrel.SelectBuilder, query DynamicQuery) (squirrel.SelectBuilder, error) {
	for _, join := range query.Joins {
		if qb.enableSecurityChecks && len(qb.allowedTables) > 0 && !qb.allowedTables[join.Table] {
			return b, fmt.Errorf("disallowed table in join: %s", join.Table)
		}
		joinType, tableWithAlias, onClause, joinArgs, err := qb.buildSingleJoinClause(join)
		if err != nil {
			return b, err
		}
		joinStr := tableWithAlias + " ON " + onClause
		switch strings.ToUpper(joinType) {
		case "LEFT":
			if join.Lateral {
				b = b.LeftJoin("LATERAL "+joinStr, joinArgs...)
			} else {
				b = b.LeftJoin(joinStr, joinArgs...)
			}
		case "RIGHT":
			// NOTE(review): Lateral is ignored for RIGHT and FULL joins,
			// matching the original behavior.
			b = b.RightJoin(joinStr, joinArgs...)
		case "FULL":
			b = b.Join("FULL JOIN "+joinStr, joinArgs...)
		default:
			if join.Lateral {
				b = b.Join("LATERAL "+joinStr, joinArgs...)
			} else {
				b = b.Join(joinStr, joinArgs...)
			}
		}
	}
	if len(query.Filters) > 0 {
		whereClause, whereArgs, err := qb.BuildWhereClause(query.Filters)
		if err != nil {
			return b, err
		}
		b = b.Where(whereClause, whereArgs...)
	}
	if len(query.GroupBy) > 0 {
		b = b.GroupBy(qb.buildGroupByColumns(query.GroupBy)...)
	}
	if len(query.Having) > 0 {
		havingClause, havingArgs, err := qb.BuildWhereClause(query.Having)
		if err != nil {
			return b, err
		}
		b = b.Having(havingClause, havingArgs...)
	}
	for _, sort := range query.Sort {
		column := qb.validateAndEscapeColumn(sort.Column)
		if column == "" {
			// Invalid/disallowed sort columns are silently skipped.
			continue
		}
		order := "ASC"
		if strings.ToUpper(sort.Order) == "DESC" {
			order = "DESC"
		}
		b = b.OrderBy(fmt.Sprintf("%s %s", column, order))
	}
	return b, nil
}
// buildWindowFunction renders a window function expression such as
// "ROW_NUMBER() OVER (PARTITION BY x ORDER BY y)", optionally aliased.
// The function name is validated; Over/OrderBy/Frame are interpolated
// verbatim — NOTE(review): confirm those come from trusted input, they are
// not escaped here.
func (qb *QueryBuilder) buildWindowFunction(wf WindowFunction) (string, error) {
	if !qb.isValidFunctionName(wf.Function) {
		return "", fmt.Errorf("invalid window function name: %s", wf.Function)
	}
	var sb strings.Builder
	sb.WriteString(wf.Function)
	sb.WriteString("() OVER (")
	if wf.Over != "" {
		sb.WriteString("PARTITION BY ")
		sb.WriteString(wf.Over)
		sb.WriteString(" ")
	}
	if wf.OrderBy != "" {
		sb.WriteString("ORDER BY ")
		sb.WriteString(wf.OrderBy)
		sb.WriteString(" ")
	}
	if wf.Frame != "" {
		sb.WriteString(wf.Frame)
	}
	sb.WriteString(")")
	if wf.Alias == "" {
		return sb.String(), nil
	}
	return sb.String() + " AS " + qb.escapeIdentifier(wf.Alias), nil
}
// buildJsonOperation builds a dialect-specific JSON expression for the
// SELECT list. Supported types: "extract", "exists", "contains".
//
// SECURITY FIX: the JSON path is interpolated into the SQL text inside a
// single-quoted literal; embedded single quotes are now doubled so a
// crafted path cannot break out of the literal.
func (qb *QueryBuilder) buildJsonOperation(jo JsonOperation) (string, []interface{}, error) {
	column := qb.validateAndEscapeColumn(jo.Column)
	if column == "" {
		return "", nil, fmt.Errorf("invalid or disallowed column: %s", jo.Column)
	}
	path := jo.Path
	if path == "" {
		path = "$"
	}
	// Escape once up front; every dialect branch below embeds the path in a
	// quoted SQL literal.
	path = strings.ReplaceAll(path, "'", "''")
	var expr string
	var args []interface{}
	switch strings.ToLower(jo.Type) {
	case "extract":
		switch qb.dbType {
		case DBTypePostgreSQL:
			// NOTE(review): ->> takes a key/index rather than a JSONPath;
			// confirm callers pass simple keys for PostgreSQL.
			expr = fmt.Sprintf("%s->>%s", column, qb.escapeJsonPath(path))
		case DBTypeMySQL:
			expr = fmt.Sprintf("JSON_EXTRACT(%s, '%s')", column, path)
		case DBTypeSQLServer:
			expr = fmt.Sprintf("JSON_VALUE(%s, '%s')", column, qb.escapeSqlServerJsonPath(path))
		case DBTypeSQLite:
			expr = fmt.Sprintf("json_extract(%s, '%s')", column, path)
		default:
			return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType)
		}
	case "exists":
		switch qb.dbType {
		case DBTypePostgreSQL:
			expr = fmt.Sprintf("jsonb_path_exists(%s, '%s')", column, path)
		case DBTypeMySQL:
			expr = fmt.Sprintf("JSON_CONTAINS_PATH(%s, 'one', '%s')", column, path)
		case DBTypeSQLServer:
			expr = fmt.Sprintf("JSON_VALUE(%s, '%s') IS NOT NULL", column, qb.escapeSqlServerJsonPath(path))
		case DBTypeSQLite:
			expr = fmt.Sprintf("json_extract(%s, '%s') IS NOT NULL", column, path)
		default:
			return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType)
		}
	case "contains":
		switch qb.dbType {
		case DBTypePostgreSQL:
			// PostgreSQL containment ignores the path (whole-document @>).
			expr = fmt.Sprintf("%s @> ?", column)
			args = append(args, jo.Value)
		case DBTypeMySQL:
			expr = fmt.Sprintf("JSON_CONTAINS(%s, ?, '%s')", column, path)
			args = append(args, jo.Value)
		case DBTypeSQLServer:
			expr = fmt.Sprintf("JSON_VALUE(%s, '%s') = ?", column, qb.escapeSqlServerJsonPath(path))
			args = append(args, jo.Value)
		case DBTypeSQLite:
			expr = fmt.Sprintf("json_extract(%s, '%s') = ?", column, path)
			args = append(args, jo.Value)
		default:
			return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType)
		}
	default:
		return "", nil, fmt.Errorf("unsupported JSON operation type: %s", jo.Type)
	}
	return expr, args, nil
}
// escapeJsonPath wraps a JSON path in single quotes for PostgreSQL,
// doubling any embedded single quotes so the literal cannot be broken out
// of. (Simple implementation; full JSONPath escaping is out of scope.)
func (qb *QueryBuilder) escapeJsonPath(path string) string {
	escaped := strings.ReplaceAll(path, "'", "''")
	return "'" + escaped + "'"
}
// escapeSqlServerJsonPath normalizes a JSON path for SQL Server's
// JSON_VALUE/JSON_QUERY functions: it prefixes a bare path with "$." and
// doubles embedded single quotes so the path can be interpolated safely
// inside a single-quoted SQL literal.
//
// BUG FIX: the previous version returned strings.ReplaceAll(path, ".", ".")
// — a no-op — leaving single quotes unescaped despite the function's name.
func (qb *QueryBuilder) escapeSqlServerJsonPath(path string) string {
	// SQL Server JSON paths must start at the root ("$").
	if !strings.HasPrefix(path, "$") {
		path = "$." + path
	}
	return strings.ReplaceAll(path, "'", "''")
}
// buildCTEClause renders the WITH clause for the given CTEs. If any CTE is
// marked recursive, the whole clause uses "WITH RECURSIVE". Each CTE body
// is built via BuildQuery; its args are returned in definition order.
func (qb *QueryBuilder) buildCTEClause(ctes []CTE) (string, []interface{}, error) {
	keyword := "WITH"
	for _, cte := range ctes {
		if cte.Recursive {
			keyword = "WITH RECURSIVE"
			break
		}
	}
	parts := make([]string, 0, len(ctes))
	var allArgs []interface{}
	for _, cte := range ctes {
		subQuery, subArgs, err := qb.BuildQuery(cte.Query)
		if err != nil {
			return "", nil, fmt.Errorf("failed to build CTE '%s': %w", cte.Name, err)
		}
		parts = append(parts, qb.escapeIdentifier(cte.Name)+" AS ("+subQuery+")")
		allArgs = append(allArgs, subArgs...)
	}
	return keyword + " " + strings.Join(parts, ", "), allArgs, nil
}
// buildFromClause renders the FROM target as an escaped table name,
// optionally followed by an escaped alias.
func (qb *QueryBuilder) buildFromClause(table, alias string) string {
	if alias == "" {
		return qb.escapeIdentifier(table)
	}
	return qb.escapeIdentifier(table) + " " + qb.escapeIdentifier(alias)
}
// buildSingleJoinClause renders the components of one JOIN: the normalized
// join type (defaulting to INNER), the escaped table (with optional alias),
// the ON clause built from the join's conditions, and its bind args.
func (qb *QueryBuilder) buildSingleJoinClause(join Join) (string, string, string, []interface{}, error) {
	kind := strings.ToUpper(join.Type)
	if kind == "" {
		kind = "INNER"
	}
	tableExpr := qb.escapeIdentifier(join.Table)
	if join.Alias != "" {
		tableExpr = tableExpr + " " + qb.escapeIdentifier(join.Alias)
	}
	onClause, onArgs, err := qb.BuildWhereClause([]FilterGroup{join.OnConditions})
	if err != nil {
		return "", "", "", nil, fmt.Errorf("failed to build ON clause for join on table %s: %w", join.Table, err)
	}
	return kind, tableExpr, onClause, onArgs, nil
}
// buildUnionClause renders the UNION tail of a query: each subquery is
// built via BuildQuery and prefixed with its union keyword ("UNION" when
// unspecified). Args are returned in subquery order.
func (qb *QueryBuilder) buildUnionClause(unions []Union) (string, []interface{}, error) {
	parts := make([]string, 0, len(unions))
	var allArgs []interface{}
	for _, u := range unions {
		subQuery, subArgs, err := qb.BuildQuery(u.Query)
		if err != nil {
			return "", nil, fmt.Errorf("failed to build subquery for UNION: %w", err)
		}
		kind := strings.ToUpper(u.Type)
		if kind == "" {
			kind = "UNION"
		}
		parts = append(parts, kind+" "+subQuery)
		allArgs = append(allArgs, subArgs...)
	}
	return strings.Join(parts, " "), allArgs, nil
}
// buildSelectFields renders the SELECT list from SelectField entries.
// Fields with empty or invalid expressions are silently skipped; a field
// with a WindowFunction has its expression replaced by the rendered window
// expression (skipped if the window function is invalid). Returns ["*"]
// when nothing usable remains.
func (qb *QueryBuilder) buildSelectFields(fields []SelectField) []string {
	var out []string
	for _, field := range fields {
		expr := field.Expression
		if expr == "" || !qb.isValidExpression(expr) {
			continue
		}
		if field.WindowFunction != nil {
			rendered, err := qb.buildWindowFunction(*field.WindowFunction)
			if err != nil {
				continue
			}
			expr = rendered
		}
		if field.Alias != "" {
			expr = expr + " AS " + qb.escapeIdentifier(field.Alias)
		}
		out = append(out, expr)
	}
	if len(out) == 0 {
		return []string{"*"}
	}
	return out
}
// BuildWhereClause renders WHERE/HAVING/ON conditions from FilterGroups.
// Each non-empty group is wrapped in parentheses; consecutive groups are
// joined by the *right-hand* group's LogicOp ("AND" when unset).
//
// BUG FIX: groups are now joined based on whether a previous condition was
// actually emitted rather than on the loop index — previously an empty or
// skipped first group produced a clause starting with a dangling "AND (...)".
func (qb *QueryBuilder) BuildWhereClause(filterGroups []FilterGroup) (string, []interface{}, error) {
	if len(filterGroups) == 0 {
		return "", nil, nil
	}
	var conditions []string
	var allArgs []interface{}
	for _, group := range filterGroups {
		if len(group.Filters) == 0 {
			continue
		}
		groupCondition, groupArgs, err := qb.buildFilterGroup(group)
		if err != nil {
			return "", nil, err
		}
		if groupCondition == "" {
			continue
		}
		if len(conditions) > 0 {
			logicOp := "AND"
			if group.LogicOp != "" {
				logicOp = strings.ToUpper(group.LogicOp)
			}
			conditions = append(conditions, logicOp)
		}
		conditions = append(conditions, fmt.Sprintf("(%s)", groupCondition))
		allArgs = append(allArgs, groupArgs...)
	}
	return strings.Join(conditions, " "), allArgs, nil
}
// buildFilterGroup renders the conditions of one group, joined by the
// group's logical operator ("AND" unless LogicOp says otherwise). Filters
// that render to an empty condition (e.g. LIKE with a nil value) are
// skipped.
//
// BUG FIX: the joining operator is now inserted based on whether a previous
// condition was emitted rather than on the loop index — previously a
// skipped first filter produced a fragment starting with a dangling "AND".
func (qb *QueryBuilder) buildFilterGroup(group FilterGroup) (string, []interface{}, error) {
	logicOp := "AND"
	if group.LogicOp != "" {
		logicOp = strings.ToUpper(group.LogicOp)
	}
	var conditions []string
	var args []interface{}
	for _, filter := range group.Filters {
		condition, filterArgs, err := qb.buildFilterCondition(filter)
		if err != nil {
			return "", nil, err
		}
		if condition == "" {
			continue
		}
		if len(conditions) > 0 {
			conditions = append(conditions, logicOp)
		}
		conditions = append(conditions, condition)
		args = append(args, filterArgs...)
	}
	return strings.Join(conditions, " "), args, nil
}
// buildFilterCondition renders the SQL fragment and bind args for a single
// filter, with dialect-specific handling where required. An empty condition
// with a nil error means "skip this filter" (e.g. LIKE with a nil value).
//
// Fixes over the previous version:
//   - _gt/_gte/_lt/_lte now map to ">", ">=", "<", "<=": the operator name
//     was previously interpolated after trimming "_", producing invalid SQL
//     such as "col gt ?".
//   - _ncontains is now detected by comparing against OpNotContains: the old
//     check strings.Contains(op, "Not") never matched the actual operator
//     value "_ncontains", so negated contains silently became a positive LIKE.
//   - SQLite no longer receives ILIKE (which it does not support); its LIKE
//     is already case-insensitive for ASCII.
//   - _nlike/_nilike, declared in the operator set, are now handled instead
//     of returning "unsupported operator".
func (qb *QueryBuilder) buildFilterCondition(filter DynamicFilter) (string, []interface{}, error) {
	column := qb.validateAndEscapeColumn(filter.Column)
	if column == "" {
		return "", nil, fmt.Errorf("invalid or disallowed column: %s", filter.Column)
	}
	// Column-to-column comparison: a string value shaped like "alias.column"
	// is treated as a column reference, not a literal.
	if valStr, ok := filter.Value.(string); ok && strings.Contains(valStr, ".") && qb.isValidExpression(valStr) && len(strings.Split(valStr, ".")) == 2 {
		ref := qb.escapeColumnReference(valStr)
		switch filter.Operator {
		case OpEqual:
			return fmt.Sprintf("%s = %s", column, ref), nil, nil
		case OpNotEqual:
			return fmt.Sprintf("%s <> %s", column, ref), nil, nil
		case OpGreaterThan:
			return fmt.Sprintf("%s > %s", column, ref), nil, nil
		case OpLessThan:
			return fmt.Sprintf("%s < %s", column, ref), nil, nil
		}
	}
	// JSON and array operators are delegated to their dedicated builders.
	switch filter.Operator {
	case OpJsonContains, OpJsonNotContains, OpJsonExists, OpJsonNotExists, OpJsonEqual, OpJsonNotEqual:
		return qb.buildJsonFilterCondition(filter)
	case OpArrayContains, OpArrayNotContains, OpArrayLength:
		return qb.buildArrayFilterCondition(filter)
	}
	// Standard operators.
	switch filter.Operator {
	case OpEqual:
		if filter.Value == nil {
			return fmt.Sprintf("%s IS NULL", column), nil, nil
		}
		return fmt.Sprintf("%s = ?", column), []interface{}{filter.Value}, nil
	case OpNotEqual:
		if filter.Value == nil {
			return fmt.Sprintf("%s IS NOT NULL", column), nil, nil
		}
		return fmt.Sprintf("%s <> ?", column), []interface{}{filter.Value}, nil
	case OpLike, OpNotLike:
		if filter.Value == nil {
			return "", nil, nil
		}
		op := "LIKE"
		if filter.Operator == OpNotLike {
			op = "NOT LIKE"
		}
		return fmt.Sprintf("%s %s ?", column, op), []interface{}{filter.Value}, nil
	case OpILike, OpNotILike:
		if filter.Value == nil {
			return "", nil, nil
		}
		negate := filter.Operator == OpNotILike
		switch qb.dbType {
		case DBTypePostgreSQL:
			op := "ILIKE"
			if negate {
				op = "NOT ILIKE"
			}
			return fmt.Sprintf("%s %s ?", column, op), []interface{}{filter.Value}, nil
		case DBTypeMySQL, DBTypeSQLServer:
			op := "LIKE"
			if negate {
				op = "NOT LIKE"
			}
			return fmt.Sprintf("LOWER(%s) %s LOWER(?)", column, op), []interface{}{filter.Value}, nil
		default:
			// SQLite (and unknown dialects): plain LIKE. SQLite's LIKE is
			// case-insensitive for ASCII by default.
			op := "LIKE"
			if negate {
				op = "NOT LIKE"
			}
			return fmt.Sprintf("%s %s ?", column, op), []interface{}{filter.Value}, nil
		}
	case OpIn, OpNotIn:
		values := qb.parseArrayValue(filter.Value)
		if len(values) == 0 {
			// An empty IN list matches nothing; emit an always-false predicate.
			return "1=0", nil, nil
		}
		op := "IN"
		if filter.Operator == OpNotIn {
			op = "NOT IN"
		}
		return fmt.Sprintf("%s %s (%s)", column, op, squirrel.Placeholders(len(values))), values, nil
	case OpGreaterThan, OpGreaterThanEqual, OpLessThan, OpLessThanEqual:
		if filter.Value == nil {
			return "", nil, nil
		}
		var op string
		switch filter.Operator {
		case OpGreaterThan:
			op = ">"
		case OpGreaterThanEqual:
			op = ">="
		case OpLessThan:
			op = "<"
		default: // OpLessThanEqual
			op = "<="
		}
		return fmt.Sprintf("%s %s ?", column, op), []interface{}{filter.Value}, nil
	case OpBetween, OpNotBetween:
		values := qb.parseArrayValue(filter.Value)
		if len(values) != 2 {
			return "", nil, fmt.Errorf("between operator requires exactly 2 values")
		}
		op := "BETWEEN"
		if filter.Operator == OpNotBetween {
			op = "NOT BETWEEN"
		}
		return fmt.Sprintf("%s %s ? AND ?", column, op), []interface{}{values[0], values[1]}, nil
	case OpNull:
		return fmt.Sprintf("%s IS NULL", column), nil, nil
	case OpNotNull:
		return fmt.Sprintf("%s IS NOT NULL", column), nil, nil
	case OpContains, OpNotContains, OpStartsWith, OpEndsWith:
		if filter.Value == nil {
			return "", nil, nil
		}
		var value string
		switch filter.Operator {
		case OpContains, OpNotContains:
			value = fmt.Sprintf("%%%v%%", filter.Value)
		case OpStartsWith:
			value = fmt.Sprintf("%v%%", filter.Value)
		case OpEndsWith:
			value = fmt.Sprintf("%%%v", filter.Value)
		}
		negate := filter.Operator == OpNotContains
		switch qb.dbType {
		case DBTypePostgreSQL:
			op := "ILIKE"
			if negate {
				op = "NOT ILIKE"
			}
			return fmt.Sprintf("%s %s ?", column, op), []interface{}{value}, nil
		case DBTypeMySQL, DBTypeSQLServer:
			op := "LIKE"
			if negate {
				op = "NOT LIKE"
			}
			return fmt.Sprintf("LOWER(%s) %s LOWER(?)", column, op), []interface{}{value}, nil
		default:
			// SQLite and unknown dialects: plain (NOT) LIKE.
			op := "LIKE"
			if negate {
				op = "NOT LIKE"
			}
			return fmt.Sprintf("%s %s ?", column, op), []interface{}{value}, nil
		}
	default:
		return "", nil, fmt.Errorf("unsupported operator: %s", filter.Operator)
	}
}
// buildJsonFilterCondition renders a dialect-specific JSON filter predicate
// for the _json_* operators. The JSON path comes from Options["path"]
// (defaulting to "$").
//
// SECURITY FIX: the path is interpolated into the SQL text inside a
// single-quoted literal; embedded single quotes are now doubled so a
// crafted path cannot break out of the literal. (Paths that legitimately
// contain quotes were previously injectable, not functional.)
func (qb *QueryBuilder) buildJsonFilterCondition(filter DynamicFilter) (string, []interface{}, error) {
	column := qb.validateAndEscapeColumn(filter.Column)
	if column == "" {
		return "", nil, fmt.Errorf("invalid or disallowed column: %s", filter.Column)
	}
	path := "$"
	if pathOption, ok := filter.Options["path"].(string); ok && pathOption != "" {
		path = pathOption
	}
	// Escape once up front; every dialect branch below embeds the path in a
	// quoted SQL literal.
	path = strings.ReplaceAll(path, "'", "''")
	var expr string
	var args []interface{}
	switch filter.Operator {
	case OpJsonContains:
		switch qb.dbType {
		case DBTypePostgreSQL:
			expr = fmt.Sprintf("%s @> ?", column)
			args = append(args, filter.Value)
		case DBTypeMySQL:
			expr = fmt.Sprintf("JSON_CONTAINS(%s, ?, '%s')", column, path)
			args = append(args, filter.Value)
		case DBTypeSQLServer:
			expr = fmt.Sprintf("JSON_VALUE(%s, '%s') = ?", column, qb.escapeSqlServerJsonPath(path))
			args = append(args, filter.Value)
		case DBTypeSQLite:
			expr = fmt.Sprintf("json_extract(%s, '%s') = ?", column, path)
			args = append(args, filter.Value)
		default:
			return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType)
		}
	case OpJsonNotContains:
		switch qb.dbType {
		case DBTypePostgreSQL:
			expr = fmt.Sprintf("NOT (%s @> ?)", column)
			args = append(args, filter.Value)
		case DBTypeMySQL:
			expr = fmt.Sprintf("NOT JSON_CONTAINS(%s, ?, '%s')", column, path)
			args = append(args, filter.Value)
		case DBTypeSQLServer:
			expr = fmt.Sprintf("JSON_VALUE(%s, '%s') <> ?", column, qb.escapeSqlServerJsonPath(path))
			args = append(args, filter.Value)
		case DBTypeSQLite:
			expr = fmt.Sprintf("json_extract(%s, '%s') <> ?", column, path)
			args = append(args, filter.Value)
		default:
			return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType)
		}
	case OpJsonExists:
		switch qb.dbType {
		case DBTypePostgreSQL:
			expr = fmt.Sprintf("jsonb_path_exists(%s, '%s')", column, path)
		case DBTypeMySQL:
			expr = fmt.Sprintf("JSON_CONTAINS_PATH(%s, 'one', '%s')", column, path)
		case DBTypeSQLServer:
			expr = fmt.Sprintf("JSON_VALUE(%s, '%s') IS NOT NULL", column, qb.escapeSqlServerJsonPath(path))
		case DBTypeSQLite:
			expr = fmt.Sprintf("json_extract(%s, '%s') IS NOT NULL", column, path)
		default:
			return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType)
		}
	case OpJsonNotExists:
		switch qb.dbType {
		case DBTypePostgreSQL:
			expr = fmt.Sprintf("NOT jsonb_path_exists(%s, '%s')", column, path)
		case DBTypeMySQL:
			expr = fmt.Sprintf("NOT JSON_CONTAINS_PATH(%s, 'one', '%s')", column, path)
		case DBTypeSQLServer:
			expr = fmt.Sprintf("JSON_VALUE(%s, '%s') IS NULL", column, qb.escapeSqlServerJsonPath(path))
		case DBTypeSQLite:
			expr = fmt.Sprintf("json_extract(%s, '%s') IS NULL", column, path)
		default:
			return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType)
		}
	case OpJsonEqual:
		switch qb.dbType {
		case DBTypePostgreSQL:
			expr = fmt.Sprintf("%s->>%s = ?", column, qb.escapeJsonPath(path))
			args = append(args, filter.Value)
		case DBTypeMySQL:
			expr = fmt.Sprintf("JSON_EXTRACT(%s, '%s') = ?", column, path)
			args = append(args, filter.Value)
		case DBTypeSQLServer:
			expr = fmt.Sprintf("JSON_VALUE(%s, '%s') = ?", column, qb.escapeSqlServerJsonPath(path))
			args = append(args, filter.Value)
		case DBTypeSQLite:
			expr = fmt.Sprintf("json_extract(%s, '%s') = ?", column, path)
			args = append(args, filter.Value)
		default:
			return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType)
		}
	case OpJsonNotEqual:
		switch qb.dbType {
		case DBTypePostgreSQL:
			expr = fmt.Sprintf("%s->>%s <> ?", column, qb.escapeJsonPath(path))
			args = append(args, filter.Value)
		case DBTypeMySQL:
			expr = fmt.Sprintf("JSON_EXTRACT(%s, '%s') <> ?", column, path)
			args = append(args, filter.Value)
		case DBTypeSQLServer:
			expr = fmt.Sprintf("JSON_VALUE(%s, '%s') <> ?", column, qb.escapeSqlServerJsonPath(path))
			args = append(args, filter.Value)
		case DBTypeSQLite:
			expr = fmt.Sprintf("json_extract(%s, '%s') <> ?", column, path)
			args = append(args, filter.Value)
		default:
			return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType)
		}
	default:
		return "", nil, fmt.Errorf("unsupported JSON operator: %s", filter.Operator)
	}
	return expr, args, nil
}
// intFilterOption extracts an integer option value from a filter Options
// map, tolerating the numeric types that commonly arrive here: native ints
// and JSON-decoded numbers (which Go unmarshals as float64).
func intFilterOption(options map[string]interface{}, key string) (int, bool) {
	switch n := options[key].(type) {
	case int:
		return n, true
	case int64:
		return int(n), true
	case float64:
		// Accept only integral floats; 2.5 is not a valid length.
		if n == float64(int(n)) {
			return int(n), true
		}
	}
	return 0, false
}

// buildArrayFilterCondition builds a SQL condition for the array operators
// (_array_contains, _array_ncontains, _array_length) in the dialect of the
// configured database. PostgreSQL uses native array operators; the other
// dialects treat the column as a JSON array.
//
// Returns the SQL expression with '?' placeholders and its bound arguments.
func (qb *QueryBuilder) buildArrayFilterCondition(filter DynamicFilter) (string, []interface{}, error) {
	column := qb.validateAndEscapeColumn(filter.Column)
	if column == "" {
		return "", nil, fmt.Errorf("invalid or disallowed column: %s", filter.Column)
	}
	var expr string
	var args []interface{}
	switch filter.Operator {
	case OpArrayContains:
		switch qb.dbType {
		case DBTypePostgreSQL:
			expr = fmt.Sprintf("? = ANY(%s)", column)
		case DBTypeMySQL:
			expr = fmt.Sprintf("JSON_CONTAINS(%s, JSON_QUOTE(?))", column)
		case DBTypeSQLServer:
			expr = fmt.Sprintf("? IN (SELECT value FROM OPENJSON(%s))", column)
		case DBTypeSQLite:
			expr = fmt.Sprintf("EXISTS (SELECT 1 FROM json_each(%s) WHERE json_each.value = ?)", column)
		default:
			return "", nil, fmt.Errorf("array operations not supported for database type: %s", qb.dbType)
		}
		args = append(args, filter.Value)
	case OpArrayNotContains:
		switch qb.dbType {
		case DBTypePostgreSQL:
			expr = fmt.Sprintf("? <> ALL(%s)", column)
		case DBTypeMySQL:
			expr = fmt.Sprintf("NOT JSON_CONTAINS(%s, JSON_QUOTE(?))", column)
		case DBTypeSQLServer:
			expr = fmt.Sprintf("? NOT IN (SELECT value FROM OPENJSON(%s))", column)
		case DBTypeSQLite:
			expr = fmt.Sprintf("NOT EXISTS (SELECT 1 FROM json_each(%s) WHERE json_each.value = ?)", column)
		default:
			return "", nil, fmt.Errorf("array operations not supported for database type: %s", qb.dbType)
		}
		args = append(args, filter.Value)
	case OpArrayLength:
		// Extract the required length once instead of repeating the type
		// assertion in every dialect branch (as the original did), and
		// accept JSON-decoded numbers (float64) in addition to ints.
		length, ok := intFilterOption(filter.Options, "length")
		if !ok {
			return "", nil, fmt.Errorf("array_length operator requires 'length' option")
		}
		switch qb.dbType {
		case DBTypePostgreSQL:
			expr = fmt.Sprintf("array_length(%s, 1) = ?", column)
		case DBTypeMySQL:
			expr = fmt.Sprintf("JSON_LENGTH(%s) = ?", column)
		case DBTypeSQLServer:
			expr = fmt.Sprintf("(SELECT COUNT(*) FROM OPENJSON(%s)) = ?", column)
		case DBTypeSQLite:
			expr = fmt.Sprintf("json_array_length(%s) = ?", column)
		default:
			return "", nil, fmt.Errorf("array operations not supported for database type: %s", qb.dbType)
		}
		args = append(args, length)
	default:
		return "", nil, fmt.Errorf("unsupported array operator: %s", filter.Operator)
	}
	return expr, args, nil
}
// =============================================================================
// SECTION 6: EXECUTION METHODS (NEW)
// Methods for executing queries directly, with performance logging.
// =============================================================================
// ExecuteQuery builds the SQL for query and runs it against db, scanning all
// rows into dest. The builder's default timeout is applied only when the
// caller's context has no deadline, and timing is logged when query logging
// is enabled.
//
// dest must be a non-nil pointer. Two destination shapes are supported:
//   - *[]map[string]interface{}: each row is scanned into a map via MapScan.
//   - any other type accepted by sqlx.SelectContext (e.g. *[]SomeStruct).
func (qb *QueryBuilder) ExecuteQuery(ctx context.Context, db *sqlx.DB, query DynamicQuery, dest interface{}) error {
	sqlStr, args, err := qb.BuildQuery(query)
	if err != nil {
		return err
	}
	// Apply the builder's default timeout only when the caller has not
	// already set a deadline.
	if _, hasDeadline := ctx.Deadline(); !hasDeadline && qb.queryTimeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, qb.queryTimeout)
		defer cancel()
	}
	start := time.Now()
	destValue := reflect.ValueOf(dest)
	if destValue.Kind() != reflect.Ptr || destValue.IsNil() {
		return fmt.Errorf("dest must be a non-nil pointer")
	}
	destElem := destValue.Elem()
	if destElem.Kind() == reflect.Slice {
		sliceType := destElem.Type().Elem()
		if sliceType.Kind() == reflect.Map &&
			sliceType.Key().Kind() == reflect.String &&
			sliceType.Elem().Kind() == reflect.Interface {
			// Special case: scan each row into a map[string]interface{}.
			rows, err := db.QueryxContext(ctx, sqlStr, args...)
			if err != nil {
				return err
			}
			defer rows.Close()
			for rows.Next() {
				row := make(map[string]interface{})
				if err := rows.MapScan(row); err != nil {
					return err
				}
				destElem.Set(reflect.Append(destElem, reflect.ValueOf(row)))
			}
			// Surface iteration errors; rows.Next() exits silently on
			// failure, so this check is required for correctness.
			if err := rows.Err(); err != nil {
				return err
			}
			// Previously this path logged unconditionally; honor the
			// logging flag like the default path does.
			if qb.enableQueryLogging {
				fmt.Printf("[DEBUG] Query executed in %v\n", time.Since(start))
			}
			return nil
		}
	}
	// Default case: let sqlx scan into structs or primitive slices.
	err = db.SelectContext(ctx, dest, sqlStr, args...)
	if qb.enableQueryLogging {
		fmt.Printf("[DEBUG] Query executed in %v\n", time.Since(start))
	}
	return err
}
// ExecuteQueryRow builds the SQL for query and scans the single resulting
// row into dest. The builder's default timeout applies when the context has
// no deadline; timing is logged when query logging is enabled.
func (qb *QueryBuilder) ExecuteQueryRow(ctx context.Context, db *sqlx.DB, query DynamicQuery, dest interface{}) error {
	sqlStr, args, err := qb.BuildQuery(query)
	if err != nil {
		return err
	}
	// Only impose the builder's timeout when the caller set no deadline.
	if _, ok := ctx.Deadline(); !ok && qb.queryTimeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, qb.queryTimeout)
		defer cancel()
	}
	started := time.Now()
	err = db.GetContext(ctx, dest, sqlStr, args...)
	if qb.enableQueryLogging {
		fmt.Printf("[DEBUG] QueryRow executed in %v\n", time.Since(started))
	}
	return err
}
// ExecuteCount builds and runs the COUNT query for query, returning the
// total row count. The builder's default timeout applies when the context
// has no deadline; timing is logged when query logging is enabled.
func (qb *QueryBuilder) ExecuteCount(ctx context.Context, db *sqlx.DB, query DynamicQuery) (int64, error) {
	sqlStr, args, err := qb.BuildCountQuery(query)
	if err != nil {
		return 0, err
	}
	// Only impose the builder's timeout when the caller set no deadline.
	if _, ok := ctx.Deadline(); !ok && qb.queryTimeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, qb.queryTimeout)
		defer cancel()
	}
	started := time.Now()
	var total int64
	err = db.GetContext(ctx, &total, sqlStr, args...)
	if qb.enableQueryLogging {
		fmt.Printf("[DEBUG] Count query executed in %v\n", time.Since(started))
	}
	return total, err
}
// ExecuteInsert builds and runs an INSERT for table, enforcing the table
// allow-list, applying the default timeout when the context has none, and
// logging timing when enabled.
func (qb *QueryBuilder) ExecuteInsert(ctx context.Context, db *sqlx.DB, table string, data InsertData, returningColumns ...string) (sql.Result, error) {
	// Reject tables that are not on the configured allow-list.
	if qb.enableSecurityChecks && len(qb.allowedTables) > 0 && !qb.allowedTables[table] {
		return nil, fmt.Errorf("disallowed table: %s", table)
	}
	sqlStr, args, err := qb.BuildInsertQuery(table, data, returningColumns...)
	if err != nil {
		return nil, err
	}
	// Only impose the builder's timeout when the caller set no deadline.
	if _, ok := ctx.Deadline(); !ok && qb.queryTimeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, qb.queryTimeout)
		defer cancel()
	}
	started := time.Now()
	res, execErr := db.ExecContext(ctx, sqlStr, args...)
	if qb.enableQueryLogging {
		fmt.Printf("[DEBUG] Insert query executed in %v\n", time.Since(started))
	}
	return res, execErr
}
// ExecuteUpdate builds and runs an UPDATE for table with the given filters,
// enforcing the table allow-list, applying the default timeout when the
// context has none, and logging timing when enabled.
func (qb *QueryBuilder) ExecuteUpdate(ctx context.Context, db *sqlx.DB, table string, updateData UpdateData, filters []FilterGroup, returningColumns ...string) (sql.Result, error) {
	// Reject tables that are not on the configured allow-list.
	if qb.enableSecurityChecks && len(qb.allowedTables) > 0 && !qb.allowedTables[table] {
		return nil, fmt.Errorf("disallowed table: %s", table)
	}
	sqlStr, args, err := qb.BuildUpdateQuery(table, updateData, filters, returningColumns...)
	if err != nil {
		return nil, err
	}
	// Only impose the builder's timeout when the caller set no deadline.
	if _, ok := ctx.Deadline(); !ok && qb.queryTimeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, qb.queryTimeout)
		defer cancel()
	}
	started := time.Now()
	res, execErr := db.ExecContext(ctx, sqlStr, args...)
	if qb.enableQueryLogging {
		fmt.Printf("[DEBUG] Update query executed in %v\n", time.Since(started))
	}
	return res, execErr
}
// ExecuteDelete builds and runs a DELETE for table with the given filters,
// enforcing the table allow-list, applying the default timeout when the
// context has none, and logging timing when enabled.
func (qb *QueryBuilder) ExecuteDelete(ctx context.Context, db *sqlx.DB, table string, filters []FilterGroup, returningColumns ...string) (sql.Result, error) {
	// Reject tables that are not on the configured allow-list.
	if qb.enableSecurityChecks && len(qb.allowedTables) > 0 && !qb.allowedTables[table] {
		return nil, fmt.Errorf("disallowed table: %s", table)
	}
	sqlStr, args, err := qb.BuildDeleteQuery(table, filters, returningColumns...)
	if err != nil {
		return nil, err
	}
	// Only impose the builder's timeout when the caller set no deadline.
	if _, ok := ctx.Deadline(); !ok && qb.queryTimeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, qb.queryTimeout)
		defer cancel()
	}
	started := time.Now()
	res, execErr := db.ExecContext(ctx, sqlStr, args...)
	if qb.enableQueryLogging {
		fmt.Printf("[DEBUG] Delete query executed in %v\n", time.Since(started))
	}
	return res, execErr
}
// ExecuteUpsert builds and runs an UPSERT (insert-or-update) for table,
// enforcing the table allow-list, applying the default timeout when the
// context has none, and logging timing when enabled.
func (qb *QueryBuilder) ExecuteUpsert(ctx context.Context, db *sqlx.DB, table string, insertData InsertData, conflictColumns []string, updateColumns []string, returningColumns ...string) (sql.Result, error) {
	// Reject tables that are not on the configured allow-list.
	if qb.enableSecurityChecks && len(qb.allowedTables) > 0 && !qb.allowedTables[table] {
		return nil, fmt.Errorf("disallowed table: %s", table)
	}
	sqlStr, args, err := qb.BuildUpsertQuery(table, insertData, conflictColumns, updateColumns, returningColumns...)
	if err != nil {
		return nil, err
	}
	// Only impose the builder's timeout when the caller set no deadline.
	if _, ok := ctx.Deadline(); !ok && qb.queryTimeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, qb.queryTimeout)
		defer cancel()
	}
	started := time.Now()
	res, execErr := db.ExecContext(ctx, sqlStr, args...)
	if qb.enableQueryLogging {
		fmt.Printf("[DEBUG] Upsert query executed in %v\n", time.Since(started))
	}
	return res, execErr
}
// --- Helper and Validation Methods ---
// buildGroupByColumns validates and escapes each GROUP BY field, silently
// dropping any field that fails validation.
func (qb *QueryBuilder) buildGroupByColumns(fields []string) []string {
	var out []string
	for _, name := range fields {
		if escaped := qb.validateAndEscapeColumn(name); escaped != "" {
			out = append(out, escaped)
		}
	}
	return out
}
// parseArrayValue normalizes value into a flat []interface{}. Slices are
// expanded element by element, comma-separated strings are split and
// trimmed, any other non-nil value becomes a single-element slice, and nil
// yields nil.
func (qb *QueryBuilder) parseArrayValue(value interface{}) []interface{} {
	if value == nil {
		return nil
	}
	if rv := reflect.ValueOf(value); rv.Kind() == reflect.Slice {
		out := make([]interface{}, rv.Len())
		for idx := range out {
			out[idx] = rv.Index(idx).Interface()
		}
		return out
	}
	if s, ok := value.(string); ok {
		if !strings.Contains(s, ",") {
			return []interface{}{s}
		}
		pieces := strings.Split(s, ",")
		out := make([]interface{}, len(pieces))
		for idx, piece := range pieces {
			out[idx] = strings.TrimSpace(piece)
		}
		return out
	}
	return []interface{}{value}
}
// validateAndEscapeColumn returns a safely quoted form of field, or "" when
// the field is empty, fails expression validation, or is not on the column
// allow-list.
func (qb *QueryBuilder) validateAndEscapeColumn(field string) string {
	if field == "" {
		return ""
	}
	switch {
	case strings.Contains(field, "("):
		// Complex expressions (function calls etc.) are passed through
		// unescaped once they pass the character/keyword validation.
		if qb.isValidExpression(field) {
			return field
		}
		return ""
	case strings.Contains(field, "."):
		// Qualified names like "table.column": escape each segment.
		if !qb.isValidExpression(field) {
			return ""
		}
		segments := strings.Split(field, ".")
		escaped := make([]string, 0, len(segments))
		for _, seg := range segments {
			escaped = append(escaped, qb.escapeIdentifier(seg))
		}
		return strings.Join(escaped, ".")
	default:
		// Simple column name: must be on the allow-list when one is set.
		if qb.allowedColumns != nil && !qb.allowedColumns[field] {
			return ""
		}
		return qb.escapeIdentifier(field)
	}
}
// isValidExpression performs a lightweight allow-list check on a SQL
// expression: only a restricted character set is permitted and a handful of
// dangerous keywords and comment tokens are rejected. It is not a full SQL
// parser. NOTE(review): the keyword check is a substring match, so benign
// names containing e.g. "created" will also be rejected — deliberately
// conservative, confirm before relaxing.
func (qb *QueryBuilder) isValidExpression(expr string) bool {
	// SQL Server bracket quoting ([]) and spaces are allowed here.
	const allowedChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_.,() *-/[]"
	for _, ch := range expr {
		if !strings.ContainsRune(allowedChars, ch) {
			return false
		}
	}
	lowered := strings.ToLower(expr)
	banned := []string{"--", "/*", "*/", "union", "select", "insert", "update", "delete", "drop", "alter", "create", "exec", "execute"}
	for _, word := range banned {
		if strings.Contains(lowered, word) {
			return false
		}
	}
	return true
}
// isValidFunctionName reports whether name (case-insensitively) is one of
// the SQL functions this builder is willing to emit: aggregates, window
// functions, JSON helpers, and common scalar functions.
//
// The original rebuilt its lookup map on every call; a switch performs the
// same membership test with no per-call allocation.
func (qb *QueryBuilder) isValidFunctionName(name string) bool {
	switch strings.ToLower(name) {
	case
		// Aggregate functions
		"count", "sum", "avg", "min", "max",
		// Window functions
		"row_number", "rank", "dense_rank", "ntile",
		"lag", "lead", "first_value", "last_value",
		// JSON functions
		"json_extract", "json_contains", "json_search",
		"json_array", "json_object", "json_merge",
		// Other functions
		"concat", "substring", "upper", "lower",
		"trim", "coalesce", "nullif", "isnull":
		return true
	}
	return false
}
// escapeColumnReference escapes each dot-separated segment of col, leaving
// segments that are already bracket-quoted ([seg]) untouched.
func (qb *QueryBuilder) escapeColumnReference(col string) string {
	segments := strings.Split(col, ".")
	out := make([]string, 0, len(segments))
	for _, seg := range segments {
		if strings.HasPrefix(seg, "[") && strings.HasSuffix(seg, "]") {
			out = append(out, seg)
			continue
		}
		out = append(out, qb.escapeIdentifier(seg))
	}
	return strings.Join(out, ".")
}
// escapeIdentifier quotes a single identifier using the configured
// database's quoting style, doubling any embedded closing-quote character.
// Unknown database types return the identifier unquoted.
func (qb *QueryBuilder) escapeIdentifier(col string) string {
	switch qb.dbType {
	case DBTypePostgreSQL, DBTypeSQLite:
		return `"` + strings.ReplaceAll(col, `"`, `""`) + `"`
	case DBTypeMySQL:
		return "`" + strings.ReplaceAll(col, "`", "``") + "`"
	case DBTypeSQLServer:
		return "[" + strings.ReplaceAll(col, "]", "]]") + "]"
	}
	return col
}
// argInjectionPatterns are the suspicious patterns scanned for by
// checkForSqlInjectionInArgs. They are compiled once at package init; the
// original compiled all nine regexes anew for every string argument on
// every call.
var argInjectionPatterns = []*regexp.Regexp{
	regexp.MustCompile(`(?i)(union\s+select)`),
	regexp.MustCompile(`(?i)(or\s+1\s*=\s*1)`),
	regexp.MustCompile(`(?i)(and\s+true)`),
	regexp.MustCompile(`(?i)(waitfor\s+delay)`),
	regexp.MustCompile(`(?i)(benchmark|sleep)\s*\(`),
	regexp.MustCompile(`(?i)(pg_sleep)\s*\(`),
	regexp.MustCompile(`(?i)(load_file|into\s+outfile)`),
	regexp.MustCompile(`(?i)(information_schema|sysobjects|syscolumns)`),
	regexp.MustCompile(`(?i)(--|\/\*|\*\/)`),
}

// checkForSqlInjectionInArgs scans string query arguments for patterns that
// commonly appear in SQL injection payloads. It is a heuristic
// defense-in-depth on top of parameterized queries, and is skipped entirely
// when security checks are disabled.
func (qb *QueryBuilder) checkForSqlInjectionInArgs(args []interface{}) error {
	if !qb.enableSecurityChecks {
		return nil
	}
	for _, arg := range args {
		str, ok := arg.(string)
		if !ok {
			continue
		}
		// All patterns are (?i) case-insensitive, so no pre-lowercasing
		// of the argument is needed.
		for _, pattern := range argInjectionPatterns {
			if pattern.MatchString(str) {
				return fmt.Errorf("potential SQL injection detected in query argument: pattern %s matched", pattern.String())
			}
		}
	}
	return nil
}
// sqlInjectionPatterns flag statement shapes that should not appear in
// builder-generated SELECT SQL (DDL/DML keywords, SQL comments, classic
// tautology injections). Compiled once at package init; the original
// recompiled all of them on every call.
var sqlInjectionPatterns = []*regexp.Regexp{
	regexp.MustCompile(`(?i)(union\s+select)`),                                          // UNION followed by SELECT
	regexp.MustCompile(`(?i)(select\s+.*\s+from\s+.*\s+where\s+.*\s+or\s+1\s*=\s*1)`),  // Classic SQL injection
	regexp.MustCompile(`(?i)(drop\s+table)`),                                            // DROP TABLE
	regexp.MustCompile(`(?i)(delete\s+from)`),                                           // DELETE FROM
	regexp.MustCompile(`(?i)(insert\s+into)`),                                           // INSERT INTO
	regexp.MustCompile(`(?i)(update\s+.*\s+set)`),                                       // UPDATE SET
	regexp.MustCompile(`(?i)(alter\s+table)`),                                           // ALTER TABLE
	regexp.MustCompile(`(?i)(create\s+table)`),                                          // CREATE TABLE
	regexp.MustCompile(`(?i)(exec\s*\(|execute\s*\()`),                                  // EXEC/EXECUTE functions
	regexp.MustCompile(`(?i)(--|\/\*|\*\/)`),                                            // SQL comments
}

// checkForSqlInjectionInSQL heuristically scans the final SQL text for
// statement types this builder never emits itself. Skipped when security
// checks are disabled. NOTE(review): these patterns also match legitimate
// INSERT/UPDATE/DELETE SQL — confirm this check is only applied to
// generated SELECT statements.
func (qb *QueryBuilder) checkForSqlInjectionInSQL(sql string) error {
	if !qb.enableSecurityChecks {
		return nil
	}
	// All patterns are (?i) case-insensitive, so no pre-lowercasing needed.
	for _, pattern := range sqlInjectionPatterns {
		if pattern.MatchString(sql) {
			return fmt.Errorf("potential SQL injection detected in SQL: pattern %s matched", pattern.String())
		}
	}
	return nil
}
// --- Other Query Builders (Insert, Update, Delete, Upsert, Count) ---
// BuildCountQuery builds a SELECT COUNT(*) query mirroring the FROM, JOIN,
// WHERE, GROUP BY, and HAVING clauses of query. Fields, sorting, limits,
// and unions are intentionally omitted since they do not affect row counts;
// joins are kept because filters may reference joined tables.
func (qb *QueryBuilder) BuildCountQuery(query DynamicQuery) (string, []interface{}, error) {
	countQuery := DynamicQuery{
		From:    query.From,
		Aliases: query.Aliases,
		Filters: query.Filters,
		GroupBy: query.GroupBy,
		Having:  query.Having,
		// Joins are important for count with filters on joined tables.
		Joins: query.Joins,
	}
	fromClause := qb.buildFromClause(countQuery.From, countQuery.Aliases)
	baseQuery := qb.sqlBuilder.Select("COUNT(*)").From(fromClause)
	for _, join := range countQuery.Joins {
		// Security check for the joined table.
		if qb.enableSecurityChecks && len(qb.allowedTables) > 0 && !qb.allowedTables[join.Table] {
			return "", nil, fmt.Errorf("disallowed table in join: %s", join.Table)
		}
		joinType, tableWithAlias, onClause, joinArgs, err := qb.buildSingleJoinClause(join)
		if err != nil {
			return "", nil, err
		}
		joinStr := tableWithAlias + " ON " + onClause
		switch strings.ToUpper(joinType) {
		case "LEFT":
			baseQuery = baseQuery.LeftJoin(joinStr, joinArgs...)
		case "RIGHT":
			baseQuery = baseQuery.RightJoin(joinStr, joinArgs...)
		case "FULL":
			// squirrel's Join() prefixes its argument with "JOIN ", so the
			// previous Join("FULL JOIN "+joinStr) emitted the invalid
			// "JOIN FULL JOIN ..."; JoinClause takes the clause verbatim.
			baseQuery = baseQuery.JoinClause("FULL JOIN "+joinStr, joinArgs...)
		default:
			baseQuery = baseQuery.Join(joinStr, joinArgs...)
		}
	}
	if len(countQuery.Filters) > 0 {
		whereClause, whereArgs, err := qb.BuildWhereClause(countQuery.Filters)
		if err != nil {
			return "", nil, err
		}
		baseQuery = baseQuery.Where(whereClause, whereArgs...)
	}
	if len(countQuery.GroupBy) > 0 {
		baseQuery = baseQuery.GroupBy(qb.buildGroupByColumns(countQuery.GroupBy)...)
	}
	if len(countQuery.Having) > 0 {
		havingClause, havingArgs, err := qb.BuildWhereClause(countQuery.Having)
		if err != nil {
			return "", nil, err
		}
		baseQuery = baseQuery.Having(havingClause, havingArgs...)
	}
	sqlStr, args, err := baseQuery.ToSql()
	if err != nil {
		return "", nil, fmt.Errorf("failed to build COUNT query: %w", err)
	}
	if qb.enableQueryLogging {
		fmt.Printf("[DEBUG] COUNT SQL query: %s\n", sqlStr)
		fmt.Printf("[DEBUG] COUNT query args: %v\n", args)
	}
	return sqlStr, args, nil
}
// BuildInsertQuery builds an INSERT statement for table from data. Plain
// column/value pairs come from data.Columns/data.Values; each entry of
// data.JsonValues is JSON-marshaled and appended as an extra column.
// RETURNING is only supported on PostgreSQL.
func (qb *QueryBuilder) BuildInsertQuery(table string, data InsertData, returningColumns ...string) (string, []interface{}, error) {
	// Validate plain columns against the allow-list.
	for _, col := range data.Columns {
		if qb.allowedColumns != nil && !qb.allowedColumns[col] {
			return "", nil, fmt.Errorf("disallowed column: %s", col)
		}
	}
	allColumns := make([]string, len(data.Columns), len(data.Columns)+len(data.JsonValues))
	copy(allColumns, data.Columns)
	allValues := make([]interface{}, len(data.Values), len(data.Values)+len(data.JsonValues))
	copy(allValues, data.Values)
	for col, val := range data.JsonValues {
		// JSON columns were previously exempt from the allow-list check;
		// validate them like any other column.
		if qb.allowedColumns != nil && !qb.allowedColumns[col] {
			return "", nil, fmt.Errorf("disallowed column: %s", col)
		}
		jsonVal, err := json.Marshal(val)
		if err != nil {
			return "", nil, fmt.Errorf("failed to marshal JSON value for column %s: %w", col, err)
		}
		allColumns = append(allColumns, col)
		allValues = append(allValues, jsonVal)
	}
	// Build the insert once (the original built it twice when JSON values
	// were present).
	insert := qb.sqlBuilder.Insert(table).Columns(allColumns...).Values(allValues...)
	if len(returningColumns) > 0 {
		if qb.dbType != DBTypePostgreSQL {
			return "", nil, fmt.Errorf("RETURNING not supported for database type: %s", qb.dbType)
		}
		insert = insert.Suffix("RETURNING " + strings.Join(returningColumns, ", "))
	}
	sqlStr, args, err := insert.ToSql()
	if err != nil {
		return "", nil, fmt.Errorf("failed to build INSERT query: %w", err)
	}
	return sqlStr, args, nil
}
// BuildUpdateQuery builds an UPDATE statement for table. Plain column
// assignments come from updateData.Columns/Values; entries in
// updateData.JsonUpdates update a JSON path inside a column using the
// dialect's JSON-modification function. RETURNING is only supported on
// PostgreSQL.
func (qb *QueryBuilder) BuildUpdateQuery(table string, updateData UpdateData, filters []FilterGroup, returningColumns ...string) (string, []interface{}, error) {
	// Validate plain columns against the allow-list.
	for _, col := range updateData.Columns {
		if qb.allowedColumns != nil && !qb.allowedColumns[col] {
			return "", nil, fmt.Errorf("disallowed column: %s", col)
		}
	}
	// Build the SET map once, then layer JSON-path updates on top (the
	// original built the map and the update builder twice).
	setMap := qb.buildSetMap(updateData)
	for col, jsonUpdate := range updateData.JsonUpdates {
		// Validate JSON-updated columns like any other column.
		if qb.allowedColumns != nil && !qb.allowedColumns[col] {
			return "", nil, fmt.Errorf("disallowed column: %s", col)
		}
		jsonVal, err := json.Marshal(jsonUpdate.Value)
		if err != nil {
			return "", nil, fmt.Errorf("failed to marshal JSON value for column %s: %w", col, err)
		}
		// NOTE(review): jsonUpdate.Path is interpolated into the SQL text;
		// confirm paths are trusted or validated upstream.
		switch qb.dbType {
		case DBTypePostgreSQL:
			// jsonb_set updates the value at a specific path.
			setMap[col] = squirrel.Expr(fmt.Sprintf("jsonb_set(%s, '%s', ?)", qb.escapeIdentifier(col), jsonUpdate.Path), jsonVal)
		case DBTypeMySQL:
			setMap[col] = squirrel.Expr(fmt.Sprintf("JSON_SET(%s, '%s', ?)", qb.escapeIdentifier(col), jsonUpdate.Path), jsonVal)
		case DBTypeSQLServer:
			setMap[col] = squirrel.Expr(fmt.Sprintf("JSON_MODIFY(%s, '%s', ?)", qb.escapeIdentifier(col), jsonUpdate.Path), jsonVal)
		case DBTypeSQLite:
			// SQLite has no path-based setter here; json_patch merges the
			// marshaled value into the existing document.
			setMap[col] = squirrel.Expr(fmt.Sprintf("json_patch(%s, ?)", qb.escapeIdentifier(col)), jsonVal)
		default:
			// Previously an unsupported dialect silently dropped the JSON
			// update; fail loudly instead.
			return "", nil, fmt.Errorf("JSON updates not supported for database type: %s", qb.dbType)
		}
	}
	update := qb.sqlBuilder.Update(table).SetMap(setMap)
	if len(filters) > 0 {
		whereClause, whereArgs, err := qb.BuildWhereClause(filters)
		if err != nil {
			return "", nil, err
		}
		update = update.Where(whereClause, whereArgs...)
	}
	if len(returningColumns) > 0 {
		if qb.dbType != DBTypePostgreSQL {
			return "", nil, fmt.Errorf("RETURNING not supported for database type: %s", qb.dbType)
		}
		update = update.Suffix("RETURNING " + strings.Join(returningColumns, ", "))
	}
	sqlStr, args, err := update.ToSql()
	if err != nil {
		return "", nil, fmt.Errorf("failed to build UPDATE query: %w", err)
	}
	return sqlStr, args, nil
}
// buildSetMap pairs updateData.Columns with updateData.Values for use with
// squirrel's SetMap. Columns and Values are assumed to be the same length.
func (qb *QueryBuilder) buildSetMap(updateData UpdateData) map[string]interface{} {
	out := make(map[string]interface{}, len(updateData.Columns))
	for idx, name := range updateData.Columns {
		out[name] = updateData.Values[idx]
	}
	return out
}
// BuildDeleteQuery builds a DELETE statement for table with optional WHERE
// filters. RETURNING is only supported on PostgreSQL.
func (qb *QueryBuilder) BuildDeleteQuery(table string, filters []FilterGroup, returningColumns ...string) (string, []interface{}, error) {
	// Named "builder" to avoid shadowing the builtin delete().
	builder := qb.sqlBuilder.Delete(table)
	if len(filters) > 0 {
		whereClause, whereArgs, err := qb.BuildWhereClause(filters)
		if err != nil {
			return "", nil, err
		}
		builder = builder.Where(whereClause, whereArgs...)
	}
	if len(returningColumns) > 0 {
		if qb.dbType != DBTypePostgreSQL {
			return "", nil, fmt.Errorf("RETURNING not supported for database type: %s", qb.dbType)
		}
		builder = builder.Suffix("RETURNING " + strings.Join(returningColumns, ", "))
	}
	sqlStr, args, err := builder.ToSql()
	if err != nil {
		return "", nil, fmt.Errorf("failed to build DELETE query: %w", err)
	}
	return sqlStr, args, nil
}
// BuildUpsertQuery builds an insert-or-update statement for table.
// PostgreSQL uses ON CONFLICT ... DO UPDATE (with optional RETURNING);
// MySQL uses ON DUPLICATE KEY UPDATE and does not support RETURNING.
// Other database types are rejected.
func (qb *QueryBuilder) BuildUpsertQuery(table string, insertData InsertData, conflictColumns []string, updateColumns []string, returningColumns ...string) (string, []interface{}, error) {
	// Validate insert and update columns against the allow-list.
	for _, col := range insertData.Columns {
		if qb.allowedColumns != nil && !qb.allowedColumns[col] {
			return "", nil, fmt.Errorf("disallowed column: %s", col)
		}
	}
	for _, col := range updateColumns {
		if qb.allowedColumns != nil && !qb.allowedColumns[col] {
			return "", nil, fmt.Errorf("disallowed column: %s", col)
		}
	}
	// Merge JSON values into the column/value lists once; this logic was
	// previously duplicated in each dialect branch. JSON columns are now
	// validated like any other column.
	allColumns := make([]string, len(insertData.Columns), len(insertData.Columns)+len(insertData.JsonValues))
	copy(allColumns, insertData.Columns)
	allValues := make([]interface{}, len(insertData.Values), len(insertData.Values)+len(insertData.JsonValues))
	copy(allValues, insertData.Values)
	for col, val := range insertData.JsonValues {
		if qb.allowedColumns != nil && !qb.allowedColumns[col] {
			return "", nil, fmt.Errorf("disallowed column: %s", col)
		}
		jsonVal, err := json.Marshal(val)
		if err != nil {
			return "", nil, fmt.Errorf("failed to marshal JSON value for column %s: %w", col, err)
		}
		allColumns = append(allColumns, col)
		allValues = append(allValues, jsonVal)
	}
	switch qb.dbType {
	case DBTypePostgreSQL:
		insert := qb.sqlBuilder.Insert(table).Columns(allColumns...).Values(allValues...)
		if len(conflictColumns) > 0 {
			conflictTarget := strings.Join(conflictColumns, ", ")
			setClauses := make([]string, 0, len(updateColumns))
			for _, col := range updateColumns {
				setClauses = append(setClauses, fmt.Sprintf("%s = EXCLUDED.%s", qb.escapeIdentifier(col), qb.escapeIdentifier(col)))
			}
			insert = insert.Suffix(fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET %s", conflictTarget, strings.Join(setClauses, ", ")))
		}
		if len(returningColumns) > 0 {
			insert = insert.Suffix("RETURNING " + strings.Join(returningColumns, ", "))
		}
		sqlStr, args, err := insert.ToSql()
		if err != nil {
			return "", nil, fmt.Errorf("failed to build UPSERT query: %w", err)
		}
		return sqlStr, args, nil
	case DBTypeMySQL:
		// MySQL has no RETURNING clause; the original silently ignored the
		// request, which hid a caller bug. Fail loudly instead.
		if len(returningColumns) > 0 {
			return "", nil, fmt.Errorf("RETURNING not supported for database type: %s", qb.dbType)
		}
		insert := qb.sqlBuilder.Insert(table).Columns(allColumns...).Values(allValues...)
		if len(updateColumns) > 0 {
			setClauses := make([]string, 0, len(updateColumns))
			for _, col := range updateColumns {
				setClauses = append(setClauses, fmt.Sprintf("%s = VALUES(%s)", qb.escapeIdentifier(col), qb.escapeIdentifier(col)))
			}
			insert = insert.Suffix(fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", strings.Join(setClauses, ", ")))
		}
		sqlStr, args, err := insert.ToSql()
		if err != nil {
			return "", nil, fmt.Errorf("failed to build UPSERT query: %w", err)
		}
		return sqlStr, args, nil
	default:
		return "", nil, fmt.Errorf("UPSERT not supported for database type: %s", qb.dbType)
	}
}
// --- QueryParser (for parsing URL query strings) ---
// QueryParser parses URL query parameters (fields, limit, offset,
// filter[col][op], sort) into a DynamicQuery.
type QueryParser struct {
	defaultLimit int // limit applied when the request does not specify one
	maxLimit int // ceiling: requested limits above this are ignored
}
// NewQueryParser returns a QueryParser with a default page size of 10 and a
// maximum page size of 100.
func NewQueryParser() *QueryParser {
	return &QueryParser{
		defaultLimit: 10,
		maxLimit:     100,
	}
}
// SetLimits overrides the default and maximum page sizes and returns the
// parser for chaining.
func (qp *QueryParser) SetLimits(defaultLimit, maxLimit int) *QueryParser {
	qp.defaultLimit, qp.maxLimit = defaultLimit, maxLimit
	return qp
}
// ParseQuery converts URL query parameters into a DynamicQuery targeting
// defaultTable. Recognized parameters: fields (comma-separated, "*" for
// all), limit (1..maxLimit), offset (>= 0), filter[col][op]=value, and
// sort. Invalid limit/offset values are silently ignored.
func (qp *QueryParser) ParseQuery(values url.Values, defaultTable string) (DynamicQuery, error) {
	result := DynamicQuery{
		From:   defaultTable,
		Limit:  qp.defaultLimit,
		Offset: 0,
	}
	// Field selection: absent or "*" means all fields.
	switch fields := values.Get("fields"); fields {
	case "", "*":
		result.Fields = []SelectField{{Expression: "*"}}
	default:
		for _, f := range strings.Split(fields, ",") {
			result.Fields = append(result.Fields, SelectField{Expression: strings.TrimSpace(f)})
		}
	}
	// Pagination: keep defaults on missing or out-of-range values.
	if raw := values.Get("limit"); raw != "" {
		if n, err := strconv.Atoi(raw); err == nil && n > 0 && n <= qp.maxLimit {
			result.Limit = n
		}
	}
	if raw := values.Get("offset"); raw != "" {
		if n, err := strconv.Atoi(raw); err == nil && n >= 0 {
			result.Offset = n
		}
	}
	filters, err := qp.parseFilters(values)
	if err != nil {
		return result, err
	}
	result.Filters = filters
	sorts, err := qp.parseSorting(values)
	if err != nil {
		return result, err
	}
	result.Sort = sorts
	return result, nil
}
// parseFilters extracts filter[column][operator]=value parameters and
// returns them as a single AND-combined FilterGroup, or nil when no
// filters are present. List operators (_in/_nin) split on commas; range
// operators (_between/_nbetween) expect exactly two comma-separated
// bounds; null checks carry no value.
func (qp *QueryParser) parseFilters(values url.Values) ([]FilterGroup, error) {
	byColumn := make(map[string]map[string]string)
	for key, vals := range values {
		if !strings.HasPrefix(key, "filter[") || !strings.HasSuffix(key, "]") || len(vals) == 0 {
			continue
		}
		// "filter[col][op]" -> "col][op" -> ["col", "op"]
		inner := key[len("filter[") : len(key)-1]
		pieces := strings.Split(inner, "][")
		if len(pieces) != 2 {
			continue
		}
		col, op := pieces[0], pieces[1]
		if byColumn[col] == nil {
			byColumn[col] = make(map[string]string)
		}
		byColumn[col][op] = vals[0]
	}
	var parsed []DynamicFilter
	for col, ops := range byColumn {
		for opName, raw := range ops {
			op := FilterOperator(opName)
			var val interface{}
			switch op {
			case OpIn, OpNotIn:
				if raw != "" {
					val = strings.Split(raw, ",")
				}
			case OpBetween, OpNotBetween:
				if raw != "" {
					if bounds := strings.Split(raw, ","); len(bounds) == 2 {
						val = []interface{}{strings.TrimSpace(bounds[0]), strings.TrimSpace(bounds[1])}
					}
				}
			case OpNull, OpNotNull:
				// Null checks carry no value.
			default:
				val = raw
			}
			parsed = append(parsed, DynamicFilter{Column: col, Operator: op, Value: val})
		}
	}
	if len(parsed) == 0 {
		return nil, nil
	}
	return []FilterGroup{{Filters: parsed, LogicOp: "AND"}}, nil
}
// parseSorting parses the comma-separated "sort" parameter. A leading "-"
// requests descending order; an optional leading "+" (or no prefix)
// requests ascending. Empty entries are skipped.
func (qp *QueryParser) parseSorting(values url.Values) ([]SortField, error) {
	raw := values.Get("sort")
	if raw == "" {
		return nil, nil
	}
	var out []SortField
	for _, item := range strings.Split(raw, ",") {
		item = strings.TrimSpace(item)
		if item == "" {
			continue
		}
		switch {
		case strings.HasPrefix(item, "-"):
			out = append(out, SortField{Column: item[1:], Order: "DESC"})
		case strings.HasPrefix(item, "+"):
			out = append(out, SortField{Column: item[1:], Order: "ASC"})
		default:
			out = append(out, SortField{Column: item, Order: "ASC"})
		}
	}
	return out, nil
}
// ParseQueryWithDefaultFields behaves like ParseQuery but substitutes
// defaultFields whenever the request selected nothing or only "*".
func (qp *QueryParser) ParseQueryWithDefaultFields(values url.Values, defaultTable string, defaultFields []string) (DynamicQuery, error) {
	query, err := qp.ParseQuery(values, defaultTable)
	if err != nil {
		return query, err
	}
	wildcardOnly := len(query.Fields) == 1 && query.Fields[0].Expression == "*"
	if len(query.Fields) == 0 || wildcardOnly {
		replacement := make([]SelectField, 0, len(defaultFields))
		for _, f := range defaultFields {
			replacement = append(replacement, SelectField{Expression: f})
		}
		query.Fields = replacement
	}
	return query, nil
}
// =============================================================================
// MONGODB QUERY BUILDER
// =============================================================================
// MongoQueryBuilder builds MongoDB queries from dynamic filters, enforcing
// optional field/collection allow-lists and a document-count ceiling.
type MongoQueryBuilder struct {
	allowedFields map[string]bool // Security: only allow specified fields
	allowedCollections map[string]bool // Security: only allow specified collections
	// Security settings
	enableSecurityChecks bool
	maxAllowedDocs int
	// Query logging
	enableQueryLogging bool
	// Connection timeout settings
	queryTimeout time.Duration
}
// NewMongoQueryBuilder returns a MongoQueryBuilder with security checks and
// query logging enabled, a 10000-document cap, and a 30-second default
// query timeout. Allow-lists start empty.
func NewMongoQueryBuilder() *MongoQueryBuilder {
	builder := &MongoQueryBuilder{
		allowedFields:        map[string]bool{},
		allowedCollections:   map[string]bool{},
		enableSecurityChecks: true,
		maxAllowedDocs:       10000,
		enableQueryLogging:   true,
		queryTimeout:         30 * time.Second,
	}
	return builder
}
// SetSecurityOptions toggles security checks and sets the maximum number
// of documents a query may request; returns the builder for chaining.
func (mqb *MongoQueryBuilder) SetSecurityOptions(enableChecks bool, maxDocs int) *MongoQueryBuilder {
	mqb.enableSecurityChecks, mqb.maxAllowedDocs = enableChecks, maxDocs
	return mqb
}
// SetAllowedFields replaces the field allow-list and returns the builder
// for chaining.
func (mqb *MongoQueryBuilder) SetAllowedFields(fields []string) *MongoQueryBuilder {
	allowed := make(map[string]bool, len(fields))
	for _, name := range fields {
		allowed[name] = true
	}
	mqb.allowedFields = allowed
	return mqb
}
// SetAllowedCollections replaces the collection allow-list and returns the
// builder for chaining.
func (mqb *MongoQueryBuilder) SetAllowedCollections(collections []string) *MongoQueryBuilder {
	allowed := make(map[string]bool, len(collections))
	for _, name := range collections {
		allowed[name] = true
	}
	mqb.allowedCollections = allowed
	return mqb
}
// SetQueryLogging enables or disables query logging (per-operation
// duration lines written to stdout via fmt.Printf in the Execute*
// methods). Returns the builder for chaining.
func (mqb *MongoQueryBuilder) SetQueryLogging(enable bool) *MongoQueryBuilder {
	mqb.enableQueryLogging = enable
	return mqb
}
// SetQueryTimeout sets the default query timeout applied by the Execute*
// methods when the caller's context carries no deadline of its own.
// A non-positive value disables the default. Returns the builder for chaining.
func (mqb *MongoQueryBuilder) SetQueryTimeout(timeout time.Duration) *MongoQueryBuilder {
	mqb.queryTimeout = timeout
	return mqb
}
// BuildFindQuery builds a MongoDB find query from DynamicQuery.
// It returns the filter document plus FindOptions carrying projection,
// sort, limit, and skip. When security checks are enabled it rejects
// limits above maxAllowedDocs and collections outside the allow-list;
// field allow-list checks apply whenever an allow-list is configured.
func (mqb *MongoQueryBuilder) BuildFindQuery(query DynamicQuery) (bson.M, *options.FindOptions, error) {
	filter := bson.M{}
	findOptions := options.Find()
	// Security check for limit
	if mqb.enableSecurityChecks && query.Limit > mqb.maxAllowedDocs {
		return nil, nil, fmt.Errorf("requested limit %d exceeds maximum allowed %d", query.Limit, mqb.maxAllowedDocs)
	}
	// Security check for collection name
	if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[query.From] {
		return nil, nil, fmt.Errorf("disallowed collection: %s", query.From)
	}
	// Build filter from DynamicQuery filters
	if len(query.Filters) > 0 {
		mongoFilter, err := mqb.buildFilter(query.Filters)
		if err != nil {
			return nil, nil, err
		}
		filter = mongoFilter
	}
	// Set projection from fields
	if len(query.Fields) > 0 {
		projection := bson.M{}
		for _, field := range query.Fields {
			if field.Expression == "*" {
				// "*" means include all fields: emit no projection entry.
				continue
			}
			fieldName := field.Expression
			if field.Alias != "" {
				// NOTE(review): MongoDB projections cannot rename fields; the
				// alias is used as the projected key here — confirm intended.
				fieldName = field.Alias
			}
			// An empty allow-list means "no field restrictions", mirroring the
			// collection check above. (The previous nil-check rejected every
			// field, because the constructor always creates a non-nil empty map.)
			if len(mqb.allowedFields) > 0 && !mqb.allowedFields[fieldName] {
				return nil, nil, fmt.Errorf("disallowed field: %s", fieldName)
			}
			projection[fieldName] = 1
		}
		if len(projection) > 0 {
			findOptions.SetProjection(projection)
		}
	}
	// Set sort order
	if len(query.Sort) > 0 {
		sort := bson.D{}
		for _, sortField := range query.Sort {
			fieldName := sortField.Column
			if len(mqb.allowedFields) > 0 && !mqb.allowedFields[fieldName] {
				return nil, nil, fmt.Errorf("disallowed field: %s", fieldName)
			}
			order := 1 // ASC
			if strings.ToUpper(sortField.Order) == "DESC" {
				order = -1 // DESC
			}
			sort = append(sort, bson.E{Key: fieldName, Value: order})
		}
		findOptions.SetSort(sort)
	}
	// Pagination: limit and offset
	if query.Limit > 0 {
		findOptions.SetLimit(int64(query.Limit))
	}
	if query.Offset > 0 {
		findOptions.SetSkip(int64(query.Offset))
	}
	return filter, findOptions, nil
}
// BuildAggregateQuery builds a MongoDB aggregation pipeline from DynamicQuery.
// Stage order: CTE lookups and inlined sub-pipelines, $match, $group, $sort,
// $skip, $limit. Security checks mirror BuildFindQuery: limit cap and
// collection allow-list when enabled, field allow-list when configured.
func (mqb *MongoQueryBuilder) BuildAggregateQuery(query DynamicQuery) ([]bson.D, error) {
	pipeline := []bson.D{}
	// Security check for limit (kept consistent with BuildFindQuery).
	if mqb.enableSecurityChecks && query.Limit > mqb.maxAllowedDocs {
		return nil, fmt.Errorf("requested limit %d exceeds maximum allowed %d", query.Limit, mqb.maxAllowedDocs)
	}
	// Security check for collection name
	if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[query.From] {
		return nil, fmt.Errorf("disallowed collection: %s", query.From)
	}
	// Handle CTEs as stages in the pipeline.
	// NOTE(review): CTE emulation inlines each CTE's sub-pipeline and emits a
	// $lookup per join using join.Alias as localField and "_id" as the foreign
	// key — this looks approximate; verify against callers before relying on it.
	for _, cte := range query.CTEs {
		// Security check for the CTE's source collection
		if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[cte.Query.From] {
			return nil, fmt.Errorf("disallowed collection in CTE: %s", cte.Query.From)
		}
		subPipeline, err := mqb.BuildAggregateQuery(cte.Query)
		if err != nil {
			return nil, fmt.Errorf("failed to build CTE '%s': %w", cte.Name, err)
		}
		// Add $lookup stages for the CTE's joins
		for _, join := range cte.Query.Joins {
			// Security check for the joined collection
			if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[join.Table] {
				return nil, fmt.Errorf("disallowed collection in join: %s", join.Table)
			}
			pipeline = append(pipeline, bson.D{
				{Key: "$lookup", Value: bson.D{
					{Key: "from", Value: join.Table},
					{Key: "localField", Value: join.Alias},
					{Key: "foreignField", Value: "_id"},
					{Key: "as", Value: join.Alias},
				}},
			})
		}
		// Inline the CTE's sub-pipeline
		pipeline = append(pipeline, subPipeline...)
	}
	// $match stage for filters
	if len(query.Filters) > 0 {
		filter, err := mqb.buildFilter(query.Filters)
		if err != nil {
			return nil, err
		}
		pipeline = append(pipeline, bson.D{{Key: "$match", Value: filter}})
	}
	// $group stage for GROUP BY
	if len(query.GroupBy) > 0 {
		groupID := bson.D{}
		for _, field := range query.GroupBy {
			// An empty allow-list means "no field restrictions".
			if len(mqb.allowedFields) > 0 && !mqb.allowedFields[field] {
				return nil, fmt.Errorf("disallowed field: %s", field)
			}
			groupID = append(groupID, bson.E{Key: field, Value: "$" + field})
		}
		// Accumulators must be keys INSIDE the $group document. The previous
		// version appended them as siblings of the "$group" key, which yields
		// an invalid multi-key pipeline stage.
		groupBody := bson.D{{Key: "_id", Value: groupID}}
		for _, field := range query.Fields {
			if !strings.Contains(field.Expression, "(") || !strings.Contains(field.Expression, ")") {
				continue // not an aggregation expression
			}
			parts := strings.SplitN(field.Expression, "(", 2)
			funcName := parts[0]
			funcField := strings.TrimSuffix(parts[1], ")")
			// "*" (as in count(*)) is not a real field; skip the allow-list check.
			if funcField != "*" && len(mqb.allowedFields) > 0 && !mqb.allowedFields[funcField] {
				return nil, fmt.Errorf("disallowed field: %s", funcField)
			}
			alias := field.Alias
			if alias == "" {
				// Avoid an empty BSON key when the caller supplied no alias.
				alias = strings.ToLower(funcName)
			}
			switch strings.ToLower(funcName) {
			case "count":
				groupBody = append(groupBody, bson.E{Key: alias, Value: bson.D{{Key: "$sum", Value: 1}}})
			case "sum":
				groupBody = append(groupBody, bson.E{Key: alias, Value: bson.D{{Key: "$sum", Value: "$" + funcField}}})
			case "avg":
				groupBody = append(groupBody, bson.E{Key: alias, Value: bson.D{{Key: "$avg", Value: "$" + funcField}}})
			case "min":
				groupBody = append(groupBody, bson.E{Key: alias, Value: bson.D{{Key: "$min", Value: "$" + funcField}}})
			case "max":
				groupBody = append(groupBody, bson.E{Key: alias, Value: bson.D{{Key: "$max", Value: "$" + funcField}}})
			}
		}
		pipeline = append(pipeline, bson.D{{Key: "$group", Value: groupBody}})
	}
	// $sort stage
	if len(query.Sort) > 0 {
		sort := bson.D{}
		for _, sortField := range query.Sort {
			fieldName := sortField.Column
			if len(mqb.allowedFields) > 0 && !mqb.allowedFields[fieldName] {
				return nil, fmt.Errorf("disallowed field: %s", fieldName)
			}
			order := 1 // ASC
			if strings.ToUpper(sortField.Order) == "DESC" {
				order = -1 // DESC
			}
			sort = append(sort, bson.E{Key: fieldName, Value: order})
		}
		pipeline = append(pipeline, bson.D{{Key: "$sort", Value: sort}})
	}
	// $skip and $limit stages
	if query.Offset > 0 {
		pipeline = append(pipeline, bson.D{{Key: "$skip", Value: query.Offset}})
	}
	if query.Limit > 0 {
		pipeline = append(pipeline, bson.D{{Key: "$limit", Value: query.Limit}})
	}
	return pipeline, nil
}
// buildFilter builds a MongoDB filter from FilterGroups. Non-empty groups are
// combined pairwise left-to-right, using each subsequent group's LogicOp
// ("OR" → $or, anything else → $and). Empty groups are skipped; an empty (or
// all-empty) input yields an empty match-all filter.
func (mqb *MongoQueryBuilder) buildFilter(filterGroups []FilterGroup) (bson.M, error) {
	// Previous version: the loop's := shadowed the outer err (always returned
	// nil), and when the FIRST group was empty a later group was combined with
	// a nil map, producing {$and: [null, ...]}. Track "first" explicitly.
	result := bson.M{}
	first := true
	for _, group := range filterGroups {
		if len(group.Filters) == 0 {
			continue
		}
		groupFilter, err := mqb.buildFilterGroup(group)
		if err != nil {
			return nil, err
		}
		if first {
			result = groupFilter
			first = false
			continue
		}
		// Combine with the accumulated filter using this group's logic op.
		logicOp := "$and"
		if strings.ToUpper(group.LogicOp) == "OR" {
			logicOp = "$or"
		}
		result = bson.M{logicOp: []bson.M{result, groupFilter}}
	}
	return result, nil
}
// buildFilterGroup builds the filter document for a single filter group.
// Conditions are joined with $and unless the group's LogicOp is "OR"
// (case-insensitive). A single condition is returned unwrapped; an empty
// group yields an empty match-all filter instead of the invalid
// {"$and": null} the previous version produced.
func (mqb *MongoQueryBuilder) buildFilterGroup(group FilterGroup) (bson.M, error) {
	logicOp := "$and"
	if strings.ToUpper(group.LogicOp) == "OR" {
		logicOp = "$or"
	}
	filters := make([]bson.M, 0, len(group.Filters))
	for _, filter := range group.Filters {
		fieldFilter, err := mqb.buildFilterCondition(filter)
		if err != nil {
			return nil, err
		}
		filters = append(filters, fieldFilter)
	}
	switch len(filters) {
	case 0:
		// $and/$or require a non-empty array in MongoDB — match everything.
		return bson.M{}, nil
	case 1:
		return filters[0], nil
	default:
		return bson.M{logicOp: filters}, nil
	}
}
// buildFilterCondition translates one DynamicFilter into a MongoDB condition.
// String-typed operators validate the value type and return an error instead
// of panicking on a failed type assertion, and literal-substring operators
// escape regex metacharacters so user input cannot inject regex syntax.
// The declared operators _nlike/_nilike/_between/_nbetween are now handled.
func (mqb *MongoQueryBuilder) buildFilterCondition(filter DynamicFilter) (bson.M, error) {
	field := filter.Column
	// An empty allow-list means "no field restrictions" (mirrors the
	// collection checks; the old nil-check rejected every field because the
	// constructor always creates a non-nil empty map).
	if len(mqb.allowedFields) > 0 && !mqb.allowedFields[field] {
		return nil, fmt.Errorf("disallowed field: %s", field)
	}
	// stringValue safely extracts the filter value as a string.
	stringValue := func() (string, error) {
		s, ok := filter.Value.(string)
		if !ok {
			return "", fmt.Errorf("operator %s requires a string value for column %s", filter.Operator, field)
		}
		return s, nil
	}
	// likeRegex converts a SQL LIKE pattern into an anchored regex:
	// metacharacters are escaped first, then % → .* and _ → . are applied
	// (SQL LIKE matches the whole string, hence the ^...$ anchors).
	likeRegex := func(pattern string) string {
		escaped := regexp.QuoteMeta(pattern)
		escaped = strings.ReplaceAll(escaped, "%", ".*")
		escaped = strings.ReplaceAll(escaped, "_", ".")
		return "^" + escaped + "$"
	}
	// jsonPath safely extracts the "path" option as a string.
	jsonPath := func() (string, error) {
		p, ok := filter.Options["path"].(string)
		if !ok {
			return "", fmt.Errorf("operator %s requires a string 'path' option", filter.Operator)
		}
		return p, nil
	}
	switch filter.Operator {
	case OpEqual:
		return bson.M{field: filter.Value}, nil
	case OpNotEqual:
		return bson.M{field: bson.M{"$ne": filter.Value}}, nil
	case OpIn:
		return bson.M{field: bson.M{"$in": mqb.parseArrayValue(filter.Value)}}, nil
	case OpNotIn:
		return bson.M{field: bson.M{"$nin": mqb.parseArrayValue(filter.Value)}}, nil
	case OpGreaterThan:
		return bson.M{field: bson.M{"$gt": filter.Value}}, nil
	case OpGreaterThanEqual:
		return bson.M{field: bson.M{"$gte": filter.Value}}, nil
	case OpLessThan:
		return bson.M{field: bson.M{"$lt": filter.Value}}, nil
	case OpLessThanEqual:
		return bson.M{field: bson.M{"$lte": filter.Value}}, nil
	case OpBetween, OpNotBetween:
		// Value must hold exactly two bounds (a slice, or a "lo,hi" string).
		bounds := mqb.parseArrayValue(filter.Value)
		if len(bounds) != 2 {
			return nil, fmt.Errorf("operator %s requires exactly 2 values for column %s", filter.Operator, field)
		}
		if filter.Operator == OpBetween {
			return bson.M{field: bson.M{"$gte": bounds[0], "$lte": bounds[1]}}, nil
		}
		return bson.M{"$or": []bson.M{
			{field: bson.M{"$lt": bounds[0]}},
			{field: bson.M{"$gt": bounds[1]}},
		}}, nil
	case OpLike, OpILike:
		// Both variants stay case-insensitive, matching previous behavior.
		s, err := stringValue()
		if err != nil {
			return nil, err
		}
		return bson.M{field: bson.M{"$regex": likeRegex(s), "$options": "i"}}, nil
	case OpNotLike, OpNotILike:
		s, err := stringValue()
		if err != nil {
			return nil, err
		}
		return bson.M{field: bson.M{"$not": bson.M{"$regex": likeRegex(s), "$options": "i"}}}, nil
	case OpContains:
		// Literal (escaped) substring match, case-insensitive.
		s, err := stringValue()
		if err != nil {
			return nil, err
		}
		return bson.M{field: bson.M{"$regex": regexp.QuoteMeta(s), "$options": "i"}}, nil
	case OpNotContains:
		s, err := stringValue()
		if err != nil {
			return nil, err
		}
		return bson.M{field: bson.M{"$not": bson.M{"$regex": regexp.QuoteMeta(s), "$options": "i"}}}, nil
	case OpStartsWith:
		s, err := stringValue()
		if err != nil {
			return nil, err
		}
		return bson.M{field: bson.M{"$regex": "^" + regexp.QuoteMeta(s), "$options": "i"}}, nil
	case OpEndsWith:
		s, err := stringValue()
		if err != nil {
			return nil, err
		}
		return bson.M{field: bson.M{"$regex": regexp.QuoteMeta(s) + "$", "$options": "i"}}, nil
	case OpNull:
		// NOTE(review): treats "missing field" as NULL; documents whose field
		// is explicitly BSON null are NOT matched by $exists:false — confirm.
		return bson.M{field: bson.M{"$exists": false}}, nil
	case OpNotNull:
		return bson.M{field: bson.M{"$exists": true}}, nil
	case OpJsonContains:
		return bson.M{field: bson.M{"$elemMatch": filter.Value}}, nil
	case OpJsonNotContains:
		return bson.M{field: bson.M{"$not": bson.M{"$elemMatch": filter.Value}}}, nil
	case OpJsonExists:
		p, err := jsonPath()
		if err != nil {
			return nil, err
		}
		return bson.M{field + "." + p: bson.M{"$exists": true}}, nil
	case OpJsonNotExists:
		p, err := jsonPath()
		if err != nil {
			return nil, err
		}
		return bson.M{field + "." + p: bson.M{"$exists": false}}, nil
	case OpArrayContains:
		return bson.M{field: bson.M{"$elemMatch": bson.M{"$eq": filter.Value}}}, nil
	case OpArrayNotContains:
		return bson.M{field: bson.M{"$not": bson.M{"$elemMatch": bson.M{"$eq": filter.Value}}}}, nil
	case OpArrayLength:
		// Accept both int (native callers) and float64 (JSON-decoded options).
		switch n := filter.Options["length"].(type) {
		case int:
			return bson.M{field: bson.M{"$size": n}}, nil
		case float64:
			return bson.M{field: bson.M{"$size": int(n)}}, nil
		}
		return nil, fmt.Errorf("array_length operator requires 'length' option")
	default:
		return nil, fmt.Errorf("unsupported operator: %s", filter.Operator)
	}
}
// parseArrayValue normalizes a filter value into a []interface{} for $in/$nin:
// nil stays nil, slices are flattened element-by-element, comma-separated
// strings are split and trimmed, and anything else becomes a one-element slice.
func (mqb *MongoQueryBuilder) parseArrayValue(value interface{}) []interface{} {
	if value == nil {
		return nil
	}
	// A comma-separated string is treated as a list of trimmed items.
	if str, isString := value.(string); isString {
		if !strings.Contains(str, ",") {
			return []interface{}{str}
		}
		parts := strings.Split(str, ",")
		out := make([]interface{}, 0, len(parts))
		for _, part := range parts {
			out = append(out, strings.TrimSpace(part))
		}
		return out
	}
	// Any concrete slice type is unpacked via reflection.
	rv := reflect.ValueOf(value)
	if rv.Kind() == reflect.Slice {
		out := make([]interface{}, rv.Len())
		for i := range out {
			out[i] = rv.Index(i).Interface()
		}
		return out
	}
	// Scalar fallback: wrap the single value.
	return []interface{}{value}
}
// ExecuteFind runs a find built from query against the collection and decodes
// every matching document into dest (a pointer to a slice). The builder's
// default timeout is applied only when ctx carries no deadline of its own.
func (mqb *MongoQueryBuilder) ExecuteFind(ctx context.Context, collection *mongo.Collection, query DynamicQuery, dest interface{}) error {
	// Collection allow-list check (when security checks are enabled).
	if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[collection.Name()] {
		return fmt.Errorf("disallowed collection: %s", collection.Name())
	}
	filter, findOptions, err := mqb.BuildFindQuery(query)
	if err != nil {
		return err
	}
	// Apply the default timeout only if the caller set none.
	if mqb.queryTimeout > 0 {
		if _, hasDeadline := ctx.Deadline(); !hasDeadline {
			var cancel context.CancelFunc
			ctx, cancel = context.WithTimeout(ctx, mqb.queryTimeout)
			defer cancel()
		}
	}
	started := time.Now()
	cursor, err := collection.Find(ctx, filter, findOptions)
	if err != nil {
		return err
	}
	defer cursor.Close(ctx)
	decodeErr := cursor.All(ctx, dest)
	if mqb.enableQueryLogging {
		fmt.Printf("[DEBUG] MongoDB Find executed in %v\n", time.Since(started))
	}
	return decodeErr
}
// ExecuteAggregate runs the aggregation pipeline built from query against the
// collection and decodes all results into dest (a pointer to a slice). The
// builder's default timeout is applied only when ctx has no deadline.
func (mqb *MongoQueryBuilder) ExecuteAggregate(ctx context.Context, collection *mongo.Collection, query DynamicQuery, dest interface{}) error {
	// Collection allow-list check (when security checks are enabled).
	if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[collection.Name()] {
		return fmt.Errorf("disallowed collection: %s", collection.Name())
	}
	pipeline, err := mqb.BuildAggregateQuery(query)
	if err != nil {
		return err
	}
	// Apply the default timeout only if the caller set none.
	if mqb.queryTimeout > 0 {
		if _, hasDeadline := ctx.Deadline(); !hasDeadline {
			var cancel context.CancelFunc
			ctx, cancel = context.WithTimeout(ctx, mqb.queryTimeout)
			defer cancel()
		}
	}
	started := time.Now()
	cursor, err := collection.Aggregate(ctx, pipeline)
	if err != nil {
		return err
	}
	defer cursor.Close(ctx)
	decodeErr := cursor.All(ctx, dest)
	if mqb.enableQueryLogging {
		fmt.Printf("[DEBUG] MongoDB Aggregate executed in %v\n", time.Since(started))
	}
	return decodeErr
}
// ExecuteCount counts the documents matching the filter derived from query.
// Projection, sort, limit, and skip from the query are intentionally ignored —
// only the filter participates. The builder's default timeout is applied when
// ctx has no deadline of its own.
func (mqb *MongoQueryBuilder) ExecuteCount(ctx context.Context, collection *mongo.Collection, query DynamicQuery) (int64, error) {
	// Collection allow-list check (when security checks are enabled).
	if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[collection.Name()] {
		return 0, fmt.Errorf("disallowed collection: %s", collection.Name())
	}
	filter, _, err := mqb.BuildFindQuery(query)
	if err != nil {
		return 0, err
	}
	// Apply the default timeout only if the caller set none.
	if mqb.queryTimeout > 0 {
		if _, hasDeadline := ctx.Deadline(); !hasDeadline {
			var cancel context.CancelFunc
			ctx, cancel = context.WithTimeout(ctx, mqb.queryTimeout)
			defer cancel()
		}
	}
	started := time.Now()
	total, countErr := collection.CountDocuments(ctx, filter)
	if mqb.enableQueryLogging {
		fmt.Printf("[DEBUG] MongoDB Count executed in %v\n", time.Since(started))
	}
	return total, countErr
}
// ExecuteInsert inserts a single document assembled from data.Columns paired
// with data.Values, merged with data.JsonValues (JSON values overwrite a
// same-named plain column). It validates the column/value counts and the
// field allow-list before touching the database.
func (mqb *MongoQueryBuilder) ExecuteInsert(ctx context.Context, collection *mongo.Collection, data InsertData) (*mongo.InsertOneResult, error) {
	// Security check for collection name
	if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[collection.Name()] {
		return nil, fmt.Errorf("disallowed collection: %s", collection.Name())
	}
	// Guard against a columns/values mismatch (previously an index panic).
	if len(data.Columns) != len(data.Values) {
		return nil, fmt.Errorf("column count %d does not match value count %d", len(data.Columns), len(data.Values))
	}
	document := bson.M{}
	for i, col := range data.Columns {
		// An empty allow-list means "no field restrictions".
		if len(mqb.allowedFields) > 0 && !mqb.allowedFields[col] {
			return nil, fmt.Errorf("disallowed field: %s", col)
		}
		document[col] = data.Values[i]
	}
	// Merge JSON values into the document.
	for col, val := range data.JsonValues {
		if len(mqb.allowedFields) > 0 && !mqb.allowedFields[col] {
			return nil, fmt.Errorf("disallowed field: %s", col)
		}
		document[col] = val
	}
	// Apply the default timeout only if the caller set none.
	if mqb.queryTimeout > 0 {
		if _, hasDeadline := ctx.Deadline(); !hasDeadline {
			var cancel context.CancelFunc
			ctx, cancel = context.WithTimeout(ctx, mqb.queryTimeout)
			defer cancel()
		}
	}
	start := time.Now()
	result, err := collection.InsertOne(ctx, document)
	if mqb.enableQueryLogging {
		fmt.Printf("[DEBUG] MongoDB Insert executed in %v\n", time.Since(start))
	}
	return result, err
}
// ExecuteUpdate applies a $set update built from updateData to every document
// matching filters. JSON updates use dot notation (column + "." + path) for
// nested fields. It validates the column/value counts and the field
// allow-list before touching the database.
func (mqb *MongoQueryBuilder) ExecuteUpdate(ctx context.Context, collection *mongo.Collection, updateData UpdateData, filters []FilterGroup) (*mongo.UpdateResult, error) {
	// Security check for collection name
	if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[collection.Name()] {
		return nil, fmt.Errorf("disallowed collection: %s", collection.Name())
	}
	// Guard against a columns/values mismatch (previously an index panic).
	if len(updateData.Columns) != len(updateData.Values) {
		return nil, fmt.Errorf("column count %d does not match value count %d", len(updateData.Columns), len(updateData.Values))
	}
	filter, err := mqb.buildFilter(filters)
	if err != nil {
		return nil, err
	}
	// Build the $set document directly (avoids the repeated
	// update["$set"].(bson.M) assertions of the previous version).
	setDoc := bson.M{}
	for i, col := range updateData.Columns {
		// An empty allow-list means "no field restrictions".
		if len(mqb.allowedFields) > 0 && !mqb.allowedFields[col] {
			return nil, fmt.Errorf("disallowed field: %s", col)
		}
		setDoc[col] = updateData.Values[i]
	}
	// Nested JSON updates via dot notation.
	for col, jsonUpdate := range updateData.JsonUpdates {
		if len(mqb.allowedFields) > 0 && !mqb.allowedFields[col] {
			return nil, fmt.Errorf("disallowed field: %s", col)
		}
		setDoc[col+"."+jsonUpdate.Path] = jsonUpdate.Value
	}
	update := bson.M{"$set": setDoc}
	// Apply the default timeout only if the caller set none.
	if mqb.queryTimeout > 0 {
		if _, hasDeadline := ctx.Deadline(); !hasDeadline {
			var cancel context.CancelFunc
			ctx, cancel = context.WithTimeout(ctx, mqb.queryTimeout)
			defer cancel()
		}
	}
	start := time.Now()
	result, err := collection.UpdateMany(ctx, filter, update)
	if mqb.enableQueryLogging {
		fmt.Printf("[DEBUG] MongoDB Update executed in %v\n", time.Since(start))
	}
	return result, err
}
// ExecuteDelete removes every document matching the filter built from
// filters. CAUTION: empty filters produce an empty (match-all) filter, so
// all documents in the collection are deleted. The builder's default
// timeout is applied only when ctx carries no deadline of its own.
func (mqb *MongoQueryBuilder) ExecuteDelete(ctx context.Context, collection *mongo.Collection, filters []FilterGroup) (*mongo.DeleteResult, error) {
	// Collection allow-list check (when security checks are enabled).
	if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[collection.Name()] {
		return nil, fmt.Errorf("disallowed collection: %s", collection.Name())
	}
	filter, err := mqb.buildFilter(filters)
	if err != nil {
		return nil, err
	}
	// Apply the default timeout only if the caller set none.
	if mqb.queryTimeout > 0 {
		if _, hasDeadline := ctx.Deadline(); !hasDeadline {
			var cancel context.CancelFunc
			ctx, cancel = context.WithTimeout(ctx, mqb.queryTimeout)
			defer cancel()
		}
	}
	started := time.Now()
	deleted, deleteErr := collection.DeleteMany(ctx, filter)
	if mqb.enableQueryLogging {
		fmt.Printf("[DEBUG] MongoDB Delete executed in %v\n", time.Since(started))
	}
	return deleted, deleteErr
}