Files
antrean-anjungan/internal/utils/query/builder.go
2025-10-23 05:37:33 +07:00

1140 lines
33 KiB
Go

package utils
import (
	"fmt"
	"net/url"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/Masterminds/squirrel"
	// Still useful for array types, especially with PostgreSQL
)
// DBType names the SQL dialect a QueryBuilder emits statements for.
// NewQueryBuilder uses it to choose the bind-placeholder style, and several
// builders branch on it for dialect-specific syntax.
type DBType string

// Dialects understood by QueryBuilder.
const (
	DBTypePostgreSQL DBType = "postgres"
	DBTypeMySQL      DBType = "mysql"
	DBTypeSQLite     DBType = "sqlite"
	DBTypeSQLServer  DBType = "sqlserver"
)
// FilterOperator is the suffix-style operator token that appears in filter
// keys such as filter[column][_eq]. buildFilterCondition translates each
// token into a SQL predicate.
type FilterOperator string

// Supported filter operators.
const (
	OpEqual            FilterOperator = "_eq"          // equality; a nil value means IS NULL
	OpNotEqual         FilterOperator = "_neq"         // inequality; a nil value means IS NOT NULL
	OpLike             FilterOperator = "_like"        // LIKE with a caller-supplied pattern
	OpILike            FilterOperator = "_ilike"       // case-insensitive LIKE (dialect-dependent)
	OpIn               FilterOperator = "_in"          // membership in a list
	OpNotIn            FilterOperator = "_nin"         // non-membership in a list
	OpGreaterThan      FilterOperator = "_gt"          // greater than
	OpGreaterThanEqual FilterOperator = "_gte"         // greater than or equal
	OpLessThan         FilterOperator = "_lt"          // less than
	OpLessThanEqual    FilterOperator = "_lte"         // less than or equal
	OpBetween          FilterOperator = "_between"     // inclusive range; needs exactly two values
	OpNotBetween       FilterOperator = "_nbetween"    // outside a range; needs exactly two values
	OpNull             FilterOperator = "_null"        // IS NULL; the value is ignored
	OpNotNull          FilterOperator = "_nnull"       // IS NOT NULL; the value is ignored
	OpContains         FilterOperator = "_contains"    // case-insensitive substring match
	OpNotContains      FilterOperator = "_ncontains"   // negated case-insensitive substring match
	OpStartsWith       FilterOperator = "_starts_with" // case-insensitive prefix match
	OpEndsWith         FilterOperator = "_ends_with"   // case-insensitive suffix match
)
// DynamicFilter is one condition: Column compared with Value using Operator.
// LogicOp ("AND" or "OR") joins this filter to the previous condition in its
// group. When it is empty, the group's LogicOp is used, and then AND.
type DynamicFilter struct {
	Column   string         `json:"column"`
	Operator FilterOperator `json:"operator"`
	Value    interface{}    `json:"value"`
	LogicOp  string         `json:"logic_op,omitempty"` // AND, OR
}
// FilterGroup is a list of filters rendered as one parenthesised expression.
// LogicOp ("AND" or "OR") joins the group to the preceding groups. It is also
// the default connector for filters that leave their own LogicOp empty.
type FilterGroup struct {
	Filters []DynamicFilter `json:"filters"`
	LogicOp string          `json:"logic_op"` // AND, OR
}
// DynamicQuery is the dialect-neutral description of a SELECT that BuildQuery
// and BuildCountQuery render. It holds the projected fields, WHERE groups,
// sort order, pagination, GROUP BY columns and HAVING groups.
type DynamicQuery struct {
	Fields  []string      `json:"fields,omitempty"`
	Filters []FilterGroup `json:"filters,omitempty"`
	Sort    []SortField   `json:"sort,omitempty"`
	Limit   int           `json:"limit"`  // values <= 0 mean no LIMIT
	Offset  int           `json:"offset"` // rows to skip
	GroupBy []string      `json:"group_by,omitempty"`
	Having  []FilterGroup `json:"having,omitempty"`
}
// SortField is one ORDER BY entry. BuildQuery treats Order "DESC"
// (case-insensitive) as descending and any other value as ascending.
type SortField struct {
	Column string `json:"column"`
	Order  string `json:"order"` // ASC, DESC
}
// UpdateData carries the SET list of an UPDATE. Columns[i] receives
// Values[i], and the two slices must have equal length.
type UpdateData struct {
	Columns []string      `json:"columns"`
	Values  []interface{} `json:"values"`
}
// InsertData carries a single row to INSERT. Columns[i] receives Values[i],
// and the two slices must have equal length.
type InsertData struct {
	Columns []string      `json:"columns"`
	Values  []interface{} `json:"values"`
}
// QueryBuilder renders SELECT, INSERT, UPDATE, DELETE and upsert statements
// for a single table from dynamic, API-supplied descriptions. SQL assembly
// and placeholder formatting are delegated to squirrel.
type QueryBuilder struct {
	tableName      string
	columnMapping  map[string]string             // API field name -> DB column name
	allowedColumns map[string]bool               // when non-empty, only these (mapped) columns are accepted
	dbType         DBType                        // dialect used for syntax decisions
	sqlBuilder     squirrel.StatementBuilderType // preconfigured with the dialect's placeholder format
}
// NewQueryBuilder returns a QueryBuilder for tableName that emits SQL for
// dbType. Placeholder styles by dialect:
//   - PostgreSQL: $1 style
//   - SQL Server: @p1 style
//   - MySQL, SQLite and any unrecognised dialect: '?'
//
// The column mapping and allow-list start empty.
func NewQueryBuilder(tableName string, dbType DBType) *QueryBuilder {
	var format squirrel.PlaceholderFormat = squirrel.Question
	switch dbType {
	case DBTypePostgreSQL:
		format = squirrel.Dollar
	case DBTypeSQLServer:
		format = squirrel.AtP
	}

	return &QueryBuilder{
		tableName:      tableName,
		dbType:         dbType,
		columnMapping:  map[string]string{},
		allowedColumns: map[string]bool{},
		sqlBuilder:     squirrel.StatementBuilder.PlaceholderFormat(format),
	}
}
// SetColumnMapping installs the API-name -> DB-column translation table
// consulted by every column lookup, and returns qb for chaining. The map is
// stored as-is, not copied, so later changes by the caller are visible.
func (qb *QueryBuilder) SetColumnMapping(mapping map[string]string) *QueryBuilder {
	qb.columnMapping = mapping
	return qb
}
// SetAllowedColumns replaces the column allow-list with columns and returns
// qb for chaining. Entries are matched against the mapped DB column name. An
// empty list leaves the allow-list empty, which disables the check.
func (qb *QueryBuilder) SetAllowedColumns(columns []string) *QueryBuilder {
	allowed := make(map[string]bool, len(columns))
	for i := range columns {
		allowed[columns[i]] = true
	}
	qb.allowedColumns = allowed
	return qb
}
// BuildQuery builds the complete SQL SELECT query
func (qb *QueryBuilder) BuildQuery(query DynamicQuery) (string, []interface{}, error) {
// Start with base query
baseQuery := qb.sqlBuilder.Select(qb.buildSelectFields(query.Fields)...).From(qb.tableName)
// Apply WHERE conditions
if len(query.Filters) > 0 {
whereClause, args, err := qb.buildWhereClause(query.Filters)
if err != nil {
return "", nil, err
}
baseQuery = baseQuery.Where(whereClause, args...)
}
// Apply GROUP BY
if len(query.GroupBy) > 0 {
groupByCols := qb.buildGroupByColumns(query.GroupBy)
if len(groupByCols) > 0 {
baseQuery = baseQuery.GroupBy(groupByCols...)
}
}
// Apply HAVING conditions
if len(query.Having) > 0 {
havingClause, args, err := qb.buildWhereClause(query.Having)
if err != nil {
return "", nil, err
}
baseQuery = baseQuery.Having(havingClause, args...)
}
// Apply ORDER BY
if len(query.Sort) > 0 {
for _, sort := range query.Sort {
column := qb.mapAndValidateColumn(sort.Column)
if column == "" {
continue // Skip invalid columns
}
order := "ASC"
if strings.ToUpper(sort.Order) == "DESC" {
order = "DESC"
}
baseQuery = baseQuery.OrderBy(fmt.Sprintf("%s %s", squirrel.EscapeIdentifier(column), order))
}
}
// Apply pagination with dialect-specific syntax
if query.Limit > 0 {
if qb.dbType == DBTypeSQLServer {
// SQL Server uses OFFSET-FETCH syntax
baseQuery = baseQuery.Offset(uint64(query.Offset)).Fetch(uint64(query.Limit))
} else {
// PostgreSQL, MySQL, SQLite use LIMIT/OFFSET
baseQuery = baseQuery.Limit(uint64(query.Limit))
if query.Offset > 0 {
baseQuery = baseQuery.Offset(uint64(query.Offset))
}
}
} else if query.Offset > 0 && qb.dbType != DBTypeSQLServer {
// SQL Server requires FETCH with OFFSET
baseQuery = baseQuery.Offset(uint64(query.Offset))
}
// Build final query
sql, args, err := baseQuery.ToSql()
if err != nil {
return "", nil, fmt.Errorf("failed to build query: %w", err)
}
return sql, args, nil
}
// BuildInsertQuery renders a single-row INSERT into qb's table.
//
// data.Columns and data.Values are paired by index. Pairs whose column fails
// mapAndValidateColumn are dropped.
//
// When returningColumns is non-empty, the clause from buildReturningClause is
// added. For RETURNING dialects it is appended as a suffix. For SQL Server it
// is spliced in before VALUES, because T-SQL requires the OUTPUT clause
// between the column list and VALUES.
//
// It returns an error when no columns or values are given, when their counts
// differ, when no usable column remains, or when squirrel fails to render the
// statement.
func (qb *QueryBuilder) BuildInsertQuery(data InsertData, returningColumns ...string) (string, []interface{}, error) {
	if len(data.Columns) == 0 || len(data.Values) == 0 {
		return "", nil, fmt.Errorf("no columns or values provided for INSERT")
	}
	if len(data.Columns) != len(data.Values) {
		return "", nil, fmt.Errorf("columns and values count mismatch for INSERT")
	}

	columns := make([]string, 0, len(data.Columns))
	values := make([]interface{}, 0, len(data.Values))
	for i, col := range data.Columns {
		mappedCol := qb.mapAndValidateColumn(col)
		if mappedCol == "" {
			continue // skip invalid or disallowed columns
		}
		columns = append(columns, mappedCol)
		values = append(values, data.Values[i])
	}
	if len(columns) == 0 {
		return "", nil, fmt.Errorf("no valid columns provided for INSERT")
	}

	query := qb.sqlBuilder.Insert(qb.tableName).Columns(columns...).Values(values...)

	returningClause := ""
	if len(returningColumns) > 0 {
		returningClause = qb.buildReturningClause(returningColumns)
	}
	if returningClause != "" && qb.dbType != DBTypeSQLServer {
		query = query.Suffix(returningClause)
	}

	sql, args, err := query.ToSql()
	if err != nil {
		return "", nil, fmt.Errorf("failed to build INSERT query: %w", err)
	}

	if returningClause != "" && qb.dbType == DBTypeSQLServer {
		// squirrel renders "INSERT INTO t (a,b) VALUES (...)"; T-SQL needs
		// "INSERT INTO t (a,b) OUTPUT INSERTED.a VALUES (...)".
		sql = strings.Replace(sql, " VALUES ", " "+returningClause+" VALUES ", 1)
	}
	return sql, args, nil
}
// BuildUpdateQuery renders an UPDATE of qb's table.
//
// data.Columns and data.Values are paired by index, and pairs with an
// unusable column are dropped. filters become the WHERE clause.
//
// If filters are supplied but none survives validation, an error is
// returned. Without that check the statement would have no usable WHERE
// clause and would either update every row or fail with a dangling WHERE.
// Passing no filters at all still produces an unconditional UPDATE.
//
// returningColumns adds the clause from buildReturningClause. It is appended
// for RETURNING dialects. For SQL Server it is placed before WHERE, as T-SQL
// requires for OUTPUT.
func (qb *QueryBuilder) BuildUpdateQuery(data UpdateData, filters []FilterGroup, returningColumns ...string) (string, []interface{}, error) {
	if len(data.Columns) == 0 || len(data.Values) == 0 {
		return "", nil, fmt.Errorf("no columns or values provided for UPDATE")
	}
	if len(data.Columns) != len(data.Values) {
		return "", nil, fmt.Errorf("columns and values count mismatch for UPDATE")
	}

	setMap := make(map[string]interface{}, len(data.Columns))
	for i, col := range data.Columns {
		mappedCol := qb.mapAndValidateColumn(col)
		if mappedCol == "" {
			continue // skip invalid or disallowed columns
		}
		setMap[mappedCol] = data.Values[i]
	}
	if len(setMap) == 0 {
		return "", nil, fmt.Errorf("no valid columns provided for UPDATE")
	}

	query := qb.sqlBuilder.Update(qb.tableName).SetMap(setMap)

	hasWhere := false
	if len(filters) > 0 {
		whereClause, args, err := qb.buildWhereClause(filters)
		if err != nil {
			return "", nil, err
		}
		if whereClause == "" {
			return "", nil, fmt.Errorf("no valid filter conditions for UPDATE")
		}
		query = query.Where(whereClause, args...)
		hasWhere = true
	}

	returningClause := ""
	if len(returningColumns) > 0 {
		returningClause = qb.buildReturningClause(returningColumns)
	}
	if returningClause != "" && qb.dbType != DBTypeSQLServer {
		query = query.Suffix(returningClause)
	}

	sql, args, err := query.ToSql()
	if err != nil {
		return "", nil, fmt.Errorf("failed to build UPDATE query: %w", err)
	}

	if returningClause != "" && qb.dbType == DBTypeSQLServer {
		// T-SQL: UPDATE t SET ... OUTPUT INSERTED.x WHERE ...
		// The WHERE we generate contains no " WHERE " itself, so the last
		// occurrence is the clause keyword.
		if idx := strings.LastIndex(sql, " WHERE "); hasWhere && idx >= 0 {
			sql = sql[:idx] + " " + returningClause + sql[idx:]
		} else {
			sql += " " + returningClause
		}
	}
	return sql, args, nil
}
// BuildDeleteQuery builds a DELETE query
func (qb *QueryBuilder) BuildDeleteQuery(filters []FilterGroup, returningColumns ...string) (string, []interface{}, error) {
// Build DELETE query
query := qb.sqlBuilder.Delete(qb.tableName)
// Apply WHERE conditions
if len(filters) > 0 {
whereClause, args, err := qb.buildWhereClause(filters)
if err != nil {
return "", nil, err
}
query = query.Where(whereClause, args...)
}
// Add RETURNING/OUTPUT clause if specified and supported
if len(returningColumns) > 0 {
returningClause := qb.buildReturningClause(returningColumns)
if returningClause != "" {
query = query.Suffix(returningClause)
}
}
sql, args, err := query.ToSql()
if err != nil {
return "", nil, fmt.Errorf("failed to build DELETE query: %w", err)
}
return sql, args, nil
}
// BuildCountQuery renders a query that returns a single row holding the
// number of rows query would produce. The field list, sort order and
// pagination of query are ignored.
//
// Without GROUP BY this is "SELECT COUNT(*) FROM t [WHERE ...] [HAVING ...]".
// With GROUP BY, a plain COUNT(*) would return one count per group. Instead,
// the grouped query (selecting its GROUP BY columns, so every derived-table
// column is named as SQL Server requires) is wrapped in a derived table, and
// the groups are counted. Empty WHERE/HAVING clauses, where every condition
// was rejected, are omitted.
func (qb *QueryBuilder) BuildCountQuery(query DynamicQuery) (string, []interface{}, error) {
	groupByCols := qb.buildGroupByColumns(query.GroupBy)
	grouped := len(groupByCols) > 0

	inner := qb.sqlBuilder.Select("COUNT(*)").From(qb.tableName)
	if grouped {
		inner = qb.sqlBuilder.Select(groupByCols...).From(qb.tableName)
	}

	if len(query.Filters) > 0 {
		whereClause, args, err := qb.buildWhereClause(query.Filters)
		if err != nil {
			return "", nil, err
		}
		if whereClause != "" {
			inner = inner.Where(whereClause, args...)
		}
	}

	if grouped {
		inner = inner.GroupBy(groupByCols...)
	}

	if len(query.Having) > 0 {
		havingClause, args, err := qb.buildWhereClause(query.Having)
		if err != nil {
			return "", nil, err
		}
		if havingClause != "" {
			inner = inner.Having(havingClause, args...)
		}
	}

	countQuery := inner
	if grouped {
		countQuery = qb.sqlBuilder.Select("COUNT(*)").FromSelect(inner, "grouped_rows")
	}

	sql, args, err := countQuery.ToSql()
	if err != nil {
		return "", nil, fmt.Errorf("failed to build COUNT query: %w", err)
	}
	return sql, args, nil
}
// BuildUpsertQuery builds an upsert query with dialect-specific syntax
func (qb *QueryBuilder) BuildUpsertQuery(data InsertData, conflictColumns []string, updateData UpdateData, returningColumns ...string) (string, []interface{}, error) {
switch qb.dbType {
case DBTypePostgreSQL, DBTypeSQLite:
return qb.buildPostgresUpsert(data, conflictColumns, updateData, returningColumns...)
case DBTypeMySQL:
return qb.buildMySQLUpsert(data, conflictColumns, updateData, returningColumns...)
case DBTypeSQLServer:
return qb.buildSQLServerUpsert(data, conflictColumns, updateData, returningColumns...)
default:
return "", nil, fmt.Errorf("upsert operation not supported for database type: %s", qb.dbType)
}
}
// buildPostgresUpsert renders
// "INSERT ... ON CONFLICT (...) DO UPDATE SET ... [RETURNING ...]" for
// PostgreSQL and SQLite.
//
// Updated columns take the value proposed for insertion (EXCLUDED.col). Only
// updateData.Columns is consulted; updateData.Values is never read. This also
// means a short Values slice can no longer cause an index panic. Conflict
// target columns are escaped like every other identifier.
//
// It returns an error when the arguments are inconsistent, when no valid
// insert, conflict or update column remains, or when squirrel fails.
func (qb *QueryBuilder) buildPostgresUpsert(data InsertData, conflictColumns []string, updateData UpdateData, returningColumns ...string) (string, []interface{}, error) {
	if len(data.Columns) == 0 || len(data.Values) == 0 || len(data.Columns) != len(data.Values) || len(conflictColumns) == 0 {
		return "", nil, fmt.Errorf("invalid arguments for PostgreSQL/SQLite upsert")
	}

	columns := make([]string, 0, len(data.Columns))
	values := make([]interface{}, 0, len(data.Values))
	for i, col := range data.Columns {
		mappedCol := qb.mapAndValidateColumn(col)
		if mappedCol == "" {
			continue
		}
		columns = append(columns, mappedCol)
		values = append(values, data.Values[i])
	}
	if len(columns) == 0 {
		return "", nil, fmt.Errorf("no valid columns for upsert")
	}

	conflictCols := make([]string, 0, len(conflictColumns))
	for _, col := range conflictColumns {
		if mappedCol := qb.mapAndValidateColumn(col); mappedCol != "" {
			conflictCols = append(conflictCols, squirrel.EscapeIdentifier(mappedCol))
		}
	}
	if len(conflictCols) == 0 {
		return "", nil, fmt.Errorf("no valid conflict columns for upsert")
	}

	// Only the keys matter: the SET clause reads EXCLUDED.<col>.
	updateMap := make(map[string]interface{}, len(updateData.Columns))
	for _, col := range updateData.Columns {
		if mappedCol := qb.mapAndValidateColumn(col); mappedCol != "" {
			updateMap[mappedCol] = nil
		}
	}
	if len(updateMap) == 0 {
		return "", nil, fmt.Errorf("no valid update columns for upsert")
	}

	suffix := fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET %s",
		strings.Join(conflictCols, ", "),
		qb.buildUpdateSetClause(updateMap))
	if len(returningColumns) > 0 {
		if returningClause := qb.buildReturningClause(returningColumns); returningClause != "" {
			suffix += " " + returningClause
		}
	}

	query := qb.sqlBuilder.Insert(qb.tableName).Columns(columns...).Values(values...).Suffix(suffix)
	sql, args, err := query.ToSql()
	if err != nil {
		return "", nil, fmt.Errorf("failed to build upsert query: %w", err)
	}
	return sql, args, nil
}
// buildMySQLUpsert builds an UPSERT query for MySQL (ON DUPLICATE KEY UPDATE)
func (qb *QueryBuilder) buildMySQLUpsert(data InsertData, conflictColumns []string, updateData UpdateData, returningColumns ...string) (string, []interface{}, error) {
// ... (Validation logic) ...
if len(data.Columns) == 0 || len(data.Values) == 0 || len(data.Columns) != len(data.Values) || len(updateData.Columns) == 0 {
return "", nil, fmt.Errorf("invalid arguments for MySQL upsert")
}
// Prepare columns and values
columns := make([]string, 0, len(data.Columns))
values := make([]interface{}, 0, len(data.Values))
for i, col := range data.Columns {
mappedCol := qb.mapAndValidateColumn(col)
if mappedCol == "" {
continue
}
columns = append(columns, mappedCol)
values = append(values, data.Values[i])
}
if len(columns) == 0 {
return "", nil, fmt.Errorf("no valid columns for upsert")
}
// Prepare update clause
var updateParts []string
updateArgs := make([]interface{}, 0)
for i, col := range updateData.Columns {
mappedCol := qb.mapAndValidateColumn(col)
if mappedCol == "" {
continue
}
// In MySQL, you can reference the new values with VALUES(column_name)
updateParts = append(updateParts, fmt.Sprintf("%s = VALUES(?)", squirrel.EscapeIdentifier(mappedCol)))
updateArgs = append(updateArgs, mappedCol) // The placeholder is for the column name itself
updateArgs = append(updateArgs, updateData.Values[i])
}
if len(updateParts) == 0 {
return "", nil, fmt.Errorf("no valid update columns for upsert")
}
// Build query
query := qb.sqlBuilder.Insert(qb.tableName).Columns(columns...).Values(values...)
// ON DUPLICATE KEY UPDATE clause
onDuplicateSuffix := fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", strings.Join(updateParts, ", "))
query = query.Suffix(onDuplicateSuffix)
// MySQL doesn't support RETURNING in the same way. This is a limitation.
// Applications would need to run a separate SELECT query.
// We will ignore returningColumns for MySQL.
sql, args, err := query.ToSql()
if err != nil {
return "", nil, err
}
// We need to merge args from the main query and the suffix
allArgs := append(args, updateArgs...)
return sql, allArgs, nil
}
// buildSQLServerUpsert renders a T-SQL MERGE that updates the row matching
// conflictColumns or inserts a new one:
//
//	MERGE INTO t AS target
//	USING (VALUES (@p1, ...)) AS source (c1, ...)
//	ON target.k = source.k [AND ...]
//	WHEN MATCHED THEN UPDATE SET target.u = @pN, ...
//	WHEN NOT MATCHED THEN INSERT (c1, ...) VALUES (source.c1, ...)
//	[OUTPUT INSERTED.x, ...];
//
// Arguments are bound in placeholder order: the insert values, then
// updateData.Values. The INSERT branch reads from the source row, so insert
// values are bound only once.
//
// Constraints:
//   - Every conflict column must also be an inserted column, because the ON
//     clause compares against source.
//   - The statement ends with the semicolon that MERGE requires.
//   - Placeholders are converted to @pN here, because this statement is
//     assembled by hand rather than by squirrel's ToSql.
func (qb *QueryBuilder) buildSQLServerUpsert(data InsertData, conflictColumns []string, updateData UpdateData, returningColumns ...string) (string, []interface{}, error) {
	if len(data.Columns) == 0 || len(data.Values) == 0 || len(data.Columns) != len(data.Values) || len(conflictColumns) == 0 || len(updateData.Columns) == 0 {
		return "", nil, fmt.Errorf("invalid arguments for SQL Server upsert")
	}
	if len(updateData.Columns) != len(updateData.Values) {
		return "", nil, fmt.Errorf("update columns and values count mismatch for SQL Server upsert")
	}

	// Source row: escaped column list, source.<col> references and insert args.
	insertCols := make([]string, 0, len(data.Columns))
	sourceRefs := make([]string, 0, len(data.Columns))
	inserted := make(map[string]bool, len(data.Columns))
	args := make([]interface{}, 0, len(data.Values)+len(updateData.Values))
	for i, col := range data.Columns {
		mappedCol := qb.mapAndValidateColumn(col)
		if mappedCol == "" {
			continue
		}
		quoted := squirrel.EscapeIdentifier(mappedCol)
		insertCols = append(insertCols, quoted)
		sourceRefs = append(sourceRefs, "source."+quoted)
		inserted[mappedCol] = true
		args = append(args, data.Values[i])
	}
	if len(insertCols) == 0 {
		return "", nil, fmt.Errorf("no valid columns for upsert")
	}

	onConditions := make([]string, 0, len(conflictColumns))
	for _, col := range conflictColumns {
		mappedCol := qb.mapAndValidateColumn(col)
		if mappedCol == "" {
			continue
		}
		if !inserted[mappedCol] {
			return "", nil, fmt.Errorf("conflict column %q must also be an inserted column", col)
		}
		quoted := squirrel.EscapeIdentifier(mappedCol)
		onConditions = append(onConditions, fmt.Sprintf("target.%s = source.%s", quoted, quoted))
	}
	if len(onConditions) == 0 {
		return "", nil, fmt.Errorf("no valid conflict columns for upsert")
	}

	updateParts := make([]string, 0, len(updateData.Columns))
	for i, col := range updateData.Columns {
		mappedCol := qb.mapAndValidateColumn(col)
		if mappedCol == "" {
			continue
		}
		updateParts = append(updateParts, fmt.Sprintf("target.%s = ?", squirrel.EscapeIdentifier(mappedCol)))
		args = append(args, updateData.Values[i])
	}
	if len(updateParts) == 0 {
		return "", nil, fmt.Errorf("no valid update columns for upsert")
	}

	var b strings.Builder
	fmt.Fprintf(&b, "MERGE INTO %s AS target USING (VALUES (%s)) AS source (%s) ON %s",
		qb.tableName,
		squirrel.Placeholders(len(insertCols)),
		strings.Join(insertCols, ", "),
		strings.Join(onConditions, " AND "),
	)
	b.WriteString(" WHEN MATCHED THEN UPDATE SET ")
	b.WriteString(strings.Join(updateParts, ", "))
	b.WriteString(" WHEN NOT MATCHED THEN INSERT (")
	b.WriteString(strings.Join(insertCols, ", "))
	b.WriteString(") VALUES (")
	b.WriteString(strings.Join(sourceRefs, ", "))
	b.WriteString(")")

	var outputFields []string
	for _, field := range returningColumns {
		if field == "*" {
			outputFields = append(outputFields, "INSERTED.*")
			continue
		}
		if mappedCol := qb.mapAndValidateColumn(field); mappedCol != "" {
			outputFields = append(outputFields, "INSERTED."+squirrel.EscapeIdentifier(mappedCol))
		}
	}
	if len(outputFields) > 0 {
		b.WriteString(" OUTPUT ")
		b.WriteString(strings.Join(outputFields, ", "))
	}
	b.WriteString(";")

	sql, err := squirrel.AtP.ReplacePlaceholders(b.String())
	if err != nil {
		return "", nil, fmt.Errorf("failed to build SQL Server upsert: %w", err)
	}
	return sql, args, nil
}
// Helper methods
// buildSelectFields converts the requested field list into SELECT expressions.
//
// An empty list, or a list containing only "*", selects "*". "*" entries
// pass through. Plain names are resolved with mapAndValidateColumn and
// escaped.
//
// Entries containing "(" or a space are treated as expressions. An
// expression must satisfy isValidExpression, and every bare identifier inside
// it must also be accepted by mapAndValidateColumn. Previously an expression
// such as "MAX(secret)" bypassed the allowed-column list. Expressions are
// emitted verbatim, without escaping or mapping.
//
// Rejected entries are dropped, and "*" is returned if nothing survives.
func (qb *QueryBuilder) buildSelectFields(fields []string) []string {
	if len(fields) == 0 || (len(fields) == 1 && fields[0] == "*") {
		return []string{"*"}
	}

	var selectedFields []string
	for _, field := range fields {
		if field == "*" {
			selectedFields = append(selectedFields, "*")
			continue
		}
		if strings.Contains(field, "(") || strings.Contains(field, " ") {
			if qb.isValidExpression(field) && qb.expressionColumnsAllowed(field) {
				selectedFields = append(selectedFields, field)
			}
			continue
		}
		if mappedCol := qb.mapAndValidateColumn(field); mappedCol != "" {
			selectedFields = append(selectedFields, squirrel.EscapeIdentifier(mappedCol))
		}
	}

	if len(selectedFields) == 0 {
		return []string{"*"}
	}
	return selectedFields
}

// expressionColumnsAllowed reports whether every bare identifier in expr is a
// column accepted by mapAndValidateColumn. The following words are not
// checked:
//   - a word followed (after optional spaces) by "(" — a function name
//   - AS and DISTINCT — keywords
//   - the word after AS — an alias
//   - words starting with a digit — numeric literals
//
// expr is expected to have passed isValidExpression (ASCII only), so it is
// scanned byte by byte.
func (qb *QueryBuilder) expressionColumnsAllowed(expr string) bool {
	isWordByte := func(c byte) bool {
		return c == '_' || c == '.' || (c >= '0' && c <= '9') || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')
	}

	expectAlias := false
	for i := 0; i < len(expr); {
		if !isWordByte(expr[i]) {
			i++
			continue
		}
		start := i
		for i < len(expr) && isWordByte(expr[i]) {
			i++
		}
		word := expr[start:i]

		next := i
		for next < len(expr) && expr[next] == ' ' {
			next++
		}
		isCall := next < len(expr) && expr[next] == '('

		switch upper := strings.ToUpper(word); {
		case expectAlias:
			expectAlias = false // alias name, not a column reference
		case upper == "AS":
			expectAlias = true
		case upper == "DISTINCT", isCall, word[0] >= '0' && word[0] <= '9':
			// keyword, function name or numeric literal
		default:
			if qb.mapAndValidateColumn(word) == "" {
				return false
			}
		}
	}
	return true
}
// buildReturningClause builds the RETURNING (Postgres/SQLite) or OUTPUT (SQL Server) clause
func (qb *QueryBuilder) buildReturningClause(columns []string) string {
if len(columns) == 0 {
return ""
}
var returningFields []string
for _, field := range columns {
if field == "*" {
returningFields = append(returningFields, "*")
continue
}
mappedCol := qb.mapAndValidateColumn(field)
if mappedCol != "" {
switch qb.dbType {
case DBTypeSQLServer:
returningFields = append(returningFields, "INSERTED."+squirrel.EscapeIdentifier(mappedCol))
default: // Postgres, SQLite
returningFields = append(returningFields, squirrel.EscapeIdentifier(mappedCol))
}
}
}
if len(returningFields) == 0 {
return ""
}
switch qb.dbType {
case DBTypeSQLServer:
return "OUTPUT " + strings.Join(returningFields, ", ")
default: // Postgres, SQLite
return "RETURNING " + strings.Join(returningFields, ", ")
}
}
// buildGroupByColumns resolves each GROUP BY field with mapAndValidateColumn
// and returns the escaped names in input order. Rejected fields are omitted,
// and the result is nil when none survive.
func (qb *QueryBuilder) buildGroupByColumns(fields []string) []string {
	var cols []string
	for i := range fields {
		mapped := qb.mapAndValidateColumn(fields[i])
		if mapped == "" {
			continue
		}
		cols = append(cols, squirrel.EscapeIdentifier(mapped))
	}
	return cols
}
// buildUpdateSetClause builds the SET clause for UPDATE (used by Postgres/SQLite upsert)
func (qb *QueryBuilder) buildUpdateSetClause(updateMap map[string]interface{}) string {
var setParts []string
for col := range updateMap {
setParts = append(setParts, fmt.Sprintf("%s = EXCLUDED.%s", squirrel.EscapeIdentifier(col), squirrel.EscapeIdentifier(col)))
}
return strings.Join(setParts, ", ")
}
// buildWhereClause combines filter groups into a single boolean expression
// for WHERE or HAVING. Each non-empty group is wrapped in parentheses.
//
// A group's LogicOp joins it to the groups before it. It may be "AND" or
// "OR" (case-insensitive, surrounding spaces ignored) and defaults to AND.
// Any other value is an error: the operator is written into the SQL text
// verbatim, so accepting arbitrary strings would allow SQL injection.
//
// Groups that yield no condition are skipped without leaving a dangling
// operator. Returns "" and nil args when no group yields a condition.
func (qb *QueryBuilder) buildWhereClause(filterGroups []FilterGroup) (string, []interface{}, error) {
	if len(filterGroups) == 0 {
		return "", nil, nil
	}

	// logicOperator normalizes raw to AND/OR, using fallback when raw is blank.
	logicOperator := func(raw, fallback string) (string, error) {
		switch op := strings.ToUpper(strings.TrimSpace(raw)); op {
		case "":
			return fallback, nil
		case "AND", "OR":
			return op, nil
		default:
			return "", fmt.Errorf("unsupported logic operator: %q", raw)
		}
	}

	var conditions []string
	var allArgs []interface{}
	for _, group := range filterGroups {
		groupCondition, groupArgs, err := qb.buildFilterGroup(group)
		if err != nil {
			return "", nil, err
		}
		if groupCondition == "" {
			continue
		}
		// Join on emitted conditions, not loop position, so a leading empty
		// group cannot leave an operator without a left operand.
		if len(conditions) > 0 {
			op, opErr := logicOperator(group.LogicOp, "AND")
			if opErr != nil {
				return "", nil, opErr
			}
			conditions = append(conditions, op)
		}
		conditions = append(conditions, "("+groupCondition+")")
		allArgs = append(allArgs, groupArgs...)
	}
	return strings.Join(conditions, " "), allArgs, nil
}
// buildFilterGroup joins the conditions of one group's filters into a single
// expression, without surrounding parentheses.
//
// Each emitted condition is joined to the previous one with the filter's
// LogicOp. If that is blank, the group's LogicOp is used, and then AND. Only
// "AND" and "OR" (case-insensitive, surrounding spaces ignored) are
// accepted. Anything else is an error, because the operator is written into
// the SQL verbatim.
//
// Filters that yield no condition are skipped without leaving a dangling
// operator. Returns "" when no filter yields a condition.
func (qb *QueryBuilder) buildFilterGroup(group FilterGroup) (string, []interface{}, error) {
	if len(group.Filters) == 0 {
		return "", nil, nil
	}

	// logicOperator normalizes raw to AND/OR, using fallback when raw is blank.
	logicOperator := func(raw, fallback string) (string, error) {
		switch op := strings.ToUpper(strings.TrimSpace(raw)); op {
		case "":
			return fallback, nil
		case "AND", "OR":
			return op, nil
		default:
			return "", fmt.Errorf("unsupported logic operator: %q", raw)
		}
	}

	var conditions []string
	var args []interface{}
	for _, filter := range group.Filters {
		condition, filterArgs, err := qb.buildFilterCondition(filter)
		if err != nil {
			return "", nil, err
		}
		if condition == "" {
			continue
		}
		if len(conditions) > 0 {
			op, opErr := logicOperator(filter.LogicOp, "")
			if opErr == nil && op == "" {
				op, opErr = logicOperator(group.LogicOp, "AND")
			}
			if opErr != nil {
				return "", nil, opErr
			}
			conditions = append(conditions, op)
		}
		conditions = append(conditions, condition)
		args = append(args, filterArgs...)
	}
	return strings.Join(conditions, " "), args, nil
}
// buildFilterCondition renders one filter as a SQL predicate plus its bound
// arguments, using dialect-specific syntax where needed.
//
// It returns ("", nil, nil), meaning "no condition", when the column fails
// mapAndValidateColumn, or when a value-taking operator gets a nil value or
// an empty list. It returns an error for unknown operators and for BETWEEN
// without exactly two values.
//
// The comparison operator tokens (_gt, _gte, _lt, _lte) are translated to
// >, >=, <, <=. Case-insensitive matching (_ilike, _contains, _ncontains,
// _starts_with, _ends_with) is rendered per dialect:
//   - PostgreSQL: ILIKE
//   - MySQL, SQL Server, SQLite: LOWER(col) LIKE LOWER(?) (SQLite has no ILIKE)
//   - unknown dialects: plain LIKE
//
// NOTE(review): '%' and '_' inside user-supplied values are not escaped, so
// they act as wildcards.
func (qb *QueryBuilder) buildFilterCondition(filter DynamicFilter) (string, []interface{}, error) {
	column := qb.mapAndValidateColumn(filter.Column)
	if column == "" {
		return "", nil, nil
	}
	column = squirrel.EscapeIdentifier(column)

	// caseInsensitiveLike renders "column [NOT] ILIKE ?" or its dialect equivalent.
	caseInsensitiveLike := func(negate bool, pattern interface{}) (string, []interface{}, error) {
		not := ""
		if negate {
			not = "NOT "
		}
		switch qb.dbType {
		case DBTypePostgreSQL:
			return fmt.Sprintf("%s %sILIKE ?", column, not), []interface{}{pattern}, nil
		case DBTypeMySQL, DBTypeSQLServer, DBTypeSQLite:
			return fmt.Sprintf("LOWER(%s) %sLIKE LOWER(?)", column, not), []interface{}{pattern}, nil
		default:
			return fmt.Sprintf("%s %sLIKE ?", column, not), []interface{}{pattern}, nil
		}
	}

	switch filter.Operator {
	case OpEqual:
		if filter.Value == nil {
			return fmt.Sprintf("%s IS NULL", column), nil, nil
		}
		return fmt.Sprintf("%s = ?", column), []interface{}{filter.Value}, nil

	case OpNotEqual:
		if filter.Value == nil {
			return fmt.Sprintf("%s IS NOT NULL", column), nil, nil
		}
		return fmt.Sprintf("%s != ?", column), []interface{}{filter.Value}, nil

	case OpLike:
		if filter.Value == nil {
			return "", nil, nil
		}
		return fmt.Sprintf("%s LIKE ?", column), []interface{}{filter.Value}, nil

	case OpILike:
		if filter.Value == nil {
			return "", nil, nil
		}
		return caseInsensitiveLike(false, filter.Value)

	case OpIn, OpNotIn:
		values := qb.parseArrayValue(filter.Value)
		if len(values) == 0 {
			return "", nil, nil
		}
		op := "IN"
		if filter.Operator == OpNotIn {
			op = "NOT IN"
		}
		return fmt.Sprintf("%s %s (%s)", column, op, squirrel.Placeholders(len(values))), values, nil

	case OpGreaterThan, OpGreaterThanEqual, OpLessThan, OpLessThanEqual:
		if filter.Value == nil {
			return "", nil, nil
		}
		// The operator tokens ("_gt", ...) are not SQL; translate them explicitly.
		var op string
		switch filter.Operator {
		case OpGreaterThan:
			op = ">"
		case OpGreaterThanEqual:
			op = ">="
		case OpLessThan:
			op = "<"
		default:
			op = "<="
		}
		return fmt.Sprintf("%s %s ?", column, op), []interface{}{filter.Value}, nil

	case OpBetween, OpNotBetween:
		if filter.Value == nil {
			return "", nil, nil
		}
		values := qb.parseArrayValue(filter.Value)
		if len(values) != 2 {
			return "", nil, fmt.Errorf("between operator requires exactly 2 values")
		}
		op := "BETWEEN"
		if filter.Operator == OpNotBetween {
			op = "NOT BETWEEN"
		}
		return fmt.Sprintf("%s %s ? AND ?", column, op), []interface{}{values[0], values[1]}, nil

	case OpNull:
		return fmt.Sprintf("%s IS NULL", column), nil, nil

	case OpNotNull:
		return fmt.Sprintf("%s IS NOT NULL", column), nil, nil

	case OpContains, OpNotContains, OpStartsWith, OpEndsWith:
		if filter.Value == nil {
			return "", nil, nil
		}
		var pattern string
		switch filter.Operator {
		case OpStartsWith:
			pattern = fmt.Sprintf("%v%%", filter.Value)
		case OpEndsWith:
			pattern = fmt.Sprintf("%%%v", filter.Value)
		default: // OpContains, OpNotContains
			pattern = fmt.Sprintf("%%%v%%", filter.Value)
		}
		return caseInsensitiveLike(filter.Operator == OpNotContains, pattern)

	default:
		return "", nil, fmt.Errorf("unsupported operator: %s", filter.Operator)
	}
}
// parseArrayValue normalizes a filter value into a list for IN/BETWEEN:
//   - nil yields nil
//   - any slice is expanded element by element
//   - a string containing commas is split, with each part whitespace-trimmed
//   - any other value, including a comma-free string, becomes a
//     one-element list
func (qb *QueryBuilder) parseArrayValue(value interface{}) []interface{} {
	switch {
	case value == nil:
		return nil
	case reflect.TypeOf(value).Kind() == reflect.Slice:
		rv := reflect.ValueOf(value)
		out := make([]interface{}, 0, rv.Len())
		for i := 0; i < rv.Len(); i++ {
			out = append(out, rv.Index(i).Interface())
		}
		return out
	}

	s, isString := value.(string)
	if !isString {
		return []interface{}{value}
	}
	if !strings.Contains(s, ",") {
		return []interface{}{s}
	}

	parts := strings.Split(s, ",")
	out := make([]interface{}, len(parts))
	for i := range parts {
		out[i] = strings.TrimSpace(parts[i])
	}
	return out
}
// mapAndValidateColumn resolves an API field name to its DB column through
// columnMapping; unmapped names are used as-is. The column is returned only
// if it is on the allow-list (when one is configured) and passes
// isValidColumnName. Otherwise "" is returned.
func (qb *QueryBuilder) mapAndValidateColumn(field string) string {
	col, ok := qb.columnMapping[field]
	if !ok {
		col = field
	}

	allowListed := len(qb.allowedColumns) == 0 || qb.allowedColumns[col]
	if !allowListed || !qb.isValidColumnName(col) {
		return ""
	}
	return col
}
// isValidColumnName reports whether column is safe to splice into SQL as an
// identifier.
//
// The name must be non-empty, consist only of ASCII letters, digits, '_' and
// '.', and have no empty segment around a '.' (so "a..b" and ".a" fail).
// That character set already excludes whitespace, quotes, ';' and comment
// markers.
//
// Keyword checks therefore apply to whole dot-separated segments, not
// substrings. A substring match would reject ordinary names such as
// created_at, updated_at, deleted_at, exp_date or resp_code. A segment is
// rejected when it equals a blocked SQL keyword or system object name, or
// starts with a system prefix (pg_, xp_, sp_, sqlite_).
func (qb *QueryBuilder) isValidColumnName(column string) bool {
	if column == "" {
		return false
	}
	for _, r := range column {
		if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' || r == '.') {
			return false
		}
	}

	for _, segment := range strings.Split(strings.ToLower(column), ".") {
		if segment == "" {
			return false
		}
		switch segment {
		case "union", "select", "insert", "update", "delete", "drop", "alter", "create", "exec", "execute",
			"information_schema", "sysobjects", "syscolumns", "sysdatabases", "mysql", "sqlite":
			return false
		}
		for _, prefix := range []string{"pg_", "xp_", "sp_", "sqlite_"} {
			if strings.HasPrefix(segment, prefix) {
				return false
			}
		}
	}
	return true
}
// isValidExpression validates SQL expressions
func (qb *QueryBuilder) isValidExpression(expr string) bool {
allowedChars := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_.,() *-+/"
for _, r := range expr {
if !strings.ContainsRune(allowedChars, r) {
return false
}
}
dangerousPatterns := []string{";", "--", "/*", "*/", "union", "select", "insert", "update", "delete", "drop", "alter", "create", "exec", "execute"}
lowerExpr := strings.ToLower(expr)
for _, pattern := range dangerousPatterns {
if strings.Contains(lowerExpr, pattern) {
return false
}
}
return true
}
// ValidateInput performs comprehensive input validation
func (qb *QueryBuilder) ValidateInput(input interface{}) error {
switch v := input.(type) {
case string:
if len(v) > 1000 {
return fmt.Errorf("input string too long")
}
if strings.Contains(v, "\x00") {
return fmt.Errorf("invalid characters in input")
}
case []string:
if len(v) > 100 {
return fmt.Errorf("too many items in array")
}
for _, item := range v {
if err := qb.ValidateInput(item); err != nil {
return err
}
}
}
return nil
}
// SanitizeString removes every NUL byte from input and then trims leading
// and trailing whitespace. All other bytes, including invalid UTF-8, are
// kept unchanged.
func (qb *QueryBuilder) SanitizeString(input string) string {
	withoutNUL := strings.Replace(input, "\x00", "", -1)
	return strings.TrimSpace(withoutNUL)
}
// QueryParser turns URL query parameters into a DynamicQuery and bounds the
// page size it accepts.
type QueryParser struct {
	defaultLimit int // page size used when no acceptable "limit" is given
	maxLimit     int // largest "limit" value accepted
}
// NewQueryParser returns a QueryParser with a default page size of 10 and a
// maximum of 100. Adjust both with SetLimits.
func NewQueryParser() *QueryParser {
	qp := new(QueryParser)
	qp.defaultLimit, qp.maxLimit = 10, 100
	return qp
}
// SetLimits sets the default page size and the largest accepted "limit",
// then returns qp for chaining. Neither value is validated.
func (qp *QueryParser) SetLimits(defaultLimit, maxLimit int) *QueryParser {
	qp.defaultLimit, qp.maxLimit = defaultLimit, maxLimit
	return qp
}
// ParseQuery builds a DynamicQuery from URL query parameters:
//
//   - fields: comma-separated column list; "*" or "*.*" selects everything
//   - limit:  page size, honoured only when in 1..maxLimit, else defaultLimit
//   - offset: non-negative row offset
//   - filter[column][operator]: see parseFilters
//   - sort:   see parseSorting
//   - group:  comma-separated GROUP BY columns
//
// List entries are whitespace-trimmed, and unparseable numbers are ignored.
// If parseFilters or parseSorting fails, the partially filled query is
// returned together with the error.
func (qp *QueryParser) ParseQuery(values url.Values) (DynamicQuery, error) {
	splitTrimmed := func(s string) []string {
		parts := strings.Split(s, ",")
		for i := range parts {
			parts[i] = strings.TrimSpace(parts[i])
		}
		return parts
	}

	query := DynamicQuery{Limit: qp.defaultLimit}

	switch fields := values.Get("fields"); fields {
	case "":
		// no projection requested
	case "*", "*.*":
		query.Fields = []string{"*"}
	default:
		query.Fields = splitTrimmed(fields)
	}

	if l, err := strconv.Atoi(values.Get("limit")); err == nil && l > 0 && l <= qp.maxLimit {
		query.Limit = l
	}
	if o, err := strconv.Atoi(values.Get("offset")); err == nil && o >= 0 {
		query.Offset = o
	}

	filters, filterErr := qp.parseFilters(values)
	if filterErr != nil {
		return query, filterErr
	}
	query.Filters = filters

	sorts, sortErr := qp.parseSorting(values)
	if sortErr != nil {
		return query, sortErr
	}
	query.Sort = sorts

	if groupBy := values.Get("group"); groupBy != "" {
		query.GroupBy = splitTrimmed(groupBy)
	}
	return query, nil
}
// parseFilters collects parameters shaped like filter[column][operator]=value
// into a single AND-joined FilterGroup. Only the first value of a repeated
// key is used, and keys of any other shape are ignored.
//
// Value handling by operator:
//   - _in / _nin: comma-separated list; each item is whitespace-trimmed
//   - _between / _nbetween: exactly two comma-separated, trimmed bounds;
//     anything else leaves the value nil
//   - _null / _nnull: value ignored (nil)
//   - anything else: the raw string
//
// Filters are emitted sorted by column and then by operator. Ranging over
// the maps directly would randomize condition and argument order on every
// request. Returns nil when no filter parameter is present; the error is
// always nil.
func (qp *QueryParser) parseFilters(values url.Values) ([]FilterGroup, error) {
	const prefix = "filter["

	filterMap := make(map[string]map[string]string)
	for key, vals := range values {
		if !strings.HasPrefix(key, prefix) || !strings.HasSuffix(key, "]") || len(vals) == 0 {
			continue
		}
		parts := strings.Split(key[len(prefix):len(key)-1], "][")
		if len(parts) != 2 {
			continue
		}
		column, operator := parts[0], parts[1]
		if filterMap[column] == nil {
			filterMap[column] = make(map[string]string)
		}
		filterMap[column][operator] = vals[0]
	}
	if len(filterMap) == 0 {
		return nil, nil
	}

	columns := make([]string, 0, len(filterMap))
	for column := range filterMap {
		columns = append(columns, column)
	}
	sort.Strings(columns)

	var filters []DynamicFilter
	for _, column := range columns {
		operators := filterMap[column]
		opNames := make([]string, 0, len(operators))
		for opStr := range operators {
			opNames = append(opNames, opStr)
		}
		sort.Strings(opNames)

		for _, opStr := range opNames {
			value := operators[opStr]
			operator := FilterOperator(opStr)

			var parsedValue interface{}
			switch operator {
			case OpIn, OpNotIn:
				if value != "" {
					items := strings.Split(value, ",")
					for i := range items {
						items[i] = strings.TrimSpace(items[i])
					}
					parsedValue = items
				}
			case OpBetween, OpNotBetween:
				if value != "" {
					bounds := strings.Split(value, ",")
					if len(bounds) == 2 {
						parsedValue = []interface{}{strings.TrimSpace(bounds[0]), strings.TrimSpace(bounds[1])}
					}
				}
			case OpNull, OpNotNull:
				parsedValue = nil
			default:
				parsedValue = value
			}
			filters = append(filters, DynamicFilter{Column: column, Operator: operator, Value: parsedValue})
		}
	}

	if len(filters) == 0 {
		return nil, nil
	}
	return []FilterGroup{{Filters: filters, LogicOp: "AND"}}, nil
}
// parseSorting reads the comma-separated "sort" parameter. A leading "-"
// means descending; a leading "+" or no prefix means ascending. Blank
// entries are skipped. Returns nil when "sort" is absent or has no entries;
// the error is always nil.
func (qp *QueryParser) parseSorting(values url.Values) ([]SortField, error) {
	raw := values.Get("sort")
	if raw == "" {
		return nil, nil
	}

	var sorts []SortField
	for _, item := range strings.Split(raw, ",") {
		item = strings.TrimSpace(item)
		if item == "" {
			continue
		}
		sf := SortField{Column: item, Order: "ASC"}
		switch item[0] {
		case '-':
			sf.Column, sf.Order = item[1:], "DESC"
		case '+':
			sf.Column = item[1:]
		}
		sorts = append(sorts, sf)
	}
	return sorts, nil
}
// ParseAdvancedFilters is a placeholder for parsing an advanced filter
// expression. It currently ignores filterParam and always returns no filter
// groups and no error.
//
// NOTE(review): callers silently get unfiltered results. Either implement
// it (the JSON tags on FilterGroup/DynamicFilter suggest a JSON payload —
// confirm the expected format) or return an explicit "not implemented" error.
func (qp *QueryParser) ParseAdvancedFilters(filterParam string) ([]FilterGroup, error) {
	var groups []FilterGroup // always nil until implemented
	return groups, nil
}
// parseDate tries these layouts in order and returns the first successful
// time.Time:
//   - "2006-01-02"
//   - "2006-01-02T15:04:05Z"
//   - "2006-01-02T15:04:05.000Z"
//   - "2006-01-02 15:04:05"
//
// None of the layouts carries a zone offset, so time.Parse yields UTC. If no
// layout matches, value is returned unchanged as a string. The error result
// is always nil.
func parseDate(value string) (interface{}, error) {
	layouts := [...]string{
		"2006-01-02",
		"2006-01-02T15:04:05Z",
		"2006-01-02T15:04:05.000Z",
		"2006-01-02 15:04:05",
	}
	for i := range layouts {
		t, err := time.Parse(layouts[i], value)
		if err != nil {
			continue
		}
		return t, nil
	}
	return value, nil
}
func parseNumeric(value string) interface{} {
if i, err := strconv.Atoi(value); err == nil {
return i
}
if f, err := strconv.ParseFloat(value, 64); err == nil {
return f
}
return value
}