2729 lines
86 KiB
Go
2729 lines
86 KiB
Go
package utils
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/url"
|
|
"reflect"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/Masterminds/squirrel"
|
|
"github.com/jmoiron/sqlx"
|
|
"go.mongodb.org/mongo-driver/bson"
|
|
"go.mongodb.org/mongo-driver/mongo"
|
|
"go.mongodb.org/mongo-driver/mongo/options"
|
|
)
|
|
|
|
// DBType identifies which database backend (and therefore which SQL dialect
// and placeholder style) a QueryBuilder generates statements for.
type DBType string

// Supported database types; the string value is the backend identifier.
const (
	// DBTypePostgreSQL targets PostgreSQL.
	DBTypePostgreSQL DBType = "postgres"
	// DBTypeMySQL targets MySQL.
	DBTypeMySQL DBType = "mysql"
	// DBTypeSQLite targets SQLite.
	DBTypeSQLite DBType = "sqlite"
	// DBTypeSQLServer targets Microsoft SQL Server.
	DBTypeSQLServer DBType = "sqlserver"
	// DBTypeMongoDB targets MongoDB.
	DBTypeMongoDB DBType = "mongodb"
)
|
|
|
|
// FilterOperator is the underscore-prefixed token that selects how a
// DynamicFilter compares its column with its value.
type FilterOperator string

const (
	// Equality and pattern matching.
	OpEqual     FilterOperator = "_eq"
	OpNotEqual  FilterOperator = "_neq"
	OpLike      FilterOperator = "_like"
	OpILike     FilterOperator = "_ilike"
	OpNotLike   FilterOperator = "_nlike"
	OpNotILike  FilterOperator = "_nilike"

	// Set membership.
	OpIn    FilterOperator = "_in"
	OpNotIn FilterOperator = "_nin"

	// Ordering comparisons.
	OpGreaterThan      FilterOperator = "_gt"
	OpGreaterThanEqual FilterOperator = "_gte"
	OpLessThan         FilterOperator = "_lt"
	OpLessThanEqual    FilterOperator = "_lte"

	// Ranges.
	OpBetween    FilterOperator = "_between"
	OpNotBetween FilterOperator = "_nbetween"

	// NULL checks.
	OpNull    FilterOperator = "_null"
	OpNotNull FilterOperator = "_nnull"

	// Substring matching.
	OpContains    FilterOperator = "_contains"
	OpNotContains FilterOperator = "_ncontains"
	OpStartsWith  FilterOperator = "_starts_with"
	OpEndsWith    FilterOperator = "_ends_with"

	// JSON document operators.
	OpJsonContains    FilterOperator = "_json_contains"
	OpJsonNotContains FilterOperator = "_json_ncontains"
	OpJsonExists      FilterOperator = "_json_exists"
	OpJsonNotExists   FilterOperator = "_json_nexists"
	OpJsonEqual       FilterOperator = "_json_eq"
	OpJsonNotEqual    FilterOperator = "_json_neq"

	// Array operators.
	OpArrayContains    FilterOperator = "_array_contains"
	OpArrayNotContains FilterOperator = "_array_ncontains"
	OpArrayLength      FilterOperator = "_array_length"
)
|
|
|
|
// DynamicFilter represents a single filter condition
|
|
type DynamicFilter struct {
|
|
Column string `json:"column"`
|
|
Operator FilterOperator `json:"operator"`
|
|
Value interface{} `json:"value"`
|
|
// Additional options for complex filters
|
|
Options map[string]interface{} `json:"options,omitempty"`
|
|
}
|
|
|
|
// FilterGroup represents a group of filters with a logical operator (AND/OR)
|
|
type FilterGroup struct {
|
|
Filters []DynamicFilter `json:"filters"`
|
|
LogicOp string `json:"logic_op"` // AND, OR
|
|
}
|
|
|
|
// SelectField represents a field in the SELECT clause, supporting expressions and aliases
|
|
type SelectField struct {
|
|
Expression string `json:"expression"` // e.g., "TMLogBarang.Nama", "COUNT(*)"
|
|
Alias string `json:"alias"` // e.g., "obat_nama", "total_count"
|
|
// Window function support
|
|
WindowFunction *WindowFunction `json:"window_function,omitempty"`
|
|
}
|
|
|
|
// WindowFunction describes FUNC() OVER (...), for example ROW_NUMBER, RANK,
// DENSE_RANK, LEAD or LAG.
type WindowFunction struct {
	Function string `json:"function"` // function name
	Over     string `json:"over"`     // PARTITION BY expression
	OrderBy  string `json:"order_by"` // ORDER BY expression inside OVER
	Frame    string `json:"frame"`    // ROWS/RANGE frame clause
	Alias    string `json:"alias"`    // result column alias
}
|
|
|
|
// Join represents a JOIN clause
|
|
type Join struct {
|
|
Type string `json:"type"` // "INNER", "LEFT", "RIGHT", "FULL"
|
|
Table string `json:"table"` // Table name to join
|
|
Alias string `json:"alias"` // Table alias
|
|
OnConditions FilterGroup `json:"on_conditions"` // Conditions for the ON clause
|
|
// LATERAL JOIN support
|
|
Lateral bool `json:"lateral,omitempty"`
|
|
}
|
|
|
|
// Union represents a UNION clause
|
|
type Union struct {
|
|
Type string `json:"type"` // "UNION", "UNION ALL"
|
|
Query DynamicQuery `json:"query"` // The subquery to union with
|
|
}
|
|
|
|
// CTE (Common Table Expression) represents a WITH clause
|
|
type CTE struct {
|
|
Name string `json:"name"` // CTE alias name
|
|
Query DynamicQuery `json:"query"` // The query defining the CTE
|
|
// Recursive CTE support
|
|
Recursive bool `json:"recursive,omitempty"`
|
|
}
|
|
|
|
// DynamicQuery represents the complete query structure
|
|
type DynamicQuery struct {
|
|
Fields []SelectField `json:"fields,omitempty"`
|
|
From string `json:"from"` // Main table name
|
|
Aliases string `json:"aliases"` // Main table alias
|
|
Joins []Join `json:"joins,omitempty"`
|
|
Filters []FilterGroup `json:"filters,omitempty"`
|
|
GroupBy []string `json:"group_by,omitempty"`
|
|
Having []FilterGroup `json:"having,omitempty"`
|
|
Unions []Union `json:"unions,omitempty"`
|
|
CTEs []CTE `json:"ctes,omitempty"`
|
|
Sort []SortField `json:"sort,omitempty"`
|
|
Limit int `json:"limit"`
|
|
Offset int `json:"offset"`
|
|
// Window function support
|
|
WindowFunctions []WindowFunction `json:"window_functions,omitempty"`
|
|
// JSON operations
|
|
JsonOperations []JsonOperation `json:"json_operations,omitempty"`
|
|
}
|
|
|
|
// JsonOperation is a JSON expression (such as "extract", "exists" or
// "contains") applied to a JSON column.
type JsonOperation struct {
	Type   string      `json:"type"`            // operation kind
	Column string      `json:"column"`          // JSON column
	Path   string      `json:"path"`            // JSON path
	Value  interface{} `json:"value,omitempty"` // comparison value
	Alias  string      `json:"alias,omitempty"` // result alias
}
|
|
|
|
// SortField is one ORDER BY entry.
type SortField struct {
	Column string `json:"column"`
	// Order is "ASC" or "DESC".
	Order string `json:"order"`
}
|
|
|
|
// UpdateData represents data for UPDATE operations
|
|
type UpdateData struct {
|
|
Columns []string `json:"columns"`
|
|
Values []interface{} `json:"values"`
|
|
// JSON update support
|
|
JsonUpdates map[string]JsonUpdate `json:"json_updates,omitempty"`
|
|
}
|
|
|
|
// JsonUpdate sets the value found at Path inside a JSON document.
type JsonUpdate struct {
	Path  string      `json:"path"`  // JSON path to modify
	Value interface{} `json:"value"` // replacement value
}
|
|
|
|
// InsertData carries the columns and values of an INSERT statement.
type InsertData struct {
	Columns []string      `json:"columns"`
	Values  []interface{} `json:"values"`

	// JsonValues holds JSON values to insert, keyed by column.
	JsonValues map[string]interface{} `json:"json_values,omitempty"`
}
|
|
|
|
// QueryBuilder builds SQL queries from dynamic filters using squirrel
|
|
type QueryBuilder struct {
|
|
dbType DBType
|
|
sqlBuilder squirrel.StatementBuilderType
|
|
allowedColumns map[string]bool // Security: only allow specified columns
|
|
allowedTables map[string]bool // Security: only allow specified tables
|
|
// Security settings
|
|
enableSecurityChecks bool
|
|
maxAllowedRows int
|
|
// SQL injection prevention patterns
|
|
dangerousPatterns []*regexp.Regexp
|
|
// Query logging
|
|
enableQueryLogging bool
|
|
// Connection timeout settings
|
|
queryTimeout time.Duration
|
|
}
|
|
|
|
// NewQueryBuilder creates a new query builder instance for a specific database type
|
|
func NewQueryBuilder(dbType DBType) *QueryBuilder {
|
|
var placeholderFormat squirrel.PlaceholderFormat
|
|
|
|
switch dbType {
|
|
case DBTypePostgreSQL:
|
|
placeholderFormat = squirrel.Dollar
|
|
case DBTypeMySQL, DBTypeSQLite:
|
|
placeholderFormat = squirrel.Question
|
|
case DBTypeSQLServer:
|
|
placeholderFormat = squirrel.AtP
|
|
default:
|
|
placeholderFormat = squirrel.Question
|
|
}
|
|
|
|
// Initialize dangerous patterns for SQL injection prevention
|
|
dangerousPatterns := []*regexp.Regexp{
|
|
regexp.MustCompile(`(?i)(union|select|insert|update|delete|drop|alter|create|exec|execute)\s`),
|
|
regexp.MustCompile(`(?i)(--|\/\*|\*\/)`),
|
|
regexp.MustCompile(`(?i)(or|and)\s+1\s*=\s*1`),
|
|
regexp.MustCompile(`(?i)(or|and)\s+true`),
|
|
regexp.MustCompile(`(?i)(xp_|sp_)\w+`), // SQL Server extended procedures
|
|
regexp.MustCompile(`(?i)(waitfor\s+delay)`), // SQL Server time-based attack
|
|
regexp.MustCompile(`(?i)(benchmark|sleep)\s*\(`), // MySQL time-based attack
|
|
regexp.MustCompile(`(?i)(pg_sleep)\s*\(`), // PostgreSQL time-based attack
|
|
regexp.MustCompile(`(?i)(load_file|into\s+outfile)`), // File operations
|
|
regexp.MustCompile(`(?i)(information_schema|sysobjects|syscolumns)`), // System tables
|
|
}
|
|
|
|
return &QueryBuilder{
|
|
dbType: dbType,
|
|
sqlBuilder: squirrel.StatementBuilder.PlaceholderFormat(placeholderFormat),
|
|
allowedColumns: make(map[string]bool),
|
|
allowedTables: make(map[string]bool),
|
|
enableSecurityChecks: true,
|
|
maxAllowedRows: 10000,
|
|
dangerousPatterns: dangerousPatterns,
|
|
enableQueryLogging: true,
|
|
queryTimeout: 30 * time.Second,
|
|
}
|
|
}
|
|
|
|
// SetSecurityOptions configures security settings
|
|
func (qb *QueryBuilder) SetSecurityOptions(enableChecks bool, maxRows int) *QueryBuilder {
|
|
qb.enableSecurityChecks = enableChecks
|
|
qb.maxAllowedRows = maxRows
|
|
return qb
|
|
}
|
|
|
|
// SetAllowedColumns sets the list of allowed columns for security
|
|
func (qb *QueryBuilder) SetAllowedColumns(columns []string) *QueryBuilder {
|
|
qb.allowedColumns = make(map[string]bool)
|
|
for _, col := range columns {
|
|
qb.allowedColumns[col] = true
|
|
}
|
|
return qb
|
|
}
|
|
|
|
// SetAllowedTables sets the list of allowed tables for security
|
|
func (qb *QueryBuilder) SetAllowedTables(tables []string) *QueryBuilder {
|
|
qb.allowedTables = make(map[string]bool)
|
|
for _, table := range tables {
|
|
qb.allowedTables[table] = true
|
|
}
|
|
return qb
|
|
}
|
|
|
|
// SetQueryLogging enables or disables query logging
|
|
func (qb *QueryBuilder) SetQueryLogging(enable bool) *QueryBuilder {
|
|
qb.enableQueryLogging = enable
|
|
return qb
|
|
}
|
|
|
|
// SetQueryTimeout stores the builder's default query timeout. It returns the
// receiver so calls can be chained.
func (qb *QueryBuilder) SetQueryTimeout(timeout time.Duration) *QueryBuilder {
	qb.queryTimeout = timeout
	return qb
}
|
|
|
|
// BuildQuery builds a complete SQL SELECT statement from query. It handles
// CTEs (WITH), JOINs, WHERE, GROUP BY, HAVING, ORDER BY, pagination and
// UNION-style set operations.
//
// The whole statement, including every CTE body and UNION branch, is first
// assembled with '?' placeholders. It is then converted to the dialect's
// placeholder style exactly once. Converting each sub-query on its own would
// restart positional numbering ($1 / @p1) in every part. A PostgreSQL or SQL
// Server statement with CTEs or UNIONs would then bind the wrong arguments.
//
// It returns the SQL text and its arguments in placeholder order. It returns
// an error when a limit, table or value security check fails, or when any
// part of the query cannot be built.
func (qb *QueryBuilder) BuildQuery(query DynamicQuery) (string, []interface{}, error) {
	rawSQL, allArgs, err := qb.buildSelectStatementRaw(query)
	if err != nil {
		return "", nil, err
	}

	finalSQL, err := qb.finalPlaceholderFormat().ReplacePlaceholders(rawSQL)
	if err != nil {
		return "", nil, fmt.Errorf("failed to apply placeholder format: %w", err)
	}

	// Security check for dangerous patterns in user input values.
	if qb.enableSecurityChecks {
		if err := qb.checkForSqlInjectionInArgs(allArgs); err != nil {
			return "", nil, err
		}
	}

	if qb.enableQueryLogging {
		fmt.Printf("[DEBUG BuilderQuery] Final SQL query: %s\n", finalSQL)
		fmt.Printf("[DEBUG] Query args: %v\n", allArgs)
	}
	return finalSQL, allArgs, nil
}

// finalPlaceholderFormat returns the placeholder style for qb.dbType, using
// the same mapping as NewQueryBuilder.
func (qb *QueryBuilder) finalPlaceholderFormat() squirrel.PlaceholderFormat {
	switch qb.dbType {
	case DBTypePostgreSQL:
		return squirrel.Dollar
	case DBTypeSQLServer:
		return squirrel.AtP
	default:
		return squirrel.Question
	}
}

// buildSelectStatementRaw assembles one statement (its CTEs, main SELECT and
// UNION branches) with '?' placeholders. It recurses for nested queries.
func (qb *QueryBuilder) buildSelectStatementRaw(query DynamicQuery) (string, []interface{}, error) {
	if qb.enableSecurityChecks && query.Limit > qb.maxAllowedRows {
		return "", nil, fmt.Errorf("requested limit %d exceeds maximum allowed %d", query.Limit, qb.maxAllowedRows)
	}
	if qb.enableSecurityChecks && len(qb.allowedTables) > 0 && !qb.allowedTables[query.From] {
		return "", nil, fmt.Errorf("disallowed table: %s", query.From)
	}

	var parts []string
	var args []interface{}

	if len(query.CTEs) > 0 {
		cteSQL, cteArgs, err := qb.buildCTEClauseRaw(query.CTEs)
		if err != nil {
			return "", nil, err
		}
		parts = append(parts, cteSQL)
		args = append(args, cteArgs...)
	}

	mainSQL, mainArgs, err := qb.buildMainSelectRaw(query)
	if err != nil {
		return "", nil, err
	}
	parts = append(parts, mainSQL)
	args = append(args, mainArgs...)

	if len(query.Unions) > 0 {
		unionSQL, unionArgs, err := qb.buildUnionClauseRaw(query.Unions)
		if err != nil {
			return "", nil, err
		}
		parts = append(parts, unionSQL)
		args = append(args, unionArgs...)
	}

	return strings.Join(parts, " "), args, nil
}

// buildMainSelectRaw builds the SELECT ... FROM ... part of a query (no CTEs,
// no UNIONs) with '?' placeholders.
func (qb *QueryBuilder) buildMainSelectRaw(query DynamicQuery) (string, []interface{}, error) {
	sb := squirrel.StatementBuilder.PlaceholderFormat(squirrel.Question).
		Select(qb.buildSelectFields(query.Fields)...).
		From(qb.buildFromClause(query.From, query.Aliases))

	// Window functions and JSON expressions follow the regular fields.
	// Column() keeps each expression's arguments ahead of FROM/WHERE arguments.
	for _, wf := range query.WindowFunctions {
		expr, err := qb.buildWindowFunction(wf)
		if err != nil {
			return "", nil, err
		}
		sb = sb.Column(expr)
	}
	for _, jo := range query.JsonOperations {
		expr, joArgs, err := qb.buildJsonOperation(jo)
		if err != nil {
			return "", nil, err
		}
		if jo.Alias != "" {
			expr += " AS " + qb.escapeIdentifier(jo.Alias)
		}
		sb = sb.Column(expr, joArgs...)
	}

	sb, err := qb.applySelectJoins(sb, query.Joins)
	if err != nil {
		return "", nil, err
	}

	if len(query.Filters) > 0 {
		whereClause, whereArgs, err := qb.BuildWhereClause(query.Filters)
		if err != nil {
			return "", nil, err
		}
		if whereClause != "" {
			sb = sb.Where(whereClause, whereArgs...)
		}
	}

	if len(query.GroupBy) > 0 {
		sb = sb.GroupBy(qb.buildGroupByColumns(query.GroupBy)...)
	}

	if len(query.Having) > 0 {
		havingClause, havingArgs, err := qb.BuildWhereClause(query.Having)
		if err != nil {
			return "", nil, err
		}
		// An empty predicate would otherwise render a dangling " HAVING ".
		if havingClause != "" {
			sb = sb.Having(havingClause, havingArgs...)
		}
	}

	ordered := false
	for _, sf := range query.Sort {
		column := qb.validateAndEscapeColumn(sf.Column)
		if column == "" {
			continue
		}
		order := "ASC"
		if strings.ToUpper(sf.Order) == "DESC" {
			order = "DESC"
		}
		sb = sb.OrderBy(column + " " + order)
		ordered = true
	}

	sb = qb.applySelectPagination(sb, query.Limit, query.Offset, ordered)

	sqlText, sqlArgs, err := sb.ToSql()
	if err != nil {
		return "", nil, fmt.Errorf("failed to build main query: %w", err)
	}
	return sqlText, sqlArgs, nil
}

// applySelectJoins appends each join to sb after checking it against the
// table allowlist. As before, LATERAL is only applied to INNER/LEFT joins.
func (qb *QueryBuilder) applySelectJoins(sb squirrel.SelectBuilder, joins []Join) (squirrel.SelectBuilder, error) {
	for _, join := range joins {
		if qb.enableSecurityChecks && len(qb.allowedTables) > 0 && !qb.allowedTables[join.Table] {
			return sb, fmt.Errorf("disallowed table in join: %s", join.Table)
		}

		joinType, tableWithAlias, onClause, joinArgs, err := qb.buildSingleJoinClause(join)
		if err != nil {
			return sb, err
		}
		joinStr := tableWithAlias + " ON " + onClause

		switch strings.ToUpper(joinType) {
		case "LEFT":
			if join.Lateral {
				joinStr = "LATERAL " + joinStr
			}
			sb = sb.LeftJoin(joinStr, joinArgs...)
		case "RIGHT":
			sb = sb.RightJoin(joinStr, joinArgs...)
		case "FULL":
			// Join() prefixes "JOIN ", which produced "JOIN FULL JOIN ...".
			// JoinClause emits the text verbatim.
			sb = sb.JoinClause("FULL JOIN "+joinStr, joinArgs...)
		default:
			if join.Lateral {
				joinStr = "LATERAL " + joinStr
			}
			sb = sb.Join(joinStr, joinArgs...)
		}
	}
	return sb, nil
}

// applySelectPagination adds LIMIT/OFFSET, or OFFSET ... FETCH on SQL Server.
// Non-positive limit or offset values mean "not set".
func (qb *QueryBuilder) applySelectPagination(sb squirrel.SelectBuilder, limit, offset int, ordered bool) squirrel.SelectBuilder {
	if qb.dbType == DBTypeSQLServer {
		if limit <= 0 && offset <= 0 {
			return sb
		}
		if offset < 0 {
			offset = 0
		}
		// OFFSET/FETCH requires an ORDER BY. Check what was actually emitted,
		// because sort entries with rejected columns are skipped.
		if !ordered {
			sb = sb.OrderBy("(SELECT 1)")
		}
		clause := fmt.Sprintf("OFFSET %d ROWS", offset)
		if limit > 0 {
			clause += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", limit)
		}
		return sb.Suffix(clause)
	}

	if limit > 0 {
		sb = sb.Limit(uint64(limit))
	}
	if offset > 0 {
		sb = sb.Offset(uint64(offset))
	}
	return sb
}

// buildCTEClauseRaw renders the WITH clause with '?' placeholders.
// T-SQL has no RECURSIVE keyword, so SQL Server always gets a plain WITH.
func (qb *QueryBuilder) buildCTEClauseRaw(ctes []CTE) (string, []interface{}, error) {
	keyword := "WITH"
	if qb.dbType != DBTypeSQLServer {
		for _, cte := range ctes {
			if cte.Recursive {
				keyword = "WITH RECURSIVE"
				break
			}
		}
	}

	parts := make([]string, 0, len(ctes))
	var args []interface{}
	for _, cte := range ctes {
		subSQL, subArgs, err := qb.buildSelectStatementRaw(cte.Query)
		if err != nil {
			return "", nil, fmt.Errorf("failed to build CTE '%s': %w", cte.Name, err)
		}
		parts = append(parts, fmt.Sprintf("%s AS (%s)", qb.escapeIdentifier(cte.Name), subSQL))
		args = append(args, subArgs...)
	}
	return keyword + " " + strings.Join(parts, ", "), args, nil
}

// buildUnionClauseRaw renders the set-operation branches with '?'
// placeholders. Union.Type is written into the SQL text, so it must be one of
// the allowlisted operators; an empty type means UNION.
func (qb *QueryBuilder) buildUnionClauseRaw(unions []Union) (string, []interface{}, error) {
	parts := make([]string, 0, len(unions))
	var args []interface{}
	for _, union := range unions {
		setOp := strings.Join(strings.Fields(strings.ToUpper(union.Type)), " ")
		switch setOp {
		case "":
			setOp = "UNION"
		case "UNION", "UNION ALL", "INTERSECT", "INTERSECT ALL", "EXCEPT", "EXCEPT ALL":
		default:
			return "", nil, fmt.Errorf("unsupported set operator: %s", union.Type)
		}

		subSQL, subArgs, err := qb.buildSelectStatementRaw(union.Query)
		if err != nil {
			return "", nil, fmt.Errorf("failed to build subquery for UNION: %w", err)
		}
		parts = append(parts, setOp+" "+subSQL)
		args = append(args, subArgs...)
	}
	return strings.Join(parts, " "), args, nil
}
|
|
|
|
// buildWindowFunction builds a window function expression
|
|
func (qb *QueryBuilder) buildWindowFunction(wf WindowFunction) (string, error) {
|
|
if !qb.isValidFunctionName(wf.Function) {
|
|
return "", fmt.Errorf("invalid window function name: %s", wf.Function)
|
|
}
|
|
|
|
windowExpr := fmt.Sprintf("%s() OVER (", wf.Function)
|
|
|
|
if wf.Over != "" {
|
|
windowExpr += fmt.Sprintf("PARTITION BY %s ", wf.Over)
|
|
}
|
|
|
|
if wf.OrderBy != "" {
|
|
windowExpr += fmt.Sprintf("ORDER BY %s ", wf.OrderBy)
|
|
}
|
|
|
|
if wf.Frame != "" {
|
|
windowExpr += wf.Frame
|
|
}
|
|
|
|
windowExpr += ")"
|
|
|
|
if wf.Alias != "" {
|
|
windowExpr += " AS " + qb.escapeIdentifier(wf.Alias)
|
|
}
|
|
|
|
return windowExpr, nil
|
|
}
|
|
|
|
// buildJsonOperation builds a JSON operation expression
|
|
func (qb *QueryBuilder) buildJsonOperation(jo JsonOperation) (string, []interface{}, error) {
|
|
column := qb.validateAndEscapeColumn(jo.Column)
|
|
if column == "" {
|
|
return "", nil, fmt.Errorf("invalid or disallowed column: %s", jo.Column)
|
|
}
|
|
|
|
path := jo.Path
|
|
if path == "" {
|
|
path = "$"
|
|
}
|
|
|
|
var expr string
|
|
var args []interface{}
|
|
|
|
switch strings.ToLower(jo.Type) {
|
|
case "extract":
|
|
switch qb.dbType {
|
|
case DBTypePostgreSQL:
|
|
expr = fmt.Sprintf("%s->>%s", column, qb.escapeJsonPath(path))
|
|
case DBTypeMySQL:
|
|
expr = fmt.Sprintf("JSON_EXTRACT(%s, '%s')", column, path)
|
|
case DBTypeSQLServer:
|
|
expr = fmt.Sprintf("JSON_VALUE(%s, '%s')", column, qb.escapeSqlServerJsonPath(path))
|
|
case DBTypeSQLite:
|
|
expr = fmt.Sprintf("json_extract(%s, '%s')", column, path)
|
|
default:
|
|
return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType)
|
|
}
|
|
case "exists":
|
|
switch qb.dbType {
|
|
case DBTypePostgreSQL:
|
|
expr = fmt.Sprintf("jsonb_path_exists(%s, '%s')", column, path)
|
|
case DBTypeMySQL:
|
|
expr = fmt.Sprintf("JSON_CONTAINS_PATH(%s, 'one', '%s')", column, path)
|
|
case DBTypeSQLServer:
|
|
expr = fmt.Sprintf("JSON_VALUE(%s, '%s') IS NOT NULL", column, qb.escapeSqlServerJsonPath(path))
|
|
case DBTypeSQLite:
|
|
expr = fmt.Sprintf("json_extract(%s, '%s') IS NOT NULL", column, path)
|
|
default:
|
|
return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType)
|
|
}
|
|
case "contains":
|
|
switch qb.dbType {
|
|
case DBTypePostgreSQL:
|
|
expr = fmt.Sprintf("%s @> %s", column, "?")
|
|
args = append(args, jo.Value)
|
|
case DBTypeMySQL:
|
|
expr = fmt.Sprintf("JSON_CONTAINS(%s, ?, '%s')", column, path)
|
|
args = append(args, jo.Value)
|
|
case DBTypeSQLServer:
|
|
expr = fmt.Sprintf("JSON_VALUE(%s, '%s') = ?", column, qb.escapeSqlServerJsonPath(path))
|
|
args = append(args, jo.Value)
|
|
case DBTypeSQLite:
|
|
expr = fmt.Sprintf("json_extract(%s, '%s') = ?", column, path)
|
|
args = append(args, jo.Value)
|
|
default:
|
|
return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType)
|
|
}
|
|
default:
|
|
return "", nil, fmt.Errorf("unsupported JSON operation type: %s", jo.Type)
|
|
}
|
|
|
|
return expr, args, nil
|
|
}
|
|
|
|
// escapeJsonPath escapes a JSON path for PostgreSQL
|
|
func (qb *QueryBuilder) escapeJsonPath(path string) string {
|
|
// Simple implementation - in a real scenario, you'd need more sophisticated escaping
|
|
return "'" + strings.ReplaceAll(path, "'", "''") + "'"
|
|
}
|
|
|
|
// escapeSqlServerJsonPath normalises a JSON path for SQL Server's JSON
// functions and makes it safe to place between single quotes.
//
// A path that does not start with "$" is rooted as "$.<path>". Callers wrap
// the result in '...'. Embedded single quotes are therefore doubled so a path
// cannot end the literal early. The previous ReplaceAll(".", ".") was a no-op
// and escaped nothing.
func (qb *QueryBuilder) escapeSqlServerJsonPath(path string) string {
	if !strings.HasPrefix(path, "$") {
		path = "$." + path
	}
	return strings.ReplaceAll(path, "'", "''")
}
|
|
|
|
// buildCTEClause renders the WITH clause for ctes. Each CTE body is built
// with BuildQuery, and the bodies' arguments are returned in clause order.
//
// "WITH RECURSIVE" is used when any CTE is marked Recursive. SQL Server is the
// exception: T-SQL has no RECURSIVE keyword, and recursive CTEs there use a
// plain WITH.
//
// NOTE(review): each body is a separately built statement. With positional
// placeholder formats ($n, @pn) its numbering starts again at 1. Confirm this
// is only combined with other parts when the dialect uses '?'.
func (qb *QueryBuilder) buildCTEClause(ctes []CTE) (string, []interface{}, error) {
	withClause := "WITH"
	if qb.dbType != DBTypeSQLServer {
		for _, cte := range ctes {
			if cte.Recursive {
				withClause = "WITH RECURSIVE"
				break
			}
		}
	}

	var cteParts []string
	var allArgs []interface{}
	for _, cte := range ctes {
		subQuery, args, err := qb.BuildQuery(cte.Query)
		if err != nil {
			return "", nil, fmt.Errorf("failed to build CTE '%s': %w", cte.Name, err)
		}
		cteParts = append(cteParts, fmt.Sprintf("%s AS (%s)", qb.escapeIdentifier(cte.Name), subQuery))
		allArgs = append(allArgs, args...)
	}

	return fmt.Sprintf("%s %s", withClause, strings.Join(cteParts, ", ")), allArgs, nil
}
|
|
|
|
// buildFromClause builds the FROM clause with optional alias
|
|
func (qb *QueryBuilder) buildFromClause(table, alias string) string {
|
|
fromClause := qb.escapeIdentifier(table)
|
|
if alias != "" {
|
|
fromClause += " " + qb.escapeIdentifier(alias)
|
|
}
|
|
return fromClause
|
|
}
|
|
|
|
// buildSingleJoinClause splits join into the pieces its caller needs. It
// returns, in order:
//   - the upper-cased join type ("INNER" when unset)
//   - the escaped table plus its optional escaped alias
//   - the ON condition rendered by BuildWhereClause from join.OnConditions
//   - that condition's arguments
//
// An ON-clause build failure is wrapped with the joined table's name.
func (qb *QueryBuilder) buildSingleJoinClause(join Join) (string, string, string, []interface{}, error) {
	joinKind := "INNER"
	if join.Type != "" {
		joinKind = strings.ToUpper(join.Type)
	}

	target := qb.escapeIdentifier(join.Table)
	if alias := join.Alias; alias != "" {
		target += " " + qb.escapeIdentifier(alias)
	}

	on, onArgs, err := qb.BuildWhereClause([]FilterGroup{join.OnConditions})
	if err != nil {
		return "", "", "", nil, fmt.Errorf("failed to build ON clause for join on table %s: %w", join.Table, err)
	}
	return joinKind, target, on, onArgs, nil
}
|
|
|
|
// buildUnionClause renders the set-operation tail of a query ("UNION ...",
// "UNION ALL ...", and so on). Each branch is built with BuildQuery, and the
// branches' arguments are returned in order.
//
// Union.Type is written into the SQL text. It is therefore normalised
// (upper-cased, whitespace collapsed) and must be one of UNION, UNION ALL,
// INTERSECT, INTERSECT ALL, EXCEPT or EXCEPT ALL. An empty type means UNION.
//
// NOTE(review): each branch is a separately built statement. With positional
// placeholder formats ($n, @pn) its numbering starts again at 1. Confirm this
// is only combined with other parts when the dialect uses '?'.
func (qb *QueryBuilder) buildUnionClause(unions []Union) (string, []interface{}, error) {
	var unionParts []string
	var allArgs []interface{}

	for _, union := range unions {
		unionType := strings.Join(strings.Fields(strings.ToUpper(union.Type)), " ")
		switch unionType {
		case "":
			unionType = "UNION"
		case "UNION", "UNION ALL", "INTERSECT", "INTERSECT ALL", "EXCEPT", "EXCEPT ALL":
		default:
			return "", nil, fmt.Errorf("unsupported set operator: %s", union.Type)
		}

		subQuery, args, err := qb.BuildQuery(union.Query)
		if err != nil {
			return "", nil, fmt.Errorf("failed to build subquery for UNION: %w", err)
		}
		unionParts = append(unionParts, unionType+" "+subQuery)
		allArgs = append(allArgs, args...)
	}

	return strings.Join(unionParts, " "), allArgs, nil
}
|
|
|
|
// buildSelectFields builds the SELECT fields from SelectField structs
|
|
func (qb *QueryBuilder) buildSelectFields(fields []SelectField) []string {
|
|
if len(fields) == 0 {
|
|
return []string{"*"}
|
|
}
|
|
|
|
var selectedFields []string
|
|
for _, field := range fields {
|
|
expr := field.Expression
|
|
if expr == "" {
|
|
continue
|
|
}
|
|
// Basic validation for expression
|
|
if !qb.isValidExpression(expr) {
|
|
continue
|
|
}
|
|
|
|
// Handle window functions
|
|
if field.WindowFunction != nil {
|
|
windowFunc, err := qb.buildWindowFunction(*field.WindowFunction)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
expr = windowFunc
|
|
}
|
|
|
|
if field.Alias != "" {
|
|
selectedFields = append(selectedFields, fmt.Sprintf("%s AS %s", expr, qb.escapeIdentifier(field.Alias)))
|
|
} else {
|
|
selectedFields = append(selectedFields, expr)
|
|
}
|
|
}
|
|
|
|
if len(selectedFields) == 0 {
|
|
return []string{"*"}
|
|
}
|
|
|
|
return selectedFields
|
|
}
|
|
|
|
// BuildWhereClause renders filterGroups as one WHERE/HAVING predicate with
// '?' placeholders. It returns the predicate and its arguments in order.
//
// Each non-empty group is wrapped in parentheses and joined to the previous
// group with that group's LogicOp (AND when unset). Groups with no filters, or
// whose filters all render empty, are skipped. The operator is only inserted
// between conditions that were actually emitted; keying it off the loop index
// produced a dangling leading "AND"/"OR" when the first group was skipped.
//
// LogicOp is written into the SQL text, both here and inside the group, so
// only AND and OR (case-insensitive, surrounding whitespace ignored) are
// accepted. Anything else is an error.
func (qb *QueryBuilder) BuildWhereClause(filterGroups []FilterGroup) (string, []interface{}, error) {
	if len(filterGroups) == 0 {
		return "", nil, nil
	}

	var conditions []string
	var allArgs []interface{}

	for _, group := range filterGroups {
		if len(group.Filters) == 0 {
			continue
		}

		logicOp := "AND"
		if group.LogicOp != "" {
			logicOp = strings.ToUpper(strings.TrimSpace(group.LogicOp))
			if logicOp != "AND" && logicOp != "OR" {
				return "", nil, fmt.Errorf("invalid logic operator: %s", group.LogicOp)
			}
		}

		groupCondition, groupArgs, err := qb.buildFilterGroup(group)
		if err != nil {
			return "", nil, err
		}
		if groupCondition == "" {
			continue
		}

		if len(conditions) > 0 {
			conditions = append(conditions, logicOp)
		}
		conditions = append(conditions, "("+groupCondition+")")
		allArgs = append(allArgs, groupArgs...)
	}

	return strings.Join(conditions, " "), allArgs, nil
}
|
|
|
|
// buildFilterGroup joins the conditions of group with the group's logic
// operator (AND when unset). Filters that render to an empty condition, such
// as a LIKE with a nil value, are skipped. The operator is only placed between
// conditions that were actually emitted, so a skipped first filter no longer
// leaves a dangling leading operator.
//
// The operator is written into the SQL text, so only AND and OR
// (case-insensitive, surrounding whitespace ignored) are accepted.
func (qb *QueryBuilder) buildFilterGroup(group FilterGroup) (string, []interface{}, error) {
	logicOp := "AND"
	if group.LogicOp != "" {
		logicOp = strings.ToUpper(strings.TrimSpace(group.LogicOp))
		if logicOp != "AND" && logicOp != "OR" {
			return "", nil, fmt.Errorf("invalid logic operator: %s", group.LogicOp)
		}
	}

	var conditions []string
	var args []interface{}
	for _, filter := range group.Filters {
		condition, filterArgs, err := qb.buildFilterCondition(filter)
		if err != nil {
			return "", nil, err
		}
		if condition == "" {
			continue
		}
		if len(conditions) > 0 {
			conditions = append(conditions, logicOp)
		}
		conditions = append(conditions, condition)
		args = append(args, filterArgs...)
	}

	return strings.Join(conditions, " "), args, nil
}
|
|
|
|
// buildFilterCondition renders one DynamicFilter as a SQL condition with '?'
// placeholders, applying dialect-specific syntax. It returns the condition
// and its arguments.
//
// Notes:
//   - Column references: a string value with exactly one dot that passes
//     isValidExpression is compared as a column reference for _eq, _neq, _gt
//     and _lt. Decimal strings such as "3.14" are excluded and are still
//     bound as values.
//   - JSON and array operators are delegated to buildJsonFilterCondition and
//     buildArrayFilterCondition.
//   - _eq and _neq with a nil Value become IS NULL / IS NOT NULL. The LIKE
//     family, the ordering comparisons and the substring operators return an
//     empty condition, which the caller skips.
//   - Empty IN lists: _in becomes "1=0" (matches nothing) and _nin becomes
//     "1=1" (matches everything).
//   - Case-insensitive matching uses ILIKE on PostgreSQL. MySQL, SQL Server
//     and SQLite use LOWER(col) LIKE LOWER(?), because SQLite has no ILIKE.
//   - NOTE(review): LIKE wildcards (% and _) inside user values are not
//     escaped.
//
// It returns an error for a disallowed column, a BETWEEN value that does not
// have exactly two elements, or an unsupported operator.
func (qb *QueryBuilder) buildFilterCondition(filter DynamicFilter) (string, []interface{}, error) {
	column := qb.validateAndEscapeColumn(filter.Column)
	if column == "" {
		return "", nil, fmt.Errorf("invalid or disallowed column: %s", filter.Column)
	}

	// Column-to-column comparison.
	if valStr, ok := filter.Value.(string); ok && strings.Count(valStr, ".") == 1 && qb.isValidExpression(valStr) {
		// A decimal such as "3.14" also has one dot; bind it as a value instead.
		if _, numErr := strconv.ParseFloat(valStr, 64); numErr != nil {
			escapedVal := qb.escapeColumnReference(valStr)
			switch filter.Operator {
			case OpEqual:
				return fmt.Sprintf("%s = %s", column, escapedVal), nil, nil
			case OpNotEqual:
				return fmt.Sprintf("%s <> %s", column, escapedVal), nil, nil
			case OpGreaterThan:
				return fmt.Sprintf("%s > %s", column, escapedVal), nil, nil
			case OpLessThan:
				return fmt.Sprintf("%s < %s", column, escapedVal), nil, nil
			}
		}
	}

	// JSON and array operators have their own builders.
	switch filter.Operator {
	case OpJsonContains, OpJsonNotContains, OpJsonExists, OpJsonNotExists, OpJsonEqual, OpJsonNotEqual:
		return qb.buildJsonFilterCondition(filter)
	case OpArrayContains, OpArrayNotContains, OpArrayLength:
		return qb.buildArrayFilterCondition(filter)
	}

	// likeCondition renders a (NOT) LIKE / ILIKE comparison for the dialect.
	likeCondition := func(pattern interface{}, caseInsensitive, negate bool) (string, []interface{}, error) {
		not := ""
		if negate {
			not = "NOT "
		}
		args := []interface{}{pattern}
		if !caseInsensitive {
			return fmt.Sprintf("%s %sLIKE ?", column, not), args, nil
		}
		switch qb.dbType {
		case DBTypePostgreSQL:
			return fmt.Sprintf("%s %sILIKE ?", column, not), args, nil
		case DBTypeMySQL, DBTypeSQLServer, DBTypeSQLite:
			return fmt.Sprintf("LOWER(%s) %sLIKE LOWER(?)", column, not), args, nil
		default:
			return fmt.Sprintf("%s %sLIKE ?", column, not), args, nil
		}
	}

	switch filter.Operator {
	case OpEqual:
		if filter.Value == nil {
			return fmt.Sprintf("%s IS NULL", column), nil, nil
		}
		return fmt.Sprintf("%s = ?", column), []interface{}{filter.Value}, nil

	case OpNotEqual:
		if filter.Value == nil {
			return fmt.Sprintf("%s IS NOT NULL", column), nil, nil
		}
		return fmt.Sprintf("%s <> ?", column), []interface{}{filter.Value}, nil

	case OpLike, OpNotLike:
		if filter.Value == nil {
			return "", nil, nil
		}
		return likeCondition(filter.Value, false, filter.Operator == OpNotLike)

	case OpILike, OpNotILike:
		if filter.Value == nil {
			return "", nil, nil
		}
		return likeCondition(filter.Value, true, filter.Operator == OpNotILike)

	case OpIn, OpNotIn:
		values := qb.parseArrayValue(filter.Value)
		if len(values) == 0 {
			// x IN () matches nothing; x NOT IN () matches everything.
			if filter.Operator == OpNotIn {
				return "1=1", nil, nil
			}
			return "1=0", nil, nil
		}
		op := "IN"
		if filter.Operator == OpNotIn {
			op = "NOT IN"
		}
		return fmt.Sprintf("%s %s (%s)", column, op, squirrel.Placeholders(len(values))), values, nil

	case OpGreaterThan, OpGreaterThanEqual, OpLessThan, OpLessThanEqual:
		if filter.Value == nil {
			return "", nil, nil
		}
		// Map the operator token to its SQL symbol. Trimming the "_" prefix
		// produced invalid SQL such as "col gt ?".
		var symbol string
		switch filter.Operator {
		case OpGreaterThan:
			symbol = ">"
		case OpGreaterThanEqual:
			symbol = ">="
		case OpLessThan:
			symbol = "<"
		default:
			symbol = "<="
		}
		return fmt.Sprintf("%s %s ?", column, symbol), []interface{}{filter.Value}, nil

	case OpBetween, OpNotBetween:
		values := qb.parseArrayValue(filter.Value)
		if len(values) != 2 {
			return "", nil, fmt.Errorf("between operator requires exactly 2 values")
		}
		op := "BETWEEN"
		if filter.Operator == OpNotBetween {
			op = "NOT BETWEEN"
		}
		return fmt.Sprintf("%s %s ? AND ?", column, op), []interface{}{values[0], values[1]}, nil

	case OpNull:
		return fmt.Sprintf("%s IS NULL", column), nil, nil

	case OpNotNull:
		return fmt.Sprintf("%s IS NOT NULL", column), nil, nil

	case OpContains, OpNotContains, OpStartsWith, OpEndsWith:
		if filter.Value == nil {
			return "", nil, nil
		}
		var pattern string
		switch filter.Operator {
		case OpStartsWith:
			pattern = fmt.Sprintf("%v%%", filter.Value)
		case OpEndsWith:
			pattern = fmt.Sprintf("%%%v", filter.Value)
		default:
			pattern = fmt.Sprintf("%%%v%%", filter.Value)
		}
		// The operator token is "_ncontains"; the old check for "Not" never
		// matched, so negation was silently dropped.
		return likeCondition(pattern, true, filter.Operator == OpNotContains)

	default:
		return "", nil, fmt.Errorf("unsupported operator: %s", filter.Operator)
	}
}
|
|
|
|
// buildJsonFilterCondition builds a JSON filter condition
|
|
func (qb *QueryBuilder) buildJsonFilterCondition(filter DynamicFilter) (string, []interface{}, error) {
|
|
column := qb.validateAndEscapeColumn(filter.Column)
|
|
if column == "" {
|
|
return "", nil, fmt.Errorf("invalid or disallowed column: %s", filter.Column)
|
|
}
|
|
|
|
path := "$"
|
|
if pathOption, ok := filter.Options["path"].(string); ok && pathOption != "" {
|
|
path = pathOption
|
|
}
|
|
|
|
var expr string
|
|
var args []interface{}
|
|
|
|
switch filter.Operator {
|
|
case OpJsonContains:
|
|
switch qb.dbType {
|
|
case DBTypePostgreSQL:
|
|
expr = fmt.Sprintf("%s @> ?", column)
|
|
args = append(args, filter.Value)
|
|
case DBTypeMySQL:
|
|
expr = fmt.Sprintf("JSON_CONTAINS(%s, ?, '%s')", column, path)
|
|
args = append(args, filter.Value)
|
|
case DBTypeSQLServer:
|
|
expr = fmt.Sprintf("JSON_VALUE(%s, '%s') = ?", column, qb.escapeSqlServerJsonPath(path))
|
|
args = append(args, filter.Value)
|
|
case DBTypeSQLite:
|
|
expr = fmt.Sprintf("json_extract(%s, '%s') = ?", column, path)
|
|
args = append(args, filter.Value)
|
|
default:
|
|
return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType)
|
|
}
|
|
case OpJsonNotContains:
|
|
switch qb.dbType {
|
|
case DBTypePostgreSQL:
|
|
expr = fmt.Sprintf("NOT (%s @> ?)", column)
|
|
args = append(args, filter.Value)
|
|
case DBTypeMySQL:
|
|
expr = fmt.Sprintf("NOT JSON_CONTAINS(%s, ?, '%s')", column, path)
|
|
args = append(args, filter.Value)
|
|
case DBTypeSQLServer:
|
|
expr = fmt.Sprintf("JSON_VALUE(%s, '%s') <> ?", column, qb.escapeSqlServerJsonPath(path))
|
|
args = append(args, filter.Value)
|
|
case DBTypeSQLite:
|
|
expr = fmt.Sprintf("json_extract(%s, '%s') <> ?", column, path)
|
|
args = append(args, filter.Value)
|
|
default:
|
|
return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType)
|
|
}
|
|
case OpJsonExists:
|
|
switch qb.dbType {
|
|
case DBTypePostgreSQL:
|
|
expr = fmt.Sprintf("jsonb_path_exists(%s, '%s')", column, path)
|
|
case DBTypeMySQL:
|
|
expr = fmt.Sprintf("JSON_CONTAINS_PATH(%s, 'one', '%s')", column, path)
|
|
case DBTypeSQLServer:
|
|
expr = fmt.Sprintf("JSON_VALUE(%s, '%s') IS NOT NULL", column, qb.escapeSqlServerJsonPath(path))
|
|
case DBTypeSQLite:
|
|
expr = fmt.Sprintf("json_extract(%s, '%s') IS NOT NULL", column, path)
|
|
default:
|
|
return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType)
|
|
}
|
|
case OpJsonNotExists:
|
|
switch qb.dbType {
|
|
case DBTypePostgreSQL:
|
|
expr = fmt.Sprintf("NOT jsonb_path_exists(%s, '%s')", column, path)
|
|
case DBTypeMySQL:
|
|
expr = fmt.Sprintf("NOT JSON_CONTAINS_PATH(%s, 'one', '%s')", column, path)
|
|
case DBTypeSQLServer:
|
|
expr = fmt.Sprintf("JSON_VALUE(%s, '%s') IS NULL", column, qb.escapeSqlServerJsonPath(path))
|
|
case DBTypeSQLite:
|
|
expr = fmt.Sprintf("json_extract(%s, '%s') IS NULL", column, path)
|
|
default:
|
|
return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType)
|
|
}
|
|
case OpJsonEqual:
|
|
switch qb.dbType {
|
|
case DBTypePostgreSQL:
|
|
expr = fmt.Sprintf("%s->>%s = ?", column, qb.escapeJsonPath(path))
|
|
args = append(args, filter.Value)
|
|
case DBTypeMySQL:
|
|
expr = fmt.Sprintf("JSON_EXTRACT(%s, '%s') = ?", column, path)
|
|
args = append(args, filter.Value)
|
|
case DBTypeSQLServer:
|
|
expr = fmt.Sprintf("JSON_VALUE(%s, '%s') = ?", column, qb.escapeSqlServerJsonPath(path))
|
|
args = append(args, filter.Value)
|
|
case DBTypeSQLite:
|
|
expr = fmt.Sprintf("json_extract(%s, '%s') = ?", column, path)
|
|
args = append(args, filter.Value)
|
|
default:
|
|
return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType)
|
|
}
|
|
case OpJsonNotEqual:
|
|
switch qb.dbType {
|
|
case DBTypePostgreSQL:
|
|
expr = fmt.Sprintf("%s->>%s <> ?", column, qb.escapeJsonPath(path))
|
|
args = append(args, filter.Value)
|
|
case DBTypeMySQL:
|
|
expr = fmt.Sprintf("JSON_EXTRACT(%s, '%s') <> ?", column, path)
|
|
args = append(args, filter.Value)
|
|
case DBTypeSQLServer:
|
|
expr = fmt.Sprintf("JSON_VALUE(%s, '%s') <> ?", column, qb.escapeSqlServerJsonPath(path))
|
|
args = append(args, filter.Value)
|
|
case DBTypeSQLite:
|
|
expr = fmt.Sprintf("json_extract(%s, '%s') <> ?", column, path)
|
|
args = append(args, filter.Value)
|
|
default:
|
|
return "", nil, fmt.Errorf("JSON operations not supported for database type: %s", qb.dbType)
|
|
}
|
|
default:
|
|
return "", nil, fmt.Errorf("unsupported JSON operator: %s", filter.Operator)
|
|
}
|
|
|
|
return expr, args, nil
|
|
}
|
|
|
|
// buildArrayFilterCondition builds an array filter condition
|
|
func (qb *QueryBuilder) buildArrayFilterCondition(filter DynamicFilter) (string, []interface{}, error) {
|
|
column := qb.validateAndEscapeColumn(filter.Column)
|
|
if column == "" {
|
|
return "", nil, fmt.Errorf("invalid or disallowed column: %s", filter.Column)
|
|
}
|
|
|
|
var expr string
|
|
var args []interface{}
|
|
|
|
switch filter.Operator {
|
|
case OpArrayContains:
|
|
switch qb.dbType {
|
|
case DBTypePostgreSQL:
|
|
expr = fmt.Sprintf("? = ANY(%s)", column)
|
|
args = append(args, filter.Value)
|
|
case DBTypeMySQL:
|
|
expr = fmt.Sprintf("JSON_CONTAINS(%s, JSON_QUOTE(?))", column)
|
|
args = append(args, filter.Value)
|
|
case DBTypeSQLServer:
|
|
expr = fmt.Sprintf("? IN (SELECT value FROM OPENJSON(%s))", column)
|
|
args = append(args, filter.Value)
|
|
case DBTypeSQLite:
|
|
expr = fmt.Sprintf("EXISTS (SELECT 1 FROM json_each(%s) WHERE json_each.value = ?)", column)
|
|
args = append(args, filter.Value)
|
|
default:
|
|
return "", nil, fmt.Errorf("Array operations not supported for database type: %s", qb.dbType)
|
|
}
|
|
case OpArrayNotContains:
|
|
switch qb.dbType {
|
|
case DBTypePostgreSQL:
|
|
expr = fmt.Sprintf("? <> ALL(%s)", column)
|
|
args = append(args, filter.Value)
|
|
case DBTypeMySQL:
|
|
expr = fmt.Sprintf("NOT JSON_CONTAINS(%s, JSON_QUOTE(?))", column)
|
|
args = append(args, filter.Value)
|
|
case DBTypeSQLServer:
|
|
expr = fmt.Sprintf("? NOT IN (SELECT value FROM OPENJSON(%s))", column)
|
|
args = append(args, filter.Value)
|
|
case DBTypeSQLite:
|
|
expr = fmt.Sprintf("NOT EXISTS (SELECT 1 FROM json_each(%s) WHERE json_each.value = ?)", column)
|
|
args = append(args, filter.Value)
|
|
default:
|
|
return "", nil, fmt.Errorf("Array operations not supported for database type: %s", qb.dbType)
|
|
}
|
|
case OpArrayLength:
|
|
switch qb.dbType {
|
|
case DBTypePostgreSQL:
|
|
if lengthOption, ok := filter.Options["length"].(int); ok {
|
|
expr = fmt.Sprintf("array_length(%s, 1) = ?", column)
|
|
args = append(args, lengthOption)
|
|
} else {
|
|
return "", nil, fmt.Errorf("array_length operator requires 'length' option")
|
|
}
|
|
case DBTypeMySQL:
|
|
if lengthOption, ok := filter.Options["length"].(int); ok {
|
|
expr = fmt.Sprintf("JSON_LENGTH(%s) = ?", column)
|
|
args = append(args, lengthOption)
|
|
} else {
|
|
return "", nil, fmt.Errorf("array_length operator requires 'length' option")
|
|
}
|
|
case DBTypeSQLServer:
|
|
if lengthOption, ok := filter.Options["length"].(int); ok {
|
|
expr = fmt.Sprintf("(SELECT COUNT(*) FROM OPENJSON(%s)) = ?", column)
|
|
args = append(args, lengthOption)
|
|
} else {
|
|
return "", nil, fmt.Errorf("array_length operator requires 'length' option")
|
|
}
|
|
case DBTypeSQLite:
|
|
if lengthOption, ok := filter.Options["length"].(int); ok {
|
|
expr = fmt.Sprintf("json_array_length(%s) = ?", column)
|
|
args = append(args, lengthOption)
|
|
} else {
|
|
return "", nil, fmt.Errorf("array_length operator requires 'length' option")
|
|
}
|
|
default:
|
|
return "", nil, fmt.Errorf("Array operations not supported for database type: %s", qb.dbType)
|
|
}
|
|
default:
|
|
return "", nil, fmt.Errorf("unsupported array operator: %s", filter.Operator)
|
|
}
|
|
|
|
return expr, args, nil
|
|
}
|
|
|
|
// =============================================================================
|
|
// SECTION 6: EXECUTION METHODS (NEW)
|
|
// Methods that execute queries directly, with performance (timing) logging.
|
|
// =============================================================================
|
|
|
|
func (qb *QueryBuilder) ExecuteQuery(ctx context.Context, db *sqlx.DB, query DynamicQuery, dest interface{}) error {
|
|
// sql, args, err := qb.BuildQuery(query)
|
|
// if err != nil {
|
|
// return err
|
|
// }
|
|
// start := time.Now()
|
|
// err = db.SelectContext(ctx, dest, sql, args...)
|
|
// fmt.Printf("[DEBUG] Query executed in %v\n", time.Since(start))
|
|
// return err
|
|
sql, args, err := qb.BuildQuery(query)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Set timeout if not already in context
|
|
if _, hasDeadline := ctx.Deadline(); !hasDeadline && qb.queryTimeout > 0 {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithTimeout(ctx, qb.queryTimeout)
|
|
defer cancel()
|
|
}
|
|
|
|
start := time.Now()
|
|
|
|
// Check if dest is a pointer to a slice of maps
|
|
destValue := reflect.ValueOf(dest)
|
|
if destValue.Kind() != reflect.Ptr || destValue.IsNil() {
|
|
return fmt.Errorf("dest must be a non-nil pointer")
|
|
}
|
|
|
|
destElem := destValue.Elem()
|
|
if destElem.Kind() == reflect.Slice {
|
|
sliceType := destElem.Type().Elem()
|
|
if sliceType.Kind() == reflect.Map &&
|
|
sliceType.Key().Kind() == reflect.String &&
|
|
sliceType.Elem().Kind() == reflect.Interface {
|
|
|
|
// Handle slice of map[string]interface{}
|
|
rows, err := db.QueryxContext(ctx, sql, args...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
row := make(map[string]interface{})
|
|
if err := rows.MapScan(row); err != nil {
|
|
return err
|
|
}
|
|
destElem.Set(reflect.Append(destElem, reflect.ValueOf(row)))
|
|
}
|
|
|
|
fmt.Printf("[DEBUG] Query executed in %v\n", time.Since(start))
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Default case: use SelectContext
|
|
err = db.SelectContext(ctx, dest, sql, args...)
|
|
if qb.enableQueryLogging {
|
|
fmt.Printf("[DEBUG] Query executed in %v\n", time.Since(start))
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (qb *QueryBuilder) ExecuteQueryRow(ctx context.Context, db *sqlx.DB, query DynamicQuery, dest interface{}) error {
|
|
sql, args, err := qb.BuildQuery(query)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Set timeout if not already in context
|
|
if _, hasDeadline := ctx.Deadline(); !hasDeadline && qb.queryTimeout > 0 {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithTimeout(ctx, qb.queryTimeout)
|
|
defer cancel()
|
|
}
|
|
|
|
start := time.Now()
|
|
err = db.GetContext(ctx, dest, sql, args...)
|
|
if qb.enableQueryLogging {
|
|
fmt.Printf("[DEBUG] QueryRow executed in %v\n", time.Since(start))
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (qb *QueryBuilder) ExecuteCount(ctx context.Context, db *sqlx.DB, query DynamicQuery) (int64, error) {
|
|
sql, args, err := qb.BuildCountQuery(query)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// Set timeout if not already in context
|
|
if _, hasDeadline := ctx.Deadline(); !hasDeadline && qb.queryTimeout > 0 {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithTimeout(ctx, qb.queryTimeout)
|
|
defer cancel()
|
|
}
|
|
|
|
var count int64
|
|
start := time.Now()
|
|
err = db.GetContext(ctx, &count, sql, args...)
|
|
if qb.enableQueryLogging {
|
|
fmt.Printf("[DEBUG] Count query executed in %v\n", time.Since(start))
|
|
}
|
|
return count, err
|
|
}
|
|
|
|
func (qb *QueryBuilder) ExecuteInsert(ctx context.Context, db *sqlx.DB, table string, data InsertData, returningColumns ...string) (sql.Result, error) {
|
|
// Security check for table name
|
|
if qb.enableSecurityChecks && len(qb.allowedTables) > 0 && !qb.allowedTables[table] {
|
|
return nil, fmt.Errorf("disallowed table: %s", table)
|
|
}
|
|
|
|
sql, args, err := qb.BuildInsertQuery(table, data, returningColumns...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Set timeout if not already in context
|
|
if _, hasDeadline := ctx.Deadline(); !hasDeadline && qb.queryTimeout > 0 {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithTimeout(ctx, qb.queryTimeout)
|
|
defer cancel()
|
|
}
|
|
|
|
start := time.Now()
|
|
result, err := db.ExecContext(ctx, sql, args...)
|
|
if qb.enableQueryLogging {
|
|
fmt.Printf("[DEBUG] Insert query executed in %v\n", time.Since(start))
|
|
}
|
|
return result, err
|
|
}
|
|
|
|
func (qb *QueryBuilder) ExecuteUpdate(ctx context.Context, db *sqlx.DB, table string, updateData UpdateData, filters []FilterGroup, returningColumns ...string) (sql.Result, error) {
|
|
// Security check for table name
|
|
if qb.enableSecurityChecks && len(qb.allowedTables) > 0 && !qb.allowedTables[table] {
|
|
return nil, fmt.Errorf("disallowed table: %s", table)
|
|
}
|
|
|
|
sql, args, err := qb.BuildUpdateQuery(table, updateData, filters, returningColumns...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Set timeout if not already in context
|
|
if _, hasDeadline := ctx.Deadline(); !hasDeadline && qb.queryTimeout > 0 {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithTimeout(ctx, qb.queryTimeout)
|
|
defer cancel()
|
|
}
|
|
|
|
start := time.Now()
|
|
result, err := db.ExecContext(ctx, sql, args...)
|
|
if qb.enableQueryLogging {
|
|
fmt.Printf("[DEBUG] Update query executed in %v\n", time.Since(start))
|
|
}
|
|
return result, err
|
|
}
|
|
|
|
func (qb *QueryBuilder) ExecuteDelete(ctx context.Context, db *sqlx.DB, table string, filters []FilterGroup, returningColumns ...string) (sql.Result, error) {
|
|
// Security check for table name
|
|
if qb.enableSecurityChecks && len(qb.allowedTables) > 0 && !qb.allowedTables[table] {
|
|
return nil, fmt.Errorf("disallowed table: %s", table)
|
|
}
|
|
|
|
sql, args, err := qb.BuildDeleteQuery(table, filters, returningColumns...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Set timeout if not already in context
|
|
if _, hasDeadline := ctx.Deadline(); !hasDeadline && qb.queryTimeout > 0 {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithTimeout(ctx, qb.queryTimeout)
|
|
defer cancel()
|
|
}
|
|
|
|
start := time.Now()
|
|
result, err := db.ExecContext(ctx, sql, args...)
|
|
if qb.enableQueryLogging {
|
|
fmt.Printf("[DEBUG] Delete query executed in %v\n", time.Since(start))
|
|
}
|
|
return result, err
|
|
}
|
|
|
|
func (qb *QueryBuilder) ExecuteUpsert(ctx context.Context, db *sqlx.DB, table string, insertData InsertData, conflictColumns []string, updateColumns []string, returningColumns ...string) (sql.Result, error) {
|
|
// Security check for table name
|
|
if qb.enableSecurityChecks && len(qb.allowedTables) > 0 && !qb.allowedTables[table] {
|
|
return nil, fmt.Errorf("disallowed table: %s", table)
|
|
}
|
|
|
|
sql, args, err := qb.BuildUpsertQuery(table, insertData, conflictColumns, updateColumns, returningColumns...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Set timeout if not already in context
|
|
if _, hasDeadline := ctx.Deadline(); !hasDeadline && qb.queryTimeout > 0 {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithTimeout(ctx, qb.queryTimeout)
|
|
defer cancel()
|
|
}
|
|
|
|
start := time.Now()
|
|
result, err := db.ExecContext(ctx, sql, args...)
|
|
if qb.enableQueryLogging {
|
|
fmt.Printf("[DEBUG] Upsert query executed in %v\n", time.Since(start))
|
|
}
|
|
return result, err
|
|
}
|
|
|
|
// --- Helper and Validation Methods ---
|
|
|
|
func (qb *QueryBuilder) buildGroupByColumns(fields []string) []string {
|
|
var groupCols []string
|
|
for _, field := range fields {
|
|
col := qb.validateAndEscapeColumn(field)
|
|
if col != "" {
|
|
groupCols = append(groupCols, col)
|
|
}
|
|
}
|
|
return groupCols
|
|
}
|
|
|
|
func (qb *QueryBuilder) parseArrayValue(value interface{}) []interface{} {
|
|
if value == nil {
|
|
return nil
|
|
}
|
|
if reflect.TypeOf(value).Kind() == reflect.Slice {
|
|
v := reflect.ValueOf(value)
|
|
result := make([]interface{}, v.Len())
|
|
for i := 0; i < v.Len(); i++ {
|
|
result[i] = v.Index(i).Interface()
|
|
}
|
|
return result
|
|
}
|
|
if str, ok := value.(string); ok {
|
|
if strings.Contains(str, ",") {
|
|
parts := strings.Split(str, ",")
|
|
result := make([]interface{}, len(parts))
|
|
for i, part := range parts {
|
|
result[i] = strings.TrimSpace(part)
|
|
}
|
|
return result
|
|
}
|
|
return []interface{}{str}
|
|
}
|
|
return []interface{}{value}
|
|
}
|
|
|
|
func (qb *QueryBuilder) validateAndEscapeColumn(field string) string {
|
|
if field == "" {
|
|
return ""
|
|
}
|
|
// Allow complex expressions like functions
|
|
if strings.Contains(field, "(") {
|
|
if qb.isValidExpression(field) {
|
|
return field // Don't escape complex expressions, assume they are safe
|
|
}
|
|
return ""
|
|
}
|
|
// Handle dotted column names like "table.column"
|
|
if strings.Contains(field, ".") {
|
|
if qb.isValidExpression(field) {
|
|
// Split on dot and escape each part
|
|
parts := strings.Split(field, ".")
|
|
var escapedParts []string
|
|
for _, part := range parts {
|
|
escapedParts = append(escapedParts, qb.escapeIdentifier(part))
|
|
}
|
|
return strings.Join(escapedParts, ".")
|
|
}
|
|
return ""
|
|
}
|
|
// Simple column name
|
|
if qb.allowedColumns != nil && !qb.allowedColumns[field] {
|
|
return ""
|
|
}
|
|
return qb.escapeIdentifier(field)
|
|
}
|
|
|
|
func (qb *QueryBuilder) isValidExpression(expr string) bool {
|
|
// This is a simplified check. A more robust solution might use a proper SQL parser library.
|
|
// For now, we allow alphanumeric, underscore, dots, parentheses, and common operators.
|
|
// For SQL Server, allow brackets [] and spaces for column names.
|
|
allowedChars := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_.,() *-/[]"
|
|
for _, r := range expr {
|
|
if !strings.ContainsRune(allowedChars, r) {
|
|
return false
|
|
}
|
|
}
|
|
// Check for dangerous keywords
|
|
dangerousPatterns := []string{"--", "/*", "*/", "union", "select", "insert", "update", "delete", "drop", "alter", "create", "exec", "execute"}
|
|
lowerExpr := strings.ToLower(expr)
|
|
for _, pattern := range dangerousPatterns {
|
|
if strings.Contains(lowerExpr, pattern) {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (qb *QueryBuilder) isValidFunctionName(name string) bool {
|
|
// Check if the function name is a valid SQL function
|
|
validFunctions := map[string]bool{
|
|
// Aggregate functions
|
|
"count": true, "sum": true, "avg": true, "min": true, "max": true,
|
|
// Window functions
|
|
"row_number": true, "rank": true, "dense_rank": true, "ntile": true,
|
|
"lag": true, "lead": true, "first_value": true, "last_value": true,
|
|
// JSON functions
|
|
"json_extract": true, "json_contains": true, "json_search": true,
|
|
"json_array": true, "json_object": true, "json_merge": true,
|
|
// Other functions
|
|
"concat": true, "substring": true, "upper": true, "lower": true,
|
|
"trim": true, "coalesce": true, "nullif": true, "isnull": true,
|
|
}
|
|
|
|
return validFunctions[strings.ToLower(name)]
|
|
}
|
|
|
|
func (qb *QueryBuilder) escapeColumnReference(col string) string {
|
|
parts := strings.Split(col, ".")
|
|
var escaped []string
|
|
for _, p := range parts {
|
|
if strings.HasPrefix(p, "[") && strings.HasSuffix(p, "]") {
|
|
escaped = append(escaped, p)
|
|
} else {
|
|
escaped = append(escaped, qb.escapeIdentifier(p))
|
|
}
|
|
}
|
|
return strings.Join(escaped, ".")
|
|
}
|
|
|
|
func (qb *QueryBuilder) escapeIdentifier(col string) string {
|
|
switch qb.dbType {
|
|
case DBTypePostgreSQL, DBTypeSQLite:
|
|
return fmt.Sprintf("\"%s\"", strings.ReplaceAll(col, "\"", "\"\""))
|
|
case DBTypeMySQL:
|
|
return fmt.Sprintf("`%s`", strings.ReplaceAll(col, "`", "``"))
|
|
case DBTypeSQLServer:
|
|
return fmt.Sprintf("[%s]", strings.ReplaceAll(col, "]", "]]"))
|
|
default:
|
|
return col
|
|
}
|
|
}
|
|
|
|
// checkForSqlInjectionInArgs checks for potential SQL injection patterns in query arguments
|
|
func (qb *QueryBuilder) checkForSqlInjectionInArgs(args []interface{}) error {
|
|
if !qb.enableSecurityChecks {
|
|
return nil
|
|
}
|
|
|
|
for _, arg := range args {
|
|
if str, ok := arg.(string); ok {
|
|
lowerStr := strings.ToLower(str)
|
|
// Check for dangerous patterns specifically in user input values
|
|
dangerousPatterns := []*regexp.Regexp{
|
|
regexp.MustCompile(`(?i)(union\s+select)`),
|
|
regexp.MustCompile(`(?i)(or\s+1\s*=\s*1)`),
|
|
regexp.MustCompile(`(?i)(and\s+true)`),
|
|
regexp.MustCompile(`(?i)(waitfor\s+delay)`),
|
|
regexp.MustCompile(`(?i)(benchmark|sleep)\s*\(`),
|
|
regexp.MustCompile(`(?i)(pg_sleep)\s*\(`),
|
|
regexp.MustCompile(`(?i)(load_file|into\s+outfile)`),
|
|
regexp.MustCompile(`(?i)(information_schema|sysobjects|syscolumns)`),
|
|
regexp.MustCompile(`(?i)(--|\/\*|\*\/)`),
|
|
}
|
|
|
|
for _, pattern := range dangerousPatterns {
|
|
if pattern.MatchString(lowerStr) {
|
|
return fmt.Errorf("potential SQL injection detected in query argument: pattern %s matched", pattern.String())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// checkForSqlInjectionInSQL checks for potential SQL injection patterns in the final SQL
|
|
func (qb *QueryBuilder) checkForSqlInjectionInSQL(sql string) error {
|
|
if !qb.enableSecurityChecks {
|
|
return nil
|
|
}
|
|
|
|
// Check for dangerous patterns in the final SQL
|
|
// But allow valid SQL keywords in their proper context
|
|
lowerSQL := strings.ToLower(sql)
|
|
|
|
// More specific patterns that actually indicate injection attempts
|
|
dangerousPatterns := []*regexp.Regexp{
|
|
regexp.MustCompile(`(?i)(union\s+select)`), // UNION followed by SELECT
|
|
regexp.MustCompile(`(?i)(select\s+.*\s+from\s+.*\s+where\s+.*\s+or\s+1\s*=\s*1)`), // Classic SQL injection
|
|
regexp.MustCompile(`(?i)(drop\s+table)`), // DROP TABLE
|
|
regexp.MustCompile(`(?i)(delete\s+from)`), // DELETE FROM
|
|
regexp.MustCompile(`(?i)(insert\s+into)`), // INSERT INTO
|
|
regexp.MustCompile(`(?i)(update\s+.*\s+set)`), // UPDATE SET
|
|
regexp.MustCompile(`(?i)(alter\s+table)`), // ALTER TABLE
|
|
regexp.MustCompile(`(?i)(create\s+table)`), // CREATE TABLE
|
|
regexp.MustCompile(`(?i)(exec\s*\(|execute\s*\()`), // EXEC/EXECUTE functions
|
|
regexp.MustCompile(`(?i)(--|\/\*|\*\/)`), // SQL comments
|
|
}
|
|
|
|
for _, pattern := range dangerousPatterns {
|
|
if pattern.MatchString(lowerSQL) {
|
|
return fmt.Errorf("potential SQL injection detected in SQL: pattern %s matched", pattern.String())
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// --- Other Query Builders (Insert, Update, Delete, Upsert, Count) ---
|
|
|
|
// BuildCountQuery builds a count query
|
|
func (qb *QueryBuilder) BuildCountQuery(query DynamicQuery) (string, []interface{}, error) {
|
|
// For a count query, we don't need fields, joins, or unions.
|
|
// We only need FROM, WHERE, GROUP BY, HAVING.
|
|
countQuery := DynamicQuery{
|
|
From: query.From,
|
|
Aliases: query.Aliases,
|
|
Filters: query.Filters,
|
|
GroupBy: query.GroupBy,
|
|
Having: query.Having,
|
|
// Joins are important for count with filters on joined tables
|
|
Joins: query.Joins,
|
|
}
|
|
|
|
// Build the base query for the count using Squirrel's From and Join methods
|
|
fromClause := qb.buildFromClause(countQuery.From, countQuery.Aliases)
|
|
baseQuery := qb.sqlBuilder.Select("COUNT(*)").From(fromClause)
|
|
|
|
// Add JOINs using Squirrel's Join method
|
|
if len(countQuery.Joins) > 0 {
|
|
for _, join := range countQuery.Joins {
|
|
// Security check for joined table
|
|
if qb.enableSecurityChecks && len(qb.allowedTables) > 0 && !qb.allowedTables[join.Table] {
|
|
return "", nil, fmt.Errorf("disallowed table in join: %s", join.Table)
|
|
}
|
|
|
|
joinType, tableWithAlias, onClause, joinArgs, err := qb.buildSingleJoinClause(join)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
joinStr := tableWithAlias + " ON " + onClause
|
|
switch strings.ToUpper(joinType) {
|
|
case "LEFT":
|
|
baseQuery = baseQuery.LeftJoin(joinStr, joinArgs...)
|
|
case "RIGHT":
|
|
baseQuery = baseQuery.RightJoin(joinStr, joinArgs...)
|
|
case "FULL":
|
|
baseQuery = baseQuery.Join("FULL JOIN "+joinStr, joinArgs...)
|
|
default:
|
|
baseQuery = baseQuery.Join(joinStr, joinArgs...)
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(countQuery.Filters) > 0 {
|
|
whereClause, whereArgs, err := qb.BuildWhereClause(countQuery.Filters)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
baseQuery = baseQuery.Where(whereClause, whereArgs...)
|
|
}
|
|
|
|
if len(countQuery.GroupBy) > 0 {
|
|
baseQuery = baseQuery.GroupBy(qb.buildGroupByColumns(countQuery.GroupBy)...)
|
|
}
|
|
|
|
if len(countQuery.Having) > 0 {
|
|
havingClause, havingArgs, err := qb.BuildWhereClause(countQuery.Having)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
baseQuery = baseQuery.Having(havingClause, havingArgs...)
|
|
}
|
|
|
|
sql, args, err := baseQuery.ToSql()
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to build COUNT query: %w", err)
|
|
}
|
|
|
|
if qb.enableQueryLogging {
|
|
fmt.Printf("[DEBUG] COUNT SQL query: %s\n", sql)
|
|
fmt.Printf("[DEBUG] COUNT query args: %v\n", args)
|
|
}
|
|
return sql, args, nil
|
|
}
|
|
|
|
// BuildInsertQuery builds an INSERT query
|
|
func (qb *QueryBuilder) BuildInsertQuery(table string, data InsertData, returningColumns ...string) (string, []interface{}, error) {
|
|
// Validate columns
|
|
for _, col := range data.Columns {
|
|
if qb.allowedColumns != nil && !qb.allowedColumns[col] {
|
|
return "", nil, fmt.Errorf("disallowed column: %s", col)
|
|
}
|
|
}
|
|
|
|
// Start with basic insert
|
|
insert := qb.sqlBuilder.Insert(table).Columns(data.Columns...).Values(data.Values...)
|
|
|
|
// Handle JSON values - we need to modify the insert statement
|
|
if len(data.JsonValues) > 0 {
|
|
// Create a new insert builder with all columns including JSON columns
|
|
allColumns := make([]string, len(data.Columns))
|
|
copy(allColumns, data.Columns)
|
|
|
|
allValues := make([]interface{}, len(data.Values))
|
|
copy(allValues, data.Values)
|
|
|
|
for col, val := range data.JsonValues {
|
|
allColumns = append(allColumns, col)
|
|
jsonVal, err := json.Marshal(val)
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to marshal JSON value for column %s: %w", col, err)
|
|
}
|
|
allValues = append(allValues, jsonVal)
|
|
}
|
|
|
|
insert = qb.sqlBuilder.Insert(table).Columns(allColumns...).Values(allValues...)
|
|
}
|
|
|
|
if len(returningColumns) > 0 {
|
|
if qb.dbType == DBTypePostgreSQL {
|
|
insert = insert.Suffix("RETURNING " + strings.Join(returningColumns, ", "))
|
|
} else {
|
|
return "", nil, fmt.Errorf("RETURNING not supported for database type: %s", qb.dbType)
|
|
}
|
|
}
|
|
|
|
sql, args, err := insert.ToSql()
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to build INSERT query: %w", err)
|
|
}
|
|
|
|
return sql, args, nil
|
|
}
|
|
|
|
// BuildUpdateQuery builds an UPDATE query
|
|
func (qb *QueryBuilder) BuildUpdateQuery(table string, updateData UpdateData, filters []FilterGroup, returningColumns ...string) (string, []interface{}, error) {
|
|
// Validate columns
|
|
for _, col := range updateData.Columns {
|
|
if qb.allowedColumns != nil && !qb.allowedColumns[col] {
|
|
return "", nil, fmt.Errorf("disallowed column: %s", col)
|
|
}
|
|
}
|
|
|
|
// Start with basic update
|
|
update := qb.sqlBuilder.Update(table).SetMap(qb.buildSetMap(updateData))
|
|
|
|
// Handle JSON updates - we need to modify the update statement
|
|
if len(updateData.JsonUpdates) > 0 {
|
|
// Create a new set map including JSON updates
|
|
setMap := qb.buildSetMap(updateData)
|
|
|
|
for col, jsonUpdate := range updateData.JsonUpdates {
|
|
switch qb.dbType {
|
|
case DBTypePostgreSQL:
|
|
jsonVal, err := json.Marshal(jsonUpdate.Value)
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to marshal JSON value for column %s: %w", col, err)
|
|
}
|
|
// Use jsonb_set function for updating specific paths
|
|
setMap[col] = squirrel.Expr(fmt.Sprintf("jsonb_set(%s, '%s', ?)", qb.escapeIdentifier(col), jsonUpdate.Path), jsonVal)
|
|
case DBTypeMySQL:
|
|
jsonVal, err := json.Marshal(jsonUpdate.Value)
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to marshal JSON value for column %s: %w", col, err)
|
|
}
|
|
// Use JSON_SET function for updating specific paths
|
|
setMap[col] = squirrel.Expr(fmt.Sprintf("JSON_SET(%s, '%s', ?)", qb.escapeIdentifier(col), jsonUpdate.Path), jsonVal)
|
|
case DBTypeSQLServer:
|
|
jsonVal, err := json.Marshal(jsonUpdate.Value)
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to marshal JSON value for column %s: %w", col, err)
|
|
}
|
|
// Use JSON_MODIFY function for updating specific paths
|
|
setMap[col] = squirrel.Expr(fmt.Sprintf("JSON_MODIFY(%s, '%s', ?)", qb.escapeIdentifier(col), jsonUpdate.Path), jsonVal)
|
|
case DBTypeSQLite:
|
|
jsonVal, err := json.Marshal(jsonUpdate.Value)
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to marshal JSON value for column %s: %w", col, err)
|
|
}
|
|
// SQLite doesn't have a built-in JSON_SET function, so we need to use json_patch
|
|
setMap[col] = squirrel.Expr(fmt.Sprintf("json_patch(%s, ?)", qb.escapeIdentifier(col)), jsonVal)
|
|
}
|
|
}
|
|
|
|
update = qb.sqlBuilder.Update(table).SetMap(setMap)
|
|
}
|
|
|
|
if len(filters) > 0 {
|
|
whereClause, whereArgs, err := qb.BuildWhereClause(filters)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
update = update.Where(whereClause, whereArgs...)
|
|
}
|
|
|
|
if len(returningColumns) > 0 {
|
|
if qb.dbType == DBTypePostgreSQL {
|
|
update = update.Suffix("RETURNING " + strings.Join(returningColumns, ", "))
|
|
} else {
|
|
return "", nil, fmt.Errorf("RETURNING not supported for database type: %s", qb.dbType)
|
|
}
|
|
}
|
|
|
|
sql, args, err := update.ToSql()
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to build UPDATE query: %w", err)
|
|
}
|
|
|
|
return sql, args, nil
|
|
}
|
|
|
|
// buildSetMap builds a map for SetMap from UpdateData
|
|
func (qb *QueryBuilder) buildSetMap(updateData UpdateData) map[string]interface{} {
|
|
setMap := make(map[string]interface{})
|
|
for i, col := range updateData.Columns {
|
|
setMap[col] = updateData.Values[i]
|
|
}
|
|
return setMap
|
|
}
|
|
|
|
// BuildDeleteQuery builds a DELETE query
|
|
func (qb *QueryBuilder) BuildDeleteQuery(table string, filters []FilterGroup, returningColumns ...string) (string, []interface{}, error) {
|
|
delete := qb.sqlBuilder.Delete(table)
|
|
|
|
if len(filters) > 0 {
|
|
whereClause, whereArgs, err := qb.BuildWhereClause(filters)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
delete = delete.Where(whereClause, whereArgs...)
|
|
}
|
|
|
|
if len(returningColumns) > 0 {
|
|
if qb.dbType == DBTypePostgreSQL {
|
|
delete = delete.Suffix("RETURNING " + strings.Join(returningColumns, ", "))
|
|
} else {
|
|
return "", nil, fmt.Errorf("RETURNING not supported for database type: %s", qb.dbType)
|
|
}
|
|
}
|
|
|
|
sql, args, err := delete.ToSql()
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to build DELETE query: %w", err)
|
|
}
|
|
|
|
return sql, args, nil
|
|
}
|
|
|
|
// BuildUpsertQuery builds an UPSERT query
|
|
func (qb *QueryBuilder) BuildUpsertQuery(table string, insertData InsertData, conflictColumns []string, updateColumns []string, returningColumns ...string) (string, []interface{}, error) {
|
|
// Validate columns
|
|
for _, col := range insertData.Columns {
|
|
if qb.allowedColumns != nil && !qb.allowedColumns[col] {
|
|
return "", nil, fmt.Errorf("disallowed column: %s", col)
|
|
}
|
|
}
|
|
for _, col := range updateColumns {
|
|
if qb.allowedColumns != nil && !qb.allowedColumns[col] {
|
|
return "", nil, fmt.Errorf("disallowed column: %s", col)
|
|
}
|
|
}
|
|
|
|
switch qb.dbType {
|
|
case DBTypePostgreSQL:
|
|
// Handle JSON values for PostgreSQL
|
|
allColumns := make([]string, len(insertData.Columns))
|
|
copy(allColumns, insertData.Columns)
|
|
|
|
allValues := make([]interface{}, len(insertData.Values))
|
|
copy(allValues, insertData.Values)
|
|
|
|
for col, val := range insertData.JsonValues {
|
|
allColumns = append(allColumns, col)
|
|
jsonVal, err := json.Marshal(val)
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to marshal JSON value for column %s: %w", col, err)
|
|
}
|
|
allValues = append(allValues, jsonVal)
|
|
}
|
|
|
|
insert := qb.sqlBuilder.Insert(table).Columns(allColumns...).Values(allValues...)
|
|
if len(conflictColumns) > 0 {
|
|
conflictTarget := strings.Join(conflictColumns, ", ")
|
|
setClause := ""
|
|
for _, col := range updateColumns {
|
|
if setClause != "" {
|
|
setClause += ", "
|
|
}
|
|
setClause += fmt.Sprintf("%s = EXCLUDED.%s", qb.escapeIdentifier(col), qb.escapeIdentifier(col))
|
|
}
|
|
insert = insert.Suffix(fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET %s", conflictTarget, setClause))
|
|
}
|
|
if len(returningColumns) > 0 {
|
|
insert = insert.Suffix("RETURNING " + strings.Join(returningColumns, ", "))
|
|
}
|
|
sql, args, err := insert.ToSql()
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to build UPSERT query: %w", err)
|
|
}
|
|
return sql, args, nil
|
|
case DBTypeMySQL:
|
|
// Handle JSON values for MySQL
|
|
allColumns := make([]string, len(insertData.Columns))
|
|
copy(allColumns, insertData.Columns)
|
|
|
|
allValues := make([]interface{}, len(insertData.Values))
|
|
copy(allValues, insertData.Values)
|
|
|
|
for col, val := range insertData.JsonValues {
|
|
allColumns = append(allColumns, col)
|
|
jsonVal, err := json.Marshal(val)
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to marshal JSON value for column %s: %w", col, err)
|
|
}
|
|
allValues = append(allValues, jsonVal)
|
|
}
|
|
|
|
insert := qb.sqlBuilder.Insert(table).Columns(allColumns...).Values(allValues...)
|
|
if len(updateColumns) > 0 {
|
|
setClause := ""
|
|
for _, col := range updateColumns {
|
|
if setClause != "" {
|
|
setClause += ", "
|
|
}
|
|
setClause += fmt.Sprintf("%s = VALUES(%s)", qb.escapeIdentifier(col), qb.escapeIdentifier(col))
|
|
}
|
|
insert = insert.Suffix(fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", setClause))
|
|
}
|
|
sql, args, err := insert.ToSql()
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to build UPSERT query: %w", err)
|
|
}
|
|
return sql, args, nil
|
|
default:
|
|
return "", nil, fmt.Errorf("UPSERT not supported for database type: %s", qb.dbType)
|
|
}
|
|
}
|
|
|
|
// --- QueryParser (for parsing URL query strings) ---

// QueryParser turns URL query parameters (fields, limit, offset,
// filter[column][operator], sort) into a DynamicQuery.
type QueryParser struct {
	// defaultLimit is used when the request has no acceptable "limit" value.
	defaultLimit int
	// maxLimit is the largest "limit" value accepted from a request.
	maxLimit int
}
|
|
|
|
func NewQueryParser() *QueryParser {
|
|
return &QueryParser{defaultLimit: 10, maxLimit: 100}
|
|
}
|
|
|
|
func (qp *QueryParser) SetLimits(defaultLimit, maxLimit int) *QueryParser {
|
|
qp.defaultLimit = defaultLimit
|
|
qp.maxLimit = maxLimit
|
|
return qp
|
|
}
|
|
|
|
// ParseQuery parses URL query parameters into a DynamicQuery struct.
|
|
func (qp *QueryParser) ParseQuery(values url.Values, defaultTable string) (DynamicQuery, error) {
|
|
query := DynamicQuery{
|
|
From: defaultTable,
|
|
Limit: qp.defaultLimit,
|
|
Offset: 0,
|
|
}
|
|
|
|
// Parse fields
|
|
if fields := values.Get("fields"); fields != "" {
|
|
if fields == "*" {
|
|
query.Fields = []SelectField{{Expression: "*"}}
|
|
} else {
|
|
fieldList := strings.Split(fields, ",")
|
|
for _, field := range fieldList {
|
|
query.Fields = append(query.Fields, SelectField{Expression: strings.TrimSpace(field)})
|
|
}
|
|
}
|
|
} else {
|
|
query.Fields = []SelectField{{Expression: "*"}}
|
|
}
|
|
|
|
// Parse pagination
|
|
if limit := values.Get("limit"); limit != "" {
|
|
if l, err := strconv.Atoi(limit); err == nil && l > 0 && l <= qp.maxLimit {
|
|
query.Limit = l
|
|
}
|
|
}
|
|
if offset := values.Get("offset"); offset != "" {
|
|
if o, err := strconv.Atoi(offset); err == nil && o >= 0 {
|
|
query.Offset = o
|
|
}
|
|
}
|
|
|
|
// Parse filters
|
|
filters, err := qp.parseFilters(values)
|
|
if err != nil {
|
|
return query, err
|
|
}
|
|
query.Filters = filters
|
|
|
|
// Parse sorting
|
|
sorts, err := qp.parseSorting(values)
|
|
if err != nil {
|
|
return query, err
|
|
}
|
|
query.Sort = sorts
|
|
|
|
return query, nil
|
|
}
|
|
|
|
func (qp *QueryParser) parseFilters(values url.Values) ([]FilterGroup, error) {
|
|
filterMap := make(map[string]map[string]string)
|
|
for key, vals := range values {
|
|
if strings.HasPrefix(key, "filter[") && strings.HasSuffix(key, "]") {
|
|
parts := strings.Split(key[7:len(key)-1], "][")
|
|
if len(parts) == 2 {
|
|
column, operator := parts[0], parts[1]
|
|
if filterMap[column] == nil {
|
|
filterMap[column] = make(map[string]string)
|
|
}
|
|
if len(vals) > 0 {
|
|
filterMap[column][operator] = vals[0]
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if len(filterMap) == 0 {
|
|
return nil, nil
|
|
}
|
|
var filters []DynamicFilter
|
|
for column, operators := range filterMap {
|
|
for opStr, value := range operators {
|
|
operator := FilterOperator(opStr)
|
|
var parsedValue interface{}
|
|
switch operator {
|
|
case OpIn, OpNotIn:
|
|
if value != "" {
|
|
parsedValue = strings.Split(value, ",")
|
|
}
|
|
case OpBetween, OpNotBetween:
|
|
if value != "" {
|
|
parts := strings.Split(value, ",")
|
|
if len(parts) == 2 {
|
|
parsedValue = []interface{}{strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])}
|
|
}
|
|
}
|
|
case OpNull, OpNotNull:
|
|
parsedValue = nil
|
|
default:
|
|
parsedValue = value
|
|
}
|
|
filters = append(filters, DynamicFilter{Column: column, Operator: operator, Value: parsedValue})
|
|
}
|
|
}
|
|
if len(filters) == 0 {
|
|
return nil, nil
|
|
}
|
|
return []FilterGroup{{Filters: filters, LogicOp: "AND"}}, nil
|
|
}
|
|
|
|
func (qp *QueryParser) parseSorting(values url.Values) ([]SortField, error) {
|
|
sortParam := values.Get("sort")
|
|
if sortParam == "" {
|
|
return nil, nil
|
|
}
|
|
var sorts []SortField
|
|
fields := strings.Split(sortParam, ",")
|
|
for _, field := range fields {
|
|
field = strings.TrimSpace(field)
|
|
if field == "" {
|
|
continue
|
|
}
|
|
order, column := "ASC", field
|
|
if strings.HasPrefix(field, "-") {
|
|
order = "DESC"
|
|
column = field[1:]
|
|
} else if strings.HasPrefix(field, "+") {
|
|
column = field[1:]
|
|
}
|
|
sorts = append(sorts, SortField{Column: column, Order: order})
|
|
}
|
|
return sorts, nil
|
|
}
|
|
|
|
// ParseQueryWithDefaultFields parses URL query parameters into a DynamicQuery struct with default fields.
|
|
func (qp *QueryParser) ParseQueryWithDefaultFields(values url.Values, defaultTable string, defaultFields []string) (DynamicQuery, error) {
|
|
query, err := qp.ParseQuery(values, defaultTable)
|
|
if err != nil {
|
|
return query, err
|
|
}
|
|
|
|
// If no fields specified, use default fields
|
|
if len(query.Fields) == 0 || (len(query.Fields) == 1 && query.Fields[0].Expression == "*") {
|
|
query.Fields = make([]SelectField, len(defaultFields))
|
|
for i, field := range defaultFields {
|
|
query.Fields[i] = SelectField{Expression: field}
|
|
}
|
|
}
|
|
|
|
return query, nil
|
|
}
|
|
|
|
// =============================================================================
|
|
// MONGODB QUERY BUILDER
|
|
// =============================================================================
|
|
|
|
// MongoQueryBuilder translates DynamicQuery values into MongoDB filters, find
// options and aggregation pipelines, and can execute them against a
// *mongo.Collection.
type MongoQueryBuilder struct {
	// allowedFields, when non-nil, restricts the field names that may appear
	// in filters, projections, sorts, groupings and writes.
	allowedFields map[string]bool
	// allowedCollections, when non-empty and security checks are enabled,
	// restricts the collections that may be queried.
	allowedCollections map[string]bool

	// enableSecurityChecks gates the collection allowlist and the
	// maxAllowedDocs check on requested limits.
	enableSecurityChecks bool
	// maxAllowedDocs is the ceiling applied to a query's Limit.
	maxAllowedDocs int

	// enableQueryLogging prints execution timings to stdout.
	enableQueryLogging bool

	// queryTimeout is applied to operations whose context has no deadline.
	queryTimeout time.Duration
}
|
|
|
|
// NewMongoQueryBuilder creates a new MongoDB query builder instance
|
|
func NewMongoQueryBuilder() *MongoQueryBuilder {
|
|
return &MongoQueryBuilder{
|
|
allowedFields: make(map[string]bool),
|
|
allowedCollections: make(map[string]bool),
|
|
enableSecurityChecks: true,
|
|
maxAllowedDocs: 10000,
|
|
enableQueryLogging: true,
|
|
queryTimeout: 30 * time.Second,
|
|
}
|
|
}
|
|
|
|
// SetSecurityOptions configures security settings
|
|
func (mqb *MongoQueryBuilder) SetSecurityOptions(enableChecks bool, maxDocs int) *MongoQueryBuilder {
|
|
mqb.enableSecurityChecks = enableChecks
|
|
mqb.maxAllowedDocs = maxDocs
|
|
return mqb
|
|
}
|
|
|
|
// SetAllowedFields sets the list of allowed fields for security
|
|
func (mqb *MongoQueryBuilder) SetAllowedFields(fields []string) *MongoQueryBuilder {
|
|
mqb.allowedFields = make(map[string]bool)
|
|
for _, field := range fields {
|
|
mqb.allowedFields[field] = true
|
|
}
|
|
return mqb
|
|
}
|
|
|
|
// SetAllowedCollections sets the list of allowed collections for security
|
|
func (mqb *MongoQueryBuilder) SetAllowedCollections(collections []string) *MongoQueryBuilder {
|
|
mqb.allowedCollections = make(map[string]bool)
|
|
for _, collection := range collections {
|
|
mqb.allowedCollections[collection] = true
|
|
}
|
|
return mqb
|
|
}
|
|
|
|
// SetQueryLogging enables or disables query logging
|
|
func (mqb *MongoQueryBuilder) SetQueryLogging(enable bool) *MongoQueryBuilder {
|
|
mqb.enableQueryLogging = enable
|
|
return mqb
|
|
}
|
|
|
|
// SetQueryTimeout sets the default query timeout
|
|
func (mqb *MongoQueryBuilder) SetQueryTimeout(timeout time.Duration) *MongoQueryBuilder {
|
|
mqb.queryTimeout = timeout
|
|
return mqb
|
|
}
|
|
|
|
// BuildFindQuery builds a MongoDB find query from DynamicQuery
|
|
func (mqb *MongoQueryBuilder) BuildFindQuery(query DynamicQuery) (bson.M, *options.FindOptions, error) {
|
|
filter := bson.M{}
|
|
findOptions := options.Find()
|
|
|
|
// Security check for limit
|
|
if mqb.enableSecurityChecks && query.Limit > mqb.maxAllowedDocs {
|
|
return nil, nil, fmt.Errorf("requested limit %d exceeds maximum allowed %d", query.Limit, mqb.maxAllowedDocs)
|
|
}
|
|
|
|
// Security check for collection name
|
|
if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[query.From] {
|
|
return nil, nil, fmt.Errorf("disallowed collection: %s", query.From)
|
|
}
|
|
|
|
// Build filter from DynamicQuery filters
|
|
if len(query.Filters) > 0 {
|
|
mongoFilter, err := mqb.buildFilter(query.Filters)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
filter = mongoFilter
|
|
}
|
|
|
|
// Set projection from fields
|
|
if len(query.Fields) > 0 {
|
|
projection := bson.M{}
|
|
for _, field := range query.Fields {
|
|
if field.Expression == "*" {
|
|
// Include all fields
|
|
continue
|
|
}
|
|
fieldName := field.Expression
|
|
if field.Alias != "" {
|
|
fieldName = field.Alias
|
|
}
|
|
if mqb.allowedFields != nil && !mqb.allowedFields[fieldName] {
|
|
return nil, nil, fmt.Errorf("disallowed field: %s", fieldName)
|
|
}
|
|
projection[fieldName] = 1
|
|
}
|
|
if len(projection) > 0 {
|
|
findOptions.SetProjection(projection)
|
|
}
|
|
}
|
|
|
|
// Set sort
|
|
if len(query.Sort) > 0 {
|
|
sort := bson.D{}
|
|
for _, sortField := range query.Sort {
|
|
fieldName := sortField.Column
|
|
if mqb.allowedFields != nil && !mqb.allowedFields[fieldName] {
|
|
return nil, nil, fmt.Errorf("disallowed field: %s", fieldName)
|
|
}
|
|
order := 1 // ASC
|
|
if strings.ToUpper(sortField.Order) == "DESC" {
|
|
order = -1 // DESC
|
|
}
|
|
sort = append(sort, bson.E{Key: fieldName, Value: order})
|
|
}
|
|
findOptions.SetSort(sort)
|
|
}
|
|
|
|
// Set limit and offset
|
|
if query.Limit > 0 {
|
|
findOptions.SetLimit(int64(query.Limit))
|
|
}
|
|
if query.Offset > 0 {
|
|
findOptions.SetSkip(int64(query.Offset))
|
|
}
|
|
|
|
return filter, findOptions, nil
|
|
}
|
|
|
|
// BuildAggregateQuery builds a MongoDB aggregation pipeline from DynamicQuery
|
|
func (mqb *MongoQueryBuilder) BuildAggregateQuery(query DynamicQuery) ([]bson.D, error) {
|
|
pipeline := []bson.D{}
|
|
|
|
// Security check for collection name
|
|
if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[query.From] {
|
|
return nil, fmt.Errorf("disallowed collection: %s", query.From)
|
|
}
|
|
|
|
// Handle CTEs as stages in the pipeline
|
|
if len(query.CTEs) > 0 {
|
|
for _, cte := range query.CTEs {
|
|
// Security check for CTE collection
|
|
if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[cte.Query.From] {
|
|
return nil, fmt.Errorf("disallowed collection in CTE: %s", cte.Query.From)
|
|
}
|
|
|
|
subPipeline, err := mqb.BuildAggregateQuery(cte.Query)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to build CTE '%s': %w", cte.Name, err)
|
|
}
|
|
// Add $lookup stage for joins
|
|
if len(cte.Query.Joins) > 0 {
|
|
for _, join := range cte.Query.Joins {
|
|
// Security check for joined collection
|
|
if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[join.Table] {
|
|
return nil, fmt.Errorf("disallowed collection in join: %s", join.Table)
|
|
}
|
|
|
|
lookupStage := bson.D{
|
|
{Key: "$lookup", Value: bson.D{
|
|
{Key: "from", Value: join.Table},
|
|
{Key: "localField", Value: join.Alias},
|
|
{Key: "foreignField", Value: "_id"},
|
|
{Key: "as", Value: join.Alias},
|
|
}},
|
|
}
|
|
pipeline = append(pipeline, lookupStage)
|
|
}
|
|
}
|
|
// Add the sub-pipeline
|
|
pipeline = append(pipeline, subPipeline...)
|
|
}
|
|
}
|
|
|
|
// Match stage for filters
|
|
if len(query.Filters) > 0 {
|
|
filter, err := mqb.buildFilter(query.Filters)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
pipeline = append(pipeline, bson.D{{Key: "$match", Value: filter}})
|
|
}
|
|
|
|
// Group stage for GROUP BY
|
|
if len(query.GroupBy) > 0 {
|
|
groupID := bson.D{}
|
|
for _, field := range query.GroupBy {
|
|
if mqb.allowedFields != nil && !mqb.allowedFields[field] {
|
|
return nil, fmt.Errorf("disallowed field: %s", field)
|
|
}
|
|
groupID = append(groupID, bson.E{Key: field, Value: "$" + field})
|
|
}
|
|
|
|
groupStage := bson.D{
|
|
{Key: "$group", Value: bson.D{
|
|
{Key: "_id", Value: groupID},
|
|
}},
|
|
}
|
|
|
|
// Add any aggregations from fields
|
|
for _, field := range query.Fields {
|
|
if strings.Contains(field.Expression, "(") && strings.Contains(field.Expression, ")") {
|
|
// This is an aggregation function
|
|
funcName := strings.Split(field.Expression, "(")[0]
|
|
funcField := strings.TrimSuffix(strings.Split(field.Expression, "(")[1], ")")
|
|
|
|
if mqb.allowedFields != nil && !mqb.allowedFields[funcField] {
|
|
return nil, fmt.Errorf("disallowed field: %s", funcField)
|
|
}
|
|
|
|
switch strings.ToLower(funcName) {
|
|
case "count":
|
|
groupStage = append(groupStage, bson.E{
|
|
Key: field.Alias, Value: bson.D{{Key: "$sum", Value: 1}},
|
|
})
|
|
case "sum":
|
|
groupStage = append(groupStage, bson.E{
|
|
Key: field.Alias, Value: bson.D{{Key: "$sum", Value: "$" + funcField}},
|
|
})
|
|
case "avg":
|
|
groupStage = append(groupStage, bson.E{
|
|
Key: field.Alias, Value: bson.D{{Key: "$avg", Value: "$" + funcField}},
|
|
})
|
|
case "min":
|
|
groupStage = append(groupStage, bson.E{
|
|
Key: field.Alias, Value: bson.D{{Key: "$min", Value: "$" + funcField}},
|
|
})
|
|
case "max":
|
|
groupStage = append(groupStage, bson.E{
|
|
Key: field.Alias, Value: bson.D{{Key: "$max", Value: "$" + funcField}},
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
pipeline = append(pipeline, groupStage)
|
|
}
|
|
|
|
// Sort stage
|
|
if len(query.Sort) > 0 {
|
|
sort := bson.D{}
|
|
for _, sortField := range query.Sort {
|
|
fieldName := sortField.Column
|
|
if mqb.allowedFields != nil && !mqb.allowedFields[fieldName] {
|
|
return nil, fmt.Errorf("disallowed field: %s", fieldName)
|
|
}
|
|
order := 1 // ASC
|
|
if strings.ToUpper(sortField.Order) == "DESC" {
|
|
order = -1 // DESC
|
|
}
|
|
sort = append(sort, bson.E{Key: fieldName, Value: order})
|
|
}
|
|
pipeline = append(pipeline, bson.D{{Key: "$sort", Value: sort}})
|
|
}
|
|
|
|
// Skip and limit stages
|
|
if query.Offset > 0 {
|
|
pipeline = append(pipeline, bson.D{{Key: "$skip", Value: query.Offset}})
|
|
}
|
|
if query.Limit > 0 {
|
|
pipeline = append(pipeline, bson.D{{Key: "$limit", Value: query.Limit}})
|
|
}
|
|
|
|
return pipeline, nil
|
|
}
|
|
|
|
// buildFilter builds a MongoDB filter from FilterGroups
|
|
func (mqb *MongoQueryBuilder) buildFilter(filterGroups []FilterGroup) (bson.M, error) {
|
|
if len(filterGroups) == 0 {
|
|
return bson.M{}, nil
|
|
}
|
|
|
|
var result bson.M
|
|
var err error
|
|
|
|
for i, group := range filterGroups {
|
|
if len(group.Filters) == 0 {
|
|
continue
|
|
}
|
|
|
|
groupFilter, err := mqb.buildFilterGroup(group)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if i == 0 {
|
|
result = groupFilter
|
|
} else {
|
|
logicOp := "$and"
|
|
if group.LogicOp != "" {
|
|
switch strings.ToUpper(group.LogicOp) {
|
|
case "OR":
|
|
logicOp = "$or"
|
|
}
|
|
}
|
|
result = bson.M{logicOp: []bson.M{result, groupFilter}}
|
|
}
|
|
}
|
|
|
|
return result, err
|
|
}
|
|
|
|
// buildFilterGroup builds a filter for a single filter group
|
|
func (mqb *MongoQueryBuilder) buildFilterGroup(group FilterGroup) (bson.M, error) {
|
|
var filters []bson.M
|
|
logicOp := "$and"
|
|
if group.LogicOp != "" {
|
|
switch strings.ToUpper(group.LogicOp) {
|
|
case "OR":
|
|
logicOp = "$or"
|
|
}
|
|
}
|
|
|
|
for _, filter := range group.Filters {
|
|
fieldFilter, err := mqb.buildFilterCondition(filter)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
filters = append(filters, fieldFilter)
|
|
}
|
|
|
|
if len(filters) == 1 {
|
|
return filters[0], nil
|
|
}
|
|
return bson.M{logicOp: filters}, nil
|
|
}
|
|
|
|
// buildFilterCondition builds a single filter condition for MongoDB
|
|
func (mqb *MongoQueryBuilder) buildFilterCondition(filter DynamicFilter) (bson.M, error) {
|
|
field := filter.Column
|
|
if mqb.allowedFields != nil && !mqb.allowedFields[field] {
|
|
return nil, fmt.Errorf("disallowed field: %s", field)
|
|
}
|
|
|
|
switch filter.Operator {
|
|
case OpEqual:
|
|
return bson.M{field: filter.Value}, nil
|
|
case OpNotEqual:
|
|
return bson.M{field: bson.M{"$ne": filter.Value}}, nil
|
|
case OpIn:
|
|
values := mqb.parseArrayValue(filter.Value)
|
|
return bson.M{field: bson.M{"$in": values}}, nil
|
|
case OpNotIn:
|
|
values := mqb.parseArrayValue(filter.Value)
|
|
return bson.M{field: bson.M{"$nin": values}}, nil
|
|
case OpGreaterThan:
|
|
return bson.M{field: bson.M{"$gt": filter.Value}}, nil
|
|
case OpGreaterThanEqual:
|
|
return bson.M{field: bson.M{"$gte": filter.Value}}, nil
|
|
case OpLessThan:
|
|
return bson.M{field: bson.M{"$lt": filter.Value}}, nil
|
|
case OpLessThanEqual:
|
|
return bson.M{field: bson.M{"$lte": filter.Value}}, nil
|
|
case OpLike:
|
|
// Convert SQL LIKE to MongoDB regex
|
|
pattern := filter.Value.(string)
|
|
pattern = strings.ReplaceAll(pattern, "%", ".*")
|
|
pattern = strings.ReplaceAll(pattern, "_", ".")
|
|
return bson.M{field: bson.M{"$regex": pattern, "$options": "i"}}, nil
|
|
case OpILike:
|
|
// Case-insensitive like
|
|
pattern := filter.Value.(string)
|
|
pattern = strings.ReplaceAll(pattern, "%", ".*")
|
|
pattern = strings.ReplaceAll(pattern, "_", ".")
|
|
return bson.M{field: bson.M{"$regex": pattern, "$options": "i"}}, nil
|
|
case OpContains:
|
|
// Contains substring
|
|
pattern := filter.Value.(string)
|
|
return bson.M{field: bson.M{"$regex": pattern, "$options": "i"}}, nil
|
|
case OpNotContains:
|
|
// Does not contain substring
|
|
pattern := filter.Value.(string)
|
|
return bson.M{field: bson.M{"$not": bson.M{"$regex": pattern, "$options": "i"}}}, nil
|
|
case OpStartsWith:
|
|
// Starts with
|
|
pattern := filter.Value.(string)
|
|
return bson.M{field: bson.M{"$regex": "^" + pattern, "$options": "i"}}, nil
|
|
case OpEndsWith:
|
|
// Ends with
|
|
pattern := filter.Value.(string)
|
|
return bson.M{field: bson.M{"$regex": pattern + "$", "$options": "i"}}, nil
|
|
case OpNull:
|
|
return bson.M{field: bson.M{"$exists": false}}, nil
|
|
case OpNotNull:
|
|
return bson.M{field: bson.M{"$exists": true}}, nil
|
|
case OpJsonContains:
|
|
// JSON contains
|
|
return bson.M{field: bson.M{"$elemMatch": filter.Value}}, nil
|
|
case OpJsonNotContains:
|
|
// JSON does not contain
|
|
return bson.M{field: bson.M{"$not": bson.M{"$elemMatch": filter.Value}}}, nil
|
|
case OpJsonExists:
|
|
// JSON path exists
|
|
return bson.M{field + "." + filter.Options["path"].(string): bson.M{"$exists": true}}, nil
|
|
case OpJsonNotExists:
|
|
// JSON path does not exist
|
|
return bson.M{field + "." + filter.Options["path"].(string): bson.M{"$exists": false}}, nil
|
|
case OpArrayContains:
|
|
// Array contains
|
|
return bson.M{field: bson.M{"$elemMatch": bson.M{"$eq": filter.Value}}}, nil
|
|
case OpArrayNotContains:
|
|
// Array does not contain
|
|
return bson.M{field: bson.M{"$not": bson.M{"$elemMatch": bson.M{"$eq": filter.Value}}}}, nil
|
|
case OpArrayLength:
|
|
// Array length
|
|
if lengthOption, ok := filter.Options["length"].(int); ok {
|
|
return bson.M{field: bson.M{"$size": lengthOption}}, nil
|
|
}
|
|
return nil, fmt.Errorf("array_length operator requires 'length' option")
|
|
default:
|
|
return nil, fmt.Errorf("unsupported operator: %s", filter.Operator)
|
|
}
|
|
}
|
|
|
|
// parseArrayValue parses an array value for MongoDB
|
|
func (mqb *MongoQueryBuilder) parseArrayValue(value interface{}) []interface{} {
|
|
if value == nil {
|
|
return nil
|
|
}
|
|
if reflect.TypeOf(value).Kind() == reflect.Slice {
|
|
v := reflect.ValueOf(value)
|
|
result := make([]interface{}, v.Len())
|
|
for i := 0; i < v.Len(); i++ {
|
|
result[i] = v.Index(i).Interface()
|
|
}
|
|
return result
|
|
}
|
|
if str, ok := value.(string); ok {
|
|
if strings.Contains(str, ",") {
|
|
parts := strings.Split(str, ",")
|
|
result := make([]interface{}, len(parts))
|
|
for i, part := range parts {
|
|
result[i] = strings.TrimSpace(part)
|
|
}
|
|
return result
|
|
}
|
|
return []interface{}{str}
|
|
}
|
|
return []interface{}{value}
|
|
}
|
|
|
|
// ExecuteFind executes a MongoDB find query
|
|
func (mqb *MongoQueryBuilder) ExecuteFind(ctx context.Context, collection *mongo.Collection, query DynamicQuery, dest interface{}) error {
|
|
// Security check for collection name
|
|
if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[collection.Name()] {
|
|
return fmt.Errorf("disallowed collection: %s", collection.Name())
|
|
}
|
|
|
|
filter, findOptions, err := mqb.BuildFindQuery(query)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Set timeout if not already in context
|
|
if _, hasDeadline := ctx.Deadline(); !hasDeadline && mqb.queryTimeout > 0 {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithTimeout(ctx, mqb.queryTimeout)
|
|
defer cancel()
|
|
}
|
|
|
|
start := time.Now()
|
|
cursor, err := collection.Find(ctx, filter, findOptions)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer cursor.Close(ctx)
|
|
err = cursor.All(ctx, dest)
|
|
if mqb.enableQueryLogging {
|
|
fmt.Printf("[DEBUG] MongoDB Find executed in %v\n", time.Since(start))
|
|
}
|
|
return err
|
|
}
|
|
|
|
// ExecuteAggregate executes a MongoDB aggregation pipeline
|
|
func (mqb *MongoQueryBuilder) ExecuteAggregate(ctx context.Context, collection *mongo.Collection, query DynamicQuery, dest interface{}) error {
|
|
// Security check for collection name
|
|
if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[collection.Name()] {
|
|
return fmt.Errorf("disallowed collection: %s", collection.Name())
|
|
}
|
|
|
|
pipeline, err := mqb.BuildAggregateQuery(query)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Set timeout if not already in context
|
|
if _, hasDeadline := ctx.Deadline(); !hasDeadline && mqb.queryTimeout > 0 {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithTimeout(ctx, mqb.queryTimeout)
|
|
defer cancel()
|
|
}
|
|
|
|
start := time.Now()
|
|
cursor, err := collection.Aggregate(ctx, pipeline)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer cursor.Close(ctx)
|
|
err = cursor.All(ctx, dest)
|
|
if mqb.enableQueryLogging {
|
|
fmt.Printf("[DEBUG] MongoDB Aggregate executed in %v\n", time.Since(start))
|
|
}
|
|
return err
|
|
}
|
|
|
|
// ExecuteCount executes a MongoDB count query
|
|
func (mqb *MongoQueryBuilder) ExecuteCount(ctx context.Context, collection *mongo.Collection, query DynamicQuery) (int64, error) {
|
|
// Security check for collection name
|
|
if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[collection.Name()] {
|
|
return 0, fmt.Errorf("disallowed collection: %s", collection.Name())
|
|
}
|
|
|
|
filter, _, err := mqb.BuildFindQuery(query)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// Set timeout if not already in context
|
|
if _, hasDeadline := ctx.Deadline(); !hasDeadline && mqb.queryTimeout > 0 {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithTimeout(ctx, mqb.queryTimeout)
|
|
defer cancel()
|
|
}
|
|
|
|
start := time.Now()
|
|
count, err := collection.CountDocuments(ctx, filter)
|
|
if mqb.enableQueryLogging {
|
|
fmt.Printf("[DEBUG] MongoDB Count executed in %v\n", time.Since(start))
|
|
}
|
|
return count, err
|
|
}
|
|
|
|
// ExecuteInsert executes a MongoDB insert operation
|
|
func (mqb *MongoQueryBuilder) ExecuteInsert(ctx context.Context, collection *mongo.Collection, data InsertData) (*mongo.InsertOneResult, error) {
|
|
// Security check for collection name
|
|
if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[collection.Name()] {
|
|
return nil, fmt.Errorf("disallowed collection: %s", collection.Name())
|
|
}
|
|
|
|
document := bson.M{}
|
|
for i, col := range data.Columns {
|
|
if mqb.allowedFields != nil && !mqb.allowedFields[col] {
|
|
return nil, fmt.Errorf("disallowed field: %s", col)
|
|
}
|
|
document[col] = data.Values[i]
|
|
}
|
|
|
|
// Handle JSON values
|
|
for col, val := range data.JsonValues {
|
|
if mqb.allowedFields != nil && !mqb.allowedFields[col] {
|
|
return nil, fmt.Errorf("disallowed field: %s", col)
|
|
}
|
|
document[col] = val
|
|
}
|
|
|
|
// Set timeout if not already in context
|
|
if _, hasDeadline := ctx.Deadline(); !hasDeadline && mqb.queryTimeout > 0 {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithTimeout(ctx, mqb.queryTimeout)
|
|
defer cancel()
|
|
}
|
|
|
|
start := time.Now()
|
|
result, err := collection.InsertOne(ctx, document)
|
|
if mqb.enableQueryLogging {
|
|
fmt.Printf("[DEBUG] MongoDB Insert executed in %v\n", time.Since(start))
|
|
}
|
|
return result, err
|
|
}
|
|
|
|
// ExecuteUpdate executes a MongoDB update operation
|
|
func (mqb *MongoQueryBuilder) ExecuteUpdate(ctx context.Context, collection *mongo.Collection, updateData UpdateData, filters []FilterGroup) (*mongo.UpdateResult, error) {
|
|
// Security check for collection name
|
|
if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[collection.Name()] {
|
|
return nil, fmt.Errorf("disallowed collection: %s", collection.Name())
|
|
}
|
|
|
|
filter, err := mqb.buildFilter(filters)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
update := bson.M{"$set": bson.M{}}
|
|
for i, col := range updateData.Columns {
|
|
if mqb.allowedFields != nil && !mqb.allowedFields[col] {
|
|
return nil, fmt.Errorf("disallowed field: %s", col)
|
|
}
|
|
update["$set"].(bson.M)[col] = updateData.Values[i]
|
|
}
|
|
|
|
// Handle JSON updates
|
|
for col, jsonUpdate := range updateData.JsonUpdates {
|
|
if mqb.allowedFields != nil && !mqb.allowedFields[col] {
|
|
return nil, fmt.Errorf("disallowed field: %s", col)
|
|
}
|
|
// Use dot notation for nested JSON updates
|
|
update["$set"].(bson.M)[col+"."+jsonUpdate.Path] = jsonUpdate.Value
|
|
}
|
|
|
|
// Set timeout if not already in context
|
|
if _, hasDeadline := ctx.Deadline(); !hasDeadline && mqb.queryTimeout > 0 {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithTimeout(ctx, mqb.queryTimeout)
|
|
defer cancel()
|
|
}
|
|
|
|
start := time.Now()
|
|
result, err := collection.UpdateMany(ctx, filter, update)
|
|
if mqb.enableQueryLogging {
|
|
fmt.Printf("[DEBUG] MongoDB Update executed in %v\n", time.Since(start))
|
|
}
|
|
return result, err
|
|
}
|
|
|
|
// ExecuteDelete executes a MongoDB delete operation
|
|
func (mqb *MongoQueryBuilder) ExecuteDelete(ctx context.Context, collection *mongo.Collection, filters []FilterGroup) (*mongo.DeleteResult, error) {
|
|
// Security check for collection name
|
|
if mqb.enableSecurityChecks && len(mqb.allowedCollections) > 0 && !mqb.allowedCollections[collection.Name()] {
|
|
return nil, fmt.Errorf("disallowed collection: %s", collection.Name())
|
|
}
|
|
|
|
filter, err := mqb.buildFilter(filters)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Set timeout if not already in context
|
|
if _, hasDeadline := ctx.Deadline(); !hasDeadline && mqb.queryTimeout > 0 {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithTimeout(ctx, mqb.queryTimeout)
|
|
defer cancel()
|
|
}
|
|
|
|
start := time.Now()
|
|
result, err := collection.DeleteMany(ctx, filter)
|
|
if mqb.enableQueryLogging {
|
|
fmt.Printf("[DEBUG] MongoDB Delete executed in %v\n", time.Since(start))
|
|
}
|
|
return result, err
|
|
}
|