first commit
This commit is contained in:
739
internal/config/config.go
Normal file
739
internal/config/config.go
Normal file
@@ -0,0 +1,739 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig
|
||||
Databases map[string]DatabaseConfig
|
||||
ReadReplicas map[string][]DatabaseConfig // For read replicas
|
||||
Keycloak KeycloakConfig
|
||||
Bpjs BpjsConfig
|
||||
SatuSehat SatuSehatConfig
|
||||
Swagger SwaggerConfig
|
||||
Validator *validator.Validate
|
||||
}
|
||||
|
||||
type SwaggerConfig struct {
|
||||
Title string
|
||||
Description string
|
||||
Version string
|
||||
TermsOfService string
|
||||
ContactName string
|
||||
ContactURL string
|
||||
ContactEmail string
|
||||
LicenseName string
|
||||
LicenseURL string
|
||||
Host string
|
||||
BasePath string
|
||||
Schemes []string
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Port int
|
||||
Mode string
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Name string
|
||||
Type string // postgres, mysql, sqlserver, sqlite, mongodb
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
Database string
|
||||
Schema string
|
||||
SSLMode string
|
||||
Path string // For SQLite
|
||||
Options string // Additional connection options
|
||||
MaxOpenConns int // Max open connections
|
||||
MaxIdleConns int // Max idle connections
|
||||
ConnMaxLifetime time.Duration // Connection max lifetime
|
||||
}
|
||||
|
||||
type KeycloakConfig struct {
|
||||
Issuer string
|
||||
Audience string
|
||||
JwksURL string
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
type BpjsConfig struct {
|
||||
BaseURL string `json:"base_url"`
|
||||
ConsID string `json:"cons_id"`
|
||||
UserKey string `json:"user_key"`
|
||||
SecretKey string `json:"secret_key"`
|
||||
Timeout time.Duration `json:"timeout"`
|
||||
}
|
||||
|
||||
type SatuSehatConfig struct {
|
||||
OrgID string `json:"org_id"`
|
||||
FasyakesID string `json:"fasyakes_id"`
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
AuthURL string `json:"auth_url"`
|
||||
BaseURL string `json:"base_url"`
|
||||
ConsentURL string `json:"consent_url"`
|
||||
KFAURL string `json:"kfa_url"`
|
||||
Timeout time.Duration `json:"timeout"`
|
||||
}
|
||||
|
||||
// SetHeader generates required headers for BPJS VClaim API
|
||||
// func (cfg BpjsConfig) SetHeader() (string, string, string, string, string) {
|
||||
// timenow := time.Now().UTC()
|
||||
// t, err := time.Parse(time.RFC3339, "1970-01-01T00:00:00Z")
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
|
||||
// tstamp := timenow.Unix() - t.Unix()
|
||||
// secret := []byte(cfg.SecretKey)
|
||||
// message := []byte(cfg.ConsID + "&" + fmt.Sprint(tstamp))
|
||||
// hash := hmac.New(sha256.New, secret)
|
||||
// hash.Write(message)
|
||||
|
||||
// // to lowercase hexits
|
||||
// hex.EncodeToString(hash.Sum(nil))
|
||||
// // to base64
|
||||
// xSignature := base64.StdEncoding.EncodeToString(hash.Sum(nil))
|
||||
|
||||
// return cfg.ConsID, cfg.SecretKey, cfg.UserKey, fmt.Sprint(tstamp), xSignature
|
||||
// }
|
||||
func (cfg BpjsConfig) SetHeader() (string, string, string, string, string) {
|
||||
timenow := time.Now().UTC()
|
||||
t, err := time.Parse(time.RFC3339, "1970-01-01T00:00:00Z")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
tstamp := timenow.Unix() - t.Unix()
|
||||
secret := []byte(cfg.SecretKey)
|
||||
message := []byte(cfg.ConsID + "&" + fmt.Sprint(tstamp))
|
||||
hash := hmac.New(sha256.New, secret)
|
||||
hash.Write(message)
|
||||
|
||||
// to lowercase hexits
|
||||
hex.EncodeToString(hash.Sum(nil))
|
||||
// to base64
|
||||
xSignature := base64.StdEncoding.EncodeToString(hash.Sum(nil))
|
||||
|
||||
return cfg.ConsID, cfg.SecretKey, cfg.UserKey, fmt.Sprint(tstamp), xSignature
|
||||
}
|
||||
|
||||
type ConfigBpjs struct {
|
||||
Cons_id string
|
||||
Secret_key string
|
||||
User_key string
|
||||
}
|
||||
|
||||
// SetHeader for backward compatibility
|
||||
func (cfg ConfigBpjs) SetHeader() (string, string, string, string, string) {
|
||||
bpjsConfig := BpjsConfig{
|
||||
ConsID: cfg.Cons_id,
|
||||
SecretKey: cfg.Secret_key,
|
||||
UserKey: cfg.User_key,
|
||||
}
|
||||
return bpjsConfig.SetHeader()
|
||||
}
|
||||
|
||||
func LoadConfig() *Config {
|
||||
config := &Config{
|
||||
Server: ServerConfig{
|
||||
Port: getEnvAsInt("PORT", 8080),
|
||||
Mode: getEnv("GIN_MODE", "debug"),
|
||||
},
|
||||
Databases: make(map[string]DatabaseConfig),
|
||||
ReadReplicas: make(map[string][]DatabaseConfig),
|
||||
Keycloak: KeycloakConfig{
|
||||
Issuer: getEnv("KEYCLOAK_ISSUER", "https://keycloak.example.com/auth/realms/yourrealm"),
|
||||
Audience: getEnv("KEYCLOAK_AUDIENCE", "your-client-id"),
|
||||
JwksURL: getEnv("KEYCLOAK_JWKS_URL", "https://keycloak.example.com/auth/realms/yourrealm/protocol/openid-connect/certs"),
|
||||
Enabled: getEnvAsBool("KEYCLOAK_ENABLED", true),
|
||||
},
|
||||
Bpjs: BpjsConfig{
|
||||
BaseURL: getEnv("BPJS_BASEURL", "https://apijkn.bpjs-kesehatan.go.id"),
|
||||
ConsID: getEnv("BPJS_CONSID", ""),
|
||||
UserKey: getEnv("BPJS_USERKEY", ""),
|
||||
SecretKey: getEnv("BPJS_SECRETKEY", ""),
|
||||
Timeout: parseDuration(getEnv("BPJS_TIMEOUT", "30s")),
|
||||
},
|
||||
SatuSehat: SatuSehatConfig{
|
||||
OrgID: getEnv("BRIDGING_SATUSEHAT_ORG_ID", ""),
|
||||
FasyakesID: getEnv("BRIDGING_SATUSEHAT_FASYAKES_ID", ""),
|
||||
ClientID: getEnv("BRIDGING_SATUSEHAT_CLIENT_ID", ""),
|
||||
ClientSecret: getEnv("BRIDGING_SATUSEHAT_CLIENT_SECRET", ""),
|
||||
AuthURL: getEnv("BRIDGING_SATUSEHAT_AUTH_URL", "https://api-satusehat.kemkes.go.id/oauth2/v1"),
|
||||
BaseURL: getEnv("BRIDGING_SATUSEHAT_BASE_URL", "https://api-satusehat.kemkes.go.id/fhir-r4/v1"),
|
||||
ConsentURL: getEnv("BRIDGING_SATUSEHAT_CONSENT_URL", "https://api-satusehat.dto.kemkes.go.id/consent/v1"),
|
||||
KFAURL: getEnv("BRIDGING_SATUSEHAT_KFA_URL", "https://api-satusehat.kemkes.go.id/kfa-v2"),
|
||||
Timeout: parseDuration(getEnv("BRIDGING_SATUSEHAT_TIMEOUT", "30s")),
|
||||
},
|
||||
Swagger: SwaggerConfig{
|
||||
Title: getEnv("SWAGGER_TITLE", "SERVICE API"),
|
||||
Description: getEnv("SWAGGER_DESCRIPTION", "CUSTUM SERVICE API"),
|
||||
Version: getEnv("SWAGGER_VERSION", "1.0.0"),
|
||||
TermsOfService: getEnv("SWAGGER_TERMS_OF_SERVICE", "http://swagger.io/terms/"),
|
||||
ContactName: getEnv("SWAGGER_CONTACT_NAME", "API Support"),
|
||||
ContactURL: getEnv("SWAGGER_CONTACT_URL", "http://rssa.example.com/support"),
|
||||
ContactEmail: getEnv("SWAGGER_CONTACT_EMAIL", "support@swagger.io"),
|
||||
LicenseName: getEnv("SWAGGER_LICENSE_NAME", "Apache 2.0"),
|
||||
LicenseURL: getEnv("SWAGGER_LICENSE_URL", "http://www.apache.org/licenses/LICENSE-2.0.html"),
|
||||
Host: getEnv("SWAGGER_HOST", "localhost:8080"),
|
||||
BasePath: getEnv("SWAGGER_BASE_PATH", "/api/v1"),
|
||||
Schemes: parseSchemes(getEnv("SWAGGER_SCHEMES", "http,https")),
|
||||
},
|
||||
}
|
||||
|
||||
// Initialize validator
|
||||
config.Validator = validator.New()
|
||||
|
||||
// Load database configurations
|
||||
config.loadDatabaseConfigs()
|
||||
|
||||
// Load read replica configurations
|
||||
config.loadReadReplicaConfigs()
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func (c *Config) loadDatabaseConfigs() {
|
||||
// Simplified approach: Directly load from environment variables
|
||||
// This ensures we get the exact values specified in .env
|
||||
|
||||
// Primary database configuration
|
||||
c.Databases["default"] = DatabaseConfig{
|
||||
Name: "default",
|
||||
Type: getEnv("DB_CONNECTION", "postgres"),
|
||||
Host: getEnv("DB_HOST", "localhost"),
|
||||
Port: getEnvAsInt("DB_PORT", 5432),
|
||||
Username: getEnv("DB_USERNAME", ""),
|
||||
Password: getEnv("DB_PASSWORD", ""),
|
||||
Database: getEnv("DB_DATABASE", "satu_db"),
|
||||
Schema: getEnv("DB_SCHEMA", "public"),
|
||||
SSLMode: getEnv("DB_SSLMODE", "disable"),
|
||||
MaxOpenConns: getEnvAsInt("DB_MAX_OPEN_CONNS", 25),
|
||||
MaxIdleConns: getEnvAsInt("DB_MAX_IDLE_CONNS", 25),
|
||||
ConnMaxLifetime: parseDuration(getEnv("DB_CONN_MAX_LIFETIME", "5m")),
|
||||
}
|
||||
|
||||
// SATUDATA database configuration
|
||||
c.addPostgreSQLConfigs()
|
||||
|
||||
// MongoDB database configuration
|
||||
c.addMongoDBConfigs()
|
||||
|
||||
// Legacy support for backward compatibility
|
||||
envVars := os.Environ()
|
||||
dbConfigs := make(map[string]map[string]string)
|
||||
|
||||
// Parse database configurations from environment variables
|
||||
for _, envVar := range envVars {
|
||||
parts := strings.SplitN(envVar, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
key := parts[0]
|
||||
value := parts[1]
|
||||
|
||||
// Parse specific database configurations
|
||||
if strings.HasSuffix(key, "_CONNECTION") || strings.HasSuffix(key, "_HOST") ||
|
||||
strings.HasSuffix(key, "_DATABASE") || strings.HasSuffix(key, "_USERNAME") ||
|
||||
strings.HasSuffix(key, "_PASSWORD") || strings.HasSuffix(key, "_PORT") ||
|
||||
strings.HasSuffix(key, "_NAME") {
|
||||
|
||||
segments := strings.Split(key, "_")
|
||||
if len(segments) >= 2 {
|
||||
dbName := strings.ToLower(strings.Join(segments[:len(segments)-1], "_"))
|
||||
property := strings.ToLower(segments[len(segments)-1])
|
||||
|
||||
if dbConfigs[dbName] == nil {
|
||||
dbConfigs[dbName] = make(map[string]string)
|
||||
}
|
||||
dbConfigs[dbName][property] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create DatabaseConfig from parsed configurations for additional databases
|
||||
for name, config := range dbConfigs {
|
||||
// Skip empty configurations or system configurations
|
||||
if name == "" || strings.Contains(name, "chrome_crashpad_pipe") || name == "primary" {
|
||||
continue
|
||||
}
|
||||
|
||||
dbConfig := DatabaseConfig{
|
||||
Name: name,
|
||||
Type: getEnvFromMap(config, "connection", getEnvFromMap(config, "type", "postgres")),
|
||||
Host: getEnvFromMap(config, "host", "localhost"),
|
||||
Port: getEnvAsIntFromMap(config, "port", 5432),
|
||||
Username: getEnvFromMap(config, "username", ""),
|
||||
Password: getEnvFromMap(config, "password", ""),
|
||||
Database: getEnvFromMap(config, "database", getEnvFromMap(config, "name", name)),
|
||||
Schema: getEnvFromMap(config, "schema", "public"),
|
||||
SSLMode: getEnvFromMap(config, "sslmode", "disable"),
|
||||
Path: getEnvFromMap(config, "path", ""),
|
||||
Options: getEnvFromMap(config, "options", ""),
|
||||
MaxOpenConns: getEnvAsIntFromMap(config, "max_open_conns", 25),
|
||||
MaxIdleConns: getEnvAsIntFromMap(config, "max_idle_conns", 25),
|
||||
ConnMaxLifetime: parseDuration(getEnvFromMap(config, "conn_max_lifetime", "5m")),
|
||||
}
|
||||
|
||||
// Skip if username is empty and it's not a system config
|
||||
if dbConfig.Username == "" && !strings.HasPrefix(name, "chrome") {
|
||||
continue
|
||||
}
|
||||
|
||||
c.Databases[name] = dbConfig
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) loadReadReplicaConfigs() {
|
||||
envVars := os.Environ()
|
||||
|
||||
for _, envVar := range envVars {
|
||||
parts := strings.SplitN(envVar, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
key := parts[0]
|
||||
value := parts[1]
|
||||
|
||||
// Parse read replica configurations (format: [DBNAME]_REPLICA_[INDEX]_[PROPERTY])
|
||||
if strings.Contains(key, "_REPLICA_") {
|
||||
segments := strings.Split(key, "_")
|
||||
if len(segments) >= 5 && strings.ToUpper(segments[2]) == "REPLICA" {
|
||||
dbName := strings.ToLower(segments[1])
|
||||
replicaIndex := segments[3]
|
||||
property := strings.ToLower(strings.Join(segments[4:], "_"))
|
||||
|
||||
replicaKey := dbName + "_replica_" + replicaIndex
|
||||
|
||||
if c.ReadReplicas[dbName] == nil {
|
||||
c.ReadReplicas[dbName] = []DatabaseConfig{}
|
||||
}
|
||||
|
||||
// Find or create replica config
|
||||
var replicaConfig *DatabaseConfig
|
||||
for i := range c.ReadReplicas[dbName] {
|
||||
if c.ReadReplicas[dbName][i].Name == replicaKey {
|
||||
replicaConfig = &c.ReadReplicas[dbName][i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if replicaConfig == nil {
|
||||
// Create new replica config
|
||||
newConfig := DatabaseConfig{
|
||||
Name: replicaKey,
|
||||
Type: c.Databases[dbName].Type,
|
||||
Host: getEnv("DB_"+strings.ToUpper(dbName)+"_REPLICA_"+replicaIndex+"_HOST", c.Databases[dbName].Host),
|
||||
Port: getEnvAsInt("DB_"+strings.ToUpper(dbName)+"_REPLICA_"+replicaIndex+"_PORT", c.Databases[dbName].Port),
|
||||
Username: getEnv("DB_"+strings.ToUpper(dbName)+"_REPLICA_"+replicaIndex+"_USERNAME", c.Databases[dbName].Username),
|
||||
Password: getEnv("DB_"+strings.ToUpper(dbName)+"_REPLICA_"+replicaIndex+"_PASSWORD", c.Databases[dbName].Password),
|
||||
Database: getEnv("DB_"+strings.ToUpper(dbName)+"_REPLICA_"+replicaIndex+"_DATABASE", c.Databases[dbName].Database),
|
||||
Schema: getEnv("DB_"+strings.ToUpper(dbName)+"_REPLICA_"+replicaIndex+"_SCHEMA", c.Databases[dbName].Schema),
|
||||
SSLMode: getEnv("DB_"+strings.ToUpper(dbName)+"_REPLICA_"+replicaIndex+"_SSLMODE", c.Databases[dbName].SSLMode),
|
||||
MaxOpenConns: getEnvAsInt("DB_"+strings.ToUpper(dbName)+"_REPLICA_"+replicaIndex+"_MAX_OPEN_CONNS", c.Databases[dbName].MaxOpenConns),
|
||||
MaxIdleConns: getEnvAsInt("DB_"+strings.ToUpper(dbName)+"_REPLICA_"+replicaIndex+"_MAX_IDLE_CONNS", c.Databases[dbName].MaxIdleConns),
|
||||
ConnMaxLifetime: parseDuration(getEnv("DB_"+strings.ToUpper(dbName)+"_REPLICA_"+replicaIndex+"_CONN_MAX_LIFETIME", "5m")),
|
||||
}
|
||||
c.ReadReplicas[dbName] = append(c.ReadReplicas[dbName], newConfig)
|
||||
replicaConfig = &c.ReadReplicas[dbName][len(c.ReadReplicas[dbName])-1]
|
||||
}
|
||||
|
||||
// Update the specific replica
|
||||
switch property {
|
||||
case "host":
|
||||
replicaConfig.Host = value
|
||||
case "port":
|
||||
replicaConfig.Port = getEnvAsInt(key, 5432)
|
||||
case "username":
|
||||
replicaConfig.Username = value
|
||||
case "password":
|
||||
replicaConfig.Password = value
|
||||
case "database":
|
||||
replicaConfig.Database = value
|
||||
case "schema":
|
||||
replicaConfig.Schema = value
|
||||
case "sslmode":
|
||||
replicaConfig.SSLMode = value
|
||||
case "max_open_conns":
|
||||
replicaConfig.MaxOpenConns = getEnvAsInt(key, 25)
|
||||
case "max_idle_conns":
|
||||
replicaConfig.MaxIdleConns = getEnvAsInt(key, 25)
|
||||
case "conn_max_lifetime":
|
||||
replicaConfig.ConnMaxLifetime = parseDuration(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) addSpecificDatabase(prefix, defaultType string) {
|
||||
connection := getEnv(strings.ToUpper(prefix)+"_CONNECTION", defaultType)
|
||||
host := getEnv(strings.ToUpper(prefix)+"_HOST", "")
|
||||
if host != "" {
|
||||
dbConfig := DatabaseConfig{
|
||||
Name: prefix,
|
||||
Type: connection,
|
||||
Host: host,
|
||||
Port: getEnvAsInt(strings.ToUpper(prefix)+"_PORT", 5432),
|
||||
Username: getEnv(strings.ToUpper(prefix)+"_USERNAME", ""),
|
||||
Password: getEnv(strings.ToUpper(prefix)+"_PASSWORD", ""),
|
||||
Database: getEnv(strings.ToUpper(prefix)+"_DATABASE", getEnv(strings.ToUpper(prefix)+"_NAME", prefix)),
|
||||
Schema: getEnv(strings.ToUpper(prefix)+"_SCHEMA", "public"),
|
||||
SSLMode: getEnv(strings.ToUpper(prefix)+"_SSLMODE", "disable"),
|
||||
MaxOpenConns: getEnvAsInt(strings.ToUpper(prefix)+"_MAX_OPEN_CONNS", 25),
|
||||
MaxIdleConns: getEnvAsInt(strings.ToUpper(prefix)+"_MAX_IDLE_CONNS", 25),
|
||||
ConnMaxLifetime: parseDuration(getEnv(strings.ToUpper(prefix)+"_CONN_MAX_LIFETIME", "5m")),
|
||||
}
|
||||
c.Databases[prefix] = dbConfig
|
||||
}
|
||||
}
|
||||
|
||||
// PostgreSQL database
|
||||
func (c *Config) addPostgreSQLConfigs() {
|
||||
// SATUDATA database configuration
|
||||
// defaultPOSTGRESHost := getEnv("POSTGRES_HOST", "localhost")
|
||||
// if defaultPOSTGRESHost != "" {
|
||||
// c.Databases["postgres"] = DatabaseConfig{
|
||||
// Name: "postgres",
|
||||
// Type: getEnv("POSTGRES_CONNECTION", "postgres"),
|
||||
// Host: defaultPOSTGRESHost,
|
||||
// Port: getEnvAsInt("POSTGRES_PORT", 5432),
|
||||
// Username: getEnv("POSTGRES_USERNAME", ""),
|
||||
// Password: getEnv("POSTGRES_PASSWORD", ""),
|
||||
// Database: getEnv("POSTGRES_DATABASE", "postgres"),
|
||||
// Schema: getEnv("POSTGRES_SCHEMA", "public"),
|
||||
// SSLMode: getEnv("POSTGRES_SSLMODE", "disable"),
|
||||
// MaxOpenConns: getEnvAsInt("POSTGRES_MAX_OPEN_CONNS", 25),
|
||||
// MaxIdleConns: getEnvAsInt("POSTGRES_MAX_IDLE_CONNS", 25),
|
||||
// ConnMaxLifetime: parseDuration(getEnv("POSTGRES_CONN_MAX_LIFETIME", "5m")),
|
||||
// }
|
||||
// }
|
||||
|
||||
// Support for custom PostgreSQL configurations with POSTGRES_ prefix
|
||||
envVars := os.Environ()
|
||||
for _, envVar := range envVars {
|
||||
parts := strings.SplitN(envVar, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
key := parts[0]
|
||||
// Parse PostgreSQL configurations (format: POSTGRES_[NAME]_[PROPERTY])
|
||||
if strings.HasPrefix(key, "POSTGRES_") && strings.Contains(key, "_") {
|
||||
segments := strings.Split(key, "_")
|
||||
if len(segments) >= 3 {
|
||||
dbName := strings.ToLower(strings.Join(segments[1:len(segments)-1], "_"))
|
||||
|
||||
// Skip if it's a standard PostgreSQL configuration
|
||||
if dbName == "connection" || dbName == "dev" || dbName == "default" || dbName == "satudata" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create or update PostgreSQL configuration
|
||||
if _, exists := c.Databases[dbName]; !exists {
|
||||
c.Databases[dbName] = DatabaseConfig{
|
||||
Name: dbName,
|
||||
Type: "postgres",
|
||||
Host: getEnv("POSTGRES_"+strings.ToUpper(dbName)+"_HOST", "localhost"),
|
||||
Port: getEnvAsInt("POSTGRES_"+strings.ToUpper(dbName)+"_PORT", 5432),
|
||||
Username: getEnv("POSTGRES_"+strings.ToUpper(dbName)+"_USERNAME", ""),
|
||||
Password: getEnv("POSTGRES_"+strings.ToUpper(dbName)+"_PASSWORD", ""),
|
||||
Database: getEnv("POSTGRES_"+strings.ToUpper(dbName)+"_DATABASE", dbName),
|
||||
Schema: getEnv("POSTGRES_"+strings.ToUpper(dbName)+"_SCHEMA", "public"),
|
||||
SSLMode: getEnv("POSTGRES_"+strings.ToUpper(dbName)+"_SSLMODE", "disable"),
|
||||
MaxOpenConns: getEnvAsInt("POSTGRES_MAX_OPEN_CONNS", 25),
|
||||
MaxIdleConns: getEnvAsInt("POSTGRES_MAX_IDLE_CONNS", 25),
|
||||
ConnMaxLifetime: parseDuration(getEnv("POSTGRES_CONN_MAX_LIFETIME", "5m")),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// addMYSQLConfigs adds MYSQL database
|
||||
func (c *Config) addMySQLConfigs() {
|
||||
// Primary MySQL configuration
|
||||
defaultMySQLHost := getEnv("MYSQL_HOST", "")
|
||||
if defaultMySQLHost != "" {
|
||||
c.Databases["mysql"] = DatabaseConfig{
|
||||
Name: "mysql",
|
||||
Type: getEnv("MYSQL_CONNECTION", "mysql"),
|
||||
Host: defaultMySQLHost,
|
||||
Port: getEnvAsInt("MYSQL_PORT", 3306),
|
||||
Username: getEnv("MYSQL_USERNAME", ""),
|
||||
Password: getEnv("MYSQL_PASSWORD", ""),
|
||||
Database: getEnv("MYSQL_DATABASE", "mysql"),
|
||||
SSLMode: getEnv("MYSQL_SSLMODE", "disable"),
|
||||
MaxOpenConns: getEnvAsInt("MYSQL_MAX_OPEN_CONNS", 25),
|
||||
MaxIdleConns: getEnvAsInt("MYSQL_MAX_IDLE_CONNS", 25),
|
||||
ConnMaxLifetime: parseDuration(getEnv("MYSQL_CONN_MAX_LIFETIME", "5m")),
|
||||
}
|
||||
}
|
||||
|
||||
// Support for custom MySQL configurations with MYSQL_ prefix
|
||||
envVars := os.Environ()
|
||||
for _, envVar := range envVars {
|
||||
parts := strings.SplitN(envVar, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
key := parts[0]
|
||||
// Parse MySQL configurations (format: MYSQL_[NAME]_[PROPERTY])
|
||||
if strings.HasPrefix(key, "MYSQL_") && strings.Contains(key, "_") {
|
||||
segments := strings.Split(key, "_")
|
||||
if len(segments) >= 3 {
|
||||
dbName := strings.ToLower(strings.Join(segments[1:len(segments)-1], "_"))
|
||||
|
||||
// Skip if it's a standard MySQL configuration
|
||||
if dbName == "connection" || dbName == "dev" || dbName == "max" || dbName == "conn" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create or update MySQL configuration
|
||||
if _, exists := c.Databases[dbName]; !exists {
|
||||
mysqlHost := getEnv("MYSQL_"+strings.ToUpper(dbName)+"_HOST", "")
|
||||
if mysqlHost != "" {
|
||||
c.Databases[dbName] = DatabaseConfig{
|
||||
Name: dbName,
|
||||
Type: getEnv("MYSQL_"+strings.ToUpper(dbName)+"_CONNECTION", "mysql"),
|
||||
Host: mysqlHost,
|
||||
Port: getEnvAsInt("MYSQL_"+strings.ToUpper(dbName)+"_PORT", 3306),
|
||||
Username: getEnv("MYSQL_"+strings.ToUpper(dbName)+"_USERNAME", ""),
|
||||
Password: getEnv("MYSQL_"+strings.ToUpper(dbName)+"_PASSWORD", ""),
|
||||
Database: getEnv("MYSQL_"+strings.ToUpper(dbName)+"_DATABASE", dbName),
|
||||
SSLMode: getEnv("MYSQL_"+strings.ToUpper(dbName)+"_SSLMODE", "disable"),
|
||||
MaxOpenConns: getEnvAsInt("MYSQL_MAX_OPEN_CONNS", 25),
|
||||
MaxIdleConns: getEnvAsInt("MYSQL_MAX_IDLE_CONNS", 25),
|
||||
ConnMaxLifetime: parseDuration(getEnv("MYSQL_CONN_MAX_LIFETIME", "5m")),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// addMongoDBConfigs adds MongoDB database configurations from environment variables
|
||||
func (c *Config) addMongoDBConfigs() {
|
||||
// Primary MongoDB configuration
|
||||
mongoHost := getEnv("MONGODB_HOST", "")
|
||||
if mongoHost != "" {
|
||||
c.Databases["mongodb"] = DatabaseConfig{
|
||||
Name: "mongodb",
|
||||
Type: getEnv("MONGODB_CONNECTION", "mongodb"),
|
||||
Host: mongoHost,
|
||||
Port: getEnvAsInt("MONGODB_PORT", 27017),
|
||||
Username: getEnv("MONGODB_USER", ""),
|
||||
Password: getEnv("MONGODB_PASS", ""),
|
||||
Database: getEnv("MONGODB_MASTER", "master"),
|
||||
SSLMode: getEnv("MONGODB_SSLMODE", "disable"),
|
||||
MaxOpenConns: getEnvAsInt("MONGODB_MAX_OPEN_CONNS", 100),
|
||||
MaxIdleConns: getEnvAsInt("MONGODB_MAX_IDLE_CONNS", 10),
|
||||
ConnMaxLifetime: parseDuration(getEnv("MONGODB_CONN_MAX_LIFETIME", "30m")),
|
||||
}
|
||||
}
|
||||
|
||||
// Additional MongoDB configurations for local database
|
||||
mongoLocalHost := getEnv("MONGODB_LOCAL_HOST", "")
|
||||
if mongoLocalHost != "" {
|
||||
c.Databases["mongodb_local"] = DatabaseConfig{
|
||||
Name: "mongodb_local",
|
||||
Type: getEnv("MONGODB_CONNECTION", "mongodb"),
|
||||
Host: mongoLocalHost,
|
||||
Port: getEnvAsInt("MONGODB_LOCAL_PORT", 27017),
|
||||
Username: getEnv("MONGODB_LOCAL_USER", ""),
|
||||
Password: getEnv("MONGODB_LOCAL_PASS", ""),
|
||||
Database: getEnv("MONGODB_LOCAL_DB", "local"),
|
||||
SSLMode: getEnv("MONGOD_SSLMODE", "disable"),
|
||||
MaxOpenConns: getEnvAsInt("MONGODB_MAX_OPEN_CONNS", 100),
|
||||
MaxIdleConns: getEnvAsInt("MONGODB_MAX_IDLE_CONNS", 10),
|
||||
ConnMaxLifetime: parseDuration(getEnv("MONGODB_CONN_MAX_LIFETIME", "30m")),
|
||||
}
|
||||
}
|
||||
|
||||
// Support for custom MongoDB configurations with MONGODB_ prefix
|
||||
envVars := os.Environ()
|
||||
for _, envVar := range envVars {
|
||||
parts := strings.SplitN(envVar, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
key := parts[0]
|
||||
// Parse MongoDB configurations (format: MONGODB_[NAME]_[PROPERTY])
|
||||
if strings.HasPrefix(key, "MONGODB_") && strings.Contains(key, "_") {
|
||||
segments := strings.Split(key, "_")
|
||||
if len(segments) >= 3 {
|
||||
dbName := strings.ToLower(strings.Join(segments[1:len(segments)-1], "_"))
|
||||
// Skip if it's a standard MongoDB configuration
|
||||
if dbName == "connection" || dbName == "dev" || dbName == "local" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create or update MongoDB configuration
|
||||
if _, exists := c.Databases[dbName]; !exists {
|
||||
c.Databases[dbName] = DatabaseConfig{
|
||||
Name: dbName,
|
||||
Type: "mongodb",
|
||||
Host: getEnv("MONGODB_"+strings.ToUpper(dbName)+"_HOST", "localhost"),
|
||||
Port: getEnvAsInt("MONGODB_"+strings.ToUpper(dbName)+"_PORT", 27017),
|
||||
Username: getEnv("MONGODB_"+strings.ToUpper(dbName)+"_USER", ""),
|
||||
Password: getEnv("MONGODB_"+strings.ToUpper(dbName)+"_PASS", ""),
|
||||
Database: getEnv("MONGODB_"+strings.ToUpper(dbName)+"_DB", dbName),
|
||||
SSLMode: getEnv("MONGOD_SSLMODE", "disable"),
|
||||
MaxOpenConns: getEnvAsInt("MONGODB_MAX_OPEN_CONNS", 100),
|
||||
MaxIdleConns: getEnvAsInt("MONGODB_MAX_IDLE_CONNS", 10),
|
||||
ConnMaxLifetime: parseDuration(getEnv("MONGODB_CONN_MAX_LIFETIME", "30m")),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getEnvFromMap(config map[string]string, key, defaultValue string) string {
|
||||
if value, exists := config[key]; exists {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvAsIntFromMap(config map[string]string, key string, defaultValue int) int {
|
||||
if value, exists := config[key]; exists {
|
||||
if intValue, err := strconv.Atoi(value); err == nil {
|
||||
return intValue
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func parseDuration(durationStr string) time.Duration {
|
||||
if duration, err := time.ParseDuration(durationStr); err == nil {
|
||||
return duration
|
||||
}
|
||||
return 5 * time.Minute
|
||||
}
|
||||
|
||||
func getEnv(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvAsInt(key string, defaultValue int) int {
|
||||
valueStr := getEnv(key, "")
|
||||
if value, err := strconv.Atoi(valueStr); err == nil {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvAsBool(key string, defaultValue bool) bool {
|
||||
valueStr := getEnv(key, "")
|
||||
if value, err := strconv.ParseBool(valueStr); err == nil {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// parseSchemes parses comma-separated schemes string into a slice
|
||||
func parseSchemes(schemesStr string) []string {
|
||||
if schemesStr == "" {
|
||||
return []string{"http"}
|
||||
}
|
||||
|
||||
schemes := strings.Split(schemesStr, ",")
|
||||
for i, scheme := range schemes {
|
||||
schemes[i] = strings.TrimSpace(scheme)
|
||||
}
|
||||
return schemes
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
if len(c.Databases) == 0 {
|
||||
log.Fatal("At least one database configuration is required")
|
||||
}
|
||||
|
||||
for name, db := range c.Databases {
|
||||
if db.Host == "" {
|
||||
log.Fatalf("Database host is required for %s", name)
|
||||
}
|
||||
if db.Username == "" {
|
||||
log.Fatalf("Database username is required for %s", name)
|
||||
}
|
||||
if db.Password == "" {
|
||||
log.Fatalf("Database password is required for %s", name)
|
||||
}
|
||||
if db.Database == "" {
|
||||
log.Fatalf("Database name is required for %s", name)
|
||||
}
|
||||
}
|
||||
|
||||
if c.Bpjs.BaseURL == "" {
|
||||
log.Fatal("BPJS Base URL is required")
|
||||
}
|
||||
if c.Bpjs.ConsID == "" {
|
||||
log.Fatal("BPJS Consumer ID is required")
|
||||
}
|
||||
if c.Bpjs.UserKey == "" {
|
||||
log.Fatal("BPJS User Key is required")
|
||||
}
|
||||
if c.Bpjs.SecretKey == "" {
|
||||
log.Fatal("BPJS Secret Key is required")
|
||||
}
|
||||
|
||||
// Validate Keycloak configuration if enabled
|
||||
if c.Keycloak.Enabled {
|
||||
if c.Keycloak.Issuer == "" {
|
||||
log.Fatal("Keycloak issuer is required when Keycloak is enabled")
|
||||
}
|
||||
if c.Keycloak.Audience == "" {
|
||||
log.Fatal("Keycloak audience is required when Keycloak is enabled")
|
||||
}
|
||||
if c.Keycloak.JwksURL == "" {
|
||||
log.Fatal("Keycloak JWKS URL is required when Keycloak is enabled")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate SatuSehat configuration
|
||||
if c.SatuSehat.OrgID == "" {
|
||||
log.Fatal("SatuSehat Organization ID is required")
|
||||
}
|
||||
if c.SatuSehat.FasyakesID == "" {
|
||||
log.Fatal("SatuSehat Fasyankes ID is required")
|
||||
}
|
||||
if c.SatuSehat.ClientID == "" {
|
||||
log.Fatal("SatuSehat Client ID is required")
|
||||
}
|
||||
if c.SatuSehat.ClientSecret == "" {
|
||||
log.Fatal("SatuSehat Client Secret is required")
|
||||
}
|
||||
if c.SatuSehat.AuthURL == "" {
|
||||
log.Fatal("SatuSehat Auth URL is required")
|
||||
}
|
||||
if c.SatuSehat.BaseURL == "" {
|
||||
log.Fatal("SatuSehat Base URL is required")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
699
internal/database/database.go
Normal file
699
internal/database/database.go
Normal file
@@ -0,0 +1,699 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log" // Import runtime package
|
||||
|
||||
// Import debug package
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"api-service/internal/config"
|
||||
|
||||
_ "github.com/jackc/pgx/v5" // Import pgx driver
|
||||
"github.com/lib/pq"
|
||||
_ "gorm.io/driver/postgres" // Import GORM PostgreSQL driver
|
||||
|
||||
_ "github.com/go-sql-driver/mysql" // MySQL driver for database/sql
|
||||
_ "gorm.io/driver/mysql" // GORM MySQL driver
|
||||
_ "gorm.io/driver/sqlserver" // GORM SQL Server driver
|
||||
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
// DatabaseType represents supported database types
|
||||
type DatabaseType string
|
||||
|
||||
const (
|
||||
Postgres DatabaseType = "postgres"
|
||||
MySQL DatabaseType = "mysql"
|
||||
SQLServer DatabaseType = "sqlserver"
|
||||
SQLite DatabaseType = "sqlite"
|
||||
MongoDB DatabaseType = "mongodb"
|
||||
)
|
||||
|
||||
// Service represents a service that interacts with multiple databases
|
||||
type Service interface {
|
||||
Health() map[string]map[string]string
|
||||
GetDB(name string) (*sql.DB, error)
|
||||
GetMongoClient(name string) (*mongo.Client, error)
|
||||
GetReadDB(name string) (*sql.DB, error) // For read replicas
|
||||
Close() error
|
||||
ListDBs() []string
|
||||
GetDBType(name string) (DatabaseType, error)
|
||||
// Tambahkan method untuk WebSocket notifications
|
||||
ListenForChanges(ctx context.Context, dbName string, channels []string, callback func(string, string)) error
|
||||
NotifyChange(dbName, channel, payload string) error
|
||||
GetPrimaryDB(name string) (*sql.DB, error) // Helper untuk get primary DB
|
||||
}
|
||||
|
||||
type service struct {
|
||||
sqlDatabases map[string]*sql.DB
|
||||
mongoClients map[string]*mongo.Client
|
||||
readReplicas map[string][]*sql.DB // Read replicas for load balancing
|
||||
configs map[string]config.DatabaseConfig
|
||||
readConfigs map[string][]config.DatabaseConfig
|
||||
mu sync.RWMutex
|
||||
readBalancer map[string]int // Round-robin counter for read replicas
|
||||
listeners map[string]*pq.Listener // Tambahkan untuk tracking listeners
|
||||
listenersMu sync.RWMutex
|
||||
}
|
||||
|
||||
var (
|
||||
dbManager *service
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
// New creates a new database service with multiple connections
|
||||
func New(cfg *config.Config) Service {
|
||||
once.Do(func() {
|
||||
dbManager = &service{
|
||||
sqlDatabases: make(map[string]*sql.DB),
|
||||
mongoClients: make(map[string]*mongo.Client),
|
||||
readReplicas: make(map[string][]*sql.DB),
|
||||
configs: make(map[string]config.DatabaseConfig),
|
||||
readConfigs: make(map[string][]config.DatabaseConfig),
|
||||
readBalancer: make(map[string]int),
|
||||
listeners: make(map[string]*pq.Listener),
|
||||
}
|
||||
|
||||
log.Println("Initializing database service...") // Log when the initialization starts
|
||||
// log.Printf("Current Goroutine ID: %d", runtime.NumGoroutine()) // Log the number of goroutines
|
||||
// log.Printf("Stack Trace: %s", debug.Stack()) // Log the stack trace
|
||||
dbManager.loadFromConfig(cfg)
|
||||
|
||||
// Initialize all databases
|
||||
for name, dbConfig := range dbManager.configs {
|
||||
if err := dbManager.addDatabase(name, dbConfig); err != nil {
|
||||
log.Printf("Failed to connect to database %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize read replicas
|
||||
for name, replicaConfigs := range dbManager.readConfigs {
|
||||
for i, replicaConfig := range replicaConfigs {
|
||||
if err := dbManager.addReadReplica(name, i, replicaConfig); err != nil {
|
||||
log.Printf("Failed to connect to read replica %s[%d]: %v", name, i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return dbManager
|
||||
}
|
||||
|
||||
func (s *service) loadFromConfig(cfg *config.Config) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Load primary databases
|
||||
for name, dbConfig := range cfg.Databases {
|
||||
s.configs[name] = dbConfig
|
||||
}
|
||||
|
||||
// Load read replicas
|
||||
for name, replicaConfigs := range cfg.ReadReplicas {
|
||||
s.readConfigs[name] = replicaConfigs
|
||||
}
|
||||
}
|
||||
|
||||
func (s *service) addDatabase(name string, config config.DatabaseConfig) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
log.Printf("=== Database Connection Debug ===")
|
||||
// log.Printf("Database: %s", name)
|
||||
// log.Printf("Type: %s", config.Type)
|
||||
// log.Printf("Host: %s", config.Host)
|
||||
// log.Printf("Port: %d", config.Port)
|
||||
// log.Printf("Database: %s", config.Database)
|
||||
// log.Printf("Username: %s", config.Username)
|
||||
// log.Printf("SSLMode: %s", config.SSLMode)
|
||||
|
||||
var db *sql.DB
|
||||
var err error
|
||||
|
||||
dbType := DatabaseType(config.Type)
|
||||
|
||||
switch dbType {
|
||||
case Postgres:
|
||||
db, err = s.openPostgresConnection(config)
|
||||
case MySQL:
|
||||
db, err = s.openMySQLConnection(config)
|
||||
case SQLServer:
|
||||
db, err = s.openSQLServerConnection(config)
|
||||
case SQLite:
|
||||
db, err = s.openSQLiteConnection(config)
|
||||
case MongoDB:
|
||||
return s.addMongoDB(name, config)
|
||||
default:
|
||||
return fmt.Errorf("unsupported database type: %s", config.Type)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Printf("❌ Error connecting to database %s: %v", name, err)
|
||||
log.Printf(" Database: %s@%s:%d/%s", config.Username, config.Host, config.Port, config.Database)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("✅ Successfully connected to database: %s", name)
|
||||
return s.configureSQLDB(name, db, config.MaxOpenConns, config.MaxIdleConns, config.ConnMaxLifetime)
|
||||
}
|
||||
|
||||
func (s *service) addReadReplica(name string, index int, config config.DatabaseConfig) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var db *sql.DB
|
||||
var err error
|
||||
|
||||
dbType := DatabaseType(config.Type)
|
||||
|
||||
switch dbType {
|
||||
case Postgres:
|
||||
db, err = s.openPostgresConnection(config)
|
||||
case MySQL:
|
||||
db, err = s.openMySQLConnection(config)
|
||||
case SQLServer:
|
||||
db, err = s.openSQLServerConnection(config)
|
||||
case SQLite:
|
||||
db, err = s.openSQLiteConnection(config)
|
||||
default:
|
||||
return fmt.Errorf("unsupported database type for read replica: %s", config.Type)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if s.readReplicas[name] == nil {
|
||||
s.readReplicas[name] = make([]*sql.DB, 0)
|
||||
}
|
||||
|
||||
// Ensure we have enough slots
|
||||
for len(s.readReplicas[name]) <= index {
|
||||
s.readReplicas[name] = append(s.readReplicas[name], nil)
|
||||
}
|
||||
|
||||
s.readReplicas[name][index] = db
|
||||
log.Printf("Successfully connected to read replica %s[%d]", name, index)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) openPostgresConnection(config config.DatabaseConfig) (*sql.DB, error) {
|
||||
connStr := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=%s",
|
||||
config.Username,
|
||||
config.Password,
|
||||
config.Host,
|
||||
config.Port,
|
||||
config.Database,
|
||||
config.SSLMode,
|
||||
)
|
||||
|
||||
if config.Schema != "" {
|
||||
connStr += "&search_path=" + config.Schema
|
||||
}
|
||||
|
||||
db, err := sql.Open("pgx", connStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open PostgreSQL connection: %w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func (s *service) openMySQLConnection(config config.DatabaseConfig) (*sql.DB, error) {
|
||||
connStr := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
config.Username,
|
||||
config.Password,
|
||||
config.Host,
|
||||
config.Port,
|
||||
config.Database,
|
||||
)
|
||||
|
||||
db, err := sql.Open("mysql", connStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open MySQL connection: %w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func (s *service) openSQLServerConnection(config config.DatabaseConfig) (*sql.DB, error) {
|
||||
connStr := fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s",
|
||||
config.Username,
|
||||
config.Password,
|
||||
config.Host,
|
||||
config.Port,
|
||||
config.Database,
|
||||
)
|
||||
|
||||
db, err := sql.Open("sqlserver", connStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open SQL Server connection: %w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func (s *service) openSQLiteConnection(config config.DatabaseConfig) (*sql.DB, error) {
|
||||
dbPath := config.Path
|
||||
if dbPath == "" {
|
||||
dbPath = fmt.Sprintf("./data/%s.db", config.Database)
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open SQLite connection: %w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func (s *service) addMongoDB(name string, config config.DatabaseConfig) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
uri := fmt.Sprintf("mongodb://%s:%s@%s:%d/%s",
|
||||
config.Username,
|
||||
config.Password,
|
||||
config.Host,
|
||||
config.Port,
|
||||
config.Database,
|
||||
)
|
||||
|
||||
client, err := mongo.Connect(ctx, options.Client().ApplyURI(uri))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to MongoDB: %w", err)
|
||||
}
|
||||
|
||||
s.mongoClients[name] = client
|
||||
log.Printf("Successfully connected to MongoDB: %s", name)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) configureSQLDB(name string, db *sql.DB, maxOpenConns, maxIdleConns int, connMaxLifetime time.Duration) error {
|
||||
db.SetMaxOpenConns(maxOpenConns)
|
||||
db.SetMaxIdleConns(maxIdleConns)
|
||||
db.SetConnMaxLifetime(connMaxLifetime)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
db.Close()
|
||||
return fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
s.sqlDatabases[name] = db
|
||||
log.Printf("Successfully connected to SQL database: %s", name)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Health checks the health of all database connections by pinging each database.
|
||||
func (s *service) Health() map[string]map[string]string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
result := make(map[string]map[string]string)
|
||||
|
||||
// Check SQL databases
|
||||
for name, db := range s.sqlDatabases {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
stats := make(map[string]string)
|
||||
|
||||
err := db.PingContext(ctx)
|
||||
if err != nil {
|
||||
stats["status"] = "down"
|
||||
stats["error"] = fmt.Sprintf("db down: %v", err)
|
||||
stats["type"] = "sql"
|
||||
stats["role"] = "primary"
|
||||
result[name] = stats
|
||||
continue
|
||||
}
|
||||
|
||||
stats["status"] = "up"
|
||||
stats["message"] = "It's healthy"
|
||||
stats["type"] = "sql"
|
||||
stats["role"] = "primary"
|
||||
|
||||
dbStats := db.Stats()
|
||||
stats["open_connections"] = strconv.Itoa(dbStats.OpenConnections)
|
||||
stats["in_use"] = strconv.Itoa(dbStats.InUse)
|
||||
stats["idle"] = strconv.Itoa(dbStats.Idle)
|
||||
stats["wait_count"] = strconv.FormatInt(dbStats.WaitCount, 10)
|
||||
stats["wait_duration"] = dbStats.WaitDuration.String()
|
||||
stats["max_idle_closed"] = strconv.FormatInt(dbStats.MaxIdleClosed, 10)
|
||||
stats["max_lifetime_closed"] = strconv.FormatInt(dbStats.MaxLifetimeClosed, 10)
|
||||
|
||||
if dbStats.OpenConnections > 40 {
|
||||
stats["message"] = "The database is experiencing heavy load."
|
||||
}
|
||||
|
||||
if dbStats.WaitCount > 1000 {
|
||||
stats["message"] = "The database has a high number of wait events, indicating potential bottlenecks."
|
||||
}
|
||||
|
||||
if dbStats.MaxIdleClosed > int64(dbStats.OpenConnections)/2 {
|
||||
stats["message"] = "Many idle connections are being closed, consider revising the connection pool settings."
|
||||
}
|
||||
|
||||
if dbStats.MaxLifetimeClosed > int64(dbStats.OpenConnections)/2 {
|
||||
stats["message"] = "Many connections are being closed due to max lifetime, consider increasing max lifetime or revising the connection usage pattern."
|
||||
}
|
||||
|
||||
result[name] = stats
|
||||
}
|
||||
|
||||
// Check read replicas
|
||||
for name, replicas := range s.readReplicas {
|
||||
for i, db := range replicas {
|
||||
if db == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
replicaName := fmt.Sprintf("%s_replica_%d", name, i)
|
||||
stats := make(map[string]string)
|
||||
|
||||
err := db.PingContext(ctx)
|
||||
if err != nil {
|
||||
stats["status"] = "down"
|
||||
stats["error"] = fmt.Sprintf("read replica down: %v", err)
|
||||
stats["type"] = "sql"
|
||||
stats["role"] = "replica"
|
||||
result[replicaName] = stats
|
||||
continue
|
||||
}
|
||||
|
||||
stats["status"] = "up"
|
||||
stats["message"] = "Read replica healthy"
|
||||
stats["type"] = "sql"
|
||||
stats["role"] = "replica"
|
||||
|
||||
dbStats := db.Stats()
|
||||
stats["open_connections"] = strconv.Itoa(dbStats.OpenConnections)
|
||||
stats["in_use"] = strconv.Itoa(dbStats.InUse)
|
||||
stats["idle"] = strconv.Itoa(dbStats.Idle)
|
||||
|
||||
result[replicaName] = stats
|
||||
}
|
||||
}
|
||||
|
||||
// Check MongoDB connections
|
||||
for name, client := range s.mongoClients {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
stats := make(map[string]string)
|
||||
|
||||
err := client.Ping(ctx, nil)
|
||||
if err != nil {
|
||||
stats["status"] = "down"
|
||||
stats["error"] = fmt.Sprintf("mongodb down: %v", err)
|
||||
stats["type"] = "mongodb"
|
||||
result[name] = stats
|
||||
continue
|
||||
}
|
||||
|
||||
stats["status"] = "up"
|
||||
stats["message"] = "It's healthy"
|
||||
stats["type"] = "mongodb"
|
||||
|
||||
result[name] = stats
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// GetDB returns a specific SQL database connection by name
|
||||
func (s *service) GetDB(name string) (*sql.DB, error) {
|
||||
log.Printf("Attempting to get database connection for: %s", name)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
db, exists := s.sqlDatabases[name]
|
||||
if !exists {
|
||||
log.Printf("Error: database %s not found", name) // Log the error
|
||||
return nil, fmt.Errorf("database %s not found", name)
|
||||
}
|
||||
|
||||
log.Printf("Current connection pool state for %s: Open: %d, In Use: %d, Idle: %d",
|
||||
name, db.Stats().OpenConnections, db.Stats().InUse, db.Stats().Idle)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// db, exists := s.sqlDatabases[name]
|
||||
// if !exists {
|
||||
// log.Printf("Error: database %s not found", name) // Log the error
|
||||
// return nil, fmt.Errorf("database %s not found", name)
|
||||
// }
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// GetReadDB returns a read replica connection using round-robin load balancing
|
||||
func (s *service) GetReadDB(name string) (*sql.DB, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
replicas, exists := s.readReplicas[name]
|
||||
if !exists || len(replicas) == 0 {
|
||||
// Fallback to primary if no replicas available
|
||||
return s.GetDB(name)
|
||||
}
|
||||
|
||||
// Round-robin load balancing
|
||||
s.readBalancer[name] = (s.readBalancer[name] + 1) % len(replicas)
|
||||
selected := replicas[s.readBalancer[name]]
|
||||
|
||||
if selected == nil {
|
||||
// Fallback to primary if replica is nil
|
||||
return s.GetDB(name)
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// GetMongoClient returns a specific MongoDB client by name
|
||||
func (s *service) GetMongoClient(name string) (*mongo.Client, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
client, exists := s.mongoClients[name]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("MongoDB client %s not found", name)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// ListDBs returns list of available database names
|
||||
func (s *service) ListDBs() []string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
names := make([]string, 0, len(s.sqlDatabases)+len(s.mongoClients))
|
||||
|
||||
for name := range s.sqlDatabases {
|
||||
names = append(names, name)
|
||||
}
|
||||
|
||||
for name := range s.mongoClients {
|
||||
names = append(names, name)
|
||||
}
|
||||
|
||||
return names
|
||||
}
|
||||
|
||||
// GetDBType returns the type of a specific database
|
||||
func (s *service) GetDBType(name string) (DatabaseType, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
config, exists := s.configs[name]
|
||||
if !exists {
|
||||
return "", fmt.Errorf("database %s not found", name)
|
||||
}
|
||||
|
||||
return DatabaseType(config.Type), nil
|
||||
}
|
||||
|
||||
// Close closes all database connections
|
||||
func (s *service) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var errs []error
|
||||
|
||||
for name, db := range s.sqlDatabases {
|
||||
if err := db.Close(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to close database %s: %w", name, err))
|
||||
} else {
|
||||
log.Printf("Disconnected from SQL database: %s", name)
|
||||
}
|
||||
}
|
||||
|
||||
for name, replicas := range s.readReplicas {
|
||||
for i, db := range replicas {
|
||||
if db != nil {
|
||||
if err := db.Close(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to close read replica %s[%d]: %w", name, i, err))
|
||||
} else {
|
||||
log.Printf("Disconnected from read replica: %s[%d]", name, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for name, client := range s.mongoClients {
|
||||
if err := client.Disconnect(context.Background()); err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to disconnect MongoDB client %s: %w", name, err))
|
||||
} else {
|
||||
log.Printf("Disconnected from MongoDB: %s", name)
|
||||
}
|
||||
}
|
||||
|
||||
s.sqlDatabases = make(map[string]*sql.DB)
|
||||
s.mongoClients = make(map[string]*mongo.Client)
|
||||
s.readReplicas = make(map[string][]*sql.DB)
|
||||
s.configs = make(map[string]config.DatabaseConfig)
|
||||
s.readConfigs = make(map[string][]config.DatabaseConfig)
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("errors closing databases: %v", errs)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPrimaryDB returns primary database connection
|
||||
func (s *service) GetPrimaryDB(name string) (*sql.DB, error) {
|
||||
return s.GetDB(name)
|
||||
}
|
||||
|
||||
// ListenForChanges implements PostgreSQL LISTEN/NOTIFY for real-time updates
|
||||
func (s *service) ListenForChanges(ctx context.Context, dbName string, channels []string, callback func(string, string)) error {
|
||||
s.mu.RLock()
|
||||
config, exists := s.configs[dbName]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return fmt.Errorf("database %s not found", dbName)
|
||||
}
|
||||
|
||||
// Only support PostgreSQL for LISTEN/NOTIFY
|
||||
if DatabaseType(config.Type) != Postgres {
|
||||
return fmt.Errorf("LISTEN/NOTIFY only supported for PostgreSQL databases")
|
||||
}
|
||||
|
||||
// Create connection string for listener
|
||||
connStr := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=%s",
|
||||
config.Username,
|
||||
config.Password,
|
||||
config.Host,
|
||||
config.Port,
|
||||
config.Database,
|
||||
config.SSLMode,
|
||||
)
|
||||
|
||||
// Create listener
|
||||
listener := pq.NewListener(
|
||||
connStr,
|
||||
10*time.Second,
|
||||
time.Minute,
|
||||
func(ev pq.ListenerEventType, err error) {
|
||||
if err != nil {
|
||||
log.Printf("Database listener (%s) error: %v", dbName, err)
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
// Store listener for cleanup
|
||||
s.listenersMu.Lock()
|
||||
s.listeners[dbName] = listener
|
||||
s.listenersMu.Unlock()
|
||||
|
||||
// Listen to specified channels
|
||||
for _, channel := range channels {
|
||||
err := listener.Listen(channel)
|
||||
if err != nil {
|
||||
listener.Close()
|
||||
return fmt.Errorf("failed to listen to channel %s: %w", channel, err)
|
||||
}
|
||||
log.Printf("Listening to database channel: %s on %s", channel, dbName)
|
||||
}
|
||||
|
||||
// Start listening loop
|
||||
go func() {
|
||||
defer func() {
|
||||
listener.Close()
|
||||
s.listenersMu.Lock()
|
||||
delete(s.listeners, dbName)
|
||||
s.listenersMu.Unlock()
|
||||
log.Printf("Database listener for %s stopped", dbName)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case n := <-listener.Notify:
|
||||
if n != nil {
|
||||
callback(n.Channel, n.Extra)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(90 * time.Second):
|
||||
// Send ping to keep connection alive
|
||||
go func() {
|
||||
if err := listener.Ping(); err != nil {
|
||||
log.Printf("Listener ping failed for %s: %v", dbName, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NotifyChange sends a notification to a PostgreSQL channel
|
||||
func (s *service) NotifyChange(dbName, channel, payload string) error {
|
||||
db, err := s.GetDB(dbName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get database %s: %w", dbName, err)
|
||||
}
|
||||
|
||||
// Check if it's PostgreSQL
|
||||
s.mu.RLock()
|
||||
config, exists := s.configs[dbName]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return fmt.Errorf("database %s configuration not found", dbName)
|
||||
}
|
||||
|
||||
if DatabaseType(config.Type) != Postgres {
|
||||
return fmt.Errorf("NOTIFY only supported for PostgreSQL databases")
|
||||
}
|
||||
|
||||
// Execute NOTIFY
|
||||
query := "SELECT pg_notify($1, $2)"
|
||||
_, err = db.Exec(query, channel, payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send notification: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Sent notification to channel %s on %s: %s", channel, dbName, payload)
|
||||
return nil
|
||||
}
|
||||
132
internal/handlers/auth/auth.go
Normal file
132
internal/handlers/auth/auth.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
models "api-service/internal/models/auth"
|
||||
services "api-service/internal/services/auth"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AuthHandler handles authentication endpoints
|
||||
type AuthHandler struct {
|
||||
authService *services.AuthService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new authentication handler
|
||||
func NewAuthHandler(authService *services.AuthService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
authService: authService,
|
||||
}
|
||||
}
|
||||
|
||||
// Login godoc
|
||||
// @Summary Login user and get JWT token
|
||||
// @Description Authenticate user with username and password to receive JWT token
|
||||
// @Tags Authentication
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param login body models.LoginRequest true "Login credentials"
|
||||
// @Success 200 {object} models.TokenResponse
|
||||
// @Failure 400 {object} map[string]string "Bad request"
|
||||
// @Failure 401 {object} map[string]string "Unauthorized"
|
||||
// @Router /api/v1/auth/login [post]
|
||||
func (h *AuthHandler) Login(c *gin.Context) {
|
||||
var loginReq models.LoginRequest
|
||||
|
||||
// Bind JSON request
|
||||
if err := c.ShouldBindJSON(&loginReq); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Authenticate user
|
||||
tokenResponse, err := h.authService.Login(loginReq.Username, loginReq.Password)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, tokenResponse)
|
||||
}
|
||||
|
||||
// RefreshToken godoc
|
||||
// @Summary Refresh JWT token
|
||||
// @Description Refresh the JWT token using a valid refresh token
|
||||
// @Tags Authentication
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param refresh body map[string]string true "Refresh token"
|
||||
// @Success 200 {object} models.TokenResponse
|
||||
// @Failure 400 {object} map[string]string "Bad request"
|
||||
// @Failure 401 {object} map[string]string "Unauthorized"
|
||||
// @Router /api/v1/auth/refresh [post]
|
||||
func (h *AuthHandler) RefreshToken(c *gin.Context) {
|
||||
// For now, this is a placeholder for refresh token functionality
|
||||
// In a real implementation, you would handle refresh tokens here
|
||||
c.JSON(http.StatusNotImplemented, gin.H{"error": "refresh token not implemented"})
|
||||
}
|
||||
|
||||
// Register godoc
|
||||
// @Summary Register new user
|
||||
// @Description Register a new user account
|
||||
// @Tags Authentication
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param register body map[string]string true "Registration data"
|
||||
// @Success 201 {object} map[string]string
|
||||
// @Failure 400 {object} map[string]string "Bad request"
|
||||
// @Router /api/v1/auth/register [post]
|
||||
func (h *AuthHandler) Register(c *gin.Context) {
|
||||
var registerReq struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
Role string `json:"role" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(®isterReq); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
err := h.authService.RegisterUser(
|
||||
registerReq.Username,
|
||||
registerReq.Email,
|
||||
registerReq.Password,
|
||||
registerReq.Role,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{"message": "user registered successfully"})
|
||||
}
|
||||
|
||||
// Me godoc
|
||||
// @Summary Get current user info
|
||||
// @Description Get information about the currently authenticated user
|
||||
// @Tags Authentication
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Success 200 {object} models.User
|
||||
// @Failure 401 {object} map[string]string "Unauthorized"
|
||||
// @Router /api/v1/auth/me [get]
|
||||
func (h *AuthHandler) Me(c *gin.Context) {
|
||||
// Get user info from context (set by middleware)
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "user not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
// In a real implementation, you would fetch user details from database
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"id": userID,
|
||||
"username": c.GetString("username"),
|
||||
"email": c.GetString("email"),
|
||||
"role": c.GetString("role"),
|
||||
})
|
||||
}
|
||||
95
internal/handlers/auth/token.go
Normal file
95
internal/handlers/auth/token.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
models "api-service/internal/models/auth"
|
||||
services "api-service/internal/services/auth"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// TokenHandler handles token generation endpoints
|
||||
type TokenHandler struct {
|
||||
authService *services.AuthService
|
||||
}
|
||||
|
||||
// NewTokenHandler creates a new token handler
|
||||
func NewTokenHandler(authService *services.AuthService) *TokenHandler {
|
||||
return &TokenHandler{
|
||||
authService: authService,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateToken godoc
|
||||
// @Summary Generate JWT token
|
||||
// @Description Generate a JWT token for a user
|
||||
// @Tags Token
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param token body models.LoginRequest true "User credentials"
|
||||
// @Success 200 {object} models.TokenResponse
|
||||
// @Failure 400 {object} map[string]string "Bad request"
|
||||
// @Failure 401 {object} map[string]string "Unauthorized"
|
||||
// @Router /api/v1/token/generate [post]
|
||||
func (h *TokenHandler) GenerateToken(c *gin.Context) {
|
||||
var loginReq models.LoginRequest
|
||||
|
||||
// Bind JSON request
|
||||
if err := c.ShouldBindJSON(&loginReq); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Generate token
|
||||
tokenResponse, err := h.authService.Login(loginReq.Username, loginReq.Password)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, tokenResponse)
|
||||
}
|
||||
|
||||
// GenerateTokenDirect godoc
|
||||
// @Summary Generate token directly
|
||||
// @Description Generate a JWT token directly without password verification (for testing)
|
||||
// @Tags Token
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param user body map[string]string true "User info"
|
||||
// @Success 200 {object} models.TokenResponse
|
||||
// @Failure 400 {object} map[string]string "Bad request"
|
||||
// @Router /api/v1/token/generate-direct [post]
|
||||
func (h *TokenHandler) GenerateTokenDirect(c *gin.Context) {
|
||||
var req struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Email string `json:"email" binding:"required"`
|
||||
Role string `json:"role" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Create a temporary user for token generation
|
||||
user := &models.User{
|
||||
ID: "temp-" + req.Username,
|
||||
Username: req.Username,
|
||||
Email: req.Email,
|
||||
Role: req.Role,
|
||||
}
|
||||
|
||||
// Generate token directly
|
||||
token, err := h.authService.GenerateTokenForUser(user)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.TokenResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
})
|
||||
}
|
||||
24
internal/handlers/healthcheck/healthcheck.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package healthcheck
|
||||
|
||||
import (
|
||||
"api-service/internal/database"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// HealthCheckHandler handles health check requests
|
||||
type HealthCheckHandler struct {
|
||||
dbService database.Service
|
||||
}
|
||||
|
||||
// NewHealthCheckHandler creates a new HealthCheckHandler
|
||||
func NewHealthCheckHandler(dbService database.Service) *HealthCheckHandler {
|
||||
return &HealthCheckHandler{dbService: dbService}
|
||||
}
|
||||
|
||||
// CheckHealth checks the health of the application
|
||||
func (h *HealthCheckHandler) CheckHealth(c *gin.Context) {
|
||||
healthStatus := h.dbService.Health() // Call the health check function from the database service
|
||||
c.JSON(http.StatusOK, healthStatus)
|
||||
}
|
||||
1401
internal/handlers/retribusi/retribusi.go
Normal file
File diff suppressed because it is too large
111
internal/handlers/websocket/broadcast.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// WebSocketBroadcaster defines the interface for broadcasting messages
|
||||
type WebSocketBroadcaster interface {
|
||||
BroadcastMessage(messageType string, data interface{})
|
||||
}
|
||||
|
||||
// Broadcaster handles server-initiated broadcasts to WebSocket clients
|
||||
type Broadcaster struct {
|
||||
handler WebSocketBroadcaster
|
||||
tickers []*time.Ticker
|
||||
quit chan struct{}
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewBroadcaster creates a new Broadcaster instance
|
||||
func NewBroadcaster(handler WebSocketBroadcaster) *Broadcaster {
|
||||
return &Broadcaster{
|
||||
handler: handler,
|
||||
tickers: make([]*time.Ticker, 0),
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// StartHeartbeat starts sending periodic heartbeat messages to all clients
|
||||
func (b *Broadcaster) StartHeartbeat(interval time.Duration) {
|
||||
ticker := time.NewTicker(interval)
|
||||
b.mu.Lock()
b.tickers = append(b.tickers, ticker)
b.mu.Unlock()
|
||||
go func() {
|
||||
defer func() {
|
||||
// Remove ticker from slice when done; the mutex guards the shared slice
b.mu.Lock()
defer b.mu.Unlock()
|
||||
for i, t := range b.tickers {
|
||||
if t == ticker {
|
||||
b.tickers = append(b.tickers[:i], b.tickers[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
b.handler.BroadcastMessage("heartbeat", map[string]interface{}{
|
||||
"message": "Server heartbeat",
|
||||
"timestamp": time.Now().Format(time.RFC3339),
|
||||
})
|
||||
case <-b.quit:
|
||||
ticker.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Stop stops the broadcaster
|
||||
func (b *Broadcaster) Stop() {
|
||||
b.mu.Lock()
defer b.mu.Unlock()
close(b.quit)
|
||||
for _, ticker := range b.tickers {
|
||||
if ticker != nil {
|
||||
ticker.Stop()
|
||||
}
|
||||
}
|
||||
b.tickers = nil
|
||||
}
|
||||
|
||||
// BroadcastNotification sends a notification message to all clients
|
||||
func (b *Broadcaster) BroadcastNotification(title, message, level string) {
|
||||
b.handler.BroadcastMessage("notification", map[string]interface{}{
|
||||
"title": title,
|
||||
"message": message,
|
||||
"level": level,
|
||||
"time": time.Now().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
// SimulateDataStream simulates streaming data to clients (useful for demos)
|
||||
func (b *Broadcaster) SimulateDataStream() {
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
b.mu.Lock()
b.tickers = append(b.tickers, ticker)
b.mu.Unlock()
|
||||
go func() {
|
||||
defer func() {
|
||||
// Remove ticker from slice when done; the mutex guards the shared slice
b.mu.Lock()
defer b.mu.Unlock()
|
||||
for i, t := range b.tickers {
|
||||
if t == ticker {
|
||||
b.tickers = append(b.tickers[:i], b.tickers[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
counter := 0
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
counter++
|
||||
b.handler.BroadcastMessage("data_stream", map[string]interface{}{
|
||||
"id": counter,
|
||||
"value": counter * 10,
|
||||
"timestamp": time.Now().Format(time.RFC3339),
|
||||
"type": "simulated_data",
|
||||
})
|
||||
case <-b.quit:
|
||||
ticker.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
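// Illustrative usage of the Broadcaster above (a sketch, not wired into the
// application in this commit); the handler argument is anything implementing
// WebSocketBroadcaster.
func exampleBroadcasterUsage(handler WebSocketBroadcaster) {
	b := NewBroadcaster(handler)
	b.StartHeartbeat(30 * time.Second)                                // periodic heartbeat broadcasts
	b.BroadcastNotification("Deploy", "New version is live", "info")  // one-off notification
	b.Stop()                                                          // stop all background tickers
}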
251
internal/handlers/websocket/broadcast_test.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MockWebSocketHandler is a mock implementation for testing
|
||||
type MockWebSocketHandler struct {
|
||||
mu sync.Mutex
|
||||
messages []map[string]interface{}
|
||||
broadcasts []string
|
||||
}
|
||||
|
||||
func (m *MockWebSocketHandler) BroadcastMessage(messageType string, data interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.broadcasts = append(m.broadcasts, messageType)
|
||||
m.messages = append(m.messages, map[string]interface{}{
|
||||
"type": messageType,
|
||||
"data": data,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *MockWebSocketHandler) GetMessages() []map[string]interface{} {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
result := make([]map[string]interface{}, len(m.messages))
|
||||
copy(result, m.messages)
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *MockWebSocketHandler) GetBroadcasts() []string {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
result := make([]string, len(m.broadcasts))
|
||||
copy(result, m.broadcasts)
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *MockWebSocketHandler) Clear() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.messages = make([]map[string]interface{}, 0)
|
||||
m.broadcasts = make([]string, 0)
|
||||
}
|
||||
|
||||
func NewMockWebSocketHandler() *MockWebSocketHandler {
|
||||
return &MockWebSocketHandler{
|
||||
messages: make([]map[string]interface{}, 0),
|
||||
broadcasts: make([]string, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcaster_StartHeartbeat(t *testing.T) {
|
||||
mockHandler := NewMockWebSocketHandler()
|
||||
broadcaster := NewBroadcaster(mockHandler)
|
||||
|
||||
// Start heartbeat with short interval for testing
|
||||
broadcaster.StartHeartbeat(100 * time.Millisecond)
|
||||
|
||||
// Wait for a few heartbeats
|
||||
time.Sleep(350 * time.Millisecond)
|
||||
|
||||
// Stop the broadcaster
|
||||
broadcaster.Stop()
|
||||
|
||||
// Check if heartbeats were sent
|
||||
messages := mockHandler.GetMessages()
|
||||
if len(messages) == 0 {
|
||||
t.Error("Expected heartbeat messages, but got none")
|
||||
}
|
||||
|
||||
// Check that all messages are heartbeat type
|
||||
broadcasts := mockHandler.GetBroadcasts()
|
||||
for _, msgType := range broadcasts {
|
||||
if msgType != "heartbeat" {
|
||||
t.Errorf("Expected heartbeat message type, got %s", msgType)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Received %d heartbeat messages", len(messages))
|
||||
}
|
||||
|
||||
func TestBroadcaster_BroadcastNotification(t *testing.T) {
|
||||
mockHandler := NewMockWebSocketHandler()
|
||||
broadcaster := NewBroadcaster(mockHandler)
|
||||
|
||||
// Send a notification
|
||||
broadcaster.BroadcastNotification("Test Title", "Test Message", "info")
|
||||
|
||||
// Check if notification was sent
|
||||
messages := mockHandler.GetMessages()
|
||||
if len(messages) != 1 {
|
||||
t.Errorf("Expected 1 message, got %d", len(messages))
|
||||
return
|
||||
}
|
||||
|
||||
msg := messages[0]
|
||||
if msg["type"] != "notification" {
|
||||
t.Errorf("Expected message type 'notification', got %s", msg["type"])
|
||||
}
|
||||
|
||||
data := msg["data"].(map[string]interface{})
|
||||
if data["title"] != "Test Title" {
|
||||
t.Errorf("Expected title 'Test Title', got %s", data["title"])
|
||||
}
|
||||
if data["message"] != "Test Message" {
|
||||
t.Errorf("Expected message 'Test Message', got %s", data["message"])
|
||||
}
|
||||
if data["level"] != "info" {
|
||||
t.Errorf("Expected level 'info', got %s", data["level"])
|
||||
}
|
||||
|
||||
t.Logf("Notification sent successfully: %+v", data)
|
||||
}
|
||||
|
||||
func TestBroadcaster_SimulateDataStream(t *testing.T) {
|
||||
mockHandler := NewMockWebSocketHandler()
|
||||
broadcaster := NewBroadcaster(mockHandler)
|
||||
|
||||
// Start data stream with short interval for testing
|
||||
broadcaster.SimulateDataStream()
|
||||
|
||||
// Wait for a few data points
|
||||
time.Sleep(550 * time.Millisecond)
|
||||
|
||||
// Stop the broadcaster
|
||||
broadcaster.Stop()
|
||||
|
||||
// Check if data stream messages were sent
|
||||
messages := mockHandler.GetMessages()
|
||||
if len(messages) == 0 {
|
||||
t.Error("Expected data stream messages, but got none")
|
||||
}
|
||||
|
||||
// Check that all messages are data_stream type
|
||||
broadcasts := mockHandler.GetBroadcasts()
|
||||
for _, msgType := range broadcasts {
|
||||
if msgType != "data_stream" {
|
||||
t.Errorf("Expected data_stream message type, got %s", msgType)
|
||||
}
|
||||
}
|
||||
|
||||
// Check data structure
|
||||
for i, msg := range messages {
|
||||
data := msg["data"].(map[string]interface{})
|
||||
if data["type"] != "simulated_data" {
|
||||
t.Errorf("Expected data type 'simulated_data', got %s", data["type"])
|
||||
}
|
||||
if id, ok := data["id"].(int); ok {
|
||||
if id != i+1 {
|
||||
t.Errorf("Expected id %d, got %d", i+1, id)
|
||||
}
|
||||
}
|
||||
if value, ok := data["value"].(int); ok {
|
||||
expectedValue := (i + 1) * 10
|
||||
if value != expectedValue {
|
||||
t.Errorf("Expected value %d, got %d", expectedValue, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Received %d data stream messages", len(messages))
|
||||
}
|
||||
|
||||
func TestBroadcaster_Stop(t *testing.T) {
|
||||
mockHandler := NewMockWebSocketHandler()
|
||||
broadcaster := NewBroadcaster(mockHandler)
|
||||
|
||||
// Start heartbeat
|
||||
broadcaster.StartHeartbeat(50 * time.Millisecond)
|
||||
|
||||
// Wait a bit
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Stop the broadcaster
|
||||
broadcaster.Stop()
|
||||
|
||||
// Clear previous messages
|
||||
mockHandler.Clear()
|
||||
|
||||
// Wait a bit more to ensure no new messages are sent
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Check that no new messages were sent after stopping
|
||||
messages := mockHandler.GetMessages()
|
||||
if len(messages) > 0 {
|
||||
t.Errorf("Expected no messages after stopping, but got %d", len(messages))
|
||||
}
|
||||
|
||||
// Clear quit channel to allow reuse in tests
|
||||
broadcaster.quit = make(chan struct{})
|
||||
|
||||
t.Log("Broadcaster stopped successfully")
|
||||
}
|
||||
|
||||
func TestBroadcaster_MultipleOperations(t *testing.T) {
|
||||
mockHandler := NewMockWebSocketHandler()
|
||||
broadcaster := NewBroadcaster(mockHandler)
|
||||
|
||||
// Start heartbeat
|
||||
broadcaster.StartHeartbeat(100 * time.Millisecond)
|
||||
|
||||
// Send notification
|
||||
broadcaster.BroadcastNotification("Test", "Message", "warning")
|
||||
|
||||
// Start data stream
|
||||
broadcaster.SimulateDataStream()
|
||||
|
||||
// Wait for some activity
|
||||
time.Sleep(350 * time.Millisecond)
|
||||
|
||||
// Stop everything
|
||||
broadcaster.Stop()
|
||||
|
||||
// Check results
|
||||
messages := mockHandler.GetMessages()
|
||||
if len(messages) == 0 {
|
||||
t.Error("Expected messages from multiple operations, but got none")
|
||||
}
|
||||
|
||||
broadcasts := mockHandler.GetBroadcasts()
|
||||
hasHeartbeat := false
|
||||
hasNotification := false
|
||||
hasDataStream := false
|
||||
|
||||
for _, msgType := range broadcasts {
|
||||
switch msgType {
|
||||
case "heartbeat":
|
||||
hasHeartbeat = true
|
||||
case "notification":
|
||||
hasNotification = true
|
||||
case "data_stream":
|
||||
hasDataStream = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasHeartbeat {
|
||||
t.Error("Expected heartbeat messages")
|
||||
}
|
||||
if !hasNotification {
|
||||
t.Error("Expected notification message")
|
||||
}
|
||||
if !hasDataStream {
|
||||
t.Error("Expected data stream messages")
|
||||
}
|
||||
|
||||
t.Logf("Multiple operations test passed: %d total messages", len(messages))
|
||||
}
|
||||
1621
internal/handlers/websocket/websocket.go
Normal file
File diff suppressed because it is too large
59
internal/middleware/auth_middleware.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"api-service/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ConfigurableAuthMiddleware provides flexible authentication based on configuration
|
||||
func ConfigurableAuthMiddleware(cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Skip authentication for development/testing if explicitly disabled
|
||||
if !cfg.Keycloak.Enabled {
|
||||
fmt.Println("Authentication is disabled - allowing all requests")
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Use Keycloak authentication when enabled
|
||||
AuthMiddleware()(c)
|
||||
}
|
||||
}
|
||||
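// Illustrative wiring (sketch only, not referenced elsewhere in this commit):
// attach ConfigurableAuthMiddleware to a route group so that authentication
// follows cfg.Keycloak.Enabled.
func exampleConfigurableAuthWiring(r *gin.Engine, cfg *config.Config) {
	protected := r.Group("/api/v1/protected")
	protected.Use(ConfigurableAuthMiddleware(cfg))
	protected.GET("/ping", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"message": "pong"})
	})
}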
|
||||
// StrictAuthMiddleware enforces authentication regardless of Keycloak.Enabled setting
|
||||
func StrictAuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if appConfig == nil {
|
||||
fmt.Println("AuthMiddleware: Config not initialized")
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "authentication service not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
// Always enforce authentication
|
||||
AuthMiddleware()(c)
|
||||
}
|
||||
}
|
||||
|
||||
// OptionalKeycloakAuthMiddleware allows requests but adds authentication info if available
|
||||
func OptionalKeycloakAuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if appConfig == nil || !appConfig.Keycloak.Enabled {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
// No token provided, but continue
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Delegate to the Keycloak middleware; note that it will still abort the request if a token is present but invalid
|
||||
AuthMiddleware()(c)
|
||||
}
|
||||
}
|
||||
54
internal/middleware/error_handler.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
models "api-service/internal/models"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ErrorHandler handles errors globally
|
||||
func ErrorHandler() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
|
||||
if len(c.Errors) > 0 {
|
||||
err := c.Errors.Last()
|
||||
status := http.StatusInternalServerError
|
||||
|
||||
// Determine status code based on error type
|
||||
switch err.Type {
|
||||
case gin.ErrorTypeBind:
|
||||
status = http.StatusBadRequest
|
||||
case gin.ErrorTypeRender:
|
||||
status = http.StatusUnprocessableEntity
|
||||
case gin.ErrorTypePrivate:
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
|
||||
response := models.ErrorResponse{
|
||||
Error: "internal_error",
|
||||
Message: err.Error(),
|
||||
Code: status,
|
||||
}
|
||||
|
||||
c.JSON(status, response)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CORS middleware configuration
|
||||
func CORSConfig() gin.HandlerFunc {
|
||||
return gin.HandlerFunc(func(c *gin.Context) {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS, PATCH")
|
||||
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
77
internal/middleware/jwt_middleware.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
services "api-service/internal/services/auth"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// JWTAuthMiddleware validates JWT tokens generated by our auth service
|
||||
func JWTAuthMiddleware(authService *services.AuthService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header missing"})
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header format must be Bearer {token}"})
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := parts[1]
|
||||
|
||||
// Validate token
|
||||
claims, err := authService.ValidateToken(tokenString)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Set user info in context
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
c.Set("email", claims.Email)
|
||||
c.Set("role", claims.Role)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
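// Illustrative wiring (sketch only): protect a group with JWTAuthMiddleware
// and read the claims it places in the request context.
func exampleJWTProtectedGroup(r *gin.Engine, authService *services.AuthService) {
	protected := r.Group("/api/v1/protected")
	protected.Use(JWTAuthMiddleware(authService))
	protected.GET("/whoami", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{
			"user_id":  c.GetString("user_id"),
			"username": c.GetString("username"),
		})
	})
}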
|
||||
// OptionalAuthMiddleware allows both authenticated and unauthenticated requests
|
||||
func OptionalAuthMiddleware(authService *services.AuthService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
// No token provided, but continue
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := parts[1]
|
||||
claims, err := authService.ValidateToken(tokenString)
|
||||
if err != nil {
|
||||
// Invalid token, but continue (don't abort)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Set user info in context
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
c.Set("email", claims.Email)
|
||||
c.Set("role", claims.Role)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
254
internal/middleware/keycloak_middleware.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package middleware
|
||||
|
||||
/** Keycloak Auth Middleware **/
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"api-service/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidToken = errors.New("invalid token")
|
||||
)
|
||||
|
||||
// JwksCache caches JWKS keys with expiration
|
||||
type JwksCache struct {
|
||||
mu sync.RWMutex
|
||||
keys map[string]*rsa.PublicKey
|
||||
expiresAt time.Time
|
||||
sfGroup singleflight.Group
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
func NewJwksCache(cfg *config.Config) *JwksCache {
|
||||
return &JwksCache{
|
||||
keys: make(map[string]*rsa.PublicKey),
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *JwksCache) GetKey(kid string) (*rsa.PublicKey, error) {
|
||||
c.mu.RLock()
|
||||
if key, ok := c.keys[kid]; ok && time.Now().Before(c.expiresAt) {
|
||||
c.mu.RUnlock()
|
||||
return key, nil
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
// Fetch keys with singleflight to avoid concurrent fetches
|
||||
v, err, _ := c.sfGroup.Do("fetch_jwks", func() (interface{}, error) {
|
||||
return c.fetchKeys()
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys := v.(map[string]*rsa.PublicKey)
|
||||
|
||||
c.mu.Lock()
|
||||
c.keys = keys
|
||||
c.expiresAt = time.Now().Add(1 * time.Hour) // cache for 1 hour
|
||||
c.mu.Unlock()
|
||||
|
||||
key, ok := keys[kid]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("key with kid %s not found", kid)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (c *JwksCache) fetchKeys() (map[string]*rsa.PublicKey, error) {
|
||||
if !c.config.Keycloak.Enabled {
|
||||
return nil, fmt.Errorf("keycloak authentication is disabled")
|
||||
}
|
||||
|
||||
jwksURL := c.config.Keycloak.JwksURL
|
||||
if jwksURL == "" {
|
||||
// Construct JWKS URL from issuer if not explicitly provided
|
||||
jwksURL = c.config.Keycloak.Issuer + "/protocol/openid-connect/certs"
|
||||
}
|
||||
|
||||
resp, err := http.Get(jwksURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var jwksData struct {
|
||||
Keys []struct {
|
||||
Kid string `json:"kid"`
|
||||
Kty string `json:"kty"`
|
||||
N string `json:"n"`
|
||||
E string `json:"e"`
|
||||
} `json:"keys"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&jwksData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys := make(map[string]*rsa.PublicKey)
|
||||
for _, key := range jwksData.Keys {
|
||||
if key.Kty != "RSA" {
|
||||
continue
|
||||
}
|
||||
pubKey, err := parseRSAPublicKey(key.N, key.E)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
keys[key.Kid] = pubKey
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// parseRSAPublicKey parses RSA public key components from base64url strings
|
||||
func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
|
||||
nBytes, err := base64UrlDecode(nStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
eBytes, err := base64UrlDecode(eStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var eInt int
|
||||
for _, b := range eBytes {
|
||||
eInt = eInt<<8 + int(b)
|
||||
}
|
||||
|
||||
pubKey := &rsa.PublicKey{
|
||||
N: new(big.Int).SetBytes(nBytes),
|
||||
E: eInt,
|
||||
}
|
||||
return pubKey, nil
|
||||
}
|
||||
|
||||
func base64UrlDecode(s string) ([]byte, error) {
|
||||
// Add padding if missing
|
||||
if m := len(s) % 4; m != 0 {
|
||||
s += strings.Repeat("=", 4-m)
|
||||
}
|
||||
return base64.URLEncoding.DecodeString(s)
|
||||
}
|
||||
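// exampleJWKExponent is a small illustrative helper (not used elsewhere): the
// common JWK exponent "AQAB" is base64url for the bytes 0x01 0x00 0x01, which
// the loop below folds into the integer 65537, mirroring parseRSAPublicKey.
func exampleJWKExponent() (int, error) {
	eBytes, err := base64UrlDecode("AQAB")
	if err != nil {
		return 0, err
	}
	e := 0
	for _, b := range eBytes {
		e = e<<8 + int(b)
	}
	return e, nil // 65537
}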
|
||||
// Global config instance
|
||||
var appConfig *config.Config
|
||||
var jwksCacheInstance *JwksCache
|
||||
|
||||
// InitializeAuth initializes the auth middleware with config
|
||||
func InitializeAuth(cfg *config.Config) {
|
||||
appConfig = cfg
|
||||
jwksCacheInstance = NewJwksCache(cfg)
|
||||
}
|
||||
|
||||
// AuthMiddleware validates Bearer token as Keycloak JWT token
|
||||
func AuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if appConfig == nil {
|
||||
fmt.Println("AuthMiddleware: Config not initialized")
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "authentication service not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
if !appConfig.Keycloak.Enabled {
|
||||
// Skip authentication if Keycloak is disabled but log for debugging
|
||||
fmt.Println("AuthMiddleware: Keycloak authentication is disabled - allowing all requests")
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("AuthMiddleware: Checking Authorization header") // Debug log
|
||||
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
fmt.Println("AuthMiddleware: Authorization header missing") // Debug log
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header missing"})
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
|
||||
fmt.Println("AuthMiddleware: Invalid Authorization header format") // Debug log
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header format must be Bearer {token}"})
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := parts[1]
|
||||
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
// Verify signing method
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
fmt.Printf("AuthMiddleware: Unexpected signing method: %v\n", token.Header["alg"]) // Debug log
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
|
||||
kid, ok := token.Header["kid"].(string)
|
||||
if !ok {
|
||||
fmt.Println("AuthMiddleware: kid header not found") // Debug log
|
||||
return nil, errors.New("kid header not found")
|
||||
}
|
||||
|
||||
return jwksCacheInstance.GetKey(kid)
|
||||
}, jwt.WithIssuer(appConfig.Keycloak.Issuer), jwt.WithAudience(appConfig.Keycloak.Audience))
|
||||
|
||||
if err != nil || !token.Valid {
|
||||
fmt.Printf("AuthMiddleware: Invalid or expired token: %v\n", err) // Debug log
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or expired token"})
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("AuthMiddleware: Token valid, proceeding") // Debug log
|
||||
// Token is valid, proceed
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
/** JWT Bearer authentication middleware */
|
||||
// import (
|
||||
// "net/http"
|
||||
// "strings"
|
||||
|
||||
// "github.com/gin-gonic/gin"
|
||||
// )
|
||||
|
||||
// AuthJWTMiddleware validates a static Bearer token in the Authorization header
|
||||
func AuthJWTMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header missing"})
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header format must be Bearer {token}"})
|
||||
return
|
||||
}
|
||||
|
||||
token := parts[1]
|
||||
// For now, use a static token for validation. Replace with your logic.
|
||||
const validToken = "your-static-token"
|
||||
|
||||
if token != validToken {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
31
internal/models/auth/auth.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package models
|
||||
|
||||
// LoginRequest represents the login request payload
|
||||
type LoginRequest struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
// TokenResponse represents the token response
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
}
|
||||
|
||||
// JWTClaims represents the JWT claims
|
||||
type JWTClaims struct {
|
||||
UserID string `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
// User represents a user for authentication
|
||||
type User struct {
|
||||
ID string `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
Password string `json:"-"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
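// For reference, the JSON shapes implied by the structs above (illustrative
// values only):
//
//	LoginRequest:  {"username": "jdoe", "password": "secret"}
//	TokenResponse: {"access_token": "<jwt>", "token_type": "Bearer", "expires_in": 3600}
//	User:          {"id": "1", "username": "jdoe", "email": "jdoe@example.com", "role": "admin"}
//	               (Password is never serialized because of its `json:"-"` tag)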
221
internal/models/models.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NullableInt32 provides consistent nullable int32 handling
|
||||
type NullableInt32 struct {
|
||||
Int32 int32 `json:"int32,omitempty"`
|
||||
Valid bool `json:"valid"`
|
||||
}
|
||||
|
||||
// Scan implements the sql.Scanner interface for NullableInt32
|
||||
func (n *NullableInt32) Scan(value interface{}) error {
|
||||
var ni sql.NullInt32
|
||||
if err := ni.Scan(value); err != nil {
|
||||
return err
|
||||
}
|
||||
n.Int32 = ni.Int32
|
||||
n.Valid = ni.Valid
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface for NullableInt32
|
||||
func (n NullableInt32) Value() (driver.Value, error) {
|
||||
if !n.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return n.Int32, nil
|
||||
}
|
||||
|
||||
// NullableString provides consistent nullable string handling
|
||||
type NullableString struct {
|
||||
String string `json:"string,omitempty"`
|
||||
Valid bool `json:"valid"`
|
||||
}
|
||||
|
||||
// Scan implements the sql.Scanner interface for NullableString
|
||||
func (n *NullableString) Scan(value interface{}) error {
|
||||
var ns sql.NullString
|
||||
if err := ns.Scan(value); err != nil {
|
||||
return err
|
||||
}
|
||||
n.String = ns.String
|
||||
n.Valid = ns.Valid
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface for NullableString
|
||||
func (n NullableString) Value() (driver.Value, error) {
|
||||
if !n.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return n.String, nil
|
||||
}
|
||||
|
||||
// NullableTime provides consistent nullable time handling
|
||||
type NullableTime struct {
|
||||
Time time.Time `json:"time,omitempty"`
|
||||
Valid bool `json:"valid"`
|
||||
}
|
||||
|
||||
// Scan implements the sql.Scanner interface for NullableTime
|
||||
func (n *NullableTime) Scan(value interface{}) error {
|
||||
var nt sql.NullTime
|
||||
if err := nt.Scan(value); err != nil {
|
||||
return err
|
||||
}
|
||||
n.Time = nt.Time
|
||||
n.Valid = nt.Valid
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface for NullableTime
|
||||
func (n NullableTime) Value() (driver.Value, error) {
|
||||
if !n.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return n.Time, nil
|
||||
}
|
||||
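// Illustrative behaviour of the nullable wrappers when scanning database
// values (a sketch, not referenced elsewhere): a NULL column leaves Valid
// false, while a concrete value sets it to true.
func exampleNullableScan() (NullableString, error) {
	var s NullableString
	if err := s.Scan(nil); err != nil { // NULL column -> s.Valid == false
		return s, err
	}
	if err := s.Scan("retribusi"); err != nil { // text column -> s.Valid == true
		return s, err
	}
	return s, nil
}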
|
||||
// Pagination metadata - optimized
|
||||
type MetaResponse struct {
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
Total int `json:"total"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
CurrentPage int `json:"current_page"`
|
||||
HasNext bool `json:"has_next"`
|
||||
HasPrev bool `json:"has_prev"`
|
||||
}
|
||||
|
||||
// Aggregate data for summaries
|
||||
type AggregateData struct {
|
||||
TotalActive int `json:"total_active"`
|
||||
TotalDraft int `json:"total_draft"`
|
||||
TotalInactive int `json:"total_inactive"`
|
||||
ByStatus map[string]int `json:"by_status"`
|
||||
ByDinas map[string]int `json:"by_dinas,omitempty"`
|
||||
ByJenis map[string]int `json:"by_jenis,omitempty"`
|
||||
LastUpdated *time.Time `json:"last_updated,omitempty"`
|
||||
CreatedToday int `json:"created_today"`
|
||||
UpdatedToday int `json:"updated_today"`
|
||||
}
|
||||
|
||||
// Consistent error response structure
|
||||
type ErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// BaseRequest contains common fields for all BPJS requests
|
||||
type BaseRequest struct {
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
Timestamp time.Time `json:"timestamp,omitempty"`
|
||||
}
|
||||
|
||||
// BaseResponse contains common response fields
|
||||
type BaseResponse struct {
|
||||
Status string `json:"status"`
|
||||
Message string `json:"message,omitempty"`
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
Timestamp string `json:"timestamp,omitempty"`
|
||||
}
|
||||
|
||||
// ErrorResponseBpjs represents the BPJS error response structure
|
||||
type ErrorResponseBpjs struct {
|
||||
Status string `json:"status"`
|
||||
Message string `json:"message"`
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
Errors map[string]interface{} `json:"errors,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
}
|
||||
|
||||
// PaginationRequest contains pagination parameters
|
||||
type PaginationRequest struct {
|
||||
Page int `json:"page" validate:"min=1"`
|
||||
Limit int `json:"limit" validate:"min=1,max=100"`
|
||||
SortBy string `json:"sort_by,omitempty"`
|
||||
SortDir string `json:"sort_dir,omitempty" validate:"omitempty,oneof=asc desc"`
|
||||
}
|
||||
|
||||
// PaginationResponse contains pagination metadata
|
||||
type PaginationResponse struct {
|
||||
CurrentPage int `json:"current_page"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
TotalItems int64 `json:"total_items"`
|
||||
ItemsPerPage int `json:"items_per_page"`
|
||||
HasNext bool `json:"has_next"`
|
||||
HasPrev bool `json:"has_previous"`
|
||||
}
|
||||
|
||||
// MetaInfo contains additional metadata
|
||||
type MetaInfo struct {
|
||||
Version string `json:"version"`
|
||||
Environment string `json:"environment"`
|
||||
ServerTime string `json:"server_time"`
|
||||
}
|
||||
|
||||
func GetStatusCodeFromMeta(metaCode interface{}) int {
|
||||
statusCode := http.StatusOK
|
||||
|
||||
if metaCode != nil {
|
||||
switch v := metaCode.(type) {
|
||||
case string:
|
||||
if code, err := strconv.Atoi(v); err == nil {
|
||||
if code >= 100 && code <= 599 {
|
||||
statusCode = code
|
||||
} else {
|
||||
statusCode = http.StatusInternalServerError
|
||||
}
|
||||
} else {
|
||||
statusCode = http.StatusInternalServerError
|
||||
}
|
||||
case int:
|
||||
if v >= 100 && v <= 599 {
|
||||
statusCode = v
|
||||
} else {
|
||||
statusCode = http.StatusInternalServerError
|
||||
}
|
||||
case float64:
|
||||
code := int(v)
|
||||
if code >= 100 && code <= 599 {
|
||||
statusCode = code
|
||||
} else {
|
||||
statusCode = http.StatusInternalServerError
|
||||
}
|
||||
default:
|
||||
statusCode = http.StatusInternalServerError
|
||||
}
|
||||
}
|
||||
|
||||
return statusCode
|
||||
}
|
||||
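// Worked examples for GetStatusCodeFromMeta (derived from the switch above):
//
//	GetStatusCodeFromMeta(nil)   -> 200 (default)
//	GetStatusCodeFromMeta("404") -> 404
//	GetStatusCodeFromMeta(201)   -> 201
//	GetStatusCodeFromMeta(700)   -> 500 (outside the 100-599 range)
//	GetStatusCodeFromMeta(404.0) -> 404 (JSON numbers decode as float64)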
|
||||
// Validation constants
|
||||
const (
|
||||
StatusDraft = "draft"
|
||||
StatusActive = "active"
|
||||
StatusInactive = "inactive"
|
||||
StatusDeleted = "deleted"
|
||||
)
|
||||
|
||||
// ValidStatuses lists the statuses accepted by validation
|
||||
var ValidStatuses = []string{StatusDraft, StatusActive, StatusInactive}
|
||||
|
||||
// IsValidStatus helper function
|
||||
func IsValidStatus(status string) bool {
|
||||
for _, validStatus := range ValidStatuses {
|
||||
if status == validStatus {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
228
internal/models/retribusi/retribusi.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package retribusi
|
||||
|
||||
import (
|
||||
"api-service/internal/models"
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Retribusi represents the data structure for the retribusi table
|
||||
// with proper null handling and optimized JSON marshaling
|
||||
type Retribusi struct {
|
||||
ID string `json:"id" db:"id"`
|
||||
Status string `json:"status" db:"status"`
|
||||
Sort models.NullableInt32 `json:"sort,omitempty" db:"sort"`
|
||||
UserCreated models.NullableString `json:"user_created,omitempty" db:"user_created"`
|
||||
DateCreated models.NullableTime `json:"date_created,omitempty" db:"date_created"`
|
||||
UserUpdated models.NullableString `json:"user_updated,omitempty" db:"user_updated"`
|
||||
DateUpdated models.NullableTime `json:"date_updated,omitempty" db:"date_updated"`
|
||||
Jenis models.NullableString `json:"jenis,omitempty" db:"Jenis"`
|
||||
Pelayanan models.NullableString `json:"pelayanan,omitempty" db:"Pelayanan"`
|
||||
Dinas models.NullableString `json:"dinas,omitempty" db:"Dinas"`
|
||||
KelompokObyek models.NullableString `json:"kelompok_obyek,omitempty" db:"Kelompok_obyek"`
|
||||
KodeTarif models.NullableString `json:"kode_tarif,omitempty" db:"Kode_tarif"`
|
||||
Tarif models.NullableString `json:"tarif,omitempty" db:"Tarif"`
|
||||
Satuan models.NullableString `json:"satuan,omitempty" db:"Satuan"`
|
||||
TarifOvertime models.NullableString `json:"tarif_overtime,omitempty" db:"Tarif_overtime"`
|
||||
SatuanOvertime models.NullableString `json:"satuan_overtime,omitempty" db:"Satuan_overtime"`
|
||||
RekeningPokok models.NullableString `json:"rekening_pokok,omitempty" db:"Rekening_pokok"`
|
||||
RekeningDenda models.NullableString `json:"rekening_denda,omitempty" db:"Rekening_denda"`
|
||||
Uraian1 models.NullableString `json:"uraian_1,omitempty" db:"Uraian_1"`
|
||||
Uraian2 models.NullableString `json:"uraian_2,omitempty" db:"Uraian_2"`
|
||||
Uraian3 models.NullableString `json:"uraian_3,omitempty" db:"Uraian_3"`
|
||||
}
|
||||
|
||||
// Custom JSON marshaling for Retribusi so that NULL values do not appear in the response
|
||||
func (r Retribusi) MarshalJSON() ([]byte, error) {
|
||||
type Alias Retribusi
|
||||
aux := &struct {
|
||||
Sort *int `json:"sort,omitempty"`
|
||||
UserCreated *string `json:"user_created,omitempty"`
|
||||
DateCreated *time.Time `json:"date_created,omitempty"`
|
||||
UserUpdated *string `json:"user_updated,omitempty"`
|
||||
DateUpdated *time.Time `json:"date_updated,omitempty"`
|
||||
Jenis *string `json:"jenis,omitempty"`
|
||||
Pelayanan *string `json:"pelayanan,omitempty"`
|
||||
Dinas *string `json:"dinas,omitempty"`
|
||||
KelompokObyek *string `json:"kelompok_obyek,omitempty"`
|
||||
KodeTarif *string `json:"kode_tarif,omitempty"`
|
||||
Tarif *string `json:"tarif,omitempty"`
|
||||
Satuan *string `json:"satuan,omitempty"`
|
||||
TarifOvertime *string `json:"tarif_overtime,omitempty"`
|
||||
SatuanOvertime *string `json:"satuan_overtime,omitempty"`
|
||||
RekeningPokok *string `json:"rekening_pokok,omitempty"`
|
||||
RekeningDenda *string `json:"rekening_denda,omitempty"`
|
||||
Uraian1 *string `json:"uraian_1,omitempty"`
|
||||
Uraian2 *string `json:"uraian_2,omitempty"`
|
||||
Uraian3 *string `json:"uraian_3,omitempty"`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(&r),
|
||||
}
|
||||
|
||||
// Convert NullableInt32 to pointer
|
||||
if r.Sort.Valid {
|
||||
sort := int(r.Sort.Int32)
|
||||
aux.Sort = &sort
|
||||
}
|
||||
if r.UserCreated.Valid {
|
||||
aux.UserCreated = &r.UserCreated.String
|
||||
}
|
||||
if r.DateCreated.Valid {
|
||||
aux.DateCreated = &r.DateCreated.Time
|
||||
}
|
||||
if r.UserUpdated.Valid {
|
||||
aux.UserUpdated = &r.UserUpdated.String
|
||||
}
|
||||
if r.DateUpdated.Valid {
|
||||
aux.DateUpdated = &r.DateUpdated.Time
|
||||
}
|
||||
if r.Jenis.Valid {
|
||||
aux.Jenis = &r.Jenis.String
|
||||
}
|
||||
if r.Pelayanan.Valid {
|
||||
aux.Pelayanan = &r.Pelayanan.String
|
||||
}
|
||||
if r.Dinas.Valid {
|
||||
aux.Dinas = &r.Dinas.String
|
||||
}
|
||||
if r.KelompokObyek.Valid {
|
||||
aux.KelompokObyek = &r.KelompokObyek.String
|
||||
}
|
||||
if r.KodeTarif.Valid {
|
||||
aux.KodeTarif = &r.KodeTarif.String
|
||||
}
|
||||
if r.Tarif.Valid {
|
||||
aux.Tarif = &r.Tarif.String
|
||||
}
|
||||
if r.Satuan.Valid {
|
||||
aux.Satuan = &r.Satuan.String
|
||||
}
|
||||
if r.TarifOvertime.Valid {
|
||||
aux.TarifOvertime = &r.TarifOvertime.String
|
||||
}
|
||||
if r.SatuanOvertime.Valid {
|
||||
aux.SatuanOvertime = &r.SatuanOvertime.String
|
||||
}
|
||||
if r.RekeningPokok.Valid {
|
||||
aux.RekeningPokok = &r.RekeningPokok.String
|
||||
}
|
||||
if r.RekeningDenda.Valid {
|
||||
aux.RekeningDenda = &r.RekeningDenda.String
|
||||
}
|
||||
if r.Uraian1.Valid {
|
||||
aux.Uraian1 = &r.Uraian1.String
|
||||
}
|
||||
if r.Uraian2.Valid {
|
||||
aux.Uraian2 = &r.Uraian2.String
|
||||
}
|
||||
if r.Uraian3.Valid {
|
||||
aux.Uraian3 = &r.Uraian3.String
|
||||
}
|
||||
|
||||
return json.Marshal(aux)
|
||||
}
|
||||
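// Illustrative effect of the marshaler above: a record whose only valid
// nullable field is Jenis serializes roughly as
//
//	{"id": "<uuid>", "status": "active", "jenis": "Parkir"}
//
// while the remaining nullable columns are omitted instead of appearing as
// {"String": "", "Valid": false} objects.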
|
||||
// Helper methods for safely retrieving values
|
||||
func (r *Retribusi) GetJenis() string {
|
||||
if r.Jenis.Valid {
|
||||
return r.Jenis.String
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (r *Retribusi) GetDinas() string {
|
||||
if r.Dinas.Valid {
|
||||
return r.Dinas.String
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (r *Retribusi) GetTarif() string {
|
||||
if r.Tarif.Valid {
|
||||
return r.Tarif.String
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Response struct for GET by ID - structure cleaned up
|
||||
type RetribusiGetByIDResponse struct {
|
||||
Message string `json:"message"`
|
||||
Data *Retribusi `json:"data"`
|
||||
}
|
||||
|
||||
// Request struct for create - optimized with validation
|
||||
type RetribusiCreateRequest struct {
|
||||
Status string `json:"status" validate:"required,oneof=draft active inactive"`
|
||||
Jenis *string `json:"jenis,omitempty" validate:"omitempty,min=1,max=255"`
|
||||
Pelayanan *string `json:"pelayanan,omitempty" validate:"omitempty,min=1,max=255"`
|
||||
Dinas *string `json:"dinas,omitempty" validate:"omitempty,min=1,max=255"`
|
||||
KelompokObyek *string `json:"kelompok_obyek,omitempty" validate:"omitempty,min=1,max=255"`
|
||||
KodeTarif *string `json:"kode_tarif,omitempty" validate:"omitempty,min=1,max=255"`
|
||||
Uraian1 *string `json:"uraian_1,omitempty"`
|
||||
Uraian2 *string `json:"uraian_2,omitempty"`
|
||||
Uraian3 *string `json:"uraian_3,omitempty"`
|
||||
Tarif *string `json:"tarif,omitempty" validate:"omitempty,numeric"`
|
||||
Satuan *string `json:"satuan,omitempty" validate:"omitempty,min=1,max=255"`
|
||||
TarifOvertime *string `json:"tarif_overtime,omitempty" validate:"omitempty,numeric"`
|
||||
SatuanOvertime *string `json:"satuan_overtime,omitempty" validate:"omitempty,min=1,max=255"`
|
||||
RekeningPokok *string `json:"rekening_pokok,omitempty" validate:"omitempty,min=1,max=255"`
|
||||
RekeningDenda *string `json:"rekening_denda,omitempty" validate:"omitempty,min=1,max=255"`
|
||||
}
|
||||
|
||||
// Response struct for create
|
||||
type RetribusiCreateResponse struct {
|
||||
Message string `json:"message"`
|
||||
Data *Retribusi `json:"data"`
|
||||
}
|
||||
|
||||
// Update request - same as create but with an ID
|
||||
type RetribusiUpdateRequest struct {
|
||||
ID string `json:"-" validate:"required,uuid4"` // ID taken from the URL path
|
||||
Status string `json:"status" validate:"required,oneof=draft active inactive"`
|
||||
Jenis *string `json:"jenis,omitempty" validate:"omitempty,min=1,max=255"`
|
||||
Pelayanan *string `json:"pelayanan,omitempty" validate:"omitempty,min=1,max=255"`
|
||||
Dinas *string `json:"dinas,omitempty" validate:"omitempty,min=1,max=255"`
|
||||
KelompokObyek *string `json:"kelompok_obyek,omitempty" validate:"omitempty,min=1,max=255"`
|
||||
KodeTarif *string `json:"kode_tarif,omitempty" validate:"omitempty,min=1,max=255"`
|
||||
Uraian1 *string `json:"uraian_1,omitempty"`
|
||||
Uraian2 *string `json:"uraian_2,omitempty"`
|
||||
Uraian3 *string `json:"uraian_3,omitempty"`
|
||||
Tarif *string `json:"tarif,omitempty" validate:"omitempty,numeric"`
|
||||
Satuan *string `json:"satuan,omitempty" validate:"omitempty,min=1,max=255"`
|
||||
TarifOvertime *string `json:"tarif_overtime,omitempty" validate:"omitempty,numeric"`
|
||||
SatuanOvertime *string `json:"satuan_overtime,omitempty" validate:"omitempty,min=1,max=255"`
|
||||
RekeningPokok *string `json:"rekening_pokok,omitempty" validate:"omitempty,min=1,max=255"`
|
||||
RekeningDenda *string `json:"rekening_denda,omitempty" validate:"omitempty,min=1,max=255"`
|
||||
}
|
||||
|
||||
// Response struct for update
|
||||
type RetribusiUpdateResponse struct {
|
||||
Message string `json:"message"`
|
||||
Data *Retribusi `json:"data"`
|
||||
}
|
||||
|
||||
// Response struct for delete
|
||||
type RetribusiDeleteResponse struct {
|
||||
Message string `json:"message"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
// Enhanced GET response with pagination and aggregation
|
||||
type RetribusiGetResponse struct {
|
||||
Message string `json:"message"`
|
||||
Data []Retribusi `json:"data"`
|
||||
Meta models.MetaResponse `json:"meta"`
|
||||
Summary *models.AggregateData `json:"summary,omitempty"`
|
||||
}
|
||||
|
||||
// Filter struct for query parameters
|
||||
type RetribusiFilter struct {
|
||||
Status *string `json:"status,omitempty" form:"status"`
|
||||
Jenis *string `json:"jenis,omitempty" form:"jenis"`
|
||||
Dinas *string `json:"dinas,omitempty" form:"dinas"`
|
||||
KelompokObyek *string `json:"kelompok_obyek,omitempty" form:"kelompok_obyek"`
|
||||
Search *string `json:"search,omitempty" form:"search"`
|
||||
DateFrom *time.Time `json:"date_from,omitempty" form:"date_from"`
|
||||
DateTo *time.Time `json:"date_to,omitempty" form:"date_to"`
|
||||
}
|
||||
106
internal/models/validation.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
// CustomValidator wraps the validator
|
||||
type CustomValidator struct {
|
||||
Validator *validator.Validate
|
||||
}
|
||||
|
||||
// Validate validates struct
|
||||
func (cv *CustomValidator) Validate(i interface{}) error {
|
||||
return cv.Validator.Struct(i)
|
||||
}
|
||||
|
||||
// RegisterCustomValidations registers custom validation rules
|
||||
func RegisterCustomValidations(v *validator.Validate) {
|
||||
// Validate Indonesian phone number
|
||||
v.RegisterValidation("indonesian_phone", validateIndonesianPhone)
|
||||
|
||||
// Validate BPJS card number format
|
||||
v.RegisterValidation("bpjs_card", validateBPJSCard)
|
||||
|
||||
// Validate Indonesian NIK
|
||||
v.RegisterValidation("indonesian_nik", validateIndonesianNIK)
|
||||
|
||||
// Validate date format YYYY-MM-DD
|
||||
v.RegisterValidation("date_format", validateDateFormat)
|
||||
|
||||
// Validate ICD-10 code format
|
||||
v.RegisterValidation("icd10", validateICD10)
|
||||
|
||||
// Validate ICD-9-CM procedure code
|
||||
v.RegisterValidation("icd9cm", validateICD9CM)
|
||||
}
|
||||
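// Illustrative wiring (sketch only, values are made up): register the custom
// rules on a validator instance and exercise them through struct tags.
func exampleCustomValidations() error {
	v := validator.New()
	RegisterCustomValidations(v)

	patient := struct {
		NIK   string `validate:"indonesian_nik"`
		Phone string `validate:"indonesian_phone"`
	}{
		NIK:   "3171234567890001", // 16 numeric digits
		Phone: "081234567890",     // local 08xx format
	}
	return v.Struct(patient) // nil when both fields pass
}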
|
||||
func validateIndonesianPhone(fl validator.FieldLevel) bool {
|
||||
phone := fl.Field().String()
|
||||
if phone == "" {
|
||||
return true // Optional field
|
||||
}
|
||||
|
||||
// Indonesian phone number pattern: +62, 62, 08, or 8
|
||||
pattern := `^(\+?62|0?8)[1-9][0-9]{7,11}$`
|
||||
matched, _ := regexp.MatchString(pattern, phone)
|
||||
return matched
|
||||
}
|
||||
|
||||
func validateBPJSCard(fl validator.FieldLevel) bool {
|
||||
card := fl.Field().String()
|
||||
if len(card) != 13 {
|
||||
return false
|
||||
}
|
||||
|
||||
// BPJS card should be numeric
|
||||
pattern := `^\d{13}$`
|
||||
matched, _ := regexp.MatchString(pattern, card)
|
||||
return matched
|
||||
}
|
||||
|
||||
func validateIndonesianNIK(fl validator.FieldLevel) bool {
|
||||
nik := fl.Field().String()
|
||||
if len(nik) != 16 {
|
||||
return false
|
||||
}
|
||||
|
||||
// NIK should be numeric
|
||||
pattern := `^\d{16}$`
|
||||
matched, _ := regexp.MatchString(pattern, nik)
|
||||
return matched
|
||||
}
|
||||
|
||||
func validateDateFormat(fl validator.FieldLevel) bool {
|
||||
dateStr := fl.Field().String()
|
||||
_, err := time.Parse("2006-01-02", dateStr)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func validateICD10(fl validator.FieldLevel) bool {
|
||||
code := fl.Field().String()
|
||||
if code == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Basic ICD-10 pattern: Letter followed by 2 digits, optional dot and more digits
|
||||
pattern := `^[A-Z]\d{2}(\.\d+)?$`
|
||||
matched, _ := regexp.MatchString(pattern, strings.ToUpper(code))
|
||||
return matched
|
||||
}
|
||||
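// Examples against the pattern above: "A09" and "J18.9" are accepted
// (letter + two digits, optional decimal part); "AB12" and "A1" are rejected.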
|
||||
func validateICD9CM(fl validator.FieldLevel) bool {
|
||||
code := fl.Field().String()
|
||||
if code == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Basic ICD-9-CM procedure pattern: 2-4 digits with optional decimal
|
||||
pattern := `^\d{2,4}(\.\d+)?$`
|
||||
matched, _ := regexp.MatchString(pattern, code)
|
||||
return matched
|
||||
}
|
||||
774
internal/routes/v1/routes.go
Normal file
@@ -0,0 +1,774 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"api-service/internal/config"
|
||||
"api-service/internal/database"
|
||||
authHandlers "api-service/internal/handlers/auth"
|
||||
healthcheckHandlers "api-service/internal/handlers/healthcheck"
|
||||
retribusiHandlers "api-service/internal/handlers/retribusi"
|
||||
"api-service/internal/handlers/websocket"
|
||||
websocketHandlers "api-service/internal/handlers/websocket"
|
||||
"api-service/internal/middleware"
|
||||
services "api-service/internal/services/auth"
|
||||
"api-service/pkg/logger"
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
swaggerFiles "github.com/swaggo/files"
|
||||
ginSwagger "github.com/swaggo/gin-swagger"
|
||||
)
|
||||
|
||||
func RegisterRoutes(cfg *config.Config) *gin.Engine {
|
||||
router := gin.New()
|
||||
|
||||
// Initialize auth middleware configuration
|
||||
middleware.InitializeAuth(cfg)
|
||||
|
||||
// Add global middleware
|
||||
router.Use(middleware.CORSConfig())
|
||||
router.Use(middleware.ErrorHandler())
|
||||
router.Use(logger.RequestLoggerMiddleware(logger.Default()))
|
||||
router.Use(gin.Recovery())
|
||||
|
||||
// Initialize services with error handling
|
||||
authService := services.NewAuthService(cfg)
|
||||
if authService == nil {
|
||||
logger.Fatal("Failed to initialize auth service")
|
||||
}
|
||||
|
||||
// Initialize database service
|
||||
dbService := database.New(cfg)
|
||||
|
||||
// Initialize WebSocket handler with enhanced features
|
||||
websocketHandler := websocketHandlers.NewWebSocketHandler(cfg, dbService)
|
||||
|
||||
// =============================================================================
|
||||
// HEALTH CHECK & SYSTEM ROUTES
|
||||
// =============================================================================
|
||||
|
||||
healthCheckHandler := healthcheckHandlers.NewHealthCheckHandler(dbService)
|
||||
sistem := router.Group("/api/sistem")
|
||||
{
|
||||
sistem.GET("/health", healthCheckHandler.CheckHealth)
|
||||
sistem.GET("/databases", func(c *gin.Context) {
|
||||
c.JSON(200, gin.H{
|
||||
"databases": dbService.ListDBs(),
|
||||
"health": dbService.Health(),
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
sistem.GET("/info", func(c *gin.Context) {
|
||||
c.JSON(200, gin.H{
|
||||
"service": "API Service v1.0.0",
|
||||
"websocket_active": true,
|
||||
"connected_clients": websocketHandler.GetConnectedClients(),
|
||||
"databases": dbService.ListDBs(),
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SWAGGER DOCUMENTATION
|
||||
// =============================================================================
|
||||
|
||||
router.GET("/swagger/*any", ginSwagger.WrapHandler(
|
||||
swaggerFiles.Handler,
|
||||
ginSwagger.DefaultModelsExpandDepth(-1),
|
||||
ginSwagger.DeepLinking(true),
|
||||
))
|
||||
|
||||
// =============================================================================
|
||||
// WEBSOCKET TEST CLIENT
|
||||
// =============================================================================
|
||||
|
||||
// router.GET("/websocket-test", func(c *gin.Context) {
|
||||
// c.Header("Content-Type", "text/html")
|
||||
// c.String(http.StatusOK, getWebSocketTestHTML())
|
||||
// })
|
||||
|
||||
// =============================================================================
|
||||
// API v1 GROUP
|
||||
// =============================================================================
|
||||
|
||||
v1 := router.Group("/api/v1")
|
||||
|
||||
// =============================================================================
|
||||
// PUBLIC ROUTES (No Authentication Required)
|
||||
// =============================================================================
|
||||
|
||||
// Authentication routes
|
||||
authHandler := authHandlers.NewAuthHandler(authService)
|
||||
tokenHandler := authHandlers.NewTokenHandler(authService)
|
||||
|
||||
// Basic auth routes
|
||||
v1.POST("/auth/login", authHandler.Login)
|
||||
v1.POST("/auth/register", authHandler.Register)
|
||||
v1.POST("/auth/refresh", authHandler.RefreshToken)
|
||||
|
||||
// Token generation routes
|
||||
v1.POST("/token/generate", tokenHandler.GenerateToken)
|
||||
v1.POST("/token/generate-direct", tokenHandler.GenerateTokenDirect)
|
||||
|
||||
// =============================================================================
|
||||
// WEBSOCKET ROUTES
|
||||
// =============================================================================
|
||||
|
||||
// Main WebSocket endpoint with enhanced features
|
||||
v1.GET("/ws", websocketHandler.HandleWebSocket)
|
||||
|
||||
// WebSocket management API
|
||||
wsAPI := router.Group("/api/websocket")
|
||||
{
|
||||
// =============================================================================
|
||||
// BASIC BROADCASTING
|
||||
// =============================================================================
|
||||
|
||||
wsAPI.POST("/broadcast", func(c *gin.Context) {
|
||||
var req struct {
|
||||
Type string `json:"type"`
|
||||
Message interface{} `json:"message"`
|
||||
Database string `json:"database,omitempty"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(400, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
websocketHandler.BroadcastMessage(req.Type, req.Message)
|
||||
c.JSON(200, gin.H{
|
||||
"status": "broadcast sent",
|
||||
"clients_count": websocketHandler.GetConnectedClients(),
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
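// Example request body for POST /api/websocket/broadcast (illustrative values):
//
//	{"type": "notification", "message": {"title": "Maintenance", "level": "info"}}
//
// The handler relays the payload to every connected client via BroadcastMessage.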
|
||||
wsAPI.POST("/broadcast/room/:room", func(c *gin.Context) {
|
||||
room := c.Param("room")
|
||||
var req struct {
|
||||
Type string `json:"type"`
|
||||
Message interface{} `json:"message"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(400, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
websocketHandler.BroadcastToRoom(room, req.Type, req.Message)
|
||||
c.JSON(200, gin.H{
|
||||
"status": "room broadcast sent",
|
||||
"room": room,
|
||||
"clients_count": websocketHandler.GetRoomClientCount(room), // Fix: gunakan GetRoomClientCount
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
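// Example (hedged sketch): calling the room broadcast endpoint from outside.
// The host/port and the room name "billing" are assumptions for illustration only.
//   curl -X POST http://localhost:8080/api/websocket/broadcast/room/billing \
//        -H "Content-Type: application/json" \
//        -d '{"type":"announcement","message":{"text":"maintenance at 22:00"}}'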
|
||||
|
||||
// =============================================================================
|
||||
// ENHANCED CLIENT TARGETING
|
||||
// =============================================================================
|
||||
|
||||
wsAPI.POST("/send/:clientId", func(c *gin.Context) {
|
||||
clientID := c.Param("clientId")
|
||||
var req struct {
|
||||
Type string `json:"type"`
|
||||
Message interface{} `json:"message"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(400, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
websocketHandler.SendToClient(clientID, req.Type, req.Message)
|
||||
c.JSON(200, gin.H{
|
||||
"status": "message sent",
|
||||
"client_id": clientID,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
|
||||
// Send to client by static ID
|
||||
wsAPI.POST("/send/static/:staticId", func(c *gin.Context) {
|
||||
staticID := c.Param("staticId")
|
||||
logger.Infof("Sending message to static client: %s", staticID)
|
||||
var req struct {
|
||||
Type string `json:"type"`
|
||||
Message interface{} `json:"message"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(400, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
success := websocketHandler.SendToClientByStaticID(staticID, req.Type, req.Message)
|
||||
if success {
|
||||
c.JSON(200, gin.H{
|
||||
"status": "message sent to static client",
|
||||
"static_id": staticID,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
} else {
|
||||
c.JSON(404, gin.H{
|
||||
"error": "static client not found",
|
||||
"static_id": staticID,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
// Broadcast to all clients from specific IP
|
||||
wsAPI.POST("/broadcast/ip/:ipAddress", func(c *gin.Context) {
|
||||
ipAddress := c.Param("ipAddress")
|
||||
var req struct {
|
||||
Type string `json:"type"`
|
||||
Message interface{} `json:"message"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(400, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
count := websocketHandler.BroadcastToIP(ipAddress, req.Type, req.Message)
|
||||
c.JSON(200, gin.H{
|
||||
"status": "ip broadcast sent",
|
||||
"ip_address": ipAddress,
|
||||
"clients_count": count,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
|
||||
// =============================================================================
|
||||
// CLIENT INFORMATION & STATISTICS
|
||||
// =============================================================================
|
||||
|
||||
wsAPI.GET("/stats", func(c *gin.Context) {
|
||||
c.JSON(200, gin.H{
|
||||
"connected_clients": websocketHandler.GetConnectedClients(),
|
||||
"databases": dbService.ListDBs(),
|
||||
"database_health": dbService.Health(),
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
|
||||
wsAPI.GET("/stats/detailed", func(c *gin.Context) {
|
||||
stats := websocketHandler.GetDetailedStats()
|
||||
c.JSON(200, gin.H{
|
||||
"stats": stats,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
|
||||
wsAPI.GET("/clients", func(c *gin.Context) {
|
||||
clients := websocketHandler.GetAllClients()
|
||||
c.JSON(200, gin.H{
|
||||
"clients": clients,
|
||||
"count": len(clients),
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
|
||||
// Resolve the client's details by IP address using the ClientInfo list
|
||||
wsAPI.GET("/clients/by-ip/:ipAddress", func(c *gin.Context) {
|
||||
ipAddress := c.Param("ipAddress")
|
||||
client := websocketHandler.GetClientsByIP(ipAddress)
|
||||
if client == nil {
|
||||
c.JSON(404, gin.H{
|
||||
"error": "client not found",
|
||||
"ip_address": ipAddress,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Use ClientInfo struct instead of direct field access
|
||||
clientInfo := websocketHandler.GetAllClients()
|
||||
var targetClientInfo *websocket.ClientInfo
|
||||
for i := range clientInfo {
|
||||
if clientInfo[i].IPAddress == ipAddress {
|
||||
targetClientInfo = &clientInfo[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if targetClientInfo == nil {
|
||||
c.JSON(404, gin.H{
|
||||
"error": "ipAddress not found",
|
||||
"client_id": ipAddress,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"client": map[string]interface{}{
|
||||
"id": targetClientInfo.ID,
|
||||
"static_id": targetClientInfo.StaticID,
|
||||
"ip_address": targetClientInfo.IPAddress,
|
||||
"user_id": targetClientInfo.UserID,
|
||||
"room": targetClientInfo.Room,
|
||||
"connected_at": targetClientInfo.ConnectedAt.Unix(), // Fixed: use exported field
|
||||
"last_ping": targetClientInfo.LastPing.Unix(), // Fixed: use exported field
|
||||
},
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
// Build the GetClientByID response from ClientInfo
|
||||
wsAPI.GET("/client/:clientId", func(c *gin.Context) {
|
||||
clientID := c.Param("clientId")
|
||||
client := websocketHandler.GetClientByID(clientID)
|
||||
|
||||
if client == nil {
|
||||
c.JSON(404, gin.H{
|
||||
"error": "client not found",
|
||||
"client_id": clientID,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Use ClientInfo struct instead of direct field access
|
||||
clientInfo := websocketHandler.GetAllClients()
|
||||
var targetClientInfo *websocket.ClientInfo
|
||||
for i := range clientInfo {
|
||||
if clientInfo[i].ID == clientID {
|
||||
targetClientInfo = &clientInfo[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if targetClientInfo == nil {
|
||||
c.JSON(404, gin.H{
|
||||
"error": "client not found",
|
||||
"client_id": clientID,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"client": map[string]interface{}{
|
||||
"id": targetClientInfo.ID,
|
||||
"static_id": targetClientInfo.StaticID,
|
||||
"ip_address": targetClientInfo.IPAddress,
|
||||
"user_id": targetClientInfo.UserID,
|
||||
"room": targetClientInfo.Room,
|
||||
"connected_at": targetClientInfo.ConnectedAt.Unix(), // Fixed: use exported field
|
||||
"last_ping": targetClientInfo.LastPing.Unix(), // Fixed: use exported field
|
||||
},
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
|
||||
// Build the GetClientByStaticID response from ClientInfo
|
||||
wsAPI.GET("/client/static/:staticId", func(c *gin.Context) {
|
||||
staticID := c.Param("staticId")
|
||||
client := websocketHandler.GetClientByStaticID(staticID)
|
||||
|
||||
if client == nil {
|
||||
c.JSON(404, gin.H{
|
||||
"error": "static client not found",
|
||||
"static_id": staticID,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Use ClientInfo struct instead of direct field access
|
||||
clientInfo := websocketHandler.GetAllClients()
|
||||
var targetClientInfo *websocket.ClientInfo
|
||||
for i := range clientInfo {
|
||||
if clientInfo[i].StaticID == staticID {
|
||||
targetClientInfo = &clientInfo[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if targetClientInfo == nil {
|
||||
c.JSON(404, gin.H{
|
||||
"error": "static client not found",
|
||||
"static_id": staticID,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"client": map[string]interface{}{
|
||||
"id": targetClientInfo.ID,
|
||||
"static_id": targetClientInfo.StaticID,
|
||||
"ip_address": targetClientInfo.IPAddress,
|
||||
"user_id": targetClientInfo.UserID,
|
||||
"room": targetClientInfo.Room,
|
||||
"connected_at": targetClientInfo.ConnectedAt.Unix(), // Fixed: use exported field
|
||||
"last_ping": targetClientInfo.LastPing.Unix(), // Fixed: use exported field
|
||||
},
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
|
||||
// =============================================================================
|
||||
// ACTIVE CLIENTS & CLEANUP
|
||||
// =============================================================================
|
||||
|
||||
// Endpoint for listing recently active clients
|
||||
wsAPI.GET("/clients/active", func(c *gin.Context) {
|
||||
// Default: clients active within the last 5 minutes
|
||||
minutes := c.DefaultQuery("minutes", "5")
|
||||
minutesInt, err := strconv.Atoi(minutes)
|
||||
if err != nil {
|
||||
minutesInt = 5
|
||||
}
|
||||
|
||||
activeClients := websocketHandler.GetActiveClients(time.Duration(minutesInt) * time.Minute)
|
||||
c.JSON(200, gin.H{
|
||||
"active_clients": activeClients,
|
||||
"count": len(activeClients),
|
||||
"threshold_minutes": minutesInt,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
|
||||
// Endpoint for cleaning up inactive clients
|
||||
wsAPI.POST("/cleanup/inactive", func(c *gin.Context) {
|
||||
var req struct {
|
||||
InactiveMinutes int `json:"inactive_minutes"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
req.InactiveMinutes = 30 // Default 30 minutes
|
||||
}
|
||||
|
||||
if req.InactiveMinutes <= 0 {
|
||||
req.InactiveMinutes = 30
|
||||
}
|
||||
|
||||
cleanedCount := websocketHandler.CleanupInactiveClients(time.Duration(req.InactiveMinutes) * time.Minute)
|
||||
c.JSON(200, gin.H{
|
||||
"status": "cleanup completed",
|
||||
"cleaned_clients": cleanedCount,
|
||||
"inactive_minutes": req.InactiveMinutes,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
|
||||
// =============================================================================
|
||||
// DATABASE NOTIFICATIONS
|
||||
// =============================================================================
|
||||
|
||||
wsAPI.POST("/notify/:database/:channel", func(c *gin.Context) {
|
||||
database := c.Param("database")
|
||||
channel := c.Param("channel")
|
||||
|
||||
var req struct {
|
||||
Payload interface{} `json:"payload"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(400, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
payloadJSON, _ := json.Marshal(req.Payload)
|
||||
err := dbService.NotifyChange(database, channel, string(payloadJSON))
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{
|
||||
"error": err.Error(),
|
||||
"database": database,
|
||||
"channel": channel,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": "notification sent",
|
||||
"database": database,
|
||||
"channel": channel,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
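// Example (hedged sketch): publishing a payload on a database notification channel.
// The database "default" and channel "system_changes" mirror the defaults used by the
// test-notification endpoint below; host, port, and the payload shape are illustrative.
//   curl -X POST http://localhost:8080/api/websocket/notify/default/system_changes \
//        -H "Content-Type: application/json" \
//        -d '{"payload":{"table":"data_retribusi","operation":"UPDATE"}}'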
|
||||
|
||||
// Test database notification
|
||||
wsAPI.POST("/test-notification", func(c *gin.Context) {
|
||||
var req struct {
|
||||
Database string `json:"database"`
|
||||
Channel string `json:"channel"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(400, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Default values
|
||||
if req.Database == "" {
|
||||
req.Database = "default"
|
||||
}
|
||||
if req.Channel == "" {
|
||||
req.Channel = "system_changes"
|
||||
}
|
||||
if req.Message == "" {
|
||||
req.Message = "Test notification from API"
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"operation": "API_TEST",
|
||||
"table": "manual_test",
|
||||
"data": map[string]interface{}{
|
||||
"message": req.Message,
|
||||
"test_data": req.Data,
|
||||
"timestamp": time.Now().Unix(),
|
||||
},
|
||||
}
|
||||
|
||||
payloadJSON, _ := json.Marshal(payload)
|
||||
err := dbService.NotifyChange(req.Database, req.Channel, string(payloadJSON))
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": "test notification sent",
|
||||
"database": req.Database,
|
||||
"channel": req.Channel,
|
||||
"payload": payload,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
|
||||
// =============================================================================
|
||||
// ROOM MANAGEMENT
|
||||
// =============================================================================
|
||||
|
||||
wsAPI.GET("/rooms", func(c *gin.Context) {
|
||||
rooms := websocketHandler.GetAllRooms()
|
||||
c.JSON(200, gin.H{
|
||||
"rooms": rooms,
|
||||
"count": len(rooms),
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
|
||||
wsAPI.GET("/room/:room/clients", func(c *gin.Context) {
|
||||
room := c.Param("room")
|
||||
clientCount := websocketHandler.GetRoomClientCount(room)
|
||||
|
||||
// Get detailed room info
|
||||
allRooms := websocketHandler.GetAllRooms()
|
||||
roomClients := allRooms[room]
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"room": room,
|
||||
"client_count": clientCount,
|
||||
"clients": roomClients,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
|
||||
// =============================================================================
|
||||
// MONITORING & DEBUGGING
|
||||
// =============================================================================
|
||||
|
||||
wsAPI.GET("/monitor", func(c *gin.Context) {
|
||||
monitor := websocketHandler.GetMonitoringData()
|
||||
c.JSON(200, monitor)
|
||||
})
|
||||
|
||||
wsAPI.POST("/ping-client/:clientId", func(c *gin.Context) {
|
||||
clientID := c.Param("clientId")
|
||||
websocketHandler.SendToClient(clientID, "server_ping", map[string]interface{}{
|
||||
"message": "Ping from server",
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
c.JSON(200, gin.H{
|
||||
"status": "ping sent",
|
||||
"client_id": clientID,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
|
||||
// Disconnect specific client
|
||||
wsAPI.POST("/disconnect/:clientId", func(c *gin.Context) {
|
||||
clientID := c.Param("clientId")
|
||||
success := websocketHandler.DisconnectClient(clientID)
|
||||
if success {
|
||||
c.JSON(200, gin.H{
|
||||
"status": "client disconnected",
|
||||
"client_id": clientID,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
} else {
|
||||
c.JSON(404, gin.H{
|
||||
"error": "client not found",
|
||||
"client_id": clientID,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
// =============================================================================
|
||||
// BULK OPERATIONS
|
||||
// =============================================================================
|
||||
|
||||
// Broadcast to multiple clients
|
||||
wsAPI.POST("/broadcast/bulk", func(c *gin.Context) {
|
||||
var req struct {
|
||||
ClientIDs []string `json:"client_ids"`
|
||||
Type string `json:"type"`
|
||||
Message interface{} `json:"message"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(400, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
successCount := 0
|
||||
for _, clientID := range req.ClientIDs {
|
||||
websocketHandler.SendToClient(clientID, req.Type, req.Message)
|
||||
successCount++
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": "bulk broadcast sent",
|
||||
"total_clients": len(req.ClientIDs),
|
||||
"success_count": successCount,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
|
||||
// Disconnect multiple clients
|
||||
wsAPI.POST("/disconnect/bulk", func(c *gin.Context) {
|
||||
var req struct {
|
||||
ClientIDs []string `json:"client_ids"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(400, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
successCount := 0
|
||||
for _, clientID := range req.ClientIDs {
|
||||
if websocketHandler.DisconnectClient(clientID) {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": "bulk disconnect completed",
|
||||
"total_clients": len(req.ClientIDs),
|
||||
"success_count": successCount,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// PUBLISHED ROUTES
|
||||
// =============================================================================
|
||||
|
||||
// Retribusi endpoints with WebSocket notifications
|
||||
retribusiHandler := retribusiHandlers.NewRetribusiHandler()
|
||||
retribusiGroup := v1.Group("/retribusi")
|
||||
{
|
||||
retribusiGroup.GET("", retribusiHandler.GetRetribusi)
|
||||
retribusiGroup.GET("/dynamic", retribusiHandler.GetRetribusiDynamic)
|
||||
retribusiGroup.GET("/search", retribusiHandler.SearchRetribusiAdvanced)
|
||||
retribusiGroup.GET("/id/:id", retribusiHandler.GetRetribusiByID)
|
||||
|
||||
// POST/PUT/DELETE with automatic WebSocket notifications
|
||||
retribusiGroup.POST("", func(c *gin.Context) {
|
||||
retribusiHandler.CreateRetribusi(c)
|
||||
|
||||
// Trigger WebSocket notification after successful creation
|
||||
if c.Writer.Status() == 200 || c.Writer.Status() == 201 {
|
||||
websocketHandler.BroadcastMessage("retribusi_created", map[string]interface{}{
|
||||
"message": "New retribusi record created",
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
retribusiGroup.PUT("/id/:id", func(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
retribusiHandler.UpdateRetribusi(c)
|
||||
|
||||
// Trigger WebSocket notification after successful update
|
||||
if c.Writer.Status() == 200 {
|
||||
websocketHandler.BroadcastMessage("retribusi_updated", map[string]interface{}{
|
||||
"message": "Retribusi record updated",
|
||||
"id": id,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
retribusiGroup.DELETE("/id/:id", func(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
retribusiHandler.DeleteRetribusi(c)
|
||||
|
||||
// Trigger WebSocket notification after successful deletion
|
||||
if c.Writer.Status() == 200 {
|
||||
websocketHandler.BroadcastMessage("retribusi_deleted", map[string]interface{}{
|
||||
"message": "Retribusi record deleted",
|
||||
"id": id,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// PROTECTED ROUTES (Authentication Required)
|
||||
// =============================================================================
|
||||
|
||||
protected := v1.Group("/")
|
||||
protected.Use(middleware.ConfigurableAuthMiddleware(cfg))
|
||||
|
||||
// Protected WebSocket management (optional)
|
||||
protectedWS := protected.Group("/ws-admin")
|
||||
{
|
||||
protectedWS.GET("/stats", func(c *gin.Context) {
|
||||
detailedStats := websocketHandler.GetDetailedStats()
|
||||
c.JSON(200, gin.H{
|
||||
"admin_stats": detailedStats,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
|
||||
protectedWS.POST("/force-disconnect/:clientId", func(c *gin.Context) {
|
||||
clientID := c.Param("clientId")
|
||||
success := websocketHandler.DisconnectClient(clientID)
|
||||
c.JSON(200, gin.H{
|
||||
"status": "force disconnect attempted",
|
||||
"client_id": clientID,
|
||||
"success": success,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
|
||||
protectedWS.POST("/cleanup/force", func(c *gin.Context) {
|
||||
var req struct {
|
||||
InactiveMinutes int `json:"inactive_minutes"`
|
||||
Force bool `json:"force"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
req.InactiveMinutes = 10
|
||||
req.Force = false
|
||||
}
|
||||
|
||||
cleanedCount := websocketHandler.CleanupInactiveClients(time.Duration(req.InactiveMinutes) * time.Minute)
|
||||
c.JSON(200, gin.H{
|
||||
"status": "admin cleanup completed",
|
||||
"cleaned_clients": cleanedCount,
|
||||
"inactive_minutes": req.InactiveMinutes,
|
||||
"force": req.Force,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
return router
|
||||
}
|
||||
53
internal/server/server.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
_ "github.com/joho/godotenv/autoload"
|
||||
|
||||
"api-service/internal/config"
|
||||
"api-service/internal/database"
|
||||
v1 "api-service/internal/routes/v1"
|
||||
)
|
||||
|
||||
var dbService database.Service // Global variable to hold the database service instance
|
||||
|
||||
type Server struct {
|
||||
port int
|
||||
db database.Service
|
||||
}
|
||||
|
||||
func NewServer() *http.Server {
|
||||
// Load configuration
|
||||
cfg := config.LoadConfig()
|
||||
cfg.Validate()
|
||||
|
||||
port, _ := strconv.Atoi(os.Getenv("PORT"))
|
||||
if port == 0 {
|
||||
port = cfg.Server.Port
|
||||
}
|
||||
|
||||
if dbService == nil { // Check if the database service is already initialized
|
||||
dbService = database.New(cfg) // Initialize only once
|
||||
}
|
||||
|
||||
NewServer := &Server{
|
||||
port: port,
|
||||
db: dbService, // Use the global database service instance
|
||||
}
|
||||
|
||||
// Configure the HTTP server
|
||||
server := &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", NewServer.port),
|
||||
Handler: v1.RegisterRoutes(cfg),
|
||||
IdleTimeout: time.Minute,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
return server
|
||||
}
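// Usage sketch (assumed entry point, not part of this file):
//   srv := server.NewServer()
//   log.Fatal(srv.ListenAndServe()) // blocks until the server exits; the error is always non-nil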
|
||||
169
internal/services/auth/auth.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"api-service/internal/config"
|
||||
models "api-service/internal/models/auth"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// AuthService handles authentication logic
|
||||
type AuthService struct {
|
||||
config *config.Config
|
||||
users map[string]*models.User // In-memory user store for demo
|
||||
}
|
||||
|
||||
// NewAuthService creates a new authentication service
|
||||
func NewAuthService(cfg *config.Config) *AuthService {
|
||||
// Initialize with demo users
|
||||
users := make(map[string]*models.User)
|
||||
|
||||
// Add demo users
|
||||
users["admin"] = &models.User{
|
||||
ID: "1",
|
||||
Username: "admin",
|
||||
Email: "admin@example.com",
|
||||
Password: "$2a$10$92IXUNpkjO0rOQ5byMi.Ye4oKoEa3Ro9llC/.og/at2.uheWG/igi", // password
|
||||
Role: "admin",
|
||||
}
|
||||
|
||||
users["user"] = &models.User{
|
||||
ID: "2",
|
||||
Username: "user",
|
||||
Email: "user@example.com",
|
||||
Password: "$2a$10$92IXUNpkjO0rOQ5byMi.Ye4oKoEa3Ro9llC/.og/at2.uheWG/igi", // password
|
||||
Role: "user",
|
||||
}
|
||||
|
||||
return &AuthService{
|
||||
config: cfg,
|
||||
users: users,
|
||||
}
|
||||
}
|
||||
|
||||
// Login authenticates user and generates JWT token
|
||||
func (s *AuthService) Login(username, password string) (*models.TokenResponse, error) {
|
||||
user, exists := s.users[username]
|
||||
if !exists {
|
||||
return nil, errors.New("invalid credentials")
|
||||
}
|
||||
|
||||
// Verify password
|
||||
err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid credentials")
|
||||
}
|
||||
|
||||
// Generate JWT token
|
||||
token, err := s.generateToken(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &models.TokenResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600, // 1 hour
|
||||
}, nil
|
||||
}
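// Example (hedged): exercising the demo login flow with the seeded users above.
//   svc := NewAuthService(cfg)                    // cfg is an assumed *config.Config
//   resp, err := svc.Login("admin", "password")   // the bcrypt hashes above encode "password"
//   if err == nil {
//       log.Println(resp.AccessToken)             // Bearer token, expires in 3600s
//   }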
|
||||
|
||||
// generateToken creates a new JWT token for the user
|
||||
func (s *AuthService) generateToken(user *models.User) (string, error) {
|
||||
// Create claims
|
||||
claims := jwt.MapClaims{
|
||||
"user_id": user.ID,
|
||||
"username": user.Username,
|
||||
"email": user.Email,
|
||||
"role": user.Role,
|
||||
"exp": time.Now().Add(time.Hour * 1).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
|
||||
// Create token
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
|
||||
// Sign token with secret key
|
||||
secretKey := []byte(s.getJWTSecret())
|
||||
return token.SignedString(secretKey)
|
||||
}
|
||||
|
||||
// GenerateTokenForUser generates a JWT token for a specific user
|
||||
func (s *AuthService) GenerateTokenForUser(user *models.User) (string, error) {
|
||||
// Create claims
|
||||
claims := jwt.MapClaims{
|
||||
"user_id": user.ID,
|
||||
"username": user.Username,
|
||||
"email": user.Email,
|
||||
"role": user.Role,
|
||||
"exp": time.Now().Add(time.Hour * 1).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
|
||||
// Create token
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
|
||||
// Sign token with secret key
|
||||
secretKey := []byte(s.getJWTSecret())
|
||||
return token.SignedString(secretKey)
|
||||
}
|
||||
|
||||
// ValidateToken validates the JWT token
|
||||
func (s *AuthService) ValidateToken(tokenString string) (*models.JWTClaims, error) {
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, errors.New("unexpected signing method")
|
||||
}
|
||||
return []byte(s.getJWTSecret()), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid claims")
|
||||
}
|
||||
|
||||
return &models.JWTClaims{
|
||||
UserID: claims["user_id"].(string),
|
||||
Username: claims["username"].(string),
|
||||
Email: claims["email"].(string),
|
||||
Role: claims["role"].(string),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getJWTSecret returns the JWT secret key
|
||||
func (s *AuthService) getJWTSecret() string {
|
||||
// In production, this should come from environment variables
|
||||
return "your-secret-key-change-this-in-production"
|
||||
}
|
||||
|
||||
// RegisterUser registers a new user (for demo purposes)
|
||||
func (s *AuthService) RegisterUser(username, email, password, role string) error {
|
||||
if _, exists := s.users[username]; exists {
|
||||
return errors.New("username already exists")
|
||||
}
|
||||
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.users[username] = &models.User{
|
||||
ID:       strconv.Itoa(len(s.users) + 1), // decimal string ID; string(rune(n)) would yield a control character
|
||||
Username: username,
|
||||
Email: email,
|
||||
Password: string(hashedPassword),
|
||||
Role: role,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
593
internal/utils/filters/dynamic_filter.go
Normal file
@@ -0,0 +1,593 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// FilterOperator represents supported filter operators
|
||||
type FilterOperator string
|
||||
|
||||
const (
|
||||
OpEqual FilterOperator = "_eq"
|
||||
OpNotEqual FilterOperator = "_neq"
|
||||
OpLike FilterOperator = "_like"
|
||||
OpILike FilterOperator = "_ilike"
|
||||
OpIn FilterOperator = "_in"
|
||||
OpNotIn FilterOperator = "_nin"
|
||||
OpGreaterThan FilterOperator = "_gt"
|
||||
OpGreaterThanEqual FilterOperator = "_gte"
|
||||
OpLessThan FilterOperator = "_lt"
|
||||
OpLessThanEqual FilterOperator = "_lte"
|
||||
OpBetween FilterOperator = "_between"
|
||||
OpNotBetween FilterOperator = "_nbetween"
|
||||
OpNull FilterOperator = "_null"
|
||||
OpNotNull FilterOperator = "_nnull"
|
||||
OpContains FilterOperator = "_contains"
|
||||
OpNotContains FilterOperator = "_ncontains"
|
||||
OpStartsWith FilterOperator = "_starts_with"
|
||||
OpEndsWith FilterOperator = "_ends_with"
|
||||
)
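// Quick reference for how these operators translate to SQL placeholders in the builder below
// (column names and values are illustrative only):
//   {Column: "status", Operator: OpEqual, Value: "active"}    -> "status" = $1
//   {Column: "amount", Operator: OpBetween, Value: "10,100"}  -> "amount" BETWEEN $1 AND $2
//   {Column: "name", Operator: OpContains, Value: "budi"}     -> "name" ILIKE $1 with arg "%budi%"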
|
||||
|
||||
// DynamicFilter represents a single filter condition
|
||||
type DynamicFilter struct {
|
||||
Column string `json:"column"`
|
||||
Operator FilterOperator `json:"operator"`
|
||||
Value interface{} `json:"value"`
|
||||
LogicOp string `json:"logic_op,omitempty"` // AND, OR
|
||||
}
|
||||
|
||||
// FilterGroup represents a group of filters
|
||||
type FilterGroup struct {
|
||||
Filters []DynamicFilter `json:"filters"`
|
||||
LogicOp string `json:"logic_op"` // AND, OR
|
||||
}
|
||||
|
||||
// DynamicQuery represents the complete query structure
|
||||
type DynamicQuery struct {
|
||||
Fields []string `json:"fields,omitempty"`
|
||||
Filters []FilterGroup `json:"filters,omitempty"`
|
||||
Sort []SortField `json:"sort,omitempty"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
GroupBy []string `json:"group_by,omitempty"`
|
||||
Having []FilterGroup `json:"having,omitempty"`
|
||||
}
|
||||
|
||||
// SortField represents sorting configuration
|
||||
type SortField struct {
|
||||
Column string `json:"column"`
|
||||
Order string `json:"order"` // ASC, DESC
|
||||
}
|
||||
|
||||
// QueryBuilder builds SQL queries from dynamic filters
|
||||
type QueryBuilder struct {
|
||||
tableName string
|
||||
columnMapping map[string]string // Maps API field names to DB column names
|
||||
allowedColumns map[string]bool // Security: only allow specified columns
|
||||
paramCounter int
|
||||
mu *sync.RWMutex
|
||||
}
|
||||
|
||||
// NewQueryBuilder creates a new query builder instance
|
||||
func NewQueryBuilder(tableName string) *QueryBuilder {
|
||||
return &QueryBuilder{
|
||||
tableName: tableName,
|
||||
columnMapping: make(map[string]string),
|
||||
allowedColumns: make(map[string]bool),
|
||||
paramCounter: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// SetColumnMapping sets the mapping between API field names and database column names
|
||||
func (qb *QueryBuilder) SetColumnMapping(mapping map[string]string) *QueryBuilder {
|
||||
qb.columnMapping = mapping
|
||||
return qb
|
||||
}
|
||||
|
||||
// SetAllowedColumns sets the list of allowed columns for security
|
||||
func (qb *QueryBuilder) SetAllowedColumns(columns []string) *QueryBuilder {
|
||||
qb.allowedColumns = make(map[string]bool)
|
||||
for _, col := range columns {
|
||||
qb.allowedColumns[col] = true
|
||||
}
|
||||
return qb
|
||||
}
|
||||
|
||||
// BuildQuery builds the complete SQL query
|
||||
func (qb *QueryBuilder) BuildQuery(query DynamicQuery) (string, []interface{}, error) {
|
||||
qb.paramCounter = 0
|
||||
|
||||
// Build SELECT clause
|
||||
selectClause := qb.buildSelectClause(query.Fields)
|
||||
|
||||
// Build FROM clause
|
||||
fromClause := fmt.Sprintf("FROM %s", qb.tableName)
|
||||
|
||||
// Build WHERE clause
|
||||
whereClause, whereArgs, err := qb.buildWhereClause(query.Filters)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// Build ORDER BY clause
|
||||
orderClause := qb.buildOrderClause(query.Sort)
|
||||
|
||||
// Build GROUP BY clause
|
||||
groupClause := qb.buildGroupByClause(query.GroupBy)
|
||||
|
||||
// Build HAVING clause
|
||||
havingClause, havingArgs, err := qb.buildHavingClause(query.Having)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// Combine all parts
|
||||
sqlParts := []string{selectClause, fromClause}
|
||||
args := []interface{}{}
|
||||
|
||||
if whereClause != "" {
|
||||
sqlParts = append(sqlParts, "WHERE "+whereClause)
|
||||
args = append(args, whereArgs...)
|
||||
}
|
||||
|
||||
if groupClause != "" {
|
||||
sqlParts = append(sqlParts, groupClause)
|
||||
}
|
||||
|
||||
if havingClause != "" {
|
||||
sqlParts = append(sqlParts, "HAVING "+havingClause)
|
||||
args = append(args, havingArgs...)
|
||||
}
|
||||
|
||||
if orderClause != "" {
|
||||
sqlParts = append(sqlParts, orderClause)
|
||||
}
|
||||
|
||||
// Add pagination
|
||||
if query.Limit > 0 {
|
||||
qb.paramCounter++
|
||||
sqlParts = append(sqlParts, fmt.Sprintf("LIMIT $%d", qb.paramCounter))
|
||||
args = append(args, query.Limit)
|
||||
}
|
||||
|
||||
if query.Offset > 0 {
|
||||
qb.paramCounter++
|
||||
sqlParts = append(sqlParts, fmt.Sprintf("OFFSET $%d", qb.paramCounter))
|
||||
args = append(args, query.Offset)
|
||||
}
|
||||
|
||||
sql := strings.Join(sqlParts, " ")
|
||||
return sql, args, nil
|
||||
}
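// Usage sketch (assumed table and columns, illustrative only):
//   qb := NewQueryBuilder("data_retribusi").
//       SetAllowedColumns([]string{"id", "status", "date_created"})
//   sql, args, err := qb.BuildQuery(DynamicQuery{
//       Fields:  []string{"id", "status"},
//       Filters: []FilterGroup{{Filters: []DynamicFilter{{Column: "status", Operator: OpEqual, Value: "active"}}}},
//       Sort:    []SortField{{Column: "date_created", Order: "DESC"}},
//       Limit:   10,
//   })
//   // sql: SELECT "id", "status" FROM data_retribusi WHERE "status" = $1 ORDER BY "date_created" DESC LIMIT $2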
|
||||
|
||||
// buildSelectClause builds the SELECT part of the query
|
||||
func (qb *QueryBuilder) buildSelectClause(fields []string) string {
|
||||
if len(fields) == 0 || (len(fields) == 1 && fields[0] == "*") {
|
||||
return "SELECT *"
|
||||
}
|
||||
|
||||
var selectedFields []string
|
||||
for _, field := range fields {
|
||||
if field == "*.*" || field == "*" {
|
||||
selectedFields = append(selectedFields, "*")
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if it's an expression (contains spaces, parentheses, etc.)
|
||||
if strings.Contains(field, " ") || strings.Contains(field, "(") || strings.Contains(field, ")") {
|
||||
// Expression, add as is
|
||||
selectedFields = append(selectedFields, field)
|
||||
continue
|
||||
}
|
||||
|
||||
// Security check: only allow specified columns (check original field name)
|
||||
if len(qb.allowedColumns) > 0 && !qb.allowedColumns[field] {
|
||||
continue
|
||||
}
|
||||
|
||||
// Map field name if mapping exists
|
||||
if mappedCol, exists := qb.columnMapping[field]; exists {
|
||||
field = mappedCol
|
||||
}
|
||||
|
||||
selectedFields = append(selectedFields, fmt.Sprintf(`"%s"`, field))
|
||||
}
|
||||
|
||||
if len(selectedFields) == 0 {
|
||||
return "SELECT *"
|
||||
}
|
||||
|
||||
return "SELECT " + strings.Join(selectedFields, ", ")
|
||||
}
|
||||
|
||||
// buildWhereClause builds the WHERE part of the query
|
||||
func (qb *QueryBuilder) buildWhereClause(filterGroups []FilterGroup) (string, []interface{}, error) {
|
||||
if len(filterGroups) == 0 {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
|
||||
for i, group := range filterGroups {
|
||||
groupCondition, groupArgs, err := qb.buildFilterGroup(group)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if groupCondition != "" {
|
||||
if i > 0 {
|
||||
logicOp := "AND"
|
||||
if group.LogicOp != "" {
|
||||
logicOp = strings.ToUpper(group.LogicOp)
|
||||
}
|
||||
conditions = append(conditions, logicOp)
|
||||
}
|
||||
|
||||
conditions = append(conditions, groupCondition)
|
||||
args = append(args, groupArgs...)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(conditions, " "), args, nil
|
||||
}
|
||||
|
||||
// buildFilterGroup builds conditions for a filter group
|
||||
func (qb *QueryBuilder) buildFilterGroup(group FilterGroup) (string, []interface{}, error) {
|
||||
if len(group.Filters) == 0 {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
|
||||
for i, filter := range group.Filters {
|
||||
condition, filterArgs, err := qb.buildFilterCondition(filter)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if condition != "" {
|
||||
if i > 0 {
|
||||
logicOp := "AND"
|
||||
if filter.LogicOp != "" {
|
||||
logicOp = strings.ToUpper(filter.LogicOp)
|
||||
} else if group.LogicOp != "" {
|
||||
logicOp = strings.ToUpper(group.LogicOp)
|
||||
}
|
||||
conditions = append(conditions, logicOp)
|
||||
}
|
||||
|
||||
conditions = append(conditions, condition)
|
||||
args = append(args, filterArgs...)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(conditions, " "), args, nil
|
||||
}
|
||||
|
||||
// buildFilterCondition builds a single filter condition
|
||||
func (qb *QueryBuilder) buildFilterCondition(filter DynamicFilter) (string, []interface{}, error) {
|
||||
// Security check (check original field name)
|
||||
if len(qb.allowedColumns) > 0 && !qb.allowedColumns[filter.Column] {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
// Map column name if mapping exists
|
||||
column := filter.Column
|
||||
if mappedCol, exists := qb.columnMapping[column]; exists {
|
||||
column = mappedCol
|
||||
}
|
||||
|
||||
// Wrap column name in quotes for PostgreSQL
|
||||
column = fmt.Sprintf(`"%s"`, column)
|
||||
|
||||
switch filter.Operator {
|
||||
case OpEqual:
|
||||
if filter.Value == nil {
|
||||
return "", nil, nil
|
||||
}
|
||||
qb.paramCounter++
|
||||
return fmt.Sprintf("%s = $%d", column, qb.paramCounter), []interface{}{filter.Value}, nil
|
||||
|
||||
case OpNotEqual:
|
||||
if filter.Value == nil {
|
||||
return "", nil, nil
|
||||
}
|
||||
qb.paramCounter++
|
||||
return fmt.Sprintf("%s != $%d", column, qb.paramCounter), []interface{}{filter.Value}, nil
|
||||
|
||||
case OpLike:
|
||||
if filter.Value == nil {
|
||||
return "", nil, nil
|
||||
}
|
||||
qb.paramCounter++
|
||||
return fmt.Sprintf("%s LIKE $%d", column, qb.paramCounter), []interface{}{filter.Value}, nil
|
||||
|
||||
case OpILike:
|
||||
if filter.Value == nil {
|
||||
return "", nil, nil
|
||||
}
|
||||
qb.paramCounter++
|
||||
return fmt.Sprintf("%s ILIKE $%d", column, qb.paramCounter), []interface{}{filter.Value}, nil
|
||||
|
||||
case OpIn:
|
||||
values := qb.parseArrayValue(filter.Value)
|
||||
if len(values) == 0 {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
var placeholders []string
|
||||
var args []interface{}
|
||||
for _, val := range values {
|
||||
qb.paramCounter++
|
||||
placeholders = append(placeholders, fmt.Sprintf("$%d", qb.paramCounter))
|
||||
args = append(args, val)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s IN (%s)", column, strings.Join(placeholders, ", ")), args, nil
|
||||
|
||||
case OpNotIn:
|
||||
values := qb.parseArrayValue(filter.Value)
|
||||
if len(values) == 0 {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
var placeholders []string
|
||||
var args []interface{}
|
||||
for _, val := range values {
|
||||
qb.paramCounter++
|
||||
placeholders = append(placeholders, fmt.Sprintf("$%d", qb.paramCounter))
|
||||
args = append(args, val)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s NOT IN (%s)", column, strings.Join(placeholders, ", ")), args, nil
|
||||
|
||||
case OpGreaterThan:
|
||||
if filter.Value == nil {
|
||||
return "", nil, nil
|
||||
}
|
||||
qb.paramCounter++
|
||||
return fmt.Sprintf("%s > $%d", column, qb.paramCounter), []interface{}{filter.Value}, nil
|
||||
|
||||
case OpGreaterThanEqual:
|
||||
if filter.Value == nil {
|
||||
return "", nil, nil
|
||||
}
|
||||
qb.paramCounter++
|
||||
return fmt.Sprintf("%s >= $%d", column, qb.paramCounter), []interface{}{filter.Value}, nil
|
||||
|
||||
case OpLessThan:
|
||||
if filter.Value == nil {
|
||||
return "", nil, nil
|
||||
}
|
||||
qb.paramCounter++
|
||||
return fmt.Sprintf("%s < $%d", column, qb.paramCounter), []interface{}{filter.Value}, nil
|
||||
|
||||
case OpLessThanEqual:
|
||||
if filter.Value == nil {
|
||||
return "", nil, nil
|
||||
}
|
||||
qb.paramCounter++
|
||||
return fmt.Sprintf("%s <= $%d", column, qb.paramCounter), []interface{}{filter.Value}, nil
|
||||
|
||||
case OpBetween:
|
||||
if filter.Value == nil {
|
||||
return "", nil, nil
|
||||
}
|
||||
values := qb.parseArrayValue(filter.Value)
|
||||
if len(values) != 2 {
|
||||
return "", nil, fmt.Errorf("between operator requires exactly 2 values")
|
||||
}
|
||||
qb.paramCounter++
|
||||
param1 := qb.paramCounter
|
||||
qb.paramCounter++
|
||||
param2 := qb.paramCounter
|
||||
return fmt.Sprintf("%s BETWEEN $%d AND $%d", column, param1, param2), []interface{}{values[0], values[1]}, nil
|
||||
|
||||
case OpNotBetween:
|
||||
if filter.Value == nil {
|
||||
return "", nil, nil
|
||||
}
|
||||
values := qb.parseArrayValue(filter.Value)
|
||||
if len(values) != 2 {
|
||||
return "", nil, fmt.Errorf("not between operator requires exactly 2 values")
|
||||
}
|
||||
qb.paramCounter++
|
||||
param1 := qb.paramCounter
|
||||
qb.paramCounter++
|
||||
param2 := qb.paramCounter
|
||||
return fmt.Sprintf("%s NOT BETWEEN $%d AND $%d", column, param1, param2), []interface{}{values[0], values[1]}, nil
|
||||
|
||||
case OpNull:
|
||||
return fmt.Sprintf("%s IS NULL", column), nil, nil
|
||||
|
||||
case OpNotNull:
|
||||
return fmt.Sprintf("%s IS NOT NULL", column), nil, nil
|
||||
|
||||
case OpContains:
|
||||
if filter.Value == nil {
|
||||
return "", nil, nil
|
||||
}
|
||||
qb.paramCounter++
|
||||
value := fmt.Sprintf("%%%v%%", filter.Value)
|
||||
return fmt.Sprintf("%s ILIKE $%d", column, qb.paramCounter), []interface{}{value}, nil
|
||||
|
||||
case OpNotContains:
|
||||
if filter.Value == nil {
|
||||
return "", nil, nil
|
||||
}
|
||||
qb.paramCounter++
|
||||
value := fmt.Sprintf("%%%v%%", filter.Value)
|
||||
return fmt.Sprintf("%s NOT ILIKE $%d", column, qb.paramCounter), []interface{}{value}, nil
|
||||
|
||||
case OpStartsWith:
|
||||
if filter.Value == nil {
|
||||
return "", nil, nil
|
||||
}
|
||||
qb.paramCounter++
|
||||
value := fmt.Sprintf("%v%%", filter.Value)
|
||||
return fmt.Sprintf("%s ILIKE $%d", column, qb.paramCounter), []interface{}{value}, nil
|
||||
|
||||
case OpEndsWith:
|
||||
if filter.Value == nil {
|
||||
return "", nil, nil
|
||||
}
|
||||
qb.paramCounter++
|
||||
value := fmt.Sprintf("%%%v", filter.Value)
|
||||
return fmt.Sprintf("%s ILIKE $%d", column, qb.paramCounter), []interface{}{value}, nil
|
||||
|
||||
default:
|
||||
return "", nil, fmt.Errorf("unsupported operator: %s", filter.Operator)
|
||||
}
|
||||
}
|
||||
|
||||
// parseArrayValue parses array values from various formats
|
||||
func (qb *QueryBuilder) parseArrayValue(value interface{}) []interface{} {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If it's already a slice
|
||||
if reflect.TypeOf(value).Kind() == reflect.Slice {
|
||||
v := reflect.ValueOf(value)
|
||||
result := make([]interface{}, v.Len())
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
result[i] = v.Index(i).Interface()
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// If it's a string, try to split by comma
|
||||
if str, ok := value.(string); ok {
|
||||
if strings.Contains(str, ",") {
|
||||
parts := strings.Split(str, ",")
|
||||
result := make([]interface{}, len(parts))
|
||||
for i, part := range parts {
|
||||
result[i] = strings.TrimSpace(part)
|
||||
}
|
||||
return result
|
||||
}
|
||||
return []interface{}{str}
|
||||
}
|
||||
|
||||
return []interface{}{value}
|
||||
}
|
||||
|
||||
// buildOrderClause builds the ORDER BY clause
|
||||
func (qb *QueryBuilder) buildOrderClause(sortFields []SortField) string {
|
||||
if len(sortFields) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var orderParts []string
|
||||
for _, sort := range sortFields {
|
||||
column := sort.Column
|
||||
|
||||
// Security check (check original field name)
|
||||
if len(qb.allowedColumns) > 0 && !qb.allowedColumns[column] {
|
||||
continue
|
||||
}
|
||||
|
||||
if mappedCol, exists := qb.columnMapping[column]; exists {
|
||||
column = mappedCol
|
||||
}
|
||||
|
||||
order := "ASC"
|
||||
if sort.Order != "" {
|
||||
order = strings.ToUpper(sort.Order)
|
||||
}
|
||||
|
||||
orderParts = append(orderParts, fmt.Sprintf(`"%s" %s`, column, order))
|
||||
}
|
||||
|
||||
if len(orderParts) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return "ORDER BY " + strings.Join(orderParts, ", ")
|
||||
}
|
||||
|
||||
// buildGroupByClause builds the GROUP BY clause
|
||||
func (qb *QueryBuilder) buildGroupByClause(groupFields []string) string {
|
||||
if len(groupFields) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var groupParts []string
|
||||
for _, field := range groupFields {
|
||||
column := field

// Security check (check original field name, consistent with the other clauses)
if len(qb.allowedColumns) > 0 && !qb.allowedColumns[column] {
continue
}

if mappedCol, exists := qb.columnMapping[column]; exists {
column = mappedCol
}
|
||||
|
||||
groupParts = append(groupParts, fmt.Sprintf(`"%s"`, column))
|
||||
}
|
||||
|
||||
if len(groupParts) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return "GROUP BY " + strings.Join(groupParts, ", ")
|
||||
}
|
||||
|
||||
// buildHavingClause builds the HAVING clause
|
||||
func (qb *QueryBuilder) buildHavingClause(havingGroups []FilterGroup) (string, []interface{}, error) {
|
||||
if len(havingGroups) == 0 {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
return qb.buildWhereClause(havingGroups)
|
||||
}
|
||||
|
||||
// BuildCountQuery builds a count query
|
||||
func (qb *QueryBuilder) BuildCountQuery(query DynamicQuery) (string, []interface{}, error) {
|
||||
qb.paramCounter = 0
|
||||
|
||||
// Build FROM clause
|
||||
fromClause := fmt.Sprintf("FROM %s", qb.tableName)
|
||||
|
||||
// Build WHERE clause
|
||||
whereClause, whereArgs, err := qb.buildWhereClause(query.Filters)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// Build GROUP BY clause
|
||||
groupClause := qb.buildGroupByClause(query.GroupBy)
|
||||
|
||||
// Build HAVING clause
|
||||
havingClause, havingArgs, err := qb.buildHavingClause(query.Having)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// Combine parts
|
||||
sqlParts := []string{"SELECT COUNT(*)", fromClause}
|
||||
args := []interface{}{}
|
||||
|
||||
if whereClause != "" {
|
||||
sqlParts = append(sqlParts, "WHERE "+whereClause)
|
||||
args = append(args, whereArgs...)
|
||||
}
|
||||
|
||||
if groupClause != "" {
|
||||
sqlParts = append(sqlParts, groupClause)
|
||||
}
|
||||
|
||||
if havingClause != "" {
|
||||
sqlParts = append(sqlParts, "HAVING "+havingClause)
|
||||
args = append(args, havingArgs...)
|
||||
}
|
||||
|
||||
sql := strings.Join(sqlParts, " ")
|
||||
return sql, args, nil
|
||||
}
|
||||
241
internal/utils/filters/query_parser.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// QueryParser parses HTTP query parameters into DynamicQuery
|
||||
type QueryParser struct {
|
||||
defaultLimit int
|
||||
maxLimit int
|
||||
}
|
||||
|
||||
// NewQueryParser creates a new query parser
|
||||
func NewQueryParser() *QueryParser {
|
||||
return &QueryParser{
|
||||
defaultLimit: 10,
|
||||
maxLimit: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// SetLimits sets default and maximum limits
|
||||
func (qp *QueryParser) SetLimits(defaultLimit, maxLimit int) *QueryParser {
|
||||
qp.defaultLimit = defaultLimit
|
||||
qp.maxLimit = maxLimit
|
||||
return qp
|
||||
}
|
||||
|
||||
// ParseQuery parses URL query parameters into DynamicQuery
|
||||
func (qp *QueryParser) ParseQuery(values url.Values) (DynamicQuery, error) {
|
||||
query := DynamicQuery{
|
||||
Limit: qp.defaultLimit,
|
||||
Offset: 0,
|
||||
}
|
||||
|
||||
// Parse fields
|
||||
if fields := values.Get("fields"); fields != "" {
|
||||
if fields == "*.*" || fields == "*" {
|
||||
query.Fields = []string{"*"}
|
||||
} else {
|
||||
query.Fields = strings.Split(fields, ",")
|
||||
for i, field := range query.Fields {
|
||||
query.Fields[i] = strings.TrimSpace(field)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse pagination
|
||||
if limit := values.Get("limit"); limit != "" {
|
||||
if l, err := strconv.Atoi(limit); err == nil {
|
||||
if l > 0 && l <= qp.maxLimit {
|
||||
query.Limit = l
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if offset := values.Get("offset"); offset != "" {
|
||||
if o, err := strconv.Atoi(offset); err == nil && o >= 0 {
|
||||
query.Offset = o
|
||||
}
|
||||
}
|
||||
|
||||
// Parse filters
|
||||
filters, err := qp.parseFilters(values)
|
||||
if err != nil {
|
||||
return query, err
|
||||
}
|
||||
query.Filters = filters
|
||||
|
||||
// Parse sorting
|
||||
sorts, err := qp.parseSorting(values)
|
||||
if err != nil {
|
||||
return query, err
|
||||
}
|
||||
query.Sort = sorts
|
||||
|
||||
// Parse group by
|
||||
if groupBy := values.Get("group"); groupBy != "" {
|
||||
query.GroupBy = strings.Split(groupBy, ",")
|
||||
for i, field := range query.GroupBy {
|
||||
query.GroupBy[i] = strings.TrimSpace(field)
|
||||
}
|
||||
}
|
||||
|
||||
return query, nil
|
||||
}
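// Example query string this parser understands (illustrative values):
//   ?fields=id,status&filter[status][_eq]=active&filter[amount][_between]=10,100&sort=-date_created&limit=20&group=status
// A leading "-" on a sort column means DESC; values for _in/_nin and _between are comma separated.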
|
||||
|
||||
// parseFilters parses filter parameters
|
||||
// Supports format: filter[column][operator]=value
|
||||
func (qp *QueryParser) parseFilters(values url.Values) ([]FilterGroup, error) {
|
||||
filterMap := make(map[string]map[string]string)
|
||||
|
||||
// Group filters by column
|
||||
for key, vals := range values {
|
||||
if strings.HasPrefix(key, "filter[") && strings.HasSuffix(key, "]") {
|
||||
// Parse filter[column][operator] format
|
||||
parts := strings.Split(key[7:len(key)-1], "][")
|
||||
if len(parts) == 2 {
|
||||
column := parts[0]
|
||||
operator := parts[1]
|
||||
|
||||
if filterMap[column] == nil {
|
||||
filterMap[column] = make(map[string]string)
|
||||
}
|
||||
|
||||
if len(vals) > 0 {
|
||||
filterMap[column][operator] = vals[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(filterMap) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Convert to FilterGroup
|
||||
var filters []DynamicFilter
|
||||
|
||||
for column, operators := range filterMap {
|
||||
for opStr, value := range operators {
|
||||
operator := FilterOperator(opStr)
|
||||
|
||||
// Parse value based on operator
|
||||
var parsedValue interface{}
|
||||
switch operator {
|
||||
case OpIn, OpNotIn:
|
||||
if value != "" {
|
||||
parsedValue = strings.Split(value, ",")
|
||||
}
|
||||
case OpBetween, OpNotBetween:
|
||||
if value != "" {
|
||||
parts := strings.Split(value, ",")
|
||||
if len(parts) == 2 {
|
||||
parsedValue = []interface{}{strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])}
|
||||
}
|
||||
}
|
||||
case OpNull, OpNotNull:
|
||||
parsedValue = nil
|
||||
default:
|
||||
parsedValue = value
|
||||
}
|
||||
|
||||
filters = append(filters, DynamicFilter{
|
||||
Column: column,
|
||||
Operator: operator,
|
||||
Value: parsedValue,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(filters) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return []FilterGroup{{
|
||||
Filters: filters,
|
||||
LogicOp: "AND",
|
||||
}}, nil
|
||||
}
|
||||
|
||||
// parseSorting parses sort parameters
|
||||
// Supports format: sort=column1,-column2 (- for DESC)
|
||||
func (qp *QueryParser) parseSorting(values url.Values) ([]SortField, error) {
|
||||
sortParam := values.Get("sort")
|
||||
if sortParam == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var sorts []SortField
|
||||
fields := strings.Split(sortParam, ",")
|
||||
|
||||
for _, field := range fields {
|
||||
field = strings.TrimSpace(field)
|
||||
if field == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
order := "ASC"
|
||||
column := field
|
||||
|
||||
if strings.HasPrefix(field, "-") {
|
||||
order = "DESC"
|
||||
column = field[1:]
|
||||
} else if strings.HasPrefix(field, "+") {
|
||||
column = field[1:]
|
||||
}
|
||||
|
||||
sorts = append(sorts, SortField{
|
||||
Column: column,
|
||||
Order: order,
|
||||
})
|
||||
}
|
||||
|
||||
return sorts, nil
|
||||
}
|
||||
|
||||
// ParseAdvancedFilters parses complex filter structures
|
||||
// Supports nested filters and logic operators
|
||||
func (qp *QueryParser) ParseAdvancedFilters(filterParam string) ([]FilterGroup, error) {
|
||||
// This would be for more complex JSON-based filters
|
||||
// Implementation depends on your specific needs
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Helper function to parse date values
|
||||
func parseDate(value string) (interface{}, error) {
|
||||
// Try different date formats
|
||||
formats := []string{
|
||||
"2006-01-02",
|
||||
"2006-01-02T15:04:05Z",
|
||||
"2006-01-02T15:04:05.000Z",
|
||||
"2006-01-02 15:04:05",
|
||||
}
|
||||
|
||||
for _, format := range formats {
|
||||
if t, err := time.Parse(format, value); err == nil {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// Helper function to parse numeric values
|
||||
func parseNumeric(value string) interface{} {
|
||||
// Try integer first
|
||||
if i, err := strconv.Atoi(value); err == nil {
|
||||
return i
|
||||
}
|
||||
|
||||
// Try float
|
||||
if f, err := strconv.ParseFloat(value, 64); err == nil {
|
||||
return f
|
||||
}
|
||||
|
||||
// Return as string
|
||||
return value
|
||||
}
|
||||
141
internal/utils/validation/duplicate_validator.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ValidationConfig holds configuration for duplicate validation
|
||||
type ValidationConfig struct {
|
||||
TableName string
|
||||
IDColumn string
|
||||
StatusColumn string
|
||||
DateColumn string
|
||||
ActiveStatuses []string
|
||||
AdditionalFields map[string]interface{}
|
||||
}
|
||||
|
||||
// DuplicateValidator provides methods for validating duplicate entries
|
||||
type DuplicateValidator struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewDuplicateValidator creates a new instance of DuplicateValidator
|
||||
func NewDuplicateValidator(db *sql.DB) *DuplicateValidator {
|
||||
return &DuplicateValidator{db: db}
|
||||
}
|
||||
|
||||
// ValidateDuplicate checks for duplicate entries based on the provided configuration
|
||||
func (dv *DuplicateValidator) ValidateDuplicate(ctx context.Context, config ValidationConfig, identifier interface{}) error {
|
||||
query := fmt.Sprintf(`
|
||||
SELECT COUNT(*)
|
||||
FROM %s
|
||||
WHERE %s = $1
|
||||
AND %s = ANY($2)
|
||||
AND DATE(%s) = CURRENT_DATE
|
||||
`, config.TableName, config.IDColumn, config.StatusColumn, config.DateColumn)
|
||||
|
||||
var count int
|
||||
err := dv.db.QueryRowContext(ctx, query, identifier, config.ActiveStatuses).Scan(&count)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check duplicate: %w", err)
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
return fmt.Errorf("data with ID %v already exists with active status today", identifier)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
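// Usage sketch with the retribusi defaults defined at the bottom of this file:
//   validator := NewDuplicateValidator(db) // db is an assumed *sql.DB connection
//   cfg := DefaultRetribusiConfig()
//   if err := validator.ValidateDuplicate(ctx, cfg, retribusiID); err != nil {
//       // an active/draft entry for this ID already exists today
//   }
// Note: passing ActiveStatuses ([]string) as a query argument relies on driver support for
// array parameters (e.g. pgx); with database/sql plus lib/pq it would need pq.Array, which
// is an assumption worth verifying against the project's driver.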
|
||||
|
||||
// ValidateDuplicateWithCustomFields checks for duplicates with additional custom fields
|
||||
func (dv *DuplicateValidator) ValidateDuplicateWithCustomFields(ctx context.Context, config ValidationConfig, fields map[string]interface{}) error {
|
||||
whereClause := fmt.Sprintf("%s = ANY($1) AND DATE(%s) = CURRENT_DATE", config.StatusColumn, config.DateColumn)
|
||||
args := []interface{}{config.ActiveStatuses}
|
||||
argIndex := 2
|
||||
|
||||
// Add additional field conditions
|
||||
for fieldName, fieldValue := range config.AdditionalFields {
|
||||
whereClause += fmt.Sprintf(" AND %s = $%d", fieldName, argIndex)
|
||||
args = append(args, fieldValue)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
// Add dynamic fields
|
||||
for fieldName, fieldValue := range fields {
|
||||
whereClause += fmt.Sprintf(" AND %s = $%d", fieldName, argIndex)
|
||||
args = append(args, fieldValue)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE %s", config.TableName, whereClause)
|
||||
|
||||
var count int
|
||||
err := dv.db.QueryRowContext(ctx, query, args...).Scan(&count)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check duplicate with custom fields: %w", err)
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
return fmt.Errorf("duplicate entry found with the specified criteria")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateOncePerDay ensures only one submission per day for a given identifier
|
||||
func (dv *DuplicateValidator) ValidateOncePerDay(ctx context.Context, tableName, idColumn, dateColumn string, identifier interface{}) error {
|
||||
query := fmt.Sprintf(`
|
||||
SELECT COUNT(*)
|
||||
FROM %s
|
||||
WHERE %s = $1
|
||||
AND DATE(%s) = CURRENT_DATE
|
||||
`, tableName, idColumn, dateColumn)
|
||||
|
||||
var count int
|
||||
err := dv.db.QueryRowContext(ctx, query, identifier).Scan(&count)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check daily submission: %w", err)
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
return fmt.Errorf("only one submission allowed per day for ID %v", identifier)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLastSubmissionTime returns the last submission time for a given identifier
|
||||
func (dv *DuplicateValidator) GetLastSubmissionTime(ctx context.Context, tableName, idColumn, dateColumn string, identifier interface{}) (*time.Time, error) {
|
||||
query := fmt.Sprintf(`
|
||||
SELECT %s
|
||||
FROM %s
|
||||
WHERE %s = $1
|
||||
ORDER BY %s DESC
|
||||
LIMIT 1
|
||||
`, dateColumn, tableName, idColumn, dateColumn)
|
||||
|
||||
var lastTime time.Time
|
||||
err := dv.db.QueryRowContext(ctx, query, identifier).Scan(&lastTime)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil // No previous submission
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get last submission time: %w", err)
|
||||
}
|
||||
|
||||
return &lastTime, nil
|
||||
}
|
||||
|
||||
// DefaultRetribusiConfig returns default configuration for retribusi validation
|
||||
func DefaultRetribusiConfig() ValidationConfig {
|
||||
return ValidationConfig{
|
||||
TableName: "data_retribusi",
|
||||
IDColumn: "id",
|
||||
StatusColumn: "status",
|
||||
DateColumn: "date_created",
|
||||
ActiveStatuses: []string{"active", "draft"},
|
||||
}
|
||||
}
|
||||