1140 lines
33 KiB
Go
1140 lines
33 KiB
Go
package utils
|
|
|
|
import (
|
|
"fmt"
|
|
"net/url"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/Masterminds/squirrel"
|
|
// Still useful for array types, especially with PostgreSQL
|
|
)
|
|
|
|
// DBType names the SQL dialect a QueryBuilder generates statements for.
type DBType string

// Dialects recognised by the builder.
const (
	// DBTypePostgreSQL selects PostgreSQL.
	DBTypePostgreSQL DBType = "postgres"

	// DBTypeMySQL selects MySQL.
	DBTypeMySQL DBType = "mysql"

	// DBTypeSQLite selects SQLite.
	DBTypeSQLite DBType = "sqlite"

	// DBTypeSQLServer selects Microsoft SQL Server.
	DBTypeSQLServer DBType = "sqlserver"
)
|
|
|
|
// FilterOperator is the wire-level name of a comparison that a DynamicFilter
// applies to its column.
type FilterOperator string

// Operators understood by the filter builder.
const (
	// OpEqual matches rows whose column equals the value.
	OpEqual FilterOperator = "_eq"

	// OpNotEqual matches rows whose column differs from the value.
	OpNotEqual FilterOperator = "_neq"

	// OpLike is a LIKE pattern match.
	OpLike FilterOperator = "_like"

	// OpILike is a case-insensitive LIKE pattern match.
	OpILike FilterOperator = "_ilike"

	// OpIn matches rows whose column is one of a list of values.
	OpIn FilterOperator = "_in"

	// OpNotIn matches rows whose column is none of a list of values.
	OpNotIn FilterOperator = "_nin"

	// OpGreaterThan is a strict greater-than comparison.
	OpGreaterThan FilterOperator = "_gt"

	// OpGreaterThanEqual is a greater-than-or-equal comparison.
	OpGreaterThanEqual FilterOperator = "_gte"

	// OpLessThan is a strict less-than comparison.
	OpLessThan FilterOperator = "_lt"

	// OpLessThanEqual is a less-than-or-equal comparison.
	OpLessThanEqual FilterOperator = "_lte"

	// OpBetween matches values inside a two-element range.
	OpBetween FilterOperator = "_between"

	// OpNotBetween matches values outside a two-element range.
	OpNotBetween FilterOperator = "_nbetween"

	// OpNull matches NULL columns.
	OpNull FilterOperator = "_null"

	// OpNotNull matches non-NULL columns.
	OpNotNull FilterOperator = "_nnull"

	// OpContains matches columns containing the value as a substring.
	OpContains FilterOperator = "_contains"

	// OpNotContains matches columns not containing the value.
	OpNotContains FilterOperator = "_ncontains"

	// OpStartsWith matches columns beginning with the value.
	OpStartsWith FilterOperator = "_starts_with"

	// OpEndsWith matches columns ending with the value.
	OpEndsWith FilterOperator = "_ends_with"
)
|
|
|
|
// DynamicFilter represents a single filter condition
|
|
type DynamicFilter struct {
|
|
Column string `json:"column"`
|
|
Operator FilterOperator `json:"operator"`
|
|
Value interface{} `json:"value"`
|
|
LogicOp string `json:"logic_op,omitempty"` // AND, OR
|
|
}
|
|
|
|
// FilterGroup represents a group of filters
|
|
type FilterGroup struct {
|
|
Filters []DynamicFilter `json:"filters"`
|
|
LogicOp string `json:"logic_op"` // AND, OR
|
|
}
|
|
|
|
// DynamicQuery represents the complete query structure
|
|
type DynamicQuery struct {
|
|
Fields []string `json:"fields,omitempty"`
|
|
Filters []FilterGroup `json:"filters,omitempty"`
|
|
Sort []SortField `json:"sort,omitempty"`
|
|
Limit int `json:"limit"`
|
|
Offset int `json:"offset"`
|
|
GroupBy []string `json:"group_by,omitempty"`
|
|
Having []FilterGroup `json:"having,omitempty"`
|
|
}
|
|
|
|
// SortField is a single ORDER BY entry.
type SortField struct {
	// Column is the API-facing field name to order by.
	Column string `json:"column"`

	// Order is the direction (ASC, DESC).
	Order string `json:"order"`
}
|
|
|
|
// UpdateData carries the column/value pairs for an UPDATE. Columns[i] is
// paired with Values[i].
type UpdateData struct {
	// Columns are the API-facing field names to assign.
	Columns []string `json:"columns"`

	// Values are the new values, positionally matched to Columns.
	Values []interface{} `json:"values"`
}
|
|
|
|
// InsertData carries the column/value pairs for an INSERT. Columns[i] is
// paired with Values[i].
type InsertData struct {
	// Columns are the API-facing field names to populate.
	Columns []string `json:"columns"`

	// Values are the row values, positionally matched to Columns.
	Values []interface{} `json:"values"`
}
|
|
|
|
// QueryBuilder builds SQL queries from dynamic filters using squirrel
|
|
type QueryBuilder struct {
|
|
tableName string
|
|
columnMapping map[string]string // Maps API field names to DB column names
|
|
allowedColumns map[string]bool // Security: only allow specified columns
|
|
dbType DBType
|
|
sqlBuilder squirrel.StatementBuilderType
|
|
}
|
|
|
|
// NewQueryBuilder creates a new query builder instance for a specific database type
|
|
func NewQueryBuilder(tableName string, dbType DBType) *QueryBuilder {
|
|
var placeholderFormat squirrel.PlaceholderFormat
|
|
|
|
switch dbType {
|
|
case DBTypePostgreSQL:
|
|
placeholderFormat = squirrel.Dollar
|
|
case DBTypeMySQL, DBTypeSQLite:
|
|
placeholderFormat = squirrel.Question
|
|
case DBTypeSQLServer:
|
|
placeholderFormat = squirrel.AtP
|
|
default:
|
|
// Default to a common format if an unknown type is provided
|
|
placeholderFormat = squirrel.Question
|
|
}
|
|
|
|
return &QueryBuilder{
|
|
tableName: tableName,
|
|
columnMapping: make(map[string]string),
|
|
allowedColumns: make(map[string]bool),
|
|
dbType: dbType,
|
|
sqlBuilder: squirrel.StatementBuilder.PlaceholderFormat(placeholderFormat),
|
|
}
|
|
}
|
|
|
|
// SetColumnMapping sets the mapping between API field names and database column names
|
|
func (qb *QueryBuilder) SetColumnMapping(mapping map[string]string) *QueryBuilder {
|
|
qb.columnMapping = mapping
|
|
return qb
|
|
}
|
|
|
|
// SetAllowedColumns sets the list of allowed columns for security
|
|
func (qb *QueryBuilder) SetAllowedColumns(columns []string) *QueryBuilder {
|
|
qb.allowedColumns = make(map[string]bool)
|
|
for _, col := range columns {
|
|
qb.allowedColumns[col] = true
|
|
}
|
|
return qb
|
|
}
|
|
|
|
// BuildQuery builds the complete SQL SELECT query
|
|
func (qb *QueryBuilder) BuildQuery(query DynamicQuery) (string, []interface{}, error) {
|
|
// Start with base query
|
|
baseQuery := qb.sqlBuilder.Select(qb.buildSelectFields(query.Fields)...).From(qb.tableName)
|
|
|
|
// Apply WHERE conditions
|
|
if len(query.Filters) > 0 {
|
|
whereClause, args, err := qb.buildWhereClause(query.Filters)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
baseQuery = baseQuery.Where(whereClause, args...)
|
|
}
|
|
|
|
// Apply GROUP BY
|
|
if len(query.GroupBy) > 0 {
|
|
groupByCols := qb.buildGroupByColumns(query.GroupBy)
|
|
if len(groupByCols) > 0 {
|
|
baseQuery = baseQuery.GroupBy(groupByCols...)
|
|
}
|
|
}
|
|
|
|
// Apply HAVING conditions
|
|
if len(query.Having) > 0 {
|
|
havingClause, args, err := qb.buildWhereClause(query.Having)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
baseQuery = baseQuery.Having(havingClause, args...)
|
|
}
|
|
|
|
// Apply ORDER BY
|
|
if len(query.Sort) > 0 {
|
|
for _, sort := range query.Sort {
|
|
column := qb.mapAndValidateColumn(sort.Column)
|
|
if column == "" {
|
|
continue // Skip invalid columns
|
|
}
|
|
order := "ASC"
|
|
if strings.ToUpper(sort.Order) == "DESC" {
|
|
order = "DESC"
|
|
}
|
|
baseQuery = baseQuery.OrderBy(fmt.Sprintf("%s %s", squirrel.EscapeIdentifier(column), order))
|
|
}
|
|
}
|
|
|
|
// Apply pagination with dialect-specific syntax
|
|
if query.Limit > 0 {
|
|
if qb.dbType == DBTypeSQLServer {
|
|
// SQL Server uses OFFSET-FETCH syntax
|
|
baseQuery = baseQuery.Offset(uint64(query.Offset)).Fetch(uint64(query.Limit))
|
|
} else {
|
|
// PostgreSQL, MySQL, SQLite use LIMIT/OFFSET
|
|
baseQuery = baseQuery.Limit(uint64(query.Limit))
|
|
if query.Offset > 0 {
|
|
baseQuery = baseQuery.Offset(uint64(query.Offset))
|
|
}
|
|
}
|
|
} else if query.Offset > 0 && qb.dbType != DBTypeSQLServer {
|
|
// SQL Server requires FETCH with OFFSET
|
|
baseQuery = baseQuery.Offset(uint64(query.Offset))
|
|
}
|
|
|
|
// Build final query
|
|
sql, args, err := baseQuery.ToSql()
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to build query: %w", err)
|
|
}
|
|
|
|
return sql, args, nil
|
|
}
|
|
|
|
// BuildInsertQuery builds an INSERT query
|
|
func (qb *QueryBuilder) BuildInsertQuery(data InsertData, returningColumns ...string) (string, []interface{}, error) {
|
|
if len(data.Columns) == 0 || len(data.Values) == 0 {
|
|
return "", nil, fmt.Errorf("no columns or values provided for INSERT")
|
|
}
|
|
|
|
if len(data.Columns) != len(data.Values) {
|
|
return "", nil, fmt.Errorf("columns and values count mismatch for INSERT")
|
|
}
|
|
|
|
// Prepare columns and values
|
|
columns := make([]string, 0, len(data.Columns))
|
|
values := make([]interface{}, 0, len(data.Values))
|
|
|
|
for i, col := range data.Columns {
|
|
mappedCol := qb.mapAndValidateColumn(col)
|
|
if mappedCol == "" {
|
|
continue // Skip invalid columns
|
|
}
|
|
columns = append(columns, mappedCol)
|
|
values = append(values, data.Values[i])
|
|
}
|
|
|
|
if len(columns) == 0 {
|
|
return "", nil, fmt.Errorf("no valid columns provided for INSERT")
|
|
}
|
|
|
|
// Build INSERT query
|
|
query := qb.sqlBuilder.Insert(qb.tableName).Columns(columns...).Values(values...)
|
|
|
|
// Add RETURNING/OUTPUT clause if specified and supported
|
|
if len(returningColumns) > 0 {
|
|
returningClause := qb.buildReturningClause(returningColumns)
|
|
if returningClause != "" {
|
|
query = query.Suffix(returningClause)
|
|
}
|
|
}
|
|
|
|
sql, args, err := query.ToSql()
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to build INSERT query: %w", err)
|
|
}
|
|
|
|
return sql, args, nil
|
|
}
|
|
|
|
// BuildUpdateQuery builds an UPDATE query
|
|
func (qb *QueryBuilder) BuildUpdateQuery(data UpdateData, filters []FilterGroup, returningColumns ...string) (string, []interface{}, error) {
|
|
if len(data.Columns) == 0 || len(data.Values) == 0 {
|
|
return "", nil, fmt.Errorf("no columns or values provided for UPDATE")
|
|
}
|
|
|
|
if len(data.Columns) != len(data.Values) {
|
|
return "", nil, fmt.Errorf("columns and values count mismatch for UPDATE")
|
|
}
|
|
|
|
// Prepare SET clause
|
|
setMap := make(map[string]interface{})
|
|
for i, col := range data.Columns {
|
|
mappedCol := qb.mapAndValidateColumn(col)
|
|
if mappedCol == "" {
|
|
continue // Skip invalid columns
|
|
}
|
|
setMap[mappedCol] = data.Values[i]
|
|
}
|
|
|
|
if len(setMap) == 0 {
|
|
return "", nil, fmt.Errorf("no valid columns provided for UPDATE")
|
|
}
|
|
|
|
// Build UPDATE query
|
|
query := qb.sqlBuilder.Update(qb.tableName).SetMap(setMap)
|
|
|
|
// Apply WHERE conditions
|
|
if len(filters) > 0 {
|
|
whereClause, args, err := qb.buildWhereClause(filters)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
query = query.Where(whereClause, args...)
|
|
}
|
|
|
|
// Add RETURNING/OUTPUT clause if specified and supported
|
|
if len(returningColumns) > 0 {
|
|
returningClause := qb.buildReturningClause(returningColumns)
|
|
if returningClause != "" {
|
|
query = query.Suffix(returningClause)
|
|
}
|
|
}
|
|
|
|
sql, args, err := query.ToSql()
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to build UPDATE query: %w", err)
|
|
}
|
|
|
|
return sql, args, nil
|
|
}
|
|
|
|
// BuildDeleteQuery builds a DELETE query
|
|
func (qb *QueryBuilder) BuildDeleteQuery(filters []FilterGroup, returningColumns ...string) (string, []interface{}, error) {
|
|
// Build DELETE query
|
|
query := qb.sqlBuilder.Delete(qb.tableName)
|
|
|
|
// Apply WHERE conditions
|
|
if len(filters) > 0 {
|
|
whereClause, args, err := qb.buildWhereClause(filters)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
query = query.Where(whereClause, args...)
|
|
}
|
|
|
|
// Add RETURNING/OUTPUT clause if specified and supported
|
|
if len(returningColumns) > 0 {
|
|
returningClause := qb.buildReturningClause(returningColumns)
|
|
if returningClause != "" {
|
|
query = query.Suffix(returningClause)
|
|
}
|
|
}
|
|
|
|
sql, args, err := query.ToSql()
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to build DELETE query: %w", err)
|
|
}
|
|
|
|
return sql, args, nil
|
|
}
|
|
|
|
// BuildCountQuery builds a count query
|
|
func (qb *QueryBuilder) BuildCountQuery(query DynamicQuery) (string, []interface{}, error) {
|
|
// Start with COUNT query
|
|
baseQuery := qb.sqlBuilder.Select("COUNT(*)").From(qb.tableName)
|
|
|
|
// Apply WHERE conditions
|
|
if len(query.Filters) > 0 {
|
|
whereClause, args, err := qb.buildWhereClause(query.Filters)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
baseQuery = baseQuery.Where(whereClause, args...)
|
|
}
|
|
|
|
// Apply GROUP BY
|
|
if len(query.GroupBy) > 0 {
|
|
groupByCols := qb.buildGroupByColumns(query.GroupBy)
|
|
if len(groupByCols) > 0 {
|
|
baseQuery = baseQuery.GroupBy(groupByCols...)
|
|
}
|
|
}
|
|
|
|
// Apply HAVING conditions
|
|
if len(query.Having) > 0 {
|
|
havingClause, args, err := qb.buildWhereClause(query.Having)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
baseQuery = baseQuery.Having(havingClause, args...)
|
|
}
|
|
|
|
sql, args, err := baseQuery.ToSql()
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to build COUNT query: %w", err)
|
|
}
|
|
|
|
return sql, args, nil
|
|
}
|
|
|
|
// BuildUpsertQuery builds an upsert query with dialect-specific syntax
|
|
func (qb *QueryBuilder) BuildUpsertQuery(data InsertData, conflictColumns []string, updateData UpdateData, returningColumns ...string) (string, []interface{}, error) {
|
|
switch qb.dbType {
|
|
case DBTypePostgreSQL, DBTypeSQLite:
|
|
return qb.buildPostgresUpsert(data, conflictColumns, updateData, returningColumns...)
|
|
case DBTypeMySQL:
|
|
return qb.buildMySQLUpsert(data, conflictColumns, updateData, returningColumns...)
|
|
case DBTypeSQLServer:
|
|
return qb.buildSQLServerUpsert(data, conflictColumns, updateData, returningColumns...)
|
|
default:
|
|
return "", nil, fmt.Errorf("upsert operation not supported for database type: %s", qb.dbType)
|
|
}
|
|
}
|
|
|
|
// buildPostgresUpsert builds an UPSERT query for PostgreSQL/SQLite (ON CONFLICT)
|
|
func (qb *QueryBuilder) buildPostgresUpsert(data InsertData, conflictColumns []string, updateData UpdateData, returningColumns ...string) (string, []interface{}, error) {
|
|
// ... (Validation logic from the original BuildUpsertQuery) ...
|
|
if len(data.Columns) == 0 || len(data.Values) == 0 || len(data.Columns) != len(data.Values) || len(conflictColumns) == 0 {
|
|
return "", nil, fmt.Errorf("invalid arguments for PostgreSQL/SQLite upsert")
|
|
}
|
|
|
|
// Prepare columns and values
|
|
columns := make([]string, 0, len(data.Columns))
|
|
values := make([]interface{}, 0, len(data.Values))
|
|
for i, col := range data.Columns {
|
|
mappedCol := qb.mapAndValidateColumn(col)
|
|
if mappedCol == "" {
|
|
continue
|
|
}
|
|
columns = append(columns, mappedCol)
|
|
values = append(values, data.Values[i])
|
|
}
|
|
if len(columns) == 0 {
|
|
return "", nil, fmt.Errorf("no valid columns for upsert")
|
|
}
|
|
|
|
// Prepare conflict columns
|
|
conflictCols := make([]string, 0, len(conflictColumns))
|
|
for _, col := range conflictColumns {
|
|
mappedCol := qb.mapAndValidateColumn(col)
|
|
if mappedCol == "" {
|
|
continue
|
|
}
|
|
conflictCols = append(conflictCols, mappedCol)
|
|
}
|
|
if len(conflictCols) == 0 {
|
|
return "", nil, fmt.Errorf("no valid conflict columns for upsert")
|
|
}
|
|
|
|
// Prepare update clause
|
|
updateMap := make(map[string]interface{})
|
|
for i, col := range updateData.Columns {
|
|
mappedCol := qb.mapAndValidateColumn(col)
|
|
if mappedCol == "" {
|
|
continue
|
|
}
|
|
updateMap[mappedCol] = updateData.Values[i]
|
|
}
|
|
if len(updateMap) == 0 {
|
|
return "", nil, fmt.Errorf("no valid update columns for upsert")
|
|
}
|
|
|
|
// Build query
|
|
query := qb.sqlBuilder.Insert(qb.tableName).Columns(columns...).Values(values...)
|
|
|
|
// ON CONFLICT clause
|
|
onConflictSuffix := fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET %s",
|
|
strings.Join(conflictCols, ", "),
|
|
qb.buildUpdateSetClause(updateMap))
|
|
query = query.Suffix(onConflictSuffix)
|
|
|
|
// Add RETURNING clause if specified
|
|
if len(returningColumns) > 0 {
|
|
returningClause := qb.buildReturningClause(returningColumns)
|
|
if returningClause != "" {
|
|
query = query.Suffix(" " + returningClause) // Prepend space
|
|
}
|
|
}
|
|
|
|
sql, args, err := query.ToSql()
|
|
return sql, args, err
|
|
}
|
|
|
|
// buildMySQLUpsert builds an UPSERT query for MySQL (ON DUPLICATE KEY UPDATE)
|
|
func (qb *QueryBuilder) buildMySQLUpsert(data InsertData, conflictColumns []string, updateData UpdateData, returningColumns ...string) (string, []interface{}, error) {
|
|
// ... (Validation logic) ...
|
|
if len(data.Columns) == 0 || len(data.Values) == 0 || len(data.Columns) != len(data.Values) || len(updateData.Columns) == 0 {
|
|
return "", nil, fmt.Errorf("invalid arguments for MySQL upsert")
|
|
}
|
|
|
|
// Prepare columns and values
|
|
columns := make([]string, 0, len(data.Columns))
|
|
values := make([]interface{}, 0, len(data.Values))
|
|
for i, col := range data.Columns {
|
|
mappedCol := qb.mapAndValidateColumn(col)
|
|
if mappedCol == "" {
|
|
continue
|
|
}
|
|
columns = append(columns, mappedCol)
|
|
values = append(values, data.Values[i])
|
|
}
|
|
if len(columns) == 0 {
|
|
return "", nil, fmt.Errorf("no valid columns for upsert")
|
|
}
|
|
|
|
// Prepare update clause
|
|
var updateParts []string
|
|
updateArgs := make([]interface{}, 0)
|
|
for i, col := range updateData.Columns {
|
|
mappedCol := qb.mapAndValidateColumn(col)
|
|
if mappedCol == "" {
|
|
continue
|
|
}
|
|
// In MySQL, you can reference the new values with VALUES(column_name)
|
|
updateParts = append(updateParts, fmt.Sprintf("%s = VALUES(?)", squirrel.EscapeIdentifier(mappedCol)))
|
|
updateArgs = append(updateArgs, mappedCol) // The placeholder is for the column name itself
|
|
updateArgs = append(updateArgs, updateData.Values[i])
|
|
}
|
|
if len(updateParts) == 0 {
|
|
return "", nil, fmt.Errorf("no valid update columns for upsert")
|
|
}
|
|
|
|
// Build query
|
|
query := qb.sqlBuilder.Insert(qb.tableName).Columns(columns...).Values(values...)
|
|
|
|
// ON DUPLICATE KEY UPDATE clause
|
|
onDuplicateSuffix := fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", strings.Join(updateParts, ", "))
|
|
query = query.Suffix(onDuplicateSuffix)
|
|
|
|
// MySQL doesn't support RETURNING in the same way. This is a limitation.
|
|
// Applications would need to run a separate SELECT query.
|
|
// We will ignore returningColumns for MySQL.
|
|
|
|
sql, args, err := query.ToSql()
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
|
|
// We need to merge args from the main query and the suffix
|
|
allArgs := append(args, updateArgs...)
|
|
|
|
return sql, allArgs, nil
|
|
}
|
|
|
|
// buildSQLServerUpsert builds an UPSERT query for SQL Server (MERGE)
|
|
func (qb *QueryBuilder) buildSQLServerUpsert(data InsertData, conflictColumns []string, updateData UpdateData, returningColumns ...string) (string, []interface{}, error) {
|
|
// MERGE is complex to build with squirrel's high-level API.
|
|
// We'll build it more manually but still use squirrel for escaping.
|
|
// This is a simplified version. A full MERGE can be more complex.
|
|
|
|
// ... (Validation logic) ...
|
|
if len(data.Columns) == 0 || len(data.Values) == 0 || len(data.Columns) != len(data.Values) || len(conflictColumns) == 0 || len(updateData.Columns) == 0 {
|
|
return "", nil, fmt.Errorf("invalid arguments for SQL Server upsert")
|
|
}
|
|
|
|
// Prepare columns and values
|
|
columns := make([]string, 0, len(data.Columns))
|
|
values := make([]interface{}, 0, len(data.Values))
|
|
for i, col := range data.Columns {
|
|
mappedCol := qb.mapAndValidateColumn(col)
|
|
if mappedCol == "" {
|
|
continue
|
|
}
|
|
columns = append(columns, mappedCol)
|
|
values = append(values, data.Values[i])
|
|
}
|
|
if len(columns) == 0 {
|
|
return "", nil, fmt.Errorf("no valid columns for upsert")
|
|
}
|
|
|
|
// Prepare conflict columns (target for ON clause)
|
|
var onConditions []string
|
|
for _, col := range conflictColumns {
|
|
mappedCol := qb.mapAndValidateColumn(col)
|
|
if mappedCol == "" {
|
|
continue
|
|
}
|
|
onConditions = append(onConditions, fmt.Sprintf("target.%s = source.%s", squirrel.EscapeIdentifier(mappedCol), squirrel.EscapeIdentifier(mappedCol)))
|
|
}
|
|
|
|
// Prepare update clause
|
|
var updateParts []string
|
|
for i, col := range updateData.Columns {
|
|
mappedCol := qb.mapAndValidateColumn(col)
|
|
if mappedCol == "" {
|
|
continue
|
|
}
|
|
updateParts = append(updateParts, fmt.Sprintf("target.%s = ?", squirrel.EscapeIdentifier(mappedCol)))
|
|
values = append(values, updateData.Values[i])
|
|
}
|
|
|
|
// Build MERGE statement manually
|
|
sql := fmt.Sprintf("MERGE INTO %s AS target USING (VALUES (%s)) AS source (%s) ON %s",
|
|
qb.tableName,
|
|
strings.Join(squirrel.Placeholders(len(columns)), ", "),
|
|
strings.Join(columns, ", "),
|
|
strings.Join(onConditions, " AND "),
|
|
)
|
|
|
|
sql += " WHEN MATCHED THEN UPDATE SET " + strings.Join(updateParts, ", ")
|
|
sql += " WHEN NOT MATCHED THEN INSERT (" + strings.Join(columns, ", ") + ") VALUES (" + strings.Join(squirrel.Placeholders(len(columns)), ", ") + ")"
|
|
|
|
// Add OUTPUT clause if specified
|
|
if len(returningColumns) > 0 {
|
|
var outputFields []string
|
|
for _, field := range returningColumns {
|
|
mappedCol := qb.mapAndValidateColumn(field)
|
|
if mappedCol != "" {
|
|
outputFields = append(outputFields, "INSERTED."+squirrel.EscapeIdentifier(mappedCol))
|
|
}
|
|
}
|
|
if len(outputFields) > 0 {
|
|
sql += " OUTPUT " + strings.Join(outputFields, ", ")
|
|
}
|
|
}
|
|
|
|
// Add final values for the INSERT part of MERGE
|
|
values = append(values, values...)
|
|
|
|
return sql, values, nil
|
|
}
|
|
|
|
// Helper methods
|
|
|
|
// buildSelectFields builds the SELECT fields
|
|
func (qb *QueryBuilder) buildSelectFields(fields []string) []string {
|
|
if len(fields) == 0 || (len(fields) == 1 && fields[0] == "*") {
|
|
return []string{"*"}
|
|
}
|
|
|
|
var selectedFields []string
|
|
for _, field := range fields {
|
|
if field == "*" {
|
|
selectedFields = append(selectedFields, "*")
|
|
continue
|
|
}
|
|
|
|
if strings.Contains(field, "(") || strings.Contains(field, " ") {
|
|
if qb.isValidExpression(field) {
|
|
selectedFields = append(selectedFields, field)
|
|
}
|
|
continue
|
|
}
|
|
|
|
mappedCol := qb.mapAndValidateColumn(field)
|
|
if mappedCol != "" {
|
|
selectedFields = append(selectedFields, squirrel.EscapeIdentifier(mappedCol))
|
|
}
|
|
}
|
|
|
|
if len(selectedFields) == 0 {
|
|
return []string{"*"}
|
|
}
|
|
|
|
return selectedFields
|
|
}
|
|
|
|
// buildReturningClause builds the RETURNING (Postgres/SQLite) or OUTPUT (SQL Server) clause
|
|
func (qb *QueryBuilder) buildReturningClause(columns []string) string {
|
|
if len(columns) == 0 {
|
|
return ""
|
|
}
|
|
|
|
var returningFields []string
|
|
for _, field := range columns {
|
|
if field == "*" {
|
|
returningFields = append(returningFields, "*")
|
|
continue
|
|
}
|
|
|
|
mappedCol := qb.mapAndValidateColumn(field)
|
|
if mappedCol != "" {
|
|
switch qb.dbType {
|
|
case DBTypeSQLServer:
|
|
returningFields = append(returningFields, "INSERTED."+squirrel.EscapeIdentifier(mappedCol))
|
|
default: // Postgres, SQLite
|
|
returningFields = append(returningFields, squirrel.EscapeIdentifier(mappedCol))
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(returningFields) == 0 {
|
|
return ""
|
|
}
|
|
|
|
switch qb.dbType {
|
|
case DBTypeSQLServer:
|
|
return "OUTPUT " + strings.Join(returningFields, ", ")
|
|
default: // Postgres, SQLite
|
|
return "RETURNING " + strings.Join(returningFields, ", ")
|
|
}
|
|
}
|
|
|
|
// buildGroupByColumns builds GROUP BY columns
|
|
func (qb *QueryBuilder) buildGroupByColumns(fields []string) []string {
|
|
var groupCols []string
|
|
for _, field := range fields {
|
|
mappedCol := qb.mapAndValidateColumn(field)
|
|
if mappedCol != "" {
|
|
groupCols = append(groupCols, squirrel.EscapeIdentifier(mappedCol))
|
|
}
|
|
}
|
|
return groupCols
|
|
}
|
|
|
|
// buildUpdateSetClause builds the SET clause for UPDATE (used by Postgres/SQLite upsert)
|
|
func (qb *QueryBuilder) buildUpdateSetClause(updateMap map[string]interface{}) string {
|
|
var setParts []string
|
|
for col := range updateMap {
|
|
setParts = append(setParts, fmt.Sprintf("%s = EXCLUDED.%s", squirrel.EscapeIdentifier(col), squirrel.EscapeIdentifier(col)))
|
|
}
|
|
return strings.Join(setParts, ", ")
|
|
}
|
|
|
|
// buildWhereClause builds WHERE/HAVING conditions
|
|
func (qb *QueryBuilder) buildWhereClause(filterGroups []FilterGroup) (string, []interface{}, error) {
|
|
if len(filterGroups) == 0 {
|
|
return "", nil, nil
|
|
}
|
|
|
|
var conditions []string
|
|
var allArgs []interface{}
|
|
|
|
for i, group := range filterGroups {
|
|
groupCondition, groupArgs, err := qb.buildFilterGroup(group)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
|
|
if groupCondition != "" {
|
|
if i > 0 {
|
|
logicOp := "AND"
|
|
if group.LogicOp != "" {
|
|
logicOp = strings.ToUpper(group.LogicOp)
|
|
}
|
|
conditions = append(conditions, logicOp)
|
|
}
|
|
conditions = append(conditions, fmt.Sprintf("(%s)", groupCondition))
|
|
allArgs = append(allArgs, groupArgs...)
|
|
}
|
|
}
|
|
|
|
return strings.Join(conditions, " "), allArgs, nil
|
|
}
|
|
|
|
// buildFilterGroup builds conditions for a filter group
|
|
func (qb *QueryBuilder) buildFilterGroup(group FilterGroup) (string, []interface{}, error) {
|
|
if len(group.Filters) == 0 {
|
|
return "", nil, nil
|
|
}
|
|
|
|
var conditions []string
|
|
var args []interface{}
|
|
|
|
for i, filter := range group.Filters {
|
|
condition, filterArgs, err := qb.buildFilterCondition(filter)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
|
|
if condition != "" {
|
|
if i > 0 {
|
|
logicOp := "AND"
|
|
if filter.LogicOp != "" {
|
|
logicOp = strings.ToUpper(filter.LogicOp)
|
|
} else if group.LogicOp != "" {
|
|
logicOp = strings.ToUpper(group.LogicOp)
|
|
}
|
|
conditions = append(conditions, logicOp)
|
|
}
|
|
|
|
conditions = append(conditions, condition)
|
|
args = append(args, filterArgs...)
|
|
}
|
|
}
|
|
|
|
return strings.Join(conditions, " "), args, nil
|
|
}
|
|
|
|
// buildFilterCondition builds a single filter condition with dialect-specific logic
|
|
func (qb *QueryBuilder) buildFilterCondition(filter DynamicFilter) (string, []interface{}, error) {
|
|
column := qb.mapAndValidateColumn(filter.Column)
|
|
if column == "" {
|
|
return "", nil, nil
|
|
}
|
|
column = squirrel.EscapeIdentifier(column)
|
|
|
|
switch filter.Operator {
|
|
case OpEqual:
|
|
if filter.Value == nil {
|
|
return fmt.Sprintf("%s IS NULL", column), nil, nil
|
|
}
|
|
return fmt.Sprintf("%s = ?", column), []interface{}{filter.Value}, nil
|
|
case OpNotEqual:
|
|
if filter.Value == nil {
|
|
return fmt.Sprintf("%s IS NOT NULL", column), nil, nil
|
|
}
|
|
return fmt.Sprintf("%s != ?", column), []interface{}{filter.Value}, nil
|
|
case OpLike:
|
|
if filter.Value == nil {
|
|
return "", nil, nil
|
|
}
|
|
return fmt.Sprintf("%s LIKE ?", column), []interface{}{filter.Value}, nil
|
|
case OpILike:
|
|
if filter.Value == nil {
|
|
return "", nil, nil
|
|
}
|
|
// Dialect-specific case-insensitive LIKE
|
|
switch qb.dbType {
|
|
case DBTypePostgreSQL, DBTypeSQLite:
|
|
return fmt.Sprintf("%s ILIKE ?", column), []interface{}{filter.Value}, nil
|
|
case DBTypeMySQL, DBTypeSQLServer:
|
|
// Use LOWER() function for case-insensitive comparison
|
|
return fmt.Sprintf("LOWER(%s) LIKE LOWER(?)", column), []interface{}{filter.Value}, nil
|
|
default:
|
|
// Fallback to case-sensitive LIKE
|
|
return fmt.Sprintf("%s LIKE ?", column), []interface{}{filter.Value}, nil
|
|
}
|
|
case OpIn, OpNotIn:
|
|
values := qb.parseArrayValue(filter.Value)
|
|
if len(values) == 0 {
|
|
return "", nil, nil
|
|
}
|
|
op := "IN"
|
|
if filter.Operator == OpNotIn {
|
|
op = "NOT IN"
|
|
}
|
|
return fmt.Sprintf("%s %s (%s)", column, op, squirrel.Placeholders(len(values))), values, nil
|
|
case OpGreaterThan, OpGreaterThanEqual, OpLessThan, OpLessThanEqual:
|
|
if filter.Value == nil {
|
|
return "", nil, nil
|
|
}
|
|
op := strings.TrimPrefix(string(filter.Operator), "_")
|
|
return fmt.Sprintf("%s %s ?", column, op), []interface{}{filter.Value}, nil
|
|
case OpBetween, OpNotBetween:
|
|
if filter.Value == nil {
|
|
return "", nil, nil
|
|
}
|
|
values := qb.parseArrayValue(filter.Value)
|
|
if len(values) != 2 {
|
|
return "", nil, fmt.Errorf("between operator requires exactly 2 values")
|
|
}
|
|
op := "BETWEEN"
|
|
if filter.Operator == OpNotBetween {
|
|
op = "NOT BETWEEN"
|
|
}
|
|
return fmt.Sprintf("%s %s ? AND ?", column, op), []interface{}{values[0], values[1]}, nil
|
|
case OpNull:
|
|
return fmt.Sprintf("%s IS NULL", column), nil, nil
|
|
case OpNotNull:
|
|
return fmt.Sprintf("%s IS NOT NULL", column), nil, nil
|
|
case OpContains, OpNotContains, OpStartsWith, OpEndsWith:
|
|
if filter.Value == nil {
|
|
return "", nil, nil
|
|
}
|
|
var value string
|
|
switch filter.Operator {
|
|
case OpContains:
|
|
value = fmt.Sprintf("%%%v%%", filter.Value)
|
|
case OpNotContains:
|
|
value = fmt.Sprintf("%%%v%%", filter.Value)
|
|
case OpStartsWith:
|
|
value = fmt.Sprintf("%v%%", filter.Value)
|
|
case OpEndsWith:
|
|
value = fmt.Sprintf("%%%v", filter.Value)
|
|
}
|
|
|
|
// Use the same logic as ILike
|
|
switch qb.dbType {
|
|
case DBTypePostgreSQL, DBTypeSQLite:
|
|
op := "ILIKE"
|
|
if filter.Operator == OpNotContains {
|
|
op = "NOT ILIKE"
|
|
}
|
|
return fmt.Sprintf("%s %s ?", column, op), []interface{}{value}, nil
|
|
case DBTypeMySQL, DBTypeSQLServer:
|
|
op := "LIKE"
|
|
if filter.Operator == OpNotContains {
|
|
op = "NOT LIKE"
|
|
}
|
|
return fmt.Sprintf("LOWER(%s) %s LOWER(?)", column, op), []interface{}{value}, nil
|
|
default:
|
|
op := "LIKE"
|
|
if filter.Operator == OpNotContains {
|
|
op = "NOT LIKE"
|
|
}
|
|
return fmt.Sprintf("%s %s ?", column, op), []interface{}{value}, nil
|
|
}
|
|
default:
|
|
return "", nil, fmt.Errorf("unsupported operator: %s", filter.Operator)
|
|
}
|
|
}
|
|
|
|
// parseArrayValue parses array values from various formats
|
|
func (qb *QueryBuilder) parseArrayValue(value interface{}) []interface{} {
|
|
if value == nil {
|
|
return nil
|
|
}
|
|
if reflect.TypeOf(value).Kind() == reflect.Slice {
|
|
v := reflect.ValueOf(value)
|
|
result := make([]interface{}, v.Len())
|
|
for i := 0; i < v.Len(); i++ {
|
|
result[i] = v.Index(i).Interface()
|
|
}
|
|
return result
|
|
}
|
|
if str, ok := value.(string); ok {
|
|
if strings.Contains(str, ",") {
|
|
parts := strings.Split(str, ",")
|
|
result := make([]interface{}, len(parts))
|
|
for i, part := range parts {
|
|
result[i] = strings.TrimSpace(part)
|
|
}
|
|
return result
|
|
}
|
|
return []interface{}{str}
|
|
}
|
|
return []interface{}{value}
|
|
}
|
|
|
|
// mapAndValidateColumn maps a column name and validates it
|
|
func (qb *QueryBuilder) mapAndValidateColumn(field string) string {
|
|
mappedCol := field
|
|
if mapped, exists := qb.columnMapping[field]; exists {
|
|
mappedCol = mapped
|
|
}
|
|
if len(qb.allowedColumns) > 0 && !qb.allowedColumns[mappedCol] {
|
|
return ""
|
|
}
|
|
if !qb.isValidColumnName(mappedCol) {
|
|
return ""
|
|
}
|
|
return mappedCol
|
|
}
|
|
|
|
// isValidColumnName validates column name format to prevent SQL injection
|
|
func (qb *QueryBuilder) isValidColumnName(column string) bool {
|
|
if column == "" {
|
|
return false
|
|
}
|
|
for _, r := range column {
|
|
if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' || r == '.') {
|
|
return false
|
|
}
|
|
}
|
|
suspiciousPatterns := []string{" ", ";", "--", "/*", "*/", "union", "select", "insert", "update", "delete", "drop", "alter", "create", "exec", "execute", "xp_", "sp_", "information_schema", "sysobjects", "syscolumns", "sysdatabases", "mysql", "pg_", "sqlite"}
|
|
lowerColumn := strings.ToLower(column)
|
|
for _, pattern := range suspiciousPatterns {
|
|
if strings.Contains(lowerColumn, pattern) {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// isValidExpression validates SQL expressions
|
|
func (qb *QueryBuilder) isValidExpression(expr string) bool {
|
|
allowedChars := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_.,() *-+/"
|
|
for _, r := range expr {
|
|
if !strings.ContainsRune(allowedChars, r) {
|
|
return false
|
|
}
|
|
}
|
|
dangerousPatterns := []string{";", "--", "/*", "*/", "union", "select", "insert", "update", "delete", "drop", "alter", "create", "exec", "execute"}
|
|
lowerExpr := strings.ToLower(expr)
|
|
for _, pattern := range dangerousPatterns {
|
|
if strings.Contains(lowerExpr, pattern) {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// ValidateInput performs comprehensive input validation
|
|
func (qb *QueryBuilder) ValidateInput(input interface{}) error {
|
|
switch v := input.(type) {
|
|
case string:
|
|
if len(v) > 1000 {
|
|
return fmt.Errorf("input string too long")
|
|
}
|
|
if strings.Contains(v, "\x00") {
|
|
return fmt.Errorf("invalid characters in input")
|
|
}
|
|
case []string:
|
|
if len(v) > 100 {
|
|
return fmt.Errorf("too many items in array")
|
|
}
|
|
for _, item := range v {
|
|
if err := qb.ValidateInput(item); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SanitizeString sanitizes string inputs
|
|
func (qb *QueryBuilder) SanitizeString(input string) string {
|
|
return strings.TrimSpace(strings.ReplaceAll(input, "\x00", ""))
|
|
}
|
|
|
|
// QueryParser converts HTTP query parameters into a DynamicQuery.
type QueryParser struct {
	// defaultLimit is the Limit a freshly parsed query starts with.
	// maxLimit is the largest client-supplied limit ParseQuery will honour.
	defaultLimit, maxLimit int
}
|
|
|
|
// NewQueryParser creates a new query parser
|
|
func NewQueryParser() *QueryParser {
|
|
return &QueryParser{defaultLimit: 10, maxLimit: 100}
|
|
}
|
|
|
|
// SetLimits sets default and maximum limits
|
|
func (qp *QueryParser) SetLimits(defaultLimit, maxLimit int) *QueryParser {
|
|
qp.defaultLimit = defaultLimit
|
|
qp.maxLimit = maxLimit
|
|
return qp
|
|
}
|
|
|
|
// ParseQuery parses URL query parameters into DynamicQuery
|
|
func (qp *QueryParser) ParseQuery(values url.Values) (DynamicQuery, error) {
|
|
query := DynamicQuery{Limit: qp.defaultLimit, Offset: 0}
|
|
// ... (implementation remains the same as before) ...
|
|
if fields := values.Get("fields"); fields != "" {
|
|
if fields == "*.*" || fields == "*" {
|
|
query.Fields = []string{"*"}
|
|
} else {
|
|
query.Fields = strings.Split(fields, ",")
|
|
for i, field := range query.Fields {
|
|
query.Fields[i] = strings.TrimSpace(field)
|
|
}
|
|
}
|
|
}
|
|
if limit := values.Get("limit"); limit != "" {
|
|
if l, err := strconv.Atoi(limit); err == nil {
|
|
if l > 0 && l <= qp.maxLimit {
|
|
query.Limit = l
|
|
}
|
|
}
|
|
}
|
|
if offset := values.Get("offset"); offset != "" {
|
|
if o, err := strconv.Atoi(offset); err == nil && o >= 0 {
|
|
query.Offset = o
|
|
}
|
|
}
|
|
filters, err := qp.parseFilters(values)
|
|
if err != nil {
|
|
return query, err
|
|
}
|
|
query.Filters = filters
|
|
sorts, err := qp.parseSorting(values)
|
|
if err != nil {
|
|
return query, err
|
|
}
|
|
query.Sort = sorts
|
|
if groupBy := values.Get("group"); groupBy != "" {
|
|
query.GroupBy = strings.Split(groupBy, ",")
|
|
for i, field := range query.GroupBy {
|
|
query.GroupBy[i] = strings.TrimSpace(field)
|
|
}
|
|
}
|
|
return query, nil
|
|
}
|
|
|
|
// parseFilters, parseSorting, ParseAdvancedFilters, parseDate, parseNumeric remain the same
|
|
func (qp *QueryParser) parseFilters(values url.Values) ([]FilterGroup, error) {
|
|
filterMap := make(map[string]map[string]string)
|
|
for key, vals := range values {
|
|
if strings.HasPrefix(key, "filter[") && strings.HasSuffix(key, "]") {
|
|
parts := strings.Split(key[7:len(key)-1], "][")
|
|
if len(parts) == 2 {
|
|
column, operator := parts[0], parts[1]
|
|
if filterMap[column] == nil {
|
|
filterMap[column] = make(map[string]string)
|
|
}
|
|
if len(vals) > 0 {
|
|
filterMap[column][operator] = vals[0]
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if len(filterMap) == 0 {
|
|
return nil, nil
|
|
}
|
|
var filters []DynamicFilter
|
|
for column, operators := range filterMap {
|
|
for opStr, value := range operators {
|
|
operator := FilterOperator(opStr)
|
|
var parsedValue interface{}
|
|
switch operator {
|
|
case OpIn, OpNotIn:
|
|
if value != "" {
|
|
parsedValue = strings.Split(value, ",")
|
|
}
|
|
case OpBetween, OpNotBetween:
|
|
if value != "" {
|
|
parts := strings.Split(value, ",")
|
|
if len(parts) == 2 {
|
|
parsedValue = []interface{}{strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])}
|
|
}
|
|
}
|
|
case OpNull, OpNotNull:
|
|
parsedValue = nil
|
|
default:
|
|
parsedValue = value
|
|
}
|
|
filters = append(filters, DynamicFilter{Column: column, Operator: operator, Value: parsedValue})
|
|
}
|
|
}
|
|
if len(filters) == 0 {
|
|
return nil, nil
|
|
}
|
|
return []FilterGroup{{Filters: filters, LogicOp: "AND"}}, nil
|
|
}
|
|
|
|
func (qp *QueryParser) parseSorting(values url.Values) ([]SortField, error) {
|
|
sortParam := values.Get("sort")
|
|
if sortParam == "" {
|
|
return nil, nil
|
|
}
|
|
var sorts []SortField
|
|
fields := strings.Split(sortParam, ",")
|
|
for _, field := range fields {
|
|
field = strings.TrimSpace(field)
|
|
if field == "" {
|
|
continue
|
|
}
|
|
order, column := "ASC", field
|
|
if strings.HasPrefix(field, "-") {
|
|
order = "DESC"
|
|
column = field[1:]
|
|
} else if strings.HasPrefix(field, "+") {
|
|
column = field[1:]
|
|
}
|
|
sorts = append(sorts, SortField{Column: column, Order: order})
|
|
}
|
|
return sorts, nil
|
|
}
|
|
|
|
func (qp *QueryParser) ParseAdvancedFilters(filterParam string) ([]FilterGroup, error) {
|
|
return nil, nil
|
|
}
|
|
// parseDate tries to interpret value as a date or timestamp.
//
// The layouts are tried in order: a plain date, "Z"-suffixed timestamps with
// or without milliseconds, a space-separated date and time, and finally
// RFC 3339, which also covers timestamps carrying a numeric UTC offset such
// as "2024-01-02T03:04:05+07:00". The first layout that matches wins and its
// time.Time is returned.
//
// When no layout matches, value itself is returned unchanged as a string.
// The error result is always nil.
func parseDate(value string) (interface{}, error) {
	layouts := []string{
		"2006-01-02",
		"2006-01-02T15:04:05Z",
		"2006-01-02T15:04:05.000Z",
		"2006-01-02 15:04:05",
		// Kept last so inputs matched by the layouts above parse exactly as before.
		time.RFC3339,
	}

	for _, layout := range layouts {
		if t, err := time.Parse(layout, value); err == nil {
			return t, nil
		}
	}

	return value, nil
}
|
|
// parseNumeric converts value to the most specific numeric type it fits.
//
// It returns an int when strconv.Atoi accepts value, otherwise a float64 when
// strconv.ParseFloat accepts it and the result is finite, otherwise value
// itself as a string.
//
// ParseFloat also accepts the words "NaN", "Inf" and "Infinity" (in any
// case, optionally signed). Those are returned as strings, because text such
// as "inf" or "nan" in a query parameter is not a usable number.
func parseNumeric(value string) interface{} {
	if i, err := strconv.Atoi(value); err == nil {
		return i
	}

	// f-f is 0 for every finite f and NaN for NaN and ±Inf, so the comparison
	// filters out the non-finite results without needing another import.
	if f, err := strconv.ParseFloat(value, 64); err == nil && f-f == 0 {
		return f
	}

	return value
}
|