package utils import ( "context" "database/sql" "encoding/json" "fmt" "net/url" "reflect" "regexp" "strconv" "strings" "time" "github.com/Masterminds/squirrel" "github.com/jmoiron/sqlx" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) // DBType represents the type of database type DBType string const ( DBTypePostgreSQL DBType = "postgres" DBTypeMySQL DBType = "mysql" DBTypeSQLite DBType = "sqlite" DBTypeSQLServer DBType = "sqlserver" DBTypeMongoDB DBType = "mongodb" ) // FilterOperator represents supported filter operators type FilterOperator string const ( OpEqual FilterOperator = "_eq" OpNotEqual FilterOperator = "_neq" OpLike FilterOperator = "_like" OpILike FilterOperator = "_ilike" OpNotLike FilterOperator = "_nlike" OpNotILike FilterOperator = "_nilike" OpIn FilterOperator = "_in" OpNotIn FilterOperator = "_nin" OpGreaterThan FilterOperator = "_gt" OpGreaterThanEqual FilterOperator = "_gte" OpLessThan FilterOperator = "_lt" OpLessThanEqual FilterOperator = "_lte" OpBetween FilterOperator = "_between" OpNotBetween FilterOperator = "_nbetween" OpNull FilterOperator = "_null" OpNotNull FilterOperator = "_nnull" OpContains FilterOperator = "_contains" OpNotContains FilterOperator = "_ncontains" OpStartsWith FilterOperator = "_starts_with" OpEndsWith FilterOperator = "_ends_with" OpJsonContains FilterOperator = "_json_contains" OpJsonNotContains FilterOperator = "_json_ncontains" OpJsonExists FilterOperator = "_json_exists" OpJsonNotExists FilterOperator = "_json_nexists" OpJsonEqual FilterOperator = "_json_eq" OpJsonNotEqual FilterOperator = "_json_neq" OpArrayContains FilterOperator = "_array_contains" OpArrayNotContains FilterOperator = "_array_ncontains" OpArrayLength FilterOperator = "_array_length" ) // DynamicFilter represents a single filter condition type DynamicFilter struct { Column string `json:"column"` Operator FilterOperator `json:"operator"` Value interface{} `json:"value"` // 
Additional options for complex filters Options map[string]interface{} `json:"options,omitempty"` } // FilterGroup represents a group of filters with a logical operator (AND/OR) type FilterGroup struct { Filters []DynamicFilter `json:"filters"` LogicOp string `json:"logic_op"` // AND, OR } // SelectField represents a field in the SELECT clause, supporting expressions and aliases type SelectField struct { Expression string `json:"expression"` // e.g., "TMLogBarang.Nama", "COUNT(*)" Alias string `json:"alias"` // e.g., "obat_nama", "total_count" // Window function support WindowFunction *WindowFunction `json:"window_function,omitempty"` } // WindowFunction represents a window function with its configuration type WindowFunction struct { Function string `json:"function"` // e.g., "ROW_NUMBER", "RANK", "DENSE_RANK", "LEAD", "LAG" Over string `json:"over"` // PARTITION BY expression OrderBy string `json:"order_by"` // ORDER BY expression Frame string `json:"frame"` // ROWS/RANGE clause Alias string `json:"alias"` // Alias for the window function } // Join represents a JOIN clause type Join struct { Type string `json:"type"` // "INNER", "LEFT", "RIGHT", "FULL" Table string `json:"table"` // Table name to join Alias string `json:"alias"` // Table alias OnConditions FilterGroup `json:"on_conditions"` // Conditions for the ON clause // LATERAL JOIN support Lateral bool `json:"lateral,omitempty"` } // Union represents a UNION clause type Union struct { Type string `json:"type"` // "UNION", "UNION ALL" Query DynamicQuery `json:"query"` // The subquery to union with } // CTE (Common Table Expression) represents a WITH clause type CTE struct { Name string `json:"name"` // CTE alias name Query DynamicQuery `json:"query"` // The query defining the CTE // Recursive CTE support Recursive bool `json:"recursive,omitempty"` } // DynamicQuery represents the complete query structure type DynamicQuery struct { Fields []SelectField `json:"fields,omitempty"` From string `json:"from"` // Main 
table name Aliases string `json:"aliases"` // Main table alias Joins []Join `json:"joins,omitempty"` Filters []FilterGroup `json:"filters,omitempty"` GroupBy []string `json:"group_by,omitempty"` Having []FilterGroup `json:"having,omitempty"` Unions []Union `json:"unions,omitempty"` CTEs []CTE `json:"ctes,omitempty"` Sort []SortField `json:"sort,omitempty"` Limit int `json:"limit"` Offset int `json:"offset"` // Window function support WindowFunctions []WindowFunction `json:"window_functions,omitempty"` // JSON operations JsonOperations []JsonOperation `json:"json_operations,omitempty"` } // JsonOperation represents a JSON operation type JsonOperation struct { Type string `json:"type"` // "extract", "exists", "contains", etc. Column string `json:"column"` // JSON column Path string `json:"path"` // JSON path Value interface{} `json:"value,omitempty"` // Value for comparison Alias string `json:"alias,omitempty"` // Alias for the result } // SortField represents sorting configuration type SortField struct { Column string `json:"column"` Order string `json:"order"` // ASC, DESC } // UpdateData represents data for UPDATE operations type UpdateData struct { Columns []string `json:"columns"` Values []interface{} `json:"values"` // JSON update support JsonUpdates map[string]JsonUpdate `json:"json_updates,omitempty"` } // JsonUpdate represents a JSON update operation type JsonUpdate struct { Path string `json:"path"` // JSON path Value interface{} `json:"value"` // New value } // InsertData represents data for INSERT operations type InsertData struct { Columns []string `json:"columns"` Values []interface{} `json:"values"` // JSON insert support JsonValues map[string]interface{} `json:"json_values,omitempty"` } // QueryBuilder builds SQL queries from dynamic filters using squirrel type QueryBuilder struct { dbType DBType sqlBuilder squirrel.StatementBuilderType allowedColumns map[string]bool // Security: only allow specified columns allowedTables map[string]bool // Security: 
only allow specified tables // Security settings enableSecurityChecks bool maxAllowedRows int // SQL injection prevention patterns dangerousPatterns []*regexp.Regexp // Query logging enableQueryLogging bool // Connection timeout settings queryTimeout time.Duration } // NewQueryBuilder creates a new query builder instance for a specific database type func NewQueryBuilder(dbType DBType) *QueryBuilder { var placeholderFormat squirrel.PlaceholderFormat switch dbType { case DBTypePostgreSQL: placeholderFormat = squirrel.Dollar case DBTypeMySQL, DBTypeSQLite: placeholderFormat = squirrel.Question case DBTypeSQLServer: placeholderFormat = squirrel.AtP default: placeholderFormat = squirrel.Question } // Initialize dangerous patterns for SQL injection prevention dangerousPatterns := []*regexp.Regexp{ regexp.MustCompile(`(?i)(union|select|insert|update|delete|drop|alter|create|exec|execute)\s`), regexp.MustCompile(`(?i)(--|\/\*|\*\/)`), regexp.MustCompile(`(?i)(or|and)\s+1\s*=\s*1`), regexp.MustCompile(`(?i)(or|and)\s+true`), regexp.MustCompile(`(?i)(xp_|sp_)\w+`), // SQL Server extended procedures regexp.MustCompile(`(?i)(waitfor\s+delay)`), // SQL Server time-based attack regexp.MustCompile(`(?i)(benchmark|sleep)\s*\(`), // MySQL time-based attack regexp.MustCompile(`(?i)(pg_sleep)\s*\(`), // PostgreSQL time-based attack regexp.MustCompile(`(?i)(load_file|into\s+outfile)`), // File operations regexp.MustCompile(`(?i)(information_schema|sysobjects|syscolumns)`), // System tables } return &QueryBuilder{ dbType: dbType, sqlBuilder: squirrel.StatementBuilder.PlaceholderFormat(placeholderFormat), allowedColumns: make(map[string]bool), allowedTables: make(map[string]bool), enableSecurityChecks: true, maxAllowedRows: 10000, dangerousPatterns: dangerousPatterns, enableQueryLogging: true, queryTimeout: 30 * time.Second, } } // SetSecurityOptions configures security settings func (qb *QueryBuilder) SetSecurityOptions(enableChecks bool, maxRows int) *QueryBuilder { 
qb.enableSecurityChecks = enableChecks qb.maxAllowedRows = maxRows return qb } // SetAllowedColumns sets the list of allowed columns for security func (qb *QueryBuilder) SetAllowedColumns(columns []string) *QueryBuilder { qb.allowedColumns = make(map[string]bool) for _, col := range columns { qb.allowedColumns[col] = true } return qb } // SetAllowedTables sets the list of allowed tables for security func (qb *QueryBuilder) SetAllowedTables(tables []string) *QueryBuilder { qb.allowedTables = make(map[string]bool) for _, table := range tables { qb.allowedTables[table] = true } return qb } // SetQueryLogging enables or disables query logging func (qb *QueryBuilder) SetQueryLogging(enable bool) *QueryBuilder { qb.enableQueryLogging = enable return qb } // SetQueryTimeout sets the default query timeout func (qb *QueryBuilder) SetQueryTimeout(timeout time.Duration) *QueryBuilder { qb.queryTimeout = timeout return qb } // BuildQuery builds the complete SQL SELECT query with support for CTEs, JOINs, and UNIONs func (qb *QueryBuilder) BuildQuery(query DynamicQuery) (string, []interface{}, error) { var allArgs []interface{} var queryParts []string // Security check for limit if qb.enableSecurityChecks && query.Limit > qb.maxAllowedRows { return "", nil, fmt.Errorf("requested limit %d exceeds maximum allowed %d", query.Limit, qb.maxAllowedRows) } // Security check for table name if qb.enableSecurityChecks && len(qb.allowedTables) > 0 && !qb.allowedTables[query.From] { return "", nil, fmt.Errorf("disallowed table: %s", query.From) } // 1. Build CTEs (WITH clause) if len(query.CTEs) > 0 { cteClause, cteArgs, err := qb.buildCTEClause(query.CTEs) if err != nil { return "", nil, err } queryParts = append(queryParts, cteClause) allArgs = append(allArgs, cteArgs...) } // 2. 
Build Main Query using Squirrel's From and Join methods fromClause := qb.buildFromClause(query.From, query.Aliases) selectFields := qb.buildSelectFields(query.Fields) // Start building the main query var mainQuery squirrel.SelectBuilder if len(query.WindowFunctions) > 0 || len(query.JsonOperations) > 0 { // We need to add window functions and JSON operations after initial select mainQuery = qb.sqlBuilder.Select(selectFields...).From(fromClause) } else { mainQuery = qb.sqlBuilder.Select(selectFields...).From(fromClause) } // Add JOINs using Squirrel's Join method if len(query.Joins) > 0 { for _, join := range query.Joins { // Security check for joined table if qb.enableSecurityChecks && len(qb.allowedTables) > 0 && !qb.allowedTables[join.Table] { return "", nil, fmt.Errorf("disallowed table in join: %s", join.Table) } joinType, tableWithAlias, onClause, joinArgs, err := qb.buildSingleJoinClause(join) if err != nil { return "", nil, err } joinStr := tableWithAlias + " ON " + onClause switch strings.ToUpper(joinType) { case "LEFT": if join.Lateral { mainQuery = mainQuery.LeftJoin("LATERAL "+joinStr, joinArgs...) } else { mainQuery = mainQuery.LeftJoin(joinStr, joinArgs...) } case "RIGHT": mainQuery = mainQuery.RightJoin(joinStr, joinArgs...) case "FULL": mainQuery = mainQuery.Join("FULL JOIN "+joinStr, joinArgs...) default: if join.Lateral { mainQuery = mainQuery.Join("LATERAL "+joinStr, joinArgs...) } else { mainQuery = mainQuery.Join(joinStr, joinArgs...) } } } } // 4. Apply WHERE conditions if len(query.Filters) > 0 { whereClause, whereArgs, err := qb.BuildWhereClause(query.Filters) if err != nil { return "", nil, err } mainQuery = mainQuery.Where(whereClause, whereArgs...) } // 5. Apply GROUP BY if len(query.GroupBy) > 0 { mainQuery = mainQuery.GroupBy(qb.buildGroupByColumns(query.GroupBy)...) } // 6. 
Apply HAVING conditions if len(query.Having) > 0 { havingClause, havingArgs, err := qb.BuildWhereClause(query.Having) if err != nil { return "", nil, err } mainQuery = mainQuery.Having(havingClause, havingArgs...) } // 7. Apply ORDER BY if len(query.Sort) > 0 { for _, sort := range query.Sort { column := qb.validateAndEscapeColumn(sort.Column) if column == "" { continue } order := "ASC" if strings.ToUpper(sort.Order) == "DESC" { order = "DESC" } mainQuery = mainQuery.OrderBy(fmt.Sprintf("%s %s", column, order)) } } // 8. Apply window functions and JSON operations by modifying the SELECT clause if len(query.WindowFunctions) > 0 || len(query.JsonOperations) > 0 { // We need to rebuild the SELECT clause with window functions and JSON operations var finalSelectFields []string finalSelectFields = append(finalSelectFields, selectFields...) // Add window functions for _, wf := range query.WindowFunctions { windowFunc, err := qb.buildWindowFunction(wf) if err != nil { return "", nil, err } finalSelectFields = append(finalSelectFields, windowFunc) } // Add JSON operations for _, jo := range query.JsonOperations { jsonExpr, jsonArgs, err := qb.buildJsonOperation(jo) if err != nil { return "", nil, err } if jo.Alias != "" { jsonExpr += " AS " + qb.escapeIdentifier(jo.Alias) } finalSelectFields = append(finalSelectFields, jsonExpr) allArgs = append(allArgs, jsonArgs...) 
} // Rebuild the query with the complete SELECT clause mainQuery = qb.sqlBuilder.Select(finalSelectFields...).From(fromClause) // Re-apply all the other clauses if len(query.Joins) > 0 { for _, join := range query.Joins { // Security check for joined table if qb.enableSecurityChecks && len(qb.allowedTables) > 0 && !qb.allowedTables[join.Table] { return "", nil, fmt.Errorf("disallowed table in join: %s", join.Table) } joinType, tableWithAlias, onClause, joinArgs, err := qb.buildSingleJoinClause(join) if err != nil { return "", nil, err } joinStr := tableWithAlias + " ON " + onClause switch strings.ToUpper(joinType) { case "LEFT": if join.Lateral { mainQuery = mainQuery.LeftJoin("LATERAL "+joinStr, joinArgs...) } else { mainQuery = mainQuery.LeftJoin(joinStr, joinArgs...) } case "RIGHT": mainQuery = mainQuery.RightJoin(joinStr, joinArgs...) case "FULL": mainQuery = mainQuery.Join("FULL JOIN "+joinStr, joinArgs...) default: if join.Lateral { mainQuery = mainQuery.Join("LATERAL "+joinStr, joinArgs...) } else { mainQuery = mainQuery.Join(joinStr, joinArgs...) } } } } if len(query.Filters) > 0 { whereClause, whereArgs, err := qb.BuildWhereClause(query.Filters) if err != nil { return "", nil, err } mainQuery = mainQuery.Where(whereClause, whereArgs...) } if len(query.GroupBy) > 0 { mainQuery = mainQuery.GroupBy(qb.buildGroupByColumns(query.GroupBy)...) } if len(query.Having) > 0 { havingClause, havingArgs, err := qb.BuildWhereClause(query.Having) if err != nil { return "", nil, err } mainQuery = mainQuery.Having(havingClause, havingArgs...) } if len(query.Sort) > 0 { for _, sort := range query.Sort { column := qb.validateAndEscapeColumn(sort.Column) if column == "" { continue } order := "ASC" if strings.ToUpper(sort.Order) == "DESC" { order = "DESC" } mainQuery = mainQuery.OrderBy(fmt.Sprintf("%s %s", column, order)) } } } // 9. 
Apply pagination with dialect-specific syntax if query.Limit > 0 { if qb.dbType == DBTypeSQLServer { // SQL Server requires ORDER BY for OFFSET FETCH if len(query.Sort) == 0 { mainQuery = mainQuery.OrderBy("(SELECT 1)") } mainQuery = mainQuery.Suffix(fmt.Sprintf("OFFSET %d ROWS FETCH NEXT %d ROWS ONLY", query.Offset, query.Limit)) } else { mainQuery = mainQuery.Limit(uint64(query.Limit)) if query.Offset > 0 { mainQuery = mainQuery.Offset(uint64(query.Offset)) } } } else if query.Offset > 0 && qb.dbType != DBTypeSQLServer { mainQuery = mainQuery.Offset(uint64(query.Offset)) } // Build final main query SQL sql, args, err := mainQuery.ToSql() if err != nil { return "", nil, fmt.Errorf("failed to build main query: %w", err) } queryParts = append(queryParts, sql) allArgs = append(allArgs, args...) // 10. Apply UNIONs if len(query.Unions) > 0 { unionClause, unionArgs, err := qb.buildUnionClause(query.Unions) if err != nil { return "", nil, err } queryParts = append(queryParts, unionClause) allArgs = append(allArgs, unionArgs...) 
} finalSQL := strings.Join(queryParts, " ") // Security check for dangerous patterns in user input values if qb.enableSecurityChecks { if err := qb.checkForSqlInjectionInArgs(allArgs); err != nil { return "", nil, err } } // Security check for dangerous patterns in the final SQL // if qb.enableSecurityChecks { // if err := qb.checkForSqlInjectionInSQL(finalSQL); err != nil { // return "", nil, err // } // } if qb.enableQueryLogging { fmt.Printf("[DEBUG BuilderQuery] Final SQL query: %s\n", finalSQL) fmt.Printf("[DEBUG] Query args: %v\n", allArgs) } return finalSQL, allArgs, nil } // buildWindowFunction builds a window function expression func (qb *QueryBuilder) buildWindowFunction(wf WindowFunction) (string, error) { if !qb.isValidFunctionName(wf.Function) { return "", fmt.Errorf("invalid window function name: %s", wf.Function) } windowExpr := fmt.Sprintf("%s() OVER (", wf.Function) if wf.Over != "" { windowExpr += fmt.Sprintf("PARTITION BY %s ", wf.Over) } if wf.OrderBy != "" { windowExpr += fmt.Sprintf("ORDER BY %s ", wf.OrderBy) } if wf.Frame != "" { windowExpr += wf.Frame } windowExpr += ")" if wf.Alias != "" { windowExpr += " AS " + qb.escapeIdentifier(wf.Alias) } return windowExpr, nil } // buildJsonOperation builds a JSON operation expression func (qb *QueryBuilder) buildJsonOperation(jo JsonOperation) (string, []interface{}, error) { column := qb.validateAndEscapeColumn(jo.Column) if column == "" { return "", nil, fmt.Errorf("invalid or disallowed column: %s", jo.Column) } path := jo.Path if path == "" { path = "$" } var expr string var args []interface{} switch strings.ToLower(jo.Type) { case "extract": switch qb.dbType { case DBTypePostgreSQL: expr = fmt.Sprintf("%s->>%s", column, qb.escapeJsonPath(path)) case DBTypeMySQL: expr = fmt.Sprintf("JSON_EXTRACT(%s, '%s')", column, path) case DBTypeSQLServer: expr = fmt.Sprintf("JSON_VALUE(%s, '%s')", column, qb.escapeSqlServerJsonPath(path)) case DBTypeSQLite: expr = fmt.Sprintf("json_extract(%s, '%s')", 
column, path) default: return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType) } case "exists": switch qb.dbType { case DBTypePostgreSQL: expr = fmt.Sprintf("jsonb_path_exists(%s, '%s')", column, path) case DBTypeMySQL: expr = fmt.Sprintf("JSON_CONTAINS_PATH(%s, 'one', '%s')", column, path) case DBTypeSQLServer: expr = fmt.Sprintf("JSON_VALUE(%s, '%s') IS NOT NULL", column, qb.escapeSqlServerJsonPath(path)) case DBTypeSQLite: expr = fmt.Sprintf("json_extract(%s, '%s') IS NOT NULL", column, path) default: return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType) } case "contains": switch qb.dbType { case DBTypePostgreSQL: expr = fmt.Sprintf("%s @> %s", column, "?") args = append(args, jo.Value) case DBTypeMySQL: expr = fmt.Sprintf("JSON_CONTAINS(%s, ?, '%s')", column, path) args = append(args, jo.Value) case DBTypeSQLServer: expr = fmt.Sprintf("JSON_VALUE(%s, '%s') = ?", column, qb.escapeSqlServerJsonPath(path)) args = append(args, jo.Value) case DBTypeSQLite: expr = fmt.Sprintf("json_extract(%s, '%s') = ?", column, path) args = append(args, jo.Value) default: return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType) } default: return "", nil, fmt.Errorf("unsupported JSON operation type: %s", jo.Type) } return expr, args, nil } // escapeJsonPath escapes a JSON path for PostgreSQL func (qb *QueryBuilder) escapeJsonPath(path string) string { // Simple implementation - in a real scenario, you'd need more sophisticated escaping return "'" + strings.ReplaceAll(path, "'", "''") + "'" } // escapeSqlServerJsonPath escapes a JSON path for SQL Server func (qb *QueryBuilder) escapeSqlServerJsonPath(path string) string { // Convert JSONPath to SQL Server format // $.path.to.property -> '$.path.to.property' if !strings.HasPrefix(path, "$") { path = "$." 
+ path } return strings.ReplaceAll(path, ".", ".") } // buildCTEClause builds the WITH clause for Common Table Expressions func (qb *QueryBuilder) buildCTEClause(ctes []CTE) (string, []interface{}, error) { var cteParts []string var allArgs []interface{} hasRecursive := false for _, cte := range ctes { if cte.Recursive { hasRecursive = true break } } withClause := "WITH" if hasRecursive { withClause = "WITH RECURSIVE" } for _, cte := range ctes { subQuery, args, err := qb.BuildQuery(cte.Query) if err != nil { return "", nil, fmt.Errorf("failed to build CTE '%s': %w", cte.Name, err) } cteParts = append(cteParts, fmt.Sprintf("%s AS (%s)", qb.escapeIdentifier(cte.Name), subQuery)) allArgs = append(allArgs, args...) } return fmt.Sprintf("%s %s", withClause, strings.Join(cteParts, ", ")), allArgs, nil } // buildFromClause builds the FROM clause with optional alias func (qb *QueryBuilder) buildFromClause(table, alias string) string { fromClause := qb.escapeIdentifier(table) if alias != "" { fromClause += " " + qb.escapeIdentifier(alias) } return fromClause } // buildSingleJoinClause builds a single JOIN clause components func (qb *QueryBuilder) buildSingleJoinClause(join Join) (string, string, string, []interface{}, error) { joinType := strings.ToUpper(join.Type) if joinType == "" { joinType = "INNER" } table := qb.escapeIdentifier(join.Table) if join.Alias != "" { table += " " + qb.escapeIdentifier(join.Alias) } onClause, onArgs, err := qb.BuildWhereClause([]FilterGroup{join.OnConditions}) if err != nil { return "", "", "", nil, fmt.Errorf("failed to build ON clause for join on table %s: %w", join.Table, err) } return joinType, table, onClause, onArgs, nil } // buildUnionClause builds the UNION clause func (qb *QueryBuilder) buildUnionClause(unions []Union) (string, []interface{}, error) { var unionParts []string var allArgs []interface{} for _, union := range unions { subQuery, args, err := qb.BuildQuery(union.Query) if err != nil { return "", nil, fmt.Errorf("failed 
to build subquery for UNION: %w", err) } unionType := strings.ToUpper(union.Type) if unionType == "" { unionType = "UNION" } unionParts = append(unionParts, fmt.Sprintf("%s %s", unionType, subQuery)) allArgs = append(allArgs, args...) } return strings.Join(unionParts, " "), allArgs, nil } // buildSelectFields builds the SELECT fields from SelectField structs func (qb *QueryBuilder) buildSelectFields(fields []SelectField) []string { if len(fields) == 0 { return []string{"*"} } var selectedFields []string for _, field := range fields { expr := field.Expression if expr == "" { continue } // Basic validation for expression if !qb.isValidExpression(expr) { continue } // Handle window functions if field.WindowFunction != nil { windowFunc, err := qb.buildWindowFunction(*field.WindowFunction) if err != nil { continue } expr = windowFunc } if field.Alias != "" { selectedFields = append(selectedFields, fmt.Sprintf("%s AS %s", expr, qb.escapeIdentifier(field.Alias))) } else { selectedFields = append(selectedFields, expr) } } if len(selectedFields) == 0 { return []string{"*"} } return selectedFields } // BuildWhereClause builds WHERE/HAVING conditions from FilterGroups func (qb *QueryBuilder) BuildWhereClause(filterGroups []FilterGroup) (string, []interface{}, error) { if len(filterGroups) == 0 { return "", nil, nil } var conditions []string var allArgs []interface{} for i, group := range filterGroups { if len(group.Filters) == 0 { continue } groupCondition, groupArgs, err := qb.buildFilterGroup(group) if err != nil { return "", nil, err } if groupCondition != "" { if i > 0 { logicOp := "AND" if group.LogicOp != "" { logicOp = strings.ToUpper(group.LogicOp) } conditions = append(conditions, logicOp) } conditions = append(conditions, fmt.Sprintf("(%s)", groupCondition)) allArgs = append(allArgs, groupArgs...) 
} } return strings.Join(conditions, " "), allArgs, nil } // buildFilterGroup builds conditions for a single filter group func (qb *QueryBuilder) buildFilterGroup(group FilterGroup) (string, []interface{}, error) { var conditions []string var args []interface{} logicOp := "AND" if group.LogicOp != "" { logicOp = strings.ToUpper(group.LogicOp) } for i, filter := range group.Filters { condition, filterArgs, err := qb.buildFilterCondition(filter) if err != nil { return "", nil, err } if condition != "" { if i > 0 { conditions = append(conditions, logicOp) } conditions = append(conditions, condition) args = append(args, filterArgs...) } } return strings.Join(conditions, " "), args, nil } // buildFilterCondition builds a single filter condition with dialect-specific logic func (qb *QueryBuilder) buildFilterCondition(filter DynamicFilter) (string, []interface{}, error) { column := qb.validateAndEscapeColumn(filter.Column) if column == "" { return "", nil, fmt.Errorf("invalid or disallowed column: %s", filter.Column) } // Handle column-to-column comparison if valStr, ok := filter.Value.(string); ok && strings.Contains(valStr, ".") && qb.isValidExpression(valStr) && len(strings.Split(valStr, ".")) == 2 { escapedVal := qb.escapeColumnReference(valStr) switch filter.Operator { case OpEqual: return fmt.Sprintf("%s = %s", column, escapedVal), nil, nil case OpNotEqual: return fmt.Sprintf("%s <> %s", column, escapedVal), nil, nil case OpGreaterThan: return fmt.Sprintf("%s > %s", column, escapedVal), nil, nil case OpLessThan: return fmt.Sprintf("%s < %s", column, escapedVal), nil, nil } } // Handle JSON operations switch filter.Operator { case OpJsonContains, OpJsonNotContains, OpJsonExists, OpJsonNotExists, OpJsonEqual, OpJsonNotEqual: return qb.buildJsonFilterCondition(filter) case OpArrayContains, OpArrayNotContains, OpArrayLength: return qb.buildArrayFilterCondition(filter) } // Handle standard operators switch filter.Operator { case OpEqual: if filter.Value == nil { return 
fmt.Sprintf("%s IS NULL", column), nil, nil } return fmt.Sprintf("%s = ?", column), []interface{}{filter.Value}, nil case OpNotEqual: if filter.Value == nil { return fmt.Sprintf("%s IS NOT NULL", column), nil, nil } return fmt.Sprintf("%s <> ?", column), []interface{}{filter.Value}, nil case OpLike: if filter.Value == nil { return "", nil, nil } return fmt.Sprintf("%s LIKE ?", column), []interface{}{filter.Value}, nil case OpILike: if filter.Value == nil { return "", nil, nil } switch qb.dbType { case DBTypePostgreSQL, DBTypeSQLite: return fmt.Sprintf("%s ILIKE ?", column), []interface{}{filter.Value}, nil case DBTypeMySQL, DBTypeSQLServer: return fmt.Sprintf("LOWER(%s) LIKE LOWER(?)", column), []interface{}{filter.Value}, nil default: return fmt.Sprintf("%s LIKE ?", column), []interface{}{filter.Value}, nil } case OpIn, OpNotIn: values := qb.parseArrayValue(filter.Value) if len(values) == 0 { return "1=0", nil, nil } op := "IN" if filter.Operator == OpNotIn { op = "NOT IN" } placeholders := squirrel.Placeholders(len(values)) return fmt.Sprintf("%s %s (%s)", column, op, placeholders), values, nil case OpGreaterThan, OpGreaterThanEqual, OpLessThan, OpLessThanEqual: if filter.Value == nil { return "", nil, nil } op := strings.TrimPrefix(string(filter.Operator), "_") return fmt.Sprintf("%s %s ?", column, op), []interface{}{filter.Value}, nil case OpBetween, OpNotBetween: values := qb.parseArrayValue(filter.Value) if len(values) != 2 { return "", nil, fmt.Errorf("between operator requires exactly 2 values") } op := "BETWEEN" if filter.Operator == OpNotBetween { op = "NOT BETWEEN" } return fmt.Sprintf("%s %s ? 
AND ?", column, op), []interface{}{values[0], values[1]}, nil case OpNull: return fmt.Sprintf("%s IS NULL", column), nil, nil case OpNotNull: return fmt.Sprintf("%s IS NOT NULL", column), nil, nil case OpContains, OpNotContains, OpStartsWith, OpEndsWith: if filter.Value == nil { return "", nil, nil } var value string switch filter.Operator { case OpContains, OpNotContains: value = fmt.Sprintf("%%%v%%", filter.Value) case OpStartsWith: value = fmt.Sprintf("%v%%", filter.Value) case OpEndsWith: value = fmt.Sprintf("%%%v", filter.Value) } switch qb.dbType { case DBTypePostgreSQL, DBTypeSQLite: op := "ILIKE" if strings.Contains(string(filter.Operator), "Not") { op = "NOT ILIKE" } return fmt.Sprintf("%s %s ?", column, op), []interface{}{value}, nil case DBTypeMySQL, DBTypeSQLServer: op := "LIKE" if strings.Contains(string(filter.Operator), "Not") { op = "NOT LIKE" } return fmt.Sprintf("LOWER(%s) %s LOWER(?)", column, op), []interface{}{value}, nil default: op := "LIKE" if strings.Contains(string(filter.Operator), "Not") { op = "NOT LIKE" } return fmt.Sprintf("%s %s ?", column, op), []interface{}{value}, nil } default: return "", nil, fmt.Errorf("unsupported operator: %s", filter.Operator) } } // buildJsonFilterCondition builds a JSON filter condition func (qb *QueryBuilder) buildJsonFilterCondition(filter DynamicFilter) (string, []interface{}, error) { column := qb.validateAndEscapeColumn(filter.Column) if column == "" { return "", nil, fmt.Errorf("invalid or disallowed column: %s", filter.Column) } path := "$" if pathOption, ok := filter.Options["path"].(string); ok && pathOption != "" { path = pathOption } var expr string var args []interface{} switch filter.Operator { case OpJsonContains: switch qb.dbType { case DBTypePostgreSQL: expr = fmt.Sprintf("%s @> ?", column) args = append(args, filter.Value) case DBTypeMySQL: expr = fmt.Sprintf("JSON_CONTAINS(%s, ?, '%s')", column, path) args = append(args, filter.Value) case DBTypeSQLServer: expr = 
fmt.Sprintf("JSON_VALUE(%s, '%s') = ?", column, qb.escapeSqlServerJsonPath(path)) args = append(args, filter.Value) case DBTypeSQLite: expr = fmt.Sprintf("json_extract(%s, '%s') = ?", column, path) args = append(args, filter.Value) default: return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType) } case OpJsonNotContains: switch qb.dbType { case DBTypePostgreSQL: expr = fmt.Sprintf("NOT (%s @> ?)", column) args = append(args, filter.Value) case DBTypeMySQL: expr = fmt.Sprintf("NOT JSON_CONTAINS(%s, ?, '%s')", column, path) args = append(args, filter.Value) case DBTypeSQLServer: expr = fmt.Sprintf("JSON_VALUE(%s, '%s') <> ?", column, qb.escapeSqlServerJsonPath(path)) args = append(args, filter.Value) case DBTypeSQLite: expr = fmt.Sprintf("json_extract(%s, '%s') <> ?", column, path) args = append(args, filter.Value) default: return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType) } case OpJsonExists: switch qb.dbType { case DBTypePostgreSQL: expr = fmt.Sprintf("jsonb_path_exists(%s, '%s')", column, path) case DBTypeMySQL: expr = fmt.Sprintf("JSON_CONTAINS_PATH(%s, 'one', '%s')", column, path) case DBTypeSQLServer: expr = fmt.Sprintf("JSON_VALUE(%s, '%s') IS NOT NULL", column, qb.escapeSqlServerJsonPath(path)) case DBTypeSQLite: expr = fmt.Sprintf("json_extract(%s, '%s') IS NOT NULL", column, path) default: return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType) } case OpJsonNotExists: switch qb.dbType { case DBTypePostgreSQL: expr = fmt.Sprintf("NOT jsonb_path_exists(%s, '%s')", column, path) case DBTypeMySQL: expr = fmt.Sprintf("NOT JSON_CONTAINS_PATH(%s, 'one', '%s')", column, path) case DBTypeSQLServer: expr = fmt.Sprintf("JSON_VALUE(%s, '%s') IS NULL", column, qb.escapeSqlServerJsonPath(path)) case DBTypeSQLite: expr = fmt.Sprintf("json_extract(%s, '%s') IS NULL", column, path) default: return "", nil, fmt.Errorf("JSON operations not supported for 
database type: %s", qb.dbType)
		}
	case OpJsonEqual:
		switch qb.dbType {
		case DBTypePostgreSQL:
			expr = fmt.Sprintf("%s->>%s = ?", column, qb.escapeJsonPath(path))
			args = append(args, filter.Value)
		case DBTypeMySQL:
			expr = fmt.Sprintf("JSON_EXTRACT(%s, '%s') = ?", column, path)
			args = append(args, filter.Value)
		case DBTypeSQLServer:
			expr = fmt.Sprintf("JSON_VALUE(%s, '%s') = ?", column, qb.escapeSqlServerJsonPath(path))
			args = append(args, filter.Value)
		case DBTypeSQLite:
			expr = fmt.Sprintf("json_extract(%s, '%s') = ?", column, path)
			args = append(args, filter.Value)
		default:
			return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType)
		}
	case OpJsonNotEqual:
		switch qb.dbType {
		case DBTypePostgreSQL:
			expr = fmt.Sprintf("%s->>%s <> ?", column, qb.escapeJsonPath(path))
			args = append(args, filter.Value)
		case DBTypeMySQL:
			expr = fmt.Sprintf("JSON_EXTRACT(%s, '%s') <> ?", column, path)
			args = append(args, filter.Value)
		case DBTypeSQLServer:
			expr = fmt.Sprintf("JSON_VALUE(%s, '%s') <> ?", column, qb.escapeSqlServerJsonPath(path))
			args = append(args, filter.Value)
		case DBTypeSQLite:
			expr = fmt.Sprintf("json_extract(%s, '%s') <> ?", column, path)
			args = append(args, filter.Value)
		default:
			return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType)
		}
	default:
		return "", nil, fmt.Errorf("unsupported JSON operator: %s", filter.Operator)
	}
	return expr, args, nil
}

// buildArrayFilterCondition builds an array filter condition for filter.
//
// PostgreSQL columns are treated as native arrays (= ANY / <> ALL /
// array_length); MySQL, SQL Server and SQLite columns are treated as JSON
// arrays. For OpArrayLength the expected length is read from
// filter.Options["length"] (see arrayLengthOption for the accepted types).
// It returns the SQL fragment (with ? placeholders) and its arguments.
func (qb *QueryBuilder) buildArrayFilterCondition(filter DynamicFilter) (string, []interface{}, error) {
	column := qb.validateAndEscapeColumn(filter.Column)
	if column == "" {
		return "", nil, fmt.Errorf("invalid or disallowed column: %s", filter.Column)
	}

	// template receives the escaped column via %s; the remaining ? is bound.
	var template string
	switch filter.Operator {
	case OpArrayContains:
		switch qb.dbType {
		case DBTypePostgreSQL:
			template = "? = ANY(%s)"
		case DBTypeMySQL:
			template = "JSON_CONTAINS(%s, JSON_QUOTE(?))"
		case DBTypeSQLServer:
			template = "? IN (SELECT value FROM OPENJSON(%s))"
		case DBTypeSQLite:
			template = "EXISTS (SELECT 1 FROM json_each(%s) WHERE json_each.value = ?)"
		default:
			return "", nil, fmt.Errorf("Array operations not supported for database type: %s", qb.dbType)
		}
		return fmt.Sprintf(template, column), []interface{}{filter.Value}, nil

	case OpArrayNotContains:
		switch qb.dbType {
		case DBTypePostgreSQL:
			template = "? <> ALL(%s)"
		case DBTypeMySQL:
			template = "NOT JSON_CONTAINS(%s, JSON_QUOTE(?))"
		case DBTypeSQLServer:
			template = "? NOT IN (SELECT value FROM OPENJSON(%s))"
		case DBTypeSQLite:
			template = "NOT EXISTS (SELECT 1 FROM json_each(%s) WHERE json_each.value = ?)"
		default:
			return "", nil, fmt.Errorf("Array operations not supported for database type: %s", qb.dbType)
		}
		return fmt.Sprintf(template, column), []interface{}{filter.Value}, nil

	case OpArrayLength:
		switch qb.dbType {
		case DBTypePostgreSQL:
			template = "array_length(%s, 1) = ?"
		case DBTypeMySQL:
			template = "JSON_LENGTH(%s) = ?"
		case DBTypeSQLServer:
			template = "(SELECT COUNT(*) FROM OPENJSON(%s)) = ?"
		case DBTypeSQLite:
			template = "json_array_length(%s) = ?"
		default:
			return "", nil, fmt.Errorf("Array operations not supported for database type: %s", qb.dbType)
		}
		length, ok := arrayLengthOption(filter.Options)
		if !ok {
			return "", nil, fmt.Errorf("array_length operator requires 'length' option")
		}
		return fmt.Sprintf(template, column), []interface{}{length}, nil

	default:
		return "", nil, fmt.Errorf("unsupported array operator: %s", filter.Operator)
	}
}

// arrayLengthOption reads opts["length"] as an int. Besides int it accepts
// int32/int64, float64 values without a fractional part (encoding/json decodes
// every JSON number into float64), json.Number and decimal strings. It reports
// false when the option is missing or is not an integral number.
func arrayLengthOption(opts map[string]interface{}) (int, bool) {
	switch v := opts["length"].(type) {
	case int:
		return v, true
	case int32:
		return int(v), true
	case int64:
		return int(v), true
	case float64:
		if v != float64(int(v)) {
			return 0, false
		}
		return int(v), true
	case json.Number:
		n, err := v.Int64()
		if err != nil {
			return 0, false
		}
		return int(n), true
	case string:
		n, err := strconv.Atoi(strings.TrimSpace(v))
		if err != nil {
			return 0, false
		}
		return n, true
	default:
		return 0, false
	}
}

// =============================================================================
// SECTION 6: EXECUTION METHODS (NEW)
// Methods that execute queries directly, with performance logging.
// =============================================================================

// contextWithDefaultTimeout bounds ctx by qb.queryTimeout when ctx has no
// deadline of its own and a positive timeout is configured. The returned
// cancel function must always be called; it is a no-op when no timeout was
// applied.
func (qb *QueryBuilder) contextWithDefaultTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
	if _, hasDeadline := ctx.Deadline(); hasDeadline || qb.queryTimeout <= 0 {
		return ctx, func() {}
	}
	return context.WithTimeout(ctx, qb.queryTimeout)
}

// logExecutionTime prints "[DEBUG] <label> executed in <elapsed>" when query
// logging is enabled.
func (qb *QueryBuilder) logExecutionTime(label string, start time.Time) {
	if qb.enableQueryLogging {
		fmt.Printf("[DEBUG] %s executed in %v\n", label, time.Since(start))
	}
}

// ensureTableAllowed rejects table when security checks are enabled and a
// non-empty table allow-list does not contain it.
func (qb *QueryBuilder) ensureTableAllowed(table string) error {
	if qb.enableSecurityChecks && len(qb.allowedTables) > 0 && !qb.allowedTables[table] {
		return fmt.Errorf("disallowed table: %s", table)
	}
	return nil
}

// execStatement runs a data-modifying statement on db, applying the default
// timeout and logging the duration under label.
func (qb *QueryBuilder) execStatement(ctx context.Context, db *sqlx.DB, label, sqlStr string, args []interface{}) (sql.Result, error) {
	ctx, cancel := qb.contextWithDefaultTimeout(ctx)
	defer cancel()

	start := time.Now()
	result, err := db.ExecContext(ctx, sqlStr, args...)
	qb.logExecutionTime(label, start)
	return result, err
}

// genericRowType is the per-row value ExecuteQuery produces with MapScan.
var genericRowType = reflect.TypeOf(map[string]interface{}(nil))

// ExecuteQuery builds query, runs it on db and stores every result row in dest.
//
// dest must be a non-nil pointer. If it points to a slice whose element type
// is a map that can hold a map[string]interface{} (e.g.
// []map[string]interface{}), each row is scanned with sqlx MapScan and
// appended to the slice; any other destination is passed to sqlx
// SelectContext. When ctx has no deadline, qb.queryTimeout (if positive)
// bounds the call.
func (qb *QueryBuilder) ExecuteQuery(ctx context.Context, db *sqlx.DB, query DynamicQuery, dest interface{}) error {
	sqlStr, args, err := qb.BuildQuery(query)
	if err != nil {
		return err
	}

	ctx, cancel := qb.contextWithDefaultTimeout(ctx)
	defer cancel()

	start := time.Now()

	destValue := reflect.ValueOf(dest)
	if destValue.Kind() != reflect.Ptr || destValue.IsNil() {
		return fmt.Errorf("dest must be a non-nil pointer")
	}

	destElem := destValue.Elem()
	if destElem.Kind() == reflect.Slice {
		rowType := destElem.Type().Elem()
		// AssignableTo (rather than "any map with interface values") keeps
		// reflect.Append below from panicking on e.g. map[string]error.
		if rowType.Kind() == reflect.Map && genericRowType.AssignableTo(rowType) {
			rows, err := db.QueryxContext(ctx, sqlStr, args...)
			if err != nil {
				return err
			}
			defer rows.Close()

			for rows.Next() {
				row := make(map[string]interface{})
				if err := rows.MapScan(row); err != nil {
					return err
				}
				destElem.Set(reflect.Append(destElem, reflect.ValueOf(row)))
			}
			// Next returns false both at the end of the result set and on
			// failure; Err tells them apart so a broken stream is not reported
			// as success.
			if err := rows.Err(); err != nil {
				return err
			}
			qb.logExecutionTime("Query", start)
			return nil
		}
	}

	// Default case: let sqlx map columns onto dest.
	err = db.SelectContext(ctx, dest, sqlStr, args...)
	qb.logExecutionTime("Query", start)
	return err
}

// ExecuteQueryRow builds query and scans its single result row into dest via
// sqlx GetContext (which returns sql.ErrNoRows when nothing matches).
func (qb *QueryBuilder) ExecuteQueryRow(ctx context.Context, db *sqlx.DB, query DynamicQuery, dest interface{}) error {
	sqlStr, args, err := qb.BuildQuery(query)
	if err != nil {
		return err
	}

	ctx, cancel := qb.contextWithDefaultTimeout(ctx)
	defer cancel()

	start := time.Now()
	err = db.GetContext(ctx, dest, sqlStr, args...)
	qb.logExecutionTime("QueryRow", start)
	return err
}

// ExecuteCount runs the query produced by BuildCountQuery and returns the
// number of matching rows (or groups, when query.GroupBy is set).
func (qb *QueryBuilder) ExecuteCount(ctx context.Context, db *sqlx.DB, query DynamicQuery) (int64, error) {
	sqlStr, args, err := qb.BuildCountQuery(query)
	if err != nil {
		return 0, err
	}

	ctx, cancel := qb.contextWithDefaultTimeout(ctx)
	defer cancel()

	var count int64
	start := time.Now()
	err = db.GetContext(ctx, &count, sqlStr, args...)
	qb.logExecutionTime("Count query", start)
	return count, err
}

// ExecuteInsert checks the table allow-list, builds an INSERT with
// BuildInsertQuery and executes it.
func (qb *QueryBuilder) ExecuteInsert(ctx context.Context, db *sqlx.DB, table string, data InsertData, returningColumns ...string) (sql.Result, error) {
	if err := qb.ensureTableAllowed(table); err != nil {
		return nil, err
	}
	sqlStr, args, err := qb.BuildInsertQuery(table, data, returningColumns...)
	if err != nil {
		return nil, err
	}
	return qb.execStatement(ctx, db, "Insert query", sqlStr, args)
}

// ExecuteUpdate checks the table allow-list, builds an UPDATE with
// BuildUpdateQuery and executes it.
func (qb *QueryBuilder) ExecuteUpdate(ctx context.Context, db *sqlx.DB, table string, updateData UpdateData, filters []FilterGroup, returningColumns ...string) (sql.Result, error) {
	if err := qb.ensureTableAllowed(table); err != nil {
		return nil, err
	}
	sqlStr, args, err := qb.BuildUpdateQuery(table, updateData, filters, returningColumns...)
	if err != nil {
		return nil, err
	}
	return qb.execStatement(ctx, db, "Update query", sqlStr, args)
}

// ExecuteDelete checks the table allow-list, builds a DELETE with
// BuildDeleteQuery and executes it. With no filters every row is deleted.
func (qb *QueryBuilder) ExecuteDelete(ctx context.Context, db *sqlx.DB, table string, filters []FilterGroup, returningColumns ...string) (sql.Result, error) {
	if err := qb.ensureTableAllowed(table); err != nil {
		return nil, err
	}
	sqlStr, args, err := qb.BuildDeleteQuery(table, filters, returningColumns...)
	if err != nil {
		return nil, err
	}
	return qb.execStatement(ctx, db, "Delete query", sqlStr, args)
}

// ExecuteUpsert checks the table allow-list, builds an upsert with
// BuildUpsertQuery and executes it.
func (qb *QueryBuilder) ExecuteUpsert(ctx context.Context, db *sqlx.DB, table string, insertData InsertData, conflictColumns []string, updateColumns []string, returningColumns ...string) (sql.Result, error) {
	if err := qb.ensureTableAllowed(table); err != nil {
		return nil, err
	}
	sqlStr, args, err := qb.BuildUpsertQuery(table, insertData, conflictColumns, updateColumns, returningColumns...)
	if err != nil {
		return nil, err
	}
	return qb.execStatement(ctx, db, "Upsert query", sqlStr, args)
}

// --- Helper and Validation Methods ---

// buildGroupByColumns validates and escapes each GROUP BY field, silently
// dropping fields that validateAndEscapeColumn rejects.
func (qb *QueryBuilder) buildGroupByColumns(fields []string) []string {
	var groupCols []string
	for _, field := range fields {
		if col := qb.validateAndEscapeColumn(field); col != "" {
			groupCols = append(groupCols, col)
		}
	}
	return groupCols
}

// parseArrayValue normalizes value into a []interface{}: slices are copied
// element by element, comma-separated strings are split and trimmed, and any
// other value becomes a one-element slice. nil yields nil.
func (qb *QueryBuilder) parseArrayValue(value interface{}) []interface{} {
	if value == nil {
		return nil
	}

	if reflect.TypeOf(value).Kind() == reflect.Slice {
		v := reflect.ValueOf(value)
		result := make([]interface{}, v.Len())
		for i := 0; i < v.Len(); i++ {
			result[i] = v.Index(i).Interface()
		}
		return result
	}

	if str, ok := value.(string); ok {
		if strings.Contains(str, ",") {
			parts := strings.Split(str, ",")
			result := make([]interface{}, len(parts))
			for i, part := range parts {
				result[i] = strings.TrimSpace(part)
			}
			return result
		}
		return []interface{}{str}
	}

	return []interface{}{value}
}

// validateAndEscapeColumn returns field as a safe SQL column reference, or ""
// if it is rejected.
//
//   - Expressions containing "(" are returned unescaped if isValidExpression
//     accepts them.
//   - Dotted names ("table.column") are validated with isValidExpression and
//     each part is escaped.
//   - Plain names must be in allowedColumns (when set) and are escaped.
//
// NOTE(review): the first two forms are not checked against allowedColumns.
func (qb *QueryBuilder) validateAndEscapeColumn(field string) string {
	if field == "" {
		return ""
	}

	// Allow complex expressions like functions.
	if strings.Contains(field, "(") {
		if qb.isValidExpression(field) {
			return field // Don't escape complex expressions, assume they are safe
		}
		return ""
	}

	// Handle dotted column names like "table.column".
	if strings.Contains(field, ".") {
		if !qb.isValidExpression(field) {
			return ""
		}
		parts := strings.Split(field, ".")
		escapedParts := make([]string, 0, len(parts))
		for _, part := range parts {
			escapedParts = append(escapedParts, qb.escapeIdentifier(part))
		}
		return strings.Join(escapedParts, ".")
	}

	// Simple column name.
	if qb.allowedColumns != nil && !qb.allowedColumns[field] {
		return ""
	}
	return qb.escapeIdentifier(field)
}

// sqlKeywordPattern matches statement keywords only as whole words, so
// identifiers that merely contain them (created_at, updated_by, deleted,
// selection) are not mistaken for injected SQL.
var sqlKeywordPattern = regexp.MustCompile(`(?i)\b(union|select|insert|update|delete|drop|alter|create|exec|execute)\b`)

// isValidExpression reports whether expr looks like a safe column expression.
// This is a simplified check; a more robust solution might use a proper SQL
// parser. It allows alphanumerics, underscore, dots, commas, parentheses,
// spaces, * - / and SQL Server brackets, and rejects SQL comment markers and
// statement keywords appearing as standalone words.
func (qb *QueryBuilder) isValidExpression(expr string) bool {
	const allowedChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_.,() *-/[]"
	for _, r := range expr {
		if !strings.ContainsRune(allowedChars, r) {
			return false
		}
	}

	for _, marker := range []string{"--", "/*", "*/"} {
		if strings.Contains(expr, marker) {
			return false
		}
	}

	return !sqlKeywordPattern.MatchString(expr)
}

// sqlFunctionAllowList holds the lower-cased function names accepted by
// isValidFunctionName.
var sqlFunctionAllowList = map[string]bool{
	// Aggregate functions
	"count": true, "sum": true, "avg": true, "min": true, "max": true,
	// Window functions
	"row_number": true, "rank": true, "dense_rank": true, "ntile": true,
	"lag": true, "lead": true, "first_value": true, "last_value": true,
	// JSON functions
	"json_extract": true, "json_contains": true, "json_search": true,
	"json_array": true, "json_object": true, "json_merge": true,
	// Other functions
	"concat": true, "substring": true, "upper": true, "lower": true,
	"trim": true, "coalesce": true, "nullif": true, "isnull": true,
}

// isValidFunctionName reports whether name (case-insensitively) is one of the
// allowed SQL functions.
func (qb *QueryBuilder) isValidFunctionName(name string) bool {
	return sqlFunctionAllowList[strings.ToLower(name)]
}

// escapeColumnReference escapes each dot-separated part of col, leaving parts
// already wrapped in [brackets] untouched.
func (qb *QueryBuilder) escapeColumnReference(col string) string {
	parts := strings.Split(col, ".")
	escaped := make([]string, 0, len(parts))
	for _, p := range parts {
		if strings.HasPrefix(p, "[") && strings.HasSuffix(p, "]") {
			escaped = append(escaped, p)
		} else {
			escaped = append(escaped, qb.escapeIdentifier(p))
		}
	}
	return strings.Join(escaped, ".")
}

// escapeIdentifier quotes col for the configured dialect ("..." for
// PostgreSQL/SQLite, `...` for MySQL, [...] for SQL Server), doubling any
// embedded closing quote. Unknown dialects get col unchanged.
func (qb *QueryBuilder) escapeIdentifier(col string) string {
	switch qb.dbType {
	case DBTypePostgreSQL, DBTypeSQLite:
		return fmt.Sprintf("\"%s\"", strings.ReplaceAll(col, "\"", "\"\""))
	case DBTypeMySQL:
		return fmt.Sprintf("`%s`", strings.ReplaceAll(col, "`", "``"))
	case DBTypeSQLServer:
		return fmt.Sprintf("[%s]", strings.ReplaceAll(col, "]", "]]"))
	default:
		return col
	}
}

// argInjectionPatterns are injection signatures checked against string query
// arguments. They are compiled once at package initialization.
var argInjectionPatterns = []*regexp.Regexp{
	regexp.MustCompile(`(?i)(union\s+select)`),
	regexp.MustCompile(`(?i)(or\s+1\s*=\s*1)`),
	regexp.MustCompile(`(?i)(and\s+true)`),
	regexp.MustCompile(`(?i)(waitfor\s+delay)`),
	regexp.MustCompile(`(?i)(benchmark|sleep)\s*\(`),
	regexp.MustCompile(`(?i)(pg_sleep)\s*\(`),
	regexp.MustCompile(`(?i)(load_file|into\s+outfile)`),
	regexp.MustCompile(`(?i)(information_schema|sysobjects|syscolumns)`),
	regexp.MustCompile(`(?i)(--|\/\*|\*\/)`),
}

// checkForSqlInjectionInArgs checks for potential SQL injection patterns in
// string query arguments. It is a no-op when security checks are disabled.
func (qb *QueryBuilder) checkForSqlInjectionInArgs(args []interface{}) error {
	if !qb.enableSecurityChecks {
		return nil
	}

	for _, arg := range args {
		str, ok := arg.(string)
		if !ok {
			continue
		}
		lowerStr := strings.ToLower(str)
		for _, pattern := range argInjectionPatterns {
			if pattern.MatchString(lowerStr) {
				return fmt.Errorf("potential SQL injection detected in query argument: pattern %s matched", pattern.String())
			}
		}
	}

	return nil
}

// sqlInjectionPatterns are signatures checked against fully built SQL text.
// They target clear injection shapes rather than valid keywords in context.
var sqlInjectionPatterns = []*regexp.Regexp{
	regexp.MustCompile(`(?i)(union\s+select)`),                                         // UNION followed by SELECT
	regexp.MustCompile(`(?i)(select\s+.*\s+from\s+.*\s+where\s+.*\s+or\s+1\s*=\s*1)`), // Classic SQL injection
	regexp.MustCompile(`(?i)(drop\s+table)`),                                           // DROP TABLE
	regexp.MustCompile(`(?i)(delete\s+from)`),                                          // DELETE FROM
	regexp.MustCompile(`(?i)(insert\s+into)`),                                          // INSERT INTO
	regexp.MustCompile(`(?i)(update\s+.*\s+set)`),                                      // UPDATE SET
	regexp.MustCompile(`(?i)(alter\s+table)`),                                          // ALTER TABLE
	regexp.MustCompile(`(?i)(create\s+table)`),                                         // CREATE TABLE
	regexp.MustCompile(`(?i)(exec\s*\(|execute\s*\()`),                                 // EXEC/EXECUTE functions
	regexp.MustCompile(`(?i)(--|\/\*|\*\/)`),                                           // SQL comments
}

// checkForSqlInjectionInSQL checks for potential SQL injection patterns in the
// final SQL. It is a no-op when security checks are disabled.
func (qb *QueryBuilder) checkForSqlInjectionInSQL(sql string) error {
	if !qb.enableSecurityChecks {
		return nil
	}

	lowerSQL := strings.ToLower(sql)
	for _, pattern := range sqlInjectionPatterns {
		if pattern.MatchString(lowerSQL) {
			return fmt.Errorf("potential SQL injection detected in SQL: pattern %s matched", pattern.String())
		}
	}

	return nil
}

// --- Other Query Builders (Insert, Update, Delete, Upsert, Count) ---

// BuildCountQuery builds a query returning the number of rows matched by query.
//
// Only FROM, JOINs (needed when filters reference joined tables), WHERE,
// GROUP BY and HAVING are used; fields, sorting, pagination and unions are
// ignored. When GroupBy is set, the grouped query is wrapped in a derived
// table so the result is one row holding the number of groups. A bare
// COUNT(*) with GROUP BY would instead return one count per group.
func (qb *QueryBuilder) BuildCountQuery(query DynamicQuery) (string, []interface{}, error) {
	grouped := len(query.GroupBy) > 0

	selectExpr := "COUNT(*)"
	if grouped {
		// Named column: SQL Server rejects derived tables with unnamed columns.
		selectExpr = "1 AS grouped_row"
	}

	fromClause := qb.buildFromClause(query.From, query.Aliases)
	baseQuery := qb.sqlBuilder.Select(selectExpr).From(fromClause)

	for _, join := range query.Joins {
		// Security check for joined table.
		if qb.enableSecurityChecks && len(qb.allowedTables) > 0 && !qb.allowedTables[join.Table] {
			return "", nil, fmt.Errorf("disallowed table in join: %s", join.Table)
		}

		joinType, tableWithAlias, onClause, joinArgs, err := qb.buildSingleJoinClause(join)
		if err != nil {
			return "", nil, err
		}

		joinStr := tableWithAlias + " ON " + onClause
		switch strings.ToUpper(joinType) {
		case "LEFT":
			baseQuery = baseQuery.LeftJoin(joinStr, joinArgs...)
		case "RIGHT":
			baseQuery = baseQuery.RightJoin(joinStr, joinArgs...)
		case "FULL":
			// Join() prepends "JOIN ", which would yield "JOIN FULL JOIN ...".
			baseQuery = baseQuery.JoinClause("FULL JOIN "+joinStr, joinArgs...)
		default:
			baseQuery = baseQuery.Join(joinStr, joinArgs...)
		}
	}

	if len(query.Filters) > 0 {
		whereClause, whereArgs, err := qb.BuildWhereClause(query.Filters)
		if err != nil {
			return "", nil, err
		}
		baseQuery = baseQuery.Where(whereClause, whereArgs...)
	}

	if grouped {
		baseQuery = baseQuery.GroupBy(qb.buildGroupByColumns(query.GroupBy)...)
	}

	if len(query.Having) > 0 {
		havingClause, havingArgs, err := qb.BuildWhereClause(query.Having)
		if err != nil {
			return "", nil, err
		}
		baseQuery = baseQuery.Having(havingClause, havingArgs...)
	}

	countQuery := baseQuery
	if grouped {
		countQuery = qb.sqlBuilder.Select("COUNT(*)").FromSelect(baseQuery, "count_subquery")
	}

	sqlStr, args, err := countQuery.ToSql()
	if err != nil {
		return "", nil, fmt.Errorf("failed to build COUNT query: %w", err)
	}

	if qb.enableQueryLogging {
		fmt.Printf("[DEBUG] COUNT SQL query: %s\n", sqlStr)
		fmt.Printf("[DEBUG] COUNT query args: %v\n", args)
	}

	return sqlStr, args, nil
}

// ensureColumnsAllowed returns an error for the first column not present in
// allowedColumns (when an allow-list is configured).
func (qb *QueryBuilder) ensureColumnsAllowed(columns []string) error {
	for _, col := range columns {
		if qb.allowedColumns != nil && !qb.allowedColumns[col] {
			return fmt.Errorf("disallowed column: %s", col)
		}
	}
	return nil
}

// insertColumnsAndValues returns data's plain columns and values followed by
// its JSON columns, whose values are marshaled to JSON bytes. JSON column names
// are checked against allowedColumns like plain columns. The input slices are
// copied, never modified.
func (qb *QueryBuilder) insertColumnsAndValues(data InsertData) ([]string, []interface{}, error) {
	columns := make([]string, len(data.Columns), len(data.Columns)+len(data.JsonValues))
	copy(columns, data.Columns)
	values := make([]interface{}, len(data.Values), len(data.Values)+len(data.JsonValues))
	copy(values, data.Values)

	for col, val := range data.JsonValues {
		if qb.allowedColumns != nil && !qb.allowedColumns[col] {
			return nil, nil, fmt.Errorf("disallowed column: %s", col)
		}
		jsonVal, err := json.Marshal(val)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to marshal JSON value for column %s: %w", col, err)
		}
		columns = append(columns, col)
		values = append(values, jsonVal)
	}
	return columns, values, nil
}

// BuildInsertQuery builds an INSERT of a single row into table.
//
// JSON values are marshaled and appended after the plain columns.
// returningColumns adds a RETURNING clause and is supported only on
// PostgreSQL; other dialects return an error.
func (qb *QueryBuilder) BuildInsertQuery(table string, data InsertData, returningColumns ...string) (string, []interface{}, error) {
	if err := qb.ensureColumnsAllowed(data.Columns); err != nil {
		return "", nil, err
	}

	columns, values, err := qb.insertColumnsAndValues(data)
	if err != nil {
		return "", nil, err
	}
	insert := qb.sqlBuilder.Insert(table).Columns(columns...).Values(values...)

	if len(returningColumns) > 0 {
		if qb.dbType != DBTypePostgreSQL {
			return "", nil, fmt.Errorf("RETURNING not supported for database type: %s", qb.dbType)
		}
		insert = insert.Suffix("RETURNING " + strings.Join(returningColumns, ", "))
	}

	sqlStr, args, err := insert.ToSql()
	if err != nil {
		return "", nil, fmt.Errorf("failed to build INSERT query: %w", err)
	}
	return sqlStr, args, nil
}

// BuildUpdateQuery builds an UPDATE on table.
//
// Columns[i] is set to Values[i]. Each JsonUpdates entry rewrites one JSON
// column through the dialect's JSON-set function. SQLite uses json_patch,
// which ignores the path. Dialects without a case are skipped. filters become
// the WHERE clause, and returningColumns (PostgreSQL only) a RETURNING clause.
func (qb *QueryBuilder) BuildUpdateQuery(table string, updateData UpdateData, filters []FilterGroup, returningColumns ...string) (string, []interface{}, error) {
	if err := qb.ensureColumnsAllowed(updateData.Columns); err != nil {
		return "", nil, err
	}
	// buildSetMap pairs Columns[i] with Values[i]; too few values would drop columns.
	if len(updateData.Values) < len(updateData.Columns) {
		return "", nil, fmt.Errorf("update data has %d columns but only %d values", len(updateData.Columns), len(updateData.Values))
	}

	setMap := qb.buildSetMap(updateData)
	for col, jsonUpdate := range updateData.JsonUpdates {
		if qb.allowedColumns != nil && !qb.allowedColumns[col] {
			return "", nil, fmt.Errorf("disallowed column: %s", col)
		}

		var setExpr string
		switch qb.dbType {
		case DBTypePostgreSQL:
			setExpr = fmt.Sprintf("jsonb_set(%s, '%s', ?)", qb.escapeIdentifier(col), jsonUpdate.Path)
		case DBTypeMySQL:
			setExpr = fmt.Sprintf("JSON_SET(%s, '%s', ?)", qb.escapeIdentifier(col), jsonUpdate.Path)
		case DBTypeSQLServer:
			setExpr = fmt.Sprintf("JSON_MODIFY(%s, '%s', ?)", qb.escapeIdentifier(col), jsonUpdate.Path)
		case DBTypeSQLite:
			// SQLite lacks a path-based setter here; json_patch merges the value.
			setExpr = fmt.Sprintf("json_patch(%s, ?)", qb.escapeIdentifier(col))
		default:
			continue
		}

		jsonVal, err := json.Marshal(jsonUpdate.Value)
		if err != nil {
			return "", nil, fmt.Errorf("failed to marshal JSON value for column %s: %w", col, err)
		}
		setMap[col] = squirrel.Expr(setExpr, jsonVal)
	}

	update := qb.sqlBuilder.Update(table).SetMap(setMap)

	if len(filters) > 0 {
		whereClause, whereArgs, err := qb.BuildWhereClause(filters)
		if err != nil {
			return "", nil, err
		}
		update = update.Where(whereClause, whereArgs...)
	}

	if len(returningColumns) > 0 {
		if qb.dbType != DBTypePostgreSQL {
			return "", nil, fmt.Errorf("RETURNING not supported for database type: %s", qb.dbType)
		}
		update = update.Suffix("RETURNING " + strings.Join(returningColumns, ", "))
	}

	sqlStr, args, err := update.ToSql()
	if err != nil {
		return "", nil, fmt.Errorf("failed to build UPDATE query: %w", err)
	}
	return sqlStr, args, nil
}

// buildSetMap builds a map for SetMap from UpdateData, pairing Columns[i] with
// Values[i]. Columns without a corresponding value are skipped rather than
// causing an index-out-of-range panic.
func (qb *QueryBuilder) buildSetMap(updateData UpdateData) map[string]interface{} {
	setMap := make(map[string]interface{}, len(updateData.Columns))
	for i, col := range updateData.Columns {
		if i >= len(updateData.Values) {
			break
		}
		setMap[col] = updateData.Values[i]
	}
	return setMap
}

// BuildDeleteQuery builds a DELETE on table. filters become the WHERE clause.
// Without filters every row is targeted. returningColumns (PostgreSQL only)
// adds a RETURNING clause.
func (qb *QueryBuilder) BuildDeleteQuery(table string, filters []FilterGroup, returningColumns ...string) (string, []interface{}, error) {
	del := qb.sqlBuilder.Delete(table)

	if len(filters) > 0 {
		whereClause, whereArgs, err := qb.BuildWhereClause(filters)
		if err != nil {
			return "", nil, err
		}
		del = del.Where(whereClause, whereArgs...)
	}

	if len(returningColumns) > 0 {
		if qb.dbType != DBTypePostgreSQL {
			return "", nil, fmt.Errorf("RETURNING not supported for database type: %s", qb.dbType)
		}
		del = del.Suffix("RETURNING " + strings.Join(returningColumns, ", "))
	}

	sqlStr, args, err := del.ToSql()
	if err != nil {
		return "", nil, fmt.Errorf("failed to build DELETE query: %w", err)
	}
	return sqlStr, args, nil
}

// BuildUpsertQuery builds an insert-or-update statement.
//
// PostgreSQL: with conflictColumns, emits ON CONFLICT (...) DO UPDATE SET
// col = EXCLUDED.col for each updateColumn. With no updateColumns it emits
// DO NOTHING. returningColumns adds RETURNING.
//
// MySQL: with updateColumns, emits ON DUPLICATE KEY UPDATE col = VALUES(col).
// conflictColumns and returningColumns are ignored.
//
// Other dialects return an error.
func (qb *QueryBuilder) BuildUpsertQuery(table string, insertData InsertData, conflictColumns []string, updateColumns []string, returningColumns ...string) (string, []interface{}, error) {
	if err := qb.ensureColumnsAllowed(insertData.Columns); err != nil {
		return "", nil, err
	}
	if err := qb.ensureColumnsAllowed(updateColumns); err != nil {
		return "", nil, err
	}
	if qb.dbType != DBTypePostgreSQL && qb.dbType != DBTypeMySQL {
		return "", nil, fmt.Errorf("UPSERT not supported for database type: %s", qb.dbType)
	}

	columns, values, err := qb.insertColumnsAndValues(insertData)
	if err != nil {
		return "", nil, err
	}
	insert := qb.sqlBuilder.Insert(table).Columns(columns...).Values(values...)

	setClauses := make([]string, 0, len(updateColumns))
	switch qb.dbType {
	case DBTypePostgreSQL:
		if len(conflictColumns) > 0 {
			for _, col := range updateColumns {
				ident := qb.escapeIdentifier(col)
				setClauses = append(setClauses, fmt.Sprintf("%s = EXCLUDED.%s", ident, ident))
			}
			// "DO UPDATE SET" with an empty list is a syntax error.
			action := "DO NOTHING"
			if len(setClauses) > 0 {
				action = "DO UPDATE SET " + strings.Join(setClauses, ", ")
			}
			insert = insert.Suffix(fmt.Sprintf("ON CONFLICT (%s) %s", strings.Join(conflictColumns, ", "), action))
		}
		if len(returningColumns) > 0 {
			insert = insert.Suffix("RETURNING " + strings.Join(returningColumns, ", "))
		}
	case DBTypeMySQL:
		if len(updateColumns) > 0 {
			for _, col := range updateColumns {
				ident := qb.escapeIdentifier(col)
				setClauses = append(setClauses, fmt.Sprintf("%s = VALUES(%s)", ident, ident))
			}
			insert = insert.Suffix("ON DUPLICATE KEY UPDATE " + strings.Join(setClauses, ", "))
		}
	}

	sqlStr, args, err := insert.ToSql()
	if err != nil {
		return "", nil, fmt.Errorf("failed to build UPSERT query: %w", err)
	}
	return sqlStr, args, nil
}

// --- QueryParser (for parsing URL query strings) ---

// QueryParser converts URL query parameters into DynamicQuery values.
type QueryParser struct {
	defaultLimit int
	maxLimit     int
}

// NewQueryParser returns a parser with a default page size of 10 and a
// maximum page size of 100.
func NewQueryParser() *QueryParser {
	return &QueryParser{defaultLimit: 10, maxLimit: 100}
}

// SetLimits overrides the default and maximum page sizes and returns qp for
// chaining.
func (qp *QueryParser) SetLimits(defaultLimit, maxLimit int) *QueryParser {
	qp.defaultLimit = defaultLimit
	qp.maxLimit = maxLimit
	return qp
}

// ParseQuery parses URL query parameters into a DynamicQuery struct.
//
// Recognized parameters:
//   - fields: comma-separated list, or "*" (the default).
//   - limit: positive integer, clamped to the parser's maximum.
//   - offset: non-negative integer.
//   - filter[column][operator]: see parseFilters.
//   - sort: see parseSorting.
//
// Malformed limit/offset values are ignored.
func (qp *QueryParser) ParseQuery(values url.Values, defaultTable string) (DynamicQuery, error) {
	query := DynamicQuery{
		From:   defaultTable,
		Limit:  qp.defaultLimit,
		Offset: 0,
	}

	// Parse fields.
	if fields := values.Get("fields"); fields != "" && fields != "*" {
		for _, field := range strings.Split(fields, ",") {
			query.Fields = append(query.Fields, SelectField{Expression: strings.TrimSpace(field)})
		}
	} else {
		query.Fields = []SelectField{{Expression: "*"}}
	}

	// Parse pagination.
	if limit := values.Get("limit"); limit != "" {
		if l, err := strconv.Atoi(limit); err == nil && l > 0 {
			// Clamp oversized requests to the maximum instead of silently
			// falling back to the (usually much smaller) default page size.
			if l > qp.maxLimit {
				l = qp.maxLimit
			}
			if l > 0 {
				query.Limit = l
			}
		}
	}
	if offset := values.Get("offset"); offset != "" {
		if o, err := strconv.Atoi(offset); err == nil && o >= 0 {
			query.Offset = o
		}
	}

	// Parse filters.
	filters, err := qp.parseFilters(values)
	if err != nil {
		return query, err
	}
	query.Filters = filters

	// Parse sorting.
	sorts, err := qp.parseSorting(values)
	if err != nil {
		return query, err
	}
	query.Sort = sorts

	return query, nil
}

// parseFilters collects every "filter[column][operator]=value" parameter
// (first value only) into a single AND-ed FilterGroup.
//
// Values are converted as follows:
//   - _in/_nin values are split on commas and trimmed.
//   - _between/_nbetween values must be exactly two comma-separated bounds.
//   - _null/_nnull values become nil.
//   - All other values stay strings.
//
// Returns nil when no filters are present.
func (qp *QueryParser) parseFilters(values url.Values) ([]FilterGroup, error) {
	filterMap := make(map[string]map[string]string)

	for key, vals := range values {
		if !strings.HasPrefix(key, "filter[") || !strings.HasSuffix(key, "]") {
			continue
		}
		parts := strings.Split(key[7:len(key)-1], "][")
		if len(parts) != 2 {
			continue
		}
		column, operator := parts[0], parts[1]
		if filterMap[column] == nil {
			filterMap[column] = make(map[string]string)
		}
		if len(vals) > 0 {
			filterMap[column][operator] = vals[0]
		}
	}

	if len(filterMap) == 0 {
		return nil, nil
	}

	var filters []DynamicFilter
	for column, operators := range filterMap {
		for opStr, value := range operators {
			operator := FilterOperator(opStr)
			var parsedValue interface{}

			switch operator {
			case OpIn, OpNotIn:
				if value != "" {
					items := strings.Split(value, ",")
					for i := range items {
						items[i] = strings.TrimSpace(items[i])
					}
					parsedValue = items
				}
			case OpBetween, OpNotBetween:
				if value != "" {
					bounds := strings.Split(value, ",")
					if len(bounds) == 2 {
						parsedValue = []interface{}{strings.TrimSpace(bounds[0]), strings.TrimSpace(bounds[1])}
					}
				}
			case OpNull, OpNotNull:
				parsedValue = nil
			default:
				parsedValue = value
			}

			filters = append(filters, DynamicFilter{Column: column, Operator: operator, Value: parsedValue})
		}
	}

	if len(filters) == 0 {
		return nil, nil
	}

	return []FilterGroup{{Filters: filters, LogicOp: "AND"}}, nil
}

// parseSorting parses "sort=col1,-col2,+col3": a leading "-" sorts descending,
// a leading "+" or no prefix sorts ascending. Empty entries are skipped.
func (qp *QueryParser) parseSorting(values url.Values) ([]SortField, error) {
	sortParam := values.Get("sort")
	if sortParam == "" {
		return nil, nil
	}

	var sorts []SortField
	for _, field := range strings.Split(sortParam, ",") {
		field = strings.TrimSpace(field)
		if field == "" {
			continue
		}

		order, column := "ASC", field
		if strings.HasPrefix(field, "-") {
			order = "DESC"
			column = field[1:]
		} else if strings.HasPrefix(field, "+") {
			column = field[1:]
		}

		sorts = append(sorts, SortField{Column: column, Order: order})
	}

	return sorts, nil
}

// ParseQueryWithDefaultFields parses URL query parameters into a DynamicQuery struct with default fields.
func (qp *QueryParser) ParseQueryWithDefaultFields(values url.Values, defaultTable string, defaultFields []string) (DynamicQuery, error) { query, err := qp.ParseQuery(values, defaultTable) if err != nil { return query, err } // If no fields specified, use default fields if len(query.Fields) == 0 || (len(query.Fields) == 1 && query.Fields[0].Expression == "*") { query.Fields = make([]SelectField, len(defaultFields)) for i, field := range defaultFields { query.Fields[i] = SelectField{Expression: field} } } return query, nil } // ============================================================================= // MONGODB QUERY BUILDER // ============================================================================= // MongoQueryBuilder builds MongoDB queries from dynamic filters type MongoQueryBuilder struct { allowedFields map[string]bool // Security: only allow specified fields allowedCollections map[string]bool // Security: only allow specified collections // Security settings enableSecurityChecks bool maxAllowedDocs int // Query logging enableQueryLogging bool // Connection timeout settings queryTimeout time.Duration } // NewMongoQueryBuilder creates a new MongoDB query builder instance func NewMongoQueryBuilder() *MongoQueryBuilder { return &MongoQueryBuilder{ allowedFields: make(map[string]bool), allowedCollections: make(map[string]bool), enableSecurityChecks: true, maxAllowedDocs: 10000, enableQueryLogging: true, queryTimeout: 30 * time.Second, } } // SetSecurityOptions configures security settings func (mqb *MongoQueryBuilder) SetSecurityOptions(enableChecks bool, maxDocs int) *MongoQueryBuilder { mqb.enableSecurityChecks = enableChecks mqb.maxAllowedDocs = maxDocs return mqb } // SetAllowedFields sets the list of allowed fields for security func (mqb *MongoQueryBuilder) SetAllowedFields(fields []string) *MongoQueryBuilder { mqb.allowedFields = make(map[string]bool) for _, field := range fields { mqb.allowedFields[field] = true } return mqb } // 
SetAllowedCollections sets the list of allowed collections for security func (mqb *MongoQueryBuilder) SetAllowedCollections(collections []string) *MongoQueryBuilder { mqb.allowedCollections = make(map[string]bool) for _, collection := range collections { mqb.allowedCollections[collection] = true } return mqb } // SetQueryLogging enables or disables query logging func (mqb *MongoQueryBuilder) SetQueryLogging(enable bool) *MongoQueryBuilder { mqb.enableQueryLogging = enable return mqb } // SetQueryTimeout sets the default query timeout func (mqb *MongoQueryBuilder) SetQueryTimeout(timeout time.Duration) *MongoQueryBuilder { mqb.queryTimeout = timeout return mqb } // BuildFindQuery builds a MongoDB find query from DynamicQuery func (mqb *MongoQueryBuilder) BuildFindQuery(query DynamicQuery) (bson.M, *options.FindOptions, error) { filter := bson.M{} findOptions := options.Find() // Security check for limit if mqb.enableSecurityChecks && query.Limit > mqb.maxAllowedDocs { return nil, nil, fmt.Errorf("requested limit %d exceeds maximum allowed %d", query.Limit, mqb.maxAllowedDocs) } // Security check for collection name if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[query.From] { return nil, nil, fmt.Errorf("disallowed collection: %s", query.From) } // Build filter from DynamicQuery filters if len(query.Filters) > 0 { mongoFilter, err := mqb.buildFilter(query.Filters) if err != nil { return nil, nil, err } filter = mongoFilter } // Set projection from fields if len(query.Fields) > 0 { projection := bson.M{} for _, field := range query.Fields { if field.Expression == "*" { // Include all fields continue } fieldName := field.Expression if field.Alias != "" { fieldName = field.Alias } if mqb.allowedFields != nil && !mqb.allowedFields[fieldName] { return nil, nil, fmt.Errorf("disallowed field: %s", fieldName) } projection[fieldName] = 1 } if len(projection) > 0 { findOptions.SetProjection(projection) } } // Set sort if 
len(query.Sort) > 0 { sort := bson.D{} for _, sortField := range query.Sort { fieldName := sortField.Column if mqb.allowedFields != nil && !mqb.allowedFields[fieldName] { return nil, nil, fmt.Errorf("disallowed field: %s", fieldName) } order := 1 // ASC if strings.ToUpper(sortField.Order) == "DESC" { order = -1 // DESC } sort = append(sort, bson.E{Key: fieldName, Value: order}) } findOptions.SetSort(sort) } // Set limit and offset if query.Limit > 0 { findOptions.SetLimit(int64(query.Limit)) } if query.Offset > 0 { findOptions.SetSkip(int64(query.Offset)) } return filter, findOptions, nil } // BuildAggregateQuery builds a MongoDB aggregation pipeline from DynamicQuery func (mqb *MongoQueryBuilder) BuildAggregateQuery(query DynamicQuery) ([]bson.D, error) { pipeline := []bson.D{} // Security check for collection name if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[query.From] { return nil, fmt.Errorf("disallowed collection: %s", query.From) } // Handle CTEs as stages in the pipeline if len(query.CTEs) > 0 { for _, cte := range query.CTEs { // Security check for CTE collection if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[cte.Query.From] { return nil, fmt.Errorf("disallowed collection in CTE: %s", cte.Query.From) } subPipeline, err := mqb.BuildAggregateQuery(cte.Query) if err != nil { return nil, fmt.Errorf("failed to build CTE '%s': %w", cte.Name, err) } // Add $lookup stage for joins if len(cte.Query.Joins) > 0 { for _, join := range cte.Query.Joins { // Security check for joined collection if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[join.Table] { return nil, fmt.Errorf("disallowed collection in join: %s", join.Table) } lookupStage := bson.D{ {Key: "$lookup", Value: bson.D{ {Key: "from", Value: join.Table}, {Key: "localField", Value: join.Alias}, {Key: "foreignField", Value: "_id"}, {Key: "as", Value: join.Alias}, }}, } pipeline = 
append(pipeline, lookupStage) } } // Add the sub-pipeline pipeline = append(pipeline, subPipeline...) } } // Match stage for filters if len(query.Filters) > 0 { filter, err := mqb.buildFilter(query.Filters) if err != nil { return nil, err } pipeline = append(pipeline, bson.D{{Key: "$match", Value: filter}}) } // Group stage for GROUP BY if len(query.GroupBy) > 0 { groupID := bson.D{} for _, field := range query.GroupBy { if mqb.allowedFields != nil && !mqb.allowedFields[field] { return nil, fmt.Errorf("disallowed field: %s", field) } groupID = append(groupID, bson.E{Key: field, Value: "$" + field}) } groupStage := bson.D{ {Key: "$group", Value: bson.D{ {Key: "_id", Value: groupID}, }}, } // Add any aggregations from fields for _, field := range query.Fields { if strings.Contains(field.Expression, "(") && strings.Contains(field.Expression, ")") { // This is an aggregation function funcName := strings.Split(field.Expression, "(")[0] funcField := strings.TrimSuffix(strings.Split(field.Expression, "(")[1], ")") if mqb.allowedFields != nil && !mqb.allowedFields[funcField] { return nil, fmt.Errorf("disallowed field: %s", funcField) } switch strings.ToLower(funcName) { case "count": groupStage = append(groupStage, bson.E{ Key: field.Alias, Value: bson.D{{Key: "$sum", Value: 1}}, }) case "sum": groupStage = append(groupStage, bson.E{ Key: field.Alias, Value: bson.D{{Key: "$sum", Value: "$" + funcField}}, }) case "avg": groupStage = append(groupStage, bson.E{ Key: field.Alias, Value: bson.D{{Key: "$avg", Value: "$" + funcField}}, }) case "min": groupStage = append(groupStage, bson.E{ Key: field.Alias, Value: bson.D{{Key: "$min", Value: "$" + funcField}}, }) case "max": groupStage = append(groupStage, bson.E{ Key: field.Alias, Value: bson.D{{Key: "$max", Value: "$" + funcField}}, }) } } } pipeline = append(pipeline, groupStage) } // Sort stage if len(query.Sort) > 0 { sort := bson.D{} for _, sortField := range query.Sort { fieldName := sortField.Column if mqb.allowedFields 
!= nil && !mqb.allowedFields[fieldName] { return nil, fmt.Errorf("disallowed field: %s", fieldName) } order := 1 // ASC if strings.ToUpper(sortField.Order) == "DESC" { order = -1 // DESC } sort = append(sort, bson.E{Key: fieldName, Value: order}) } pipeline = append(pipeline, bson.D{{Key: "$sort", Value: sort}}) } // Skip and limit stages if query.Offset > 0 { pipeline = append(pipeline, bson.D{{Key: "$skip", Value: query.Offset}}) } if query.Limit > 0 { pipeline = append(pipeline, bson.D{{Key: "$limit", Value: query.Limit}}) } return pipeline, nil } // buildFilter builds a MongoDB filter from FilterGroups func (mqb *MongoQueryBuilder) buildFilter(filterGroups []FilterGroup) (bson.M, error) { if len(filterGroups) == 0 { return bson.M{}, nil } var result bson.M var err error for i, group := range filterGroups { if len(group.Filters) == 0 { continue } groupFilter, err := mqb.buildFilterGroup(group) if err != nil { return nil, err } if i == 0 { result = groupFilter } else { logicOp := "$and" if group.LogicOp != "" { switch strings.ToUpper(group.LogicOp) { case "OR": logicOp = "$or" } } result = bson.M{logicOp: []bson.M{result, groupFilter}} } } return result, err } // buildFilterGroup builds a filter for a single filter group func (mqb *MongoQueryBuilder) buildFilterGroup(group FilterGroup) (bson.M, error) { var filters []bson.M logicOp := "$and" if group.LogicOp != "" { switch strings.ToUpper(group.LogicOp) { case "OR": logicOp = "$or" } } for _, filter := range group.Filters { fieldFilter, err := mqb.buildFilterCondition(filter) if err != nil { return nil, err } filters = append(filters, fieldFilter) } if len(filters) == 1 { return filters[0], nil } return bson.M{logicOp: filters}, nil } // buildFilterCondition builds a single filter condition for MongoDB func (mqb *MongoQueryBuilder) buildFilterCondition(filter DynamicFilter) (bson.M, error) { field := filter.Column if mqb.allowedFields != nil && !mqb.allowedFields[field] { return nil, fmt.Errorf("disallowed field: 
%s", field) } switch filter.Operator { case OpEqual: return bson.M{field: filter.Value}, nil case OpNotEqual: return bson.M{field: bson.M{"$ne": filter.Value}}, nil case OpIn: values := mqb.parseArrayValue(filter.Value) return bson.M{field: bson.M{"$in": values}}, nil case OpNotIn: values := mqb.parseArrayValue(filter.Value) return bson.M{field: bson.M{"$nin": values}}, nil case OpGreaterThan: return bson.M{field: bson.M{"$gt": filter.Value}}, nil case OpGreaterThanEqual: return bson.M{field: bson.M{"$gte": filter.Value}}, nil case OpLessThan: return bson.M{field: bson.M{"$lt": filter.Value}}, nil case OpLessThanEqual: return bson.M{field: bson.M{"$lte": filter.Value}}, nil case OpLike: // Convert SQL LIKE to MongoDB regex pattern := filter.Value.(string) pattern = strings.ReplaceAll(pattern, "%", ".*") pattern = strings.ReplaceAll(pattern, "_", ".") return bson.M{field: bson.M{"$regex": pattern, "$options": "i"}}, nil case OpILike: // Case-insensitive like pattern := filter.Value.(string) pattern = strings.ReplaceAll(pattern, "%", ".*") pattern = strings.ReplaceAll(pattern, "_", ".") return bson.M{field: bson.M{"$regex": pattern, "$options": "i"}}, nil case OpContains: // Contains substring pattern := filter.Value.(string) return bson.M{field: bson.M{"$regex": pattern, "$options": "i"}}, nil case OpNotContains: // Does not contain substring pattern := filter.Value.(string) return bson.M{field: bson.M{"$not": bson.M{"$regex": pattern, "$options": "i"}}}, nil case OpStartsWith: // Starts with pattern := filter.Value.(string) return bson.M{field: bson.M{"$regex": "^" + pattern, "$options": "i"}}, nil case OpEndsWith: // Ends with pattern := filter.Value.(string) return bson.M{field: bson.M{"$regex": pattern + "$", "$options": "i"}}, nil case OpNull: return bson.M{field: bson.M{"$exists": false}}, nil case OpNotNull: return bson.M{field: bson.M{"$exists": true}}, nil case OpJsonContains: // JSON contains return bson.M{field: bson.M{"$elemMatch": filter.Value}}, nil 
case OpJsonNotContains: // JSON does not contain return bson.M{field: bson.M{"$not": bson.M{"$elemMatch": filter.Value}}}, nil case OpJsonExists: // JSON path exists return bson.M{field + "." + filter.Options["path"].(string): bson.M{"$exists": true}}, nil case OpJsonNotExists: // JSON path does not exist return bson.M{field + "." + filter.Options["path"].(string): bson.M{"$exists": false}}, nil case OpArrayContains: // Array contains return bson.M{field: bson.M{"$elemMatch": bson.M{"$eq": filter.Value}}}, nil case OpArrayNotContains: // Array does not contain return bson.M{field: bson.M{"$not": bson.M{"$elemMatch": bson.M{"$eq": filter.Value}}}}, nil case OpArrayLength: // Array length if lengthOption, ok := filter.Options["length"].(int); ok { return bson.M{field: bson.M{"$size": lengthOption}}, nil } return nil, fmt.Errorf("array_length operator requires 'length' option") default: return nil, fmt.Errorf("unsupported operator: %s", filter.Operator) } } // parseArrayValue parses an array value for MongoDB func (mqb *MongoQueryBuilder) parseArrayValue(value interface{}) []interface{} { if value == nil { return nil } if reflect.TypeOf(value).Kind() == reflect.Slice { v := reflect.ValueOf(value) result := make([]interface{}, v.Len()) for i := 0; i < v.Len(); i++ { result[i] = v.Index(i).Interface() } return result } if str, ok := value.(string); ok { if strings.Contains(str, ",") { parts := strings.Split(str, ",") result := make([]interface{}, len(parts)) for i, part := range parts { result[i] = strings.TrimSpace(part) } return result } return []interface{}{str} } return []interface{}{value} } // ExecuteFind executes a MongoDB find query func (mqb *MongoQueryBuilder) ExecuteFind(ctx context.Context, collection *mongo.Collection, query DynamicQuery, dest interface{}) error { // Security check for collection name if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[collection.Name()] { return fmt.Errorf("disallowed collection: %s", 
collection.Name()) } filter, findOptions, err := mqb.BuildFindQuery(query) if err != nil { return err } // Set timeout if not already in context if _, hasDeadline := ctx.Deadline(); !hasDeadline && mqb.queryTimeout > 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, mqb.queryTimeout) defer cancel() } start := time.Now() cursor, err := collection.Find(ctx, filter, findOptions) if err != nil { return err } defer cursor.Close(ctx) err = cursor.All(ctx, dest) if mqb.enableQueryLogging { fmt.Printf("[DEBUG] MongoDB Find executed in %v\n", time.Since(start)) } return err } // ExecuteAggregate executes a MongoDB aggregation pipeline func (mqb *MongoQueryBuilder) ExecuteAggregate(ctx context.Context, collection *mongo.Collection, query DynamicQuery, dest interface{}) error { // Security check for collection name if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[collection.Name()] { return fmt.Errorf("disallowed collection: %s", collection.Name()) } pipeline, err := mqb.BuildAggregateQuery(query) if err != nil { return err } // Set timeout if not already in context if _, hasDeadline := ctx.Deadline(); !hasDeadline && mqb.queryTimeout > 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, mqb.queryTimeout) defer cancel() } start := time.Now() cursor, err := collection.Aggregate(ctx, pipeline) if err != nil { return err } defer cursor.Close(ctx) err = cursor.All(ctx, dest) if mqb.enableQueryLogging { fmt.Printf("[DEBUG] MongoDB Aggregate executed in %v\n", time.Since(start)) } return err } // ExecuteCount executes a MongoDB count query func (mqb *MongoQueryBuilder) ExecuteCount(ctx context.Context, collection *mongo.Collection, query DynamicQuery) (int64, error) { // Security check for collection name if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[collection.Name()] { return 0, fmt.Errorf("disallowed collection: %s", collection.Name()) } filter, _, err 
:= mqb.BuildFindQuery(query) if err != nil { return 0, err } // Set timeout if not already in context if _, hasDeadline := ctx.Deadline(); !hasDeadline && mqb.queryTimeout > 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, mqb.queryTimeout) defer cancel() } start := time.Now() count, err := collection.CountDocuments(ctx, filter) if mqb.enableQueryLogging { fmt.Printf("[DEBUG] MongoDB Count executed in %v\n", time.Since(start)) } return count, err } // ExecuteInsert executes a MongoDB insert operation func (mqb *MongoQueryBuilder) ExecuteInsert(ctx context.Context, collection *mongo.Collection, data InsertData) (*mongo.InsertOneResult, error) { // Security check for collection name if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[collection.Name()] { return nil, fmt.Errorf("disallowed collection: %s", collection.Name()) } document := bson.M{} for i, col := range data.Columns { if mqb.allowedFields != nil && !mqb.allowedFields[col] { return nil, fmt.Errorf("disallowed field: %s", col) } document[col] = data.Values[i] } // Handle JSON values for col, val := range data.JsonValues { if mqb.allowedFields != nil && !mqb.allowedFields[col] { return nil, fmt.Errorf("disallowed field: %s", col) } document[col] = val } // Set timeout if not already in context if _, hasDeadline := ctx.Deadline(); !hasDeadline && mqb.queryTimeout > 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, mqb.queryTimeout) defer cancel() } start := time.Now() result, err := collection.InsertOne(ctx, document) if mqb.enableQueryLogging { fmt.Printf("[DEBUG] MongoDB Insert executed in %v\n", time.Since(start)) } return result, err } // ExecuteUpdate executes a MongoDB update operation func (mqb *MongoQueryBuilder) ExecuteUpdate(ctx context.Context, collection *mongo.Collection, updateData UpdateData, filters []FilterGroup) (*mongo.UpdateResult, error) { // Security check for collection name if mqb.enableSecurityChecks 
&& len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[collection.Name()] { return nil, fmt.Errorf("disallowed collection: %s", collection.Name()) } filter, err := mqb.buildFilter(filters) if err != nil { return nil, err } update := bson.M{"$set": bson.M{}} for i, col := range updateData.Columns { if mqb.allowedFields != nil && !mqb.allowedFields[col] { return nil, fmt.Errorf("disallowed field: %s", col) } update["$set"].(bson.M)[col] = updateData.Values[i] } // Handle JSON updates for col, jsonUpdate := range updateData.JsonUpdates { if mqb.allowedFields != nil && !mqb.allowedFields[col] { return nil, fmt.Errorf("disallowed field: %s", col) } // Use dot notation for nested JSON updates update["$set"].(bson.M)[col+"."+jsonUpdate.Path] = jsonUpdate.Value } // Set timeout if not already in context if _, hasDeadline := ctx.Deadline(); !hasDeadline && mqb.queryTimeout > 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, mqb.queryTimeout) defer cancel() } start := time.Now() result, err := collection.UpdateMany(ctx, filter, update) if mqb.enableQueryLogging { fmt.Printf("[DEBUG] MongoDB Update executed in %v\n", time.Since(start)) } return result, err } // ExecuteDelete executes a MongoDB delete operation func (mqb *MongoQueryBuilder) ExecuteDelete(ctx context.Context, collection *mongo.Collection, filters []FilterGroup) (*mongo.DeleteResult, error) { // Security check for collection name if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[collection.Name()] { return nil, fmt.Errorf("disallowed collection: %s", collection.Name()) } filter, err := mqb.buildFilter(filters) if err != nil { return nil, err } // Set timeout if not already in context if _, hasDeadline := ctx.Deadline(); !hasDeadline && mqb.queryTimeout > 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, mqb.queryTimeout) defer cancel() } start := time.Now() result, err := collection.DeleteMany(ctx, filter) if 
mqb.enableQueryLogging { fmt.Printf("[DEBUG] MongoDB Delete executed in %v\n", time.Since(start)) } return result, err }