first commit

This commit is contained in:
2025-09-24 18:42:16 +07:00
commit daffbc67dc
72 changed files with 40710 additions and 0 deletions

739
internal/config/config.go Normal file
View File

@@ -0,0 +1,739 @@
package config
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"log"
"os"
"strconv"
"strings"
"time"
"github.com/go-playground/validator/v10"
)
type Config struct {
Server ServerConfig
Databases map[string]DatabaseConfig
ReadReplicas map[string][]DatabaseConfig // For read replicas
Keycloak KeycloakConfig
Bpjs BpjsConfig
SatuSehat SatuSehatConfig
Swagger SwaggerConfig
Validator *validator.Validate
}
type SwaggerConfig struct {
Title string
Description string
Version string
TermsOfService string
ContactName string
ContactURL string
ContactEmail string
LicenseName string
LicenseURL string
Host string
BasePath string
Schemes []string
}
type ServerConfig struct {
Port int
Mode string
}
type DatabaseConfig struct {
Name string
Type string // postgres, mysql, sqlserver, sqlite, mongodb
Host string
Port int
Username string
Password string
Database string
Schema string
SSLMode string
Path string // For SQLite
Options string // Additional connection options
MaxOpenConns int // Max open connections
MaxIdleConns int // Max idle connections
ConnMaxLifetime time.Duration // Connection max lifetime
}
type KeycloakConfig struct {
Issuer string
Audience string
JwksURL string
Enabled bool
}
type BpjsConfig struct {
BaseURL string `json:"base_url"`
ConsID string `json:"cons_id"`
UserKey string `json:"user_key"`
SecretKey string `json:"secret_key"`
Timeout time.Duration `json:"timeout"`
}
type SatuSehatConfig struct {
OrgID string `json:"org_id"`
FasyakesID string `json:"fasyakes_id"`
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
AuthURL string `json:"auth_url"`
BaseURL string `json:"base_url"`
ConsentURL string `json:"consent_url"`
KFAURL string `json:"kfa_url"`
Timeout time.Duration `json:"timeout"`
}
// SetHeader generates required headers for BPJS VClaim API
// func (cfg BpjsConfig) SetHeader() (string, string, string, string, string) {
// timenow := time.Now().UTC()
// t, err := time.Parse(time.RFC3339, "1970-01-01T00:00:00Z")
// if err != nil {
// log.Fatal(err)
// }
// tstamp := timenow.Unix() - t.Unix()
// secret := []byte(cfg.SecretKey)
// message := []byte(cfg.ConsID + "&" + fmt.Sprint(tstamp))
// hash := hmac.New(sha256.New, secret)
// hash.Write(message)
// // to lowercase hexits
// hex.EncodeToString(hash.Sum(nil))
// // to base64
// xSignature := base64.StdEncoding.EncodeToString(hash.Sum(nil))
// return cfg.ConsID, cfg.SecretKey, cfg.UserKey, fmt.Sprint(tstamp), xSignature
// }
func (cfg BpjsConfig) SetHeader() (string, string, string, string, string) {
timenow := time.Now().UTC()
t, err := time.Parse(time.RFC3339, "1970-01-01T00:00:00Z")
if err != nil {
log.Fatal(err)
}
tstamp := timenow.Unix() - t.Unix()
secret := []byte(cfg.SecretKey)
message := []byte(cfg.ConsID + "&" + fmt.Sprint(tstamp))
hash := hmac.New(sha256.New, secret)
hash.Write(message)
// to lowercase hexits
hex.EncodeToString(hash.Sum(nil))
// to base64
xSignature := base64.StdEncoding.EncodeToString(hash.Sum(nil))
return cfg.ConsID, cfg.SecretKey, cfg.UserKey, fmt.Sprint(tstamp), xSignature
}
type ConfigBpjs struct {
Cons_id string
Secret_key string
User_key string
}
// SetHeader for backward compatibility
func (cfg ConfigBpjs) SetHeader() (string, string, string, string, string) {
bpjsConfig := BpjsConfig{
ConsID: cfg.Cons_id,
SecretKey: cfg.Secret_key,
UserKey: cfg.User_key,
}
return bpjsConfig.SetHeader()
}
func LoadConfig() *Config {
config := &Config{
Server: ServerConfig{
Port: getEnvAsInt("PORT", 8080),
Mode: getEnv("GIN_MODE", "debug"),
},
Databases: make(map[string]DatabaseConfig),
ReadReplicas: make(map[string][]DatabaseConfig),
Keycloak: KeycloakConfig{
Issuer: getEnv("KEYCLOAK_ISSUER", "https://keycloak.example.com/auth/realms/yourrealm"),
Audience: getEnv("KEYCLOAK_AUDIENCE", "your-client-id"),
JwksURL: getEnv("KEYCLOAK_JWKS_URL", "https://keycloak.example.com/auth/realms/yourrealm/protocol/openid-connect/certs"),
Enabled: getEnvAsBool("KEYCLOAK_ENABLED", true),
},
Bpjs: BpjsConfig{
BaseURL: getEnv("BPJS_BASEURL", "https://apijkn.bpjs-kesehatan.go.id"),
ConsID: getEnv("BPJS_CONSID", ""),
UserKey: getEnv("BPJS_USERKEY", ""),
SecretKey: getEnv("BPJS_SECRETKEY", ""),
Timeout: parseDuration(getEnv("BPJS_TIMEOUT", "30s")),
},
SatuSehat: SatuSehatConfig{
OrgID: getEnv("BRIDGING_SATUSEHAT_ORG_ID", ""),
FasyakesID: getEnv("BRIDGING_SATUSEHAT_FASYAKES_ID", ""),
ClientID: getEnv("BRIDGING_SATUSEHAT_CLIENT_ID", ""),
ClientSecret: getEnv("BRIDGING_SATUSEHAT_CLIENT_SECRET", ""),
AuthURL: getEnv("BRIDGING_SATUSEHAT_AUTH_URL", "https://api-satusehat.kemkes.go.id/oauth2/v1"),
BaseURL: getEnv("BRIDGING_SATUSEHAT_BASE_URL", "https://api-satusehat.kemkes.go.id/fhir-r4/v1"),
ConsentURL: getEnv("BRIDGING_SATUSEHAT_CONSENT_URL", "https://api-satusehat.dto.kemkes.go.id/consent/v1"),
KFAURL: getEnv("BRIDGING_SATUSEHAT_KFA_URL", "https://api-satusehat.kemkes.go.id/kfa-v2"),
Timeout: parseDuration(getEnv("BRIDGING_SATUSEHAT_TIMEOUT", "30s")),
},
Swagger: SwaggerConfig{
Title: getEnv("SWAGGER_TITLE", "SERVICE API"),
Description: getEnv("SWAGGER_DESCRIPTION", "CUSTUM SERVICE API"),
Version: getEnv("SWAGGER_VERSION", "1.0.0"),
TermsOfService: getEnv("SWAGGER_TERMS_OF_SERVICE", "http://swagger.io/terms/"),
ContactName: getEnv("SWAGGER_CONTACT_NAME", "API Support"),
ContactURL: getEnv("SWAGGER_CONTACT_URL", "http://rssa.example.com/support"),
ContactEmail: getEnv("SWAGGER_CONTACT_EMAIL", "support@swagger.io"),
LicenseName: getEnv("SWAGGER_LICENSE_NAME", "Apache 2.0"),
LicenseURL: getEnv("SWAGGER_LICENSE_URL", "http://www.apache.org/licenses/LICENSE-2.0.html"),
Host: getEnv("SWAGGER_HOST", "localhost:8080"),
BasePath: getEnv("SWAGGER_BASE_PATH", "/api/v1"),
Schemes: parseSchemes(getEnv("SWAGGER_SCHEMES", "http,https")),
},
}
// Initialize validator
config.Validator = validator.New()
// Load database configurations
config.loadDatabaseConfigs()
// Load read replica configurations
config.loadReadReplicaConfigs()
return config
}
func (c *Config) loadDatabaseConfigs() {
// Simplified approach: Directly load from environment variables
// This ensures we get the exact values specified in .env
// Primary database configuration
c.Databases["default"] = DatabaseConfig{
Name: "default",
Type: getEnv("DB_CONNECTION", "postgres"),
Host: getEnv("DB_HOST", "localhost"),
Port: getEnvAsInt("DB_PORT", 5432),
Username: getEnv("DB_USERNAME", ""),
Password: getEnv("DB_PASSWORD", ""),
Database: getEnv("DB_DATABASE", "satu_db"),
Schema: getEnv("DB_SCHEMA", "public"),
SSLMode: getEnv("DB_SSLMODE", "disable"),
MaxOpenConns: getEnvAsInt("DB_MAX_OPEN_CONNS", 25),
MaxIdleConns: getEnvAsInt("DB_MAX_IDLE_CONNS", 25),
ConnMaxLifetime: parseDuration(getEnv("DB_CONN_MAX_LIFETIME", "5m")),
}
// SATUDATA database configuration
c.addPostgreSQLConfigs()
// MongoDB database configuration
c.addMongoDBConfigs()
// Legacy support for backward compatibility
envVars := os.Environ()
dbConfigs := make(map[string]map[string]string)
// Parse database configurations from environment variables
for _, envVar := range envVars {
parts := strings.SplitN(envVar, "=", 2)
if len(parts) != 2 {
continue
}
key := parts[0]
value := parts[1]
// Parse specific database configurations
if strings.HasSuffix(key, "_CONNECTION") || strings.HasSuffix(key, "_HOST") ||
strings.HasSuffix(key, "_DATABASE") || strings.HasSuffix(key, "_USERNAME") ||
strings.HasSuffix(key, "_PASSWORD") || strings.HasSuffix(key, "_PORT") ||
strings.HasSuffix(key, "_NAME") {
segments := strings.Split(key, "_")
if len(segments) >= 2 {
dbName := strings.ToLower(strings.Join(segments[:len(segments)-1], "_"))
property := strings.ToLower(segments[len(segments)-1])
if dbConfigs[dbName] == nil {
dbConfigs[dbName] = make(map[string]string)
}
dbConfigs[dbName][property] = value
}
}
}
// Create DatabaseConfig from parsed configurations for additional databases
for name, config := range dbConfigs {
// Skip empty configurations or system configurations
if name == "" || strings.Contains(name, "chrome_crashpad_pipe") || name == "primary" {
continue
}
dbConfig := DatabaseConfig{
Name: name,
Type: getEnvFromMap(config, "connection", getEnvFromMap(config, "type", "postgres")),
Host: getEnvFromMap(config, "host", "localhost"),
Port: getEnvAsIntFromMap(config, "port", 5432),
Username: getEnvFromMap(config, "username", ""),
Password: getEnvFromMap(config, "password", ""),
Database: getEnvFromMap(config, "database", getEnvFromMap(config, "name", name)),
Schema: getEnvFromMap(config, "schema", "public"),
SSLMode: getEnvFromMap(config, "sslmode", "disable"),
Path: getEnvFromMap(config, "path", ""),
Options: getEnvFromMap(config, "options", ""),
MaxOpenConns: getEnvAsIntFromMap(config, "max_open_conns", 25),
MaxIdleConns: getEnvAsIntFromMap(config, "max_idle_conns", 25),
ConnMaxLifetime: parseDuration(getEnvFromMap(config, "conn_max_lifetime", "5m")),
}
// Skip if username is empty and it's not a system config
if dbConfig.Username == "" && !strings.HasPrefix(name, "chrome") {
continue
}
c.Databases[name] = dbConfig
}
}
func (c *Config) loadReadReplicaConfigs() {
envVars := os.Environ()
for _, envVar := range envVars {
parts := strings.SplitN(envVar, "=", 2)
if len(parts) != 2 {
continue
}
key := parts[0]
value := parts[1]
// Parse read replica configurations (format: [DBNAME]_REPLICA_[INDEX]_[PROPERTY])
if strings.Contains(key, "_REPLICA_") {
segments := strings.Split(key, "_")
if len(segments) >= 5 && strings.ToUpper(segments[2]) == "REPLICA" {
dbName := strings.ToLower(segments[1])
replicaIndex := segments[3]
property := strings.ToLower(strings.Join(segments[4:], "_"))
replicaKey := dbName + "_replica_" + replicaIndex
if c.ReadReplicas[dbName] == nil {
c.ReadReplicas[dbName] = []DatabaseConfig{}
}
// Find or create replica config
var replicaConfig *DatabaseConfig
for i := range c.ReadReplicas[dbName] {
if c.ReadReplicas[dbName][i].Name == replicaKey {
replicaConfig = &c.ReadReplicas[dbName][i]
break
}
}
if replicaConfig == nil {
// Create new replica config
newConfig := DatabaseConfig{
Name: replicaKey,
Type: c.Databases[dbName].Type,
Host: getEnv("DB_"+strings.ToUpper(dbName)+"_REPLICA_"+replicaIndex+"_HOST", c.Databases[dbName].Host),
Port: getEnvAsInt("DB_"+strings.ToUpper(dbName)+"_REPLICA_"+replicaIndex+"_PORT", c.Databases[dbName].Port),
Username: getEnv("DB_"+strings.ToUpper(dbName)+"_REPLICA_"+replicaIndex+"_USERNAME", c.Databases[dbName].Username),
Password: getEnv("DB_"+strings.ToUpper(dbName)+"_REPLICA_"+replicaIndex+"_PASSWORD", c.Databases[dbName].Password),
Database: getEnv("DB_"+strings.ToUpper(dbName)+"_REPLICA_"+replicaIndex+"_DATABASE", c.Databases[dbName].Database),
Schema: getEnv("DB_"+strings.ToUpper(dbName)+"_REPLICA_"+replicaIndex+"_SCHEMA", c.Databases[dbName].Schema),
SSLMode: getEnv("DB_"+strings.ToUpper(dbName)+"_REPLICA_"+replicaIndex+"_SSLMODE", c.Databases[dbName].SSLMode),
MaxOpenConns: getEnvAsInt("DB_"+strings.ToUpper(dbName)+"_REPLICA_"+replicaIndex+"_MAX_OPEN_CONNS", c.Databases[dbName].MaxOpenConns),
MaxIdleConns: getEnvAsInt("DB_"+strings.ToUpper(dbName)+"_REPLICA_"+replicaIndex+"_MAX_IDLE_CONNS", c.Databases[dbName].MaxIdleConns),
ConnMaxLifetime: parseDuration(getEnv("DB_"+strings.ToUpper(dbName)+"_REPLICA_"+replicaIndex+"_CONN_MAX_LIFETIME", "5m")),
}
c.ReadReplicas[dbName] = append(c.ReadReplicas[dbName], newConfig)
replicaConfig = &c.ReadReplicas[dbName][len(c.ReadReplicas[dbName])-1]
}
// Update the specific replica
switch property {
case "host":
replicaConfig.Host = value
case "port":
replicaConfig.Port = getEnvAsInt(key, 5432)
case "username":
replicaConfig.Username = value
case "password":
replicaConfig.Password = value
case "database":
replicaConfig.Database = value
case "schema":
replicaConfig.Schema = value
case "sslmode":
replicaConfig.SSLMode = value
case "max_open_conns":
replicaConfig.MaxOpenConns = getEnvAsInt(key, 25)
case "max_idle_conns":
replicaConfig.MaxIdleConns = getEnvAsInt(key, 25)
case "conn_max_lifetime":
replicaConfig.ConnMaxLifetime = parseDuration(value)
}
}
}
}
}
func (c *Config) addSpecificDatabase(prefix, defaultType string) {
connection := getEnv(strings.ToUpper(prefix)+"_CONNECTION", defaultType)
host := getEnv(strings.ToUpper(prefix)+"_HOST", "")
if host != "" {
dbConfig := DatabaseConfig{
Name: prefix,
Type: connection,
Host: host,
Port: getEnvAsInt(strings.ToUpper(prefix)+"_PORT", 5432),
Username: getEnv(strings.ToUpper(prefix)+"_USERNAME", ""),
Password: getEnv(strings.ToUpper(prefix)+"_PASSWORD", ""),
Database: getEnv(strings.ToUpper(prefix)+"_DATABASE", getEnv(strings.ToUpper(prefix)+"_NAME", prefix)),
Schema: getEnv(strings.ToUpper(prefix)+"_SCHEMA", "public"),
SSLMode: getEnv(strings.ToUpper(prefix)+"_SSLMODE", "disable"),
MaxOpenConns: getEnvAsInt(strings.ToUpper(prefix)+"_MAX_OPEN_CONNS", 25),
MaxIdleConns: getEnvAsInt(strings.ToUpper(prefix)+"_MAX_IDLE_CONNS", 25),
ConnMaxLifetime: parseDuration(getEnv(strings.ToUpper(prefix)+"_CONN_MAX_LIFETIME", "5m")),
}
c.Databases[prefix] = dbConfig
}
}
// PostgreSQL database
func (c *Config) addPostgreSQLConfigs() {
// SATUDATA database configuration
// defaultPOSTGRESHost := getEnv("POSTGRES_HOST", "localhost")
// if defaultPOSTGRESHost != "" {
// c.Databases["postgres"] = DatabaseConfig{
// Name: "postgres",
// Type: getEnv("POSTGRES_CONNECTION", "postgres"),
// Host: defaultPOSTGRESHost,
// Port: getEnvAsInt("POSTGRES_PORT", 5432),
// Username: getEnv("POSTGRES_USERNAME", ""),
// Password: getEnv("POSTGRES_PASSWORD", ""),
// Database: getEnv("POSTGRES_DATABASE", "postgres"),
// Schema: getEnv("POSTGRES_SCHEMA", "public"),
// SSLMode: getEnv("POSTGRES_SSLMODE", "disable"),
// MaxOpenConns: getEnvAsInt("POSTGRES_MAX_OPEN_CONNS", 25),
// MaxIdleConns: getEnvAsInt("POSTGRES_MAX_IDLE_CONNS", 25),
// ConnMaxLifetime: parseDuration(getEnv("POSTGRES_CONN_MAX_LIFETIME", "5m")),
// }
// }
// Support for custom PostgreSQL configurations with POSTGRES_ prefix
envVars := os.Environ()
for _, envVar := range envVars {
parts := strings.SplitN(envVar, "=", 2)
if len(parts) != 2 {
continue
}
key := parts[0]
// Parse PostgreSQL configurations (format: POSTGRES_[NAME]_[PROPERTY])
if strings.HasPrefix(key, "POSTGRES_") && strings.Contains(key, "_") {
segments := strings.Split(key, "_")
if len(segments) >= 3 {
dbName := strings.ToLower(strings.Join(segments[1:len(segments)-1], "_"))
// Skip if it's a standard PostgreSQL configuration
if dbName == "connection" || dbName == "dev" || dbName == "default" || dbName == "satudata" {
continue
}
// Create or update PostgreSQL configuration
if _, exists := c.Databases[dbName]; !exists {
c.Databases[dbName] = DatabaseConfig{
Name: dbName,
Type: "postgres",
Host: getEnv("POSTGRES_"+strings.ToUpper(dbName)+"_HOST", "localhost"),
Port: getEnvAsInt("POSTGRES_"+strings.ToUpper(dbName)+"_PORT", 5432),
Username: getEnv("POSTGRES_"+strings.ToUpper(dbName)+"_USERNAME", ""),
Password: getEnv("POSTGRES_"+strings.ToUpper(dbName)+"_PASSWORD", ""),
Database: getEnv("POSTGRES_"+strings.ToUpper(dbName)+"_DATABASE", dbName),
Schema: getEnv("POSTGRES_"+strings.ToUpper(dbName)+"_SCHEMA", "public"),
SSLMode: getEnv("POSTGRES_"+strings.ToUpper(dbName)+"_SSLMODE", "disable"),
MaxOpenConns: getEnvAsInt("POSTGRES_MAX_OPEN_CONNS", 25),
MaxIdleConns: getEnvAsInt("POSTGRES_MAX_IDLE_CONNS", 25),
ConnMaxLifetime: parseDuration(getEnv("POSTGRES_CONN_MAX_LIFETIME", "5m")),
}
}
}
}
}
}
// addMYSQLConfigs adds MYSQL database
func (c *Config) addMySQLConfigs() {
// Primary MySQL configuration
defaultMySQLHost := getEnv("MYSQL_HOST", "")
if defaultMySQLHost != "" {
c.Databases["mysql"] = DatabaseConfig{
Name: "mysql",
Type: getEnv("MYSQL_CONNECTION", "mysql"),
Host: defaultMySQLHost,
Port: getEnvAsInt("MYSQL_PORT", 3306),
Username: getEnv("MYSQL_USERNAME", ""),
Password: getEnv("MYSQL_PASSWORD", ""),
Database: getEnv("MYSQL_DATABASE", "mysql"),
SSLMode: getEnv("MYSQL_SSLMODE", "disable"),
MaxOpenConns: getEnvAsInt("MYSQL_MAX_OPEN_CONNS", 25),
MaxIdleConns: getEnvAsInt("MYSQL_MAX_IDLE_CONNS", 25),
ConnMaxLifetime: parseDuration(getEnv("MYSQL_CONN_MAX_LIFETIME", "5m")),
}
}
// Support for custom MySQL configurations with MYSQL_ prefix
envVars := os.Environ()
for _, envVar := range envVars {
parts := strings.SplitN(envVar, "=", 2)
if len(parts) != 2 {
continue
}
key := parts[0]
// Parse MySQL configurations (format: MYSQL_[NAME]_[PROPERTY])
if strings.HasPrefix(key, "MYSQL_") && strings.Contains(key, "_") {
segments := strings.Split(key, "_")
if len(segments) >= 3 {
dbName := strings.ToLower(strings.Join(segments[1:len(segments)-1], "_"))
// Skip if it's a standard MySQL configuration
if dbName == "connection" || dbName == "dev" || dbName == "max" || dbName == "conn" {
continue
}
// Create or update MySQL configuration
if _, exists := c.Databases[dbName]; !exists {
mysqlHost := getEnv("MYSQL_"+strings.ToUpper(dbName)+"_HOST", "")
if mysqlHost != "" {
c.Databases[dbName] = DatabaseConfig{
Name: dbName,
Type: getEnv("MYSQL_"+strings.ToUpper(dbName)+"_CONNECTION", "mysql"),
Host: mysqlHost,
Port: getEnvAsInt("MYSQL_"+strings.ToUpper(dbName)+"_PORT", 3306),
Username: getEnv("MYSQL_"+strings.ToUpper(dbName)+"_USERNAME", ""),
Password: getEnv("MYSQL_"+strings.ToUpper(dbName)+"_PASSWORD", ""),
Database: getEnv("MYSQL_"+strings.ToUpper(dbName)+"_DATABASE", dbName),
SSLMode: getEnv("MYSQL_"+strings.ToUpper(dbName)+"_SSLMODE", "disable"),
MaxOpenConns: getEnvAsInt("MYSQL_MAX_OPEN_CONNS", 25),
MaxIdleConns: getEnvAsInt("MYSQL_MAX_IDLE_CONNS", 25),
ConnMaxLifetime: parseDuration(getEnv("MYSQL_CONN_MAX_LIFETIME", "5m")),
}
}
}
}
}
}
}
// addMongoDBConfigs adds MongoDB database configurations from environment variables
func (c *Config) addMongoDBConfigs() {
// Primary MongoDB configuration
mongoHost := getEnv("MONGODB_HOST", "")
if mongoHost != "" {
c.Databases["mongodb"] = DatabaseConfig{
Name: "mongodb",
Type: getEnv("MONGODB_CONNECTION", "mongodb"),
Host: mongoHost,
Port: getEnvAsInt("MONGODB_PORT", 27017),
Username: getEnv("MONGODB_USER", ""),
Password: getEnv("MONGODB_PASS", ""),
Database: getEnv("MONGODB_MASTER", "master"),
SSLMode: getEnv("MONGODB_SSLMODE", "disable"),
MaxOpenConns: getEnvAsInt("MONGODB_MAX_OPEN_CONNS", 100),
MaxIdleConns: getEnvAsInt("MONGODB_MAX_IDLE_CONNS", 10),
ConnMaxLifetime: parseDuration(getEnv("MONGODB_CONN_MAX_LIFETIME", "30m")),
}
}
// Additional MongoDB configurations for local database
mongoLocalHost := getEnv("MONGODB_LOCAL_HOST", "")
if mongoLocalHost != "" {
c.Databases["mongodb_local"] = DatabaseConfig{
Name: "mongodb_local",
Type: getEnv("MONGODB_CONNECTION", "mongodb"),
Host: mongoLocalHost,
Port: getEnvAsInt("MONGODB_LOCAL_PORT", 27017),
Username: getEnv("MONGODB_LOCAL_USER", ""),
Password: getEnv("MONGODB_LOCAL_PASS", ""),
Database: getEnv("MONGODB_LOCAL_DB", "local"),
SSLMode: getEnv("MONGOD_SSLMODE", "disable"),
MaxOpenConns: getEnvAsInt("MONGODB_MAX_OPEN_CONNS", 100),
MaxIdleConns: getEnvAsInt("MONGODB_MAX_IDLE_CONNS", 10),
ConnMaxLifetime: parseDuration(getEnv("MONGODB_CONN_MAX_LIFETIME", "30m")),
}
}
// Support for custom MongoDB configurations with MONGODB_ prefix
envVars := os.Environ()
for _, envVar := range envVars {
parts := strings.SplitN(envVar, "=", 2)
if len(parts) != 2 {
continue
}
key := parts[0]
// Parse MongoDB configurations (format: MONGODB_[NAME]_[PROPERTY])
if strings.HasPrefix(key, "MONGODB_") && strings.Contains(key, "_") {
segments := strings.Split(key, "_")
if len(segments) >= 3 {
dbName := strings.ToLower(strings.Join(segments[1:len(segments)-1], "_"))
// Skip if it's a standard MongoDB configuration
if dbName == "connection" || dbName == "dev" || dbName == "local" {
continue
}
// Create or update MongoDB configuration
if _, exists := c.Databases[dbName]; !exists {
c.Databases[dbName] = DatabaseConfig{
Name: dbName,
Type: "mongodb",
Host: getEnv("MONGODB_"+strings.ToUpper(dbName)+"_HOST", "localhost"),
Port: getEnvAsInt("MONGODB_"+strings.ToUpper(dbName)+"_PORT", 27017),
Username: getEnv("MONGODB_"+strings.ToUpper(dbName)+"_USER", ""),
Password: getEnv("MONGODB_"+strings.ToUpper(dbName)+"_PASS", ""),
Database: getEnv("MONGODB_"+strings.ToUpper(dbName)+"_DB", dbName),
SSLMode: getEnv("MONGOD_SSLMODE", "disable"),
MaxOpenConns: getEnvAsInt("MONGODB_MAX_OPEN_CONNS", 100),
MaxIdleConns: getEnvAsInt("MONGODB_MAX_IDLE_CONNS", 10),
ConnMaxLifetime: parseDuration(getEnv("MONGODB_CONN_MAX_LIFETIME", "30m")),
}
}
}
}
}
}
func getEnvFromMap(config map[string]string, key, defaultValue string) string {
if value, exists := config[key]; exists {
return value
}
return defaultValue
}
func getEnvAsIntFromMap(config map[string]string, key string, defaultValue int) int {
if value, exists := config[key]; exists {
if intValue, err := strconv.Atoi(value); err == nil {
return intValue
}
}
return defaultValue
}
func parseDuration(durationStr string) time.Duration {
if duration, err := time.ParseDuration(durationStr); err == nil {
return duration
}
return 5 * time.Minute
}
func getEnv(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
func getEnvAsInt(key string, defaultValue int) int {
valueStr := getEnv(key, "")
if value, err := strconv.Atoi(valueStr); err == nil {
return value
}
return defaultValue
}
func getEnvAsBool(key string, defaultValue bool) bool {
valueStr := getEnv(key, "")
if value, err := strconv.ParseBool(valueStr); err == nil {
return value
}
return defaultValue
}
// parseSchemes parses comma-separated schemes string into a slice
func parseSchemes(schemesStr string) []string {
if schemesStr == "" {
return []string{"http"}
}
schemes := strings.Split(schemesStr, ",")
for i, scheme := range schemes {
schemes[i] = strings.TrimSpace(scheme)
}
return schemes
}
func (c *Config) Validate() error {
if len(c.Databases) == 0 {
log.Fatal("At least one database configuration is required")
}
for name, db := range c.Databases {
if db.Host == "" {
log.Fatalf("Database host is required for %s", name)
}
if db.Username == "" {
log.Fatalf("Database username is required for %s", name)
}
if db.Password == "" {
log.Fatalf("Database password is required for %s", name)
}
if db.Database == "" {
log.Fatalf("Database name is required for %s", name)
}
}
if c.Bpjs.BaseURL == "" {
log.Fatal("BPJS Base URL is required")
}
if c.Bpjs.ConsID == "" {
log.Fatal("BPJS Consumer ID is required")
}
if c.Bpjs.UserKey == "" {
log.Fatal("BPJS User Key is required")
}
if c.Bpjs.SecretKey == "" {
log.Fatal("BPJS Secret Key is required")
}
// Validate Keycloak configuration if enabled
if c.Keycloak.Enabled {
if c.Keycloak.Issuer == "" {
log.Fatal("Keycloak issuer is required when Keycloak is enabled")
}
if c.Keycloak.Audience == "" {
log.Fatal("Keycloak audience is required when Keycloak is enabled")
}
if c.Keycloak.JwksURL == "" {
log.Fatal("Keycloak JWKS URL is required when Keycloak is enabled")
}
}
// Validate SatuSehat configuration
if c.SatuSehat.OrgID == "" {
log.Fatal("SatuSehat Organization ID is required")
}
if c.SatuSehat.FasyakesID == "" {
log.Fatal("SatuSehat Fasyankes ID is required")
}
if c.SatuSehat.ClientID == "" {
log.Fatal("SatuSehat Client ID is required")
}
if c.SatuSehat.ClientSecret == "" {
log.Fatal("SatuSehat Client Secret is required")
}
if c.SatuSehat.AuthURL == "" {
log.Fatal("SatuSehat Auth URL is required")
}
if c.SatuSehat.BaseURL == "" {
log.Fatal("SatuSehat Base URL is required")
}
return nil
}

View File

@@ -0,0 +1,699 @@
package database
import (
"context"
"database/sql"
"fmt"
"log" // Import runtime package
// Import debug package
"strconv"
"sync"
"time"
"api-service/internal/config"
_ "github.com/jackc/pgx/v5" // Import pgx driver
"github.com/lib/pq"
_ "gorm.io/driver/postgres" // Import GORM PostgreSQL driver
_ "github.com/go-sql-driver/mysql" // MySQL driver for database/sql
_ "gorm.io/driver/mysql" // GORM MySQL driver
_ "gorm.io/driver/sqlserver" // GORM SQL Server driver
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// DatabaseType represents supported database types
type DatabaseType string
const (
Postgres DatabaseType = "postgres"
MySQL DatabaseType = "mysql"
SQLServer DatabaseType = "sqlserver"
SQLite DatabaseType = "sqlite"
MongoDB DatabaseType = "mongodb"
)
// Service represents a service that interacts with multiple databases
type Service interface {
Health() map[string]map[string]string
GetDB(name string) (*sql.DB, error)
GetMongoClient(name string) (*mongo.Client, error)
GetReadDB(name string) (*sql.DB, error) // For read replicas
Close() error
ListDBs() []string
GetDBType(name string) (DatabaseType, error)
// Tambahkan method untuk WebSocket notifications
ListenForChanges(ctx context.Context, dbName string, channels []string, callback func(string, string)) error
NotifyChange(dbName, channel, payload string) error
GetPrimaryDB(name string) (*sql.DB, error) // Helper untuk get primary DB
}
type service struct {
sqlDatabases map[string]*sql.DB
mongoClients map[string]*mongo.Client
readReplicas map[string][]*sql.DB // Read replicas for load balancing
configs map[string]config.DatabaseConfig
readConfigs map[string][]config.DatabaseConfig
mu sync.RWMutex
readBalancer map[string]int // Round-robin counter for read replicas
listeners map[string]*pq.Listener // Tambahkan untuk tracking listeners
listenersMu sync.RWMutex
}
var (
dbManager *service
once sync.Once
)
// New creates a new database service with multiple connections
func New(cfg *config.Config) Service {
once.Do(func() {
dbManager = &service{
sqlDatabases: make(map[string]*sql.DB),
mongoClients: make(map[string]*mongo.Client),
readReplicas: make(map[string][]*sql.DB),
configs: make(map[string]config.DatabaseConfig),
readConfigs: make(map[string][]config.DatabaseConfig),
readBalancer: make(map[string]int),
listeners: make(map[string]*pq.Listener),
}
log.Println("Initializing database service...") // Log when the initialization starts
// log.Printf("Current Goroutine ID: %d", runtime.NumGoroutine()) // Log the number of goroutines
// log.Printf("Stack Trace: %s", debug.Stack()) // Log the stack trace
dbManager.loadFromConfig(cfg)
// Initialize all databases
for name, dbConfig := range dbManager.configs {
if err := dbManager.addDatabase(name, dbConfig); err != nil {
log.Printf("Failed to connect to database %s: %v", name, err)
}
}
// Initialize read replicas
for name, replicaConfigs := range dbManager.readConfigs {
for i, replicaConfig := range replicaConfigs {
if err := dbManager.addReadReplica(name, i, replicaConfig); err != nil {
log.Printf("Failed to connect to read replica %s[%d]: %v", name, i, err)
}
}
}
})
return dbManager
}
func (s *service) loadFromConfig(cfg *config.Config) {
s.mu.Lock()
defer s.mu.Unlock()
// Load primary databases
for name, dbConfig := range cfg.Databases {
s.configs[name] = dbConfig
}
// Load read replicas
for name, replicaConfigs := range cfg.ReadReplicas {
s.readConfigs[name] = replicaConfigs
}
}
func (s *service) addDatabase(name string, config config.DatabaseConfig) error {
s.mu.Lock()
defer s.mu.Unlock()
log.Printf("=== Database Connection Debug ===")
// log.Printf("Database: %s", name)
// log.Printf("Type: %s", config.Type)
// log.Printf("Host: %s", config.Host)
// log.Printf("Port: %d", config.Port)
// log.Printf("Database: %s", config.Database)
// log.Printf("Username: %s", config.Username)
// log.Printf("SSLMode: %s", config.SSLMode)
var db *sql.DB
var err error
dbType := DatabaseType(config.Type)
switch dbType {
case Postgres:
db, err = s.openPostgresConnection(config)
case MySQL:
db, err = s.openMySQLConnection(config)
case SQLServer:
db, err = s.openSQLServerConnection(config)
case SQLite:
db, err = s.openSQLiteConnection(config)
case MongoDB:
return s.addMongoDB(name, config)
default:
return fmt.Errorf("unsupported database type: %s", config.Type)
}
if err != nil {
log.Printf("❌ Error connecting to database %s: %v", name, err)
log.Printf(" Database: %s@%s:%d/%s", config.Username, config.Host, config.Port, config.Database)
return err
}
log.Printf("✅ Successfully connected to database: %s", name)
return s.configureSQLDB(name, db, config.MaxOpenConns, config.MaxIdleConns, config.ConnMaxLifetime)
}
func (s *service) addReadReplica(name string, index int, config config.DatabaseConfig) error {
s.mu.Lock()
defer s.mu.Unlock()
var db *sql.DB
var err error
dbType := DatabaseType(config.Type)
switch dbType {
case Postgres:
db, err = s.openPostgresConnection(config)
case MySQL:
db, err = s.openMySQLConnection(config)
case SQLServer:
db, err = s.openSQLServerConnection(config)
case SQLite:
db, err = s.openSQLiteConnection(config)
default:
return fmt.Errorf("unsupported database type for read replica: %s", config.Type)
}
if err != nil {
return err
}
if s.readReplicas[name] == nil {
s.readReplicas[name] = make([]*sql.DB, 0)
}
// Ensure we have enough slots
for len(s.readReplicas[name]) <= index {
s.readReplicas[name] = append(s.readReplicas[name], nil)
}
s.readReplicas[name][index] = db
log.Printf("Successfully connected to read replica %s[%d]", name, index)
return nil
}
func (s *service) openPostgresConnection(config config.DatabaseConfig) (*sql.DB, error) {
connStr := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=%s",
config.Username,
config.Password,
config.Host,
config.Port,
config.Database,
config.SSLMode,
)
if config.Schema != "" {
connStr += "&search_path=" + config.Schema
}
db, err := sql.Open("pgx", connStr)
if err != nil {
return nil, fmt.Errorf("failed to open PostgreSQL connection: %w", err)
}
return db, nil
}
func (s *service) openMySQLConnection(config config.DatabaseConfig) (*sql.DB, error) {
connStr := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
config.Username,
config.Password,
config.Host,
config.Port,
config.Database,
)
db, err := sql.Open("mysql", connStr)
if err != nil {
return nil, fmt.Errorf("failed to open MySQL connection: %w", err)
}
return db, nil
}
func (s *service) openSQLServerConnection(config config.DatabaseConfig) (*sql.DB, error) {
connStr := fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s",
config.Username,
config.Password,
config.Host,
config.Port,
config.Database,
)
db, err := sql.Open("sqlserver", connStr)
if err != nil {
return nil, fmt.Errorf("failed to open SQL Server connection: %w", err)
}
return db, nil
}
func (s *service) openSQLiteConnection(config config.DatabaseConfig) (*sql.DB, error) {
dbPath := config.Path
if dbPath == "" {
dbPath = fmt.Sprintf("./data/%s.db", config.Database)
}
db, err := sql.Open("sqlite3", dbPath)
if err != nil {
return nil, fmt.Errorf("failed to open SQLite connection: %w", err)
}
return db, nil
}
func (s *service) addMongoDB(name string, config config.DatabaseConfig) error {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
uri := fmt.Sprintf("mongodb://%s:%s@%s:%d/%s",
config.Username,
config.Password,
config.Host,
config.Port,
config.Database,
)
client, err := mongo.Connect(ctx, options.Client().ApplyURI(uri))
if err != nil {
return fmt.Errorf("failed to connect to MongoDB: %w", err)
}
s.mongoClients[name] = client
log.Printf("Successfully connected to MongoDB: %s", name)
return nil
}
func (s *service) configureSQLDB(name string, db *sql.DB, maxOpenConns, maxIdleConns int, connMaxLifetime time.Duration) error {
db.SetMaxOpenConns(maxOpenConns)
db.SetMaxIdleConns(maxIdleConns)
db.SetConnMaxLifetime(connMaxLifetime)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := db.PingContext(ctx); err != nil {
db.Close()
return fmt.Errorf("failed to ping database: %w", err)
}
s.sqlDatabases[name] = db
log.Printf("Successfully connected to SQL database: %s", name)
return nil
}
// Health checks the health of all database connections by pinging each database.
func (s *service) Health() map[string]map[string]string {
s.mu.RLock()
defer s.mu.RUnlock()
result := make(map[string]map[string]string)
// Check SQL databases
for name, db := range s.sqlDatabases {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
stats := make(map[string]string)
err := db.PingContext(ctx)
if err != nil {
stats["status"] = "down"
stats["error"] = fmt.Sprintf("db down: %v", err)
stats["type"] = "sql"
stats["role"] = "primary"
result[name] = stats
continue
}
stats["status"] = "up"
stats["message"] = "It's healthy"
stats["type"] = "sql"
stats["role"] = "primary"
dbStats := db.Stats()
stats["open_connections"] = strconv.Itoa(dbStats.OpenConnections)
stats["in_use"] = strconv.Itoa(dbStats.InUse)
stats["idle"] = strconv.Itoa(dbStats.Idle)
stats["wait_count"] = strconv.FormatInt(dbStats.WaitCount, 10)
stats["wait_duration"] = dbStats.WaitDuration.String()
stats["max_idle_closed"] = strconv.FormatInt(dbStats.MaxIdleClosed, 10)
stats["max_lifetime_closed"] = strconv.FormatInt(dbStats.MaxLifetimeClosed, 10)
if dbStats.OpenConnections > 40 {
stats["message"] = "The database is experiencing heavy load."
}
if dbStats.WaitCount > 1000 {
stats["message"] = "The database has a high number of wait events, indicating potential bottlenecks."
}
if dbStats.MaxIdleClosed > int64(dbStats.OpenConnections)/2 {
stats["message"] = "Many idle connections are being closed, consider revising the connection pool settings."
}
if dbStats.MaxLifetimeClosed > int64(dbStats.OpenConnections)/2 {
stats["message"] = "Many connections are being closed due to max lifetime, consider increasing max lifetime or revising the connection usage pattern."
}
result[name] = stats
}
// Check read replicas
for name, replicas := range s.readReplicas {
for i, db := range replicas {
if db == nil {
continue
}
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
replicaName := fmt.Sprintf("%s_replica_%d", name, i)
stats := make(map[string]string)
err := db.PingContext(ctx)
if err != nil {
stats["status"] = "down"
stats["error"] = fmt.Sprintf("read replica down: %v", err)
stats["type"] = "sql"
stats["role"] = "replica"
result[replicaName] = stats
continue
}
stats["status"] = "up"
stats["message"] = "Read replica healthy"
stats["type"] = "sql"
stats["role"] = "replica"
dbStats := db.Stats()
stats["open_connections"] = strconv.Itoa(dbStats.OpenConnections)
stats["in_use"] = strconv.Itoa(dbStats.InUse)
stats["idle"] = strconv.Itoa(dbStats.Idle)
result[replicaName] = stats
}
}
// Check MongoDB connections
for name, client := range s.mongoClients {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
stats := make(map[string]string)
err := client.Ping(ctx, nil)
if err != nil {
stats["status"] = "down"
stats["error"] = fmt.Sprintf("mongodb down: %v", err)
stats["type"] = "mongodb"
result[name] = stats
continue
}
stats["status"] = "up"
stats["message"] = "It's healthy"
stats["type"] = "mongodb"
result[name] = stats
}
return result
}
// GetDB returns a specific SQL database connection by name
func (s *service) GetDB(name string) (*sql.DB, error) {
log.Printf("Attempting to get database connection for: %s", name)
s.mu.RLock()
defer s.mu.RUnlock()
db, exists := s.sqlDatabases[name]
if !exists {
log.Printf("Error: database %s not found", name) // Log the error
return nil, fmt.Errorf("database %s not found", name)
}
log.Printf("Current connection pool state for %s: Open: %d, In Use: %d, Idle: %d",
name, db.Stats().OpenConnections, db.Stats().InUse, db.Stats().Idle)
s.mu.RLock()
defer s.mu.RUnlock()
// db, exists := s.sqlDatabases[name]
// if !exists {
// log.Printf("Error: database %s not found", name) // Log the error
// return nil, fmt.Errorf("database %s not found", name)
// }
return db, nil
}
// GetReadDB returns a read replica connection using round-robin load balancing
func (s *service) GetReadDB(name string) (*sql.DB, error) {
s.mu.RLock()
defer s.mu.RUnlock()
replicas, exists := s.readReplicas[name]
if !exists || len(replicas) == 0 {
// Fallback to primary if no replicas available
return s.GetDB(name)
}
// Round-robin load balancing
s.readBalancer[name] = (s.readBalancer[name] + 1) % len(replicas)
selected := replicas[s.readBalancer[name]]
if selected == nil {
// Fallback to primary if replica is nil
return s.GetDB(name)
}
return selected, nil
}
// GetMongoClient returns a specific MongoDB client by name
func (s *service) GetMongoClient(name string) (*mongo.Client, error) {
s.mu.RLock()
defer s.mu.RUnlock()
client, exists := s.mongoClients[name]
if !exists {
return nil, fmt.Errorf("MongoDB client %s not found", name)
}
return client, nil
}
// ListDBs returns list of available database names
func (s *service) ListDBs() []string {
s.mu.RLock()
defer s.mu.RUnlock()
names := make([]string, 0, len(s.sqlDatabases)+len(s.mongoClients))
for name := range s.sqlDatabases {
names = append(names, name)
}
for name := range s.mongoClients {
names = append(names, name)
}
return names
}
// GetDBType returns the type of a specific database
func (s *service) GetDBType(name string) (DatabaseType, error) {
s.mu.RLock()
defer s.mu.RUnlock()
config, exists := s.configs[name]
if !exists {
return "", fmt.Errorf("database %s not found", name)
}
return DatabaseType(config.Type), nil
}
// Close closes all database connections
func (s *service) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
var errs []error
for name, db := range s.sqlDatabases {
if err := db.Close(); err != nil {
errs = append(errs, fmt.Errorf("failed to close database %s: %w", name, err))
} else {
log.Printf("Disconnected from SQL database: %s", name)
}
}
for name, replicas := range s.readReplicas {
for i, db := range replicas {
if db != nil {
if err := db.Close(); err != nil {
errs = append(errs, fmt.Errorf("failed to close read replica %s[%d]: %w", name, i, err))
} else {
log.Printf("Disconnected from read replica: %s[%d]", name, i)
}
}
}
}
for name, client := range s.mongoClients {
if err := client.Disconnect(context.Background()); err != nil {
errs = append(errs, fmt.Errorf("failed to disconnect MongoDB client %s: %w", name, err))
} else {
log.Printf("Disconnected from MongoDB: %s", name)
}
}
s.sqlDatabases = make(map[string]*sql.DB)
s.mongoClients = make(map[string]*mongo.Client)
s.readReplicas = make(map[string][]*sql.DB)
s.configs = make(map[string]config.DatabaseConfig)
s.readConfigs = make(map[string][]config.DatabaseConfig)
if len(errs) > 0 {
return fmt.Errorf("errors closing databases: %v", errs)
}
return nil
}
// GetPrimaryDB returns primary database connection
func (s *service) GetPrimaryDB(name string) (*sql.DB, error) {
return s.GetDB(name)
}
// ListenForChanges implements PostgreSQL LISTEN/NOTIFY for real-time updates
func (s *service) ListenForChanges(ctx context.Context, dbName string, channels []string, callback func(string, string)) error {
s.mu.RLock()
config, exists := s.configs[dbName]
s.mu.RUnlock()
if !exists {
return fmt.Errorf("database %s not found", dbName)
}
// Only support PostgreSQL for LISTEN/NOTIFY
if DatabaseType(config.Type) != Postgres {
return fmt.Errorf("LISTEN/NOTIFY only supported for PostgreSQL databases")
}
// Create connection string for listener
connStr := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=%s",
config.Username,
config.Password,
config.Host,
config.Port,
config.Database,
config.SSLMode,
)
// Create listener
listener := pq.NewListener(
connStr,
10*time.Second,
time.Minute,
func(ev pq.ListenerEventType, err error) {
if err != nil {
log.Printf("Database listener (%s) error: %v", dbName, err)
}
},
)
// Store listener for cleanup
s.listenersMu.Lock()
s.listeners[dbName] = listener
s.listenersMu.Unlock()
// Listen to specified channels
for _, channel := range channels {
err := listener.Listen(channel)
if err != nil {
listener.Close()
return fmt.Errorf("failed to listen to channel %s: %w", channel, err)
}
log.Printf("Listening to database channel: %s on %s", channel, dbName)
}
// Start listening loop
go func() {
defer func() {
listener.Close()
s.listenersMu.Lock()
delete(s.listeners, dbName)
s.listenersMu.Unlock()
log.Printf("Database listener for %s stopped", dbName)
}()
for {
select {
case n := <-listener.Notify:
if n != nil {
callback(n.Channel, n.Extra)
}
case <-ctx.Done():
return
case <-time.After(90 * time.Second):
// Send ping to keep connection alive
go func() {
if err := listener.Ping(); err != nil {
log.Printf("Listener ping failed for %s: %v", dbName, err)
}
}()
}
}
}()
return nil
}
// NotifyChange sends a notification to a PostgreSQL channel
func (s *service) NotifyChange(dbName, channel, payload string) error {
db, err := s.GetDB(dbName)
if err != nil {
return fmt.Errorf("failed to get database %s: %w", dbName, err)
}
// Check if it's PostgreSQL
s.mu.RLock()
config, exists := s.configs[dbName]
s.mu.RUnlock()
if !exists {
return fmt.Errorf("database %s configuration not found", dbName)
}
if DatabaseType(config.Type) != Postgres {
return fmt.Errorf("NOTIFY only supported for PostgreSQL databases")
}
// Execute NOTIFY
query := "SELECT pg_notify($1, $2)"
_, err = db.Exec(query, channel, payload)
if err != nil {
return fmt.Errorf("failed to send notification: %w", err)
}
log.Printf("Sent notification to channel %s on %s: %s", channel, dbName, payload)
return nil
}

View File

@@ -0,0 +1,132 @@
package handlers
import (
models "api-service/internal/models/auth"
services "api-service/internal/services/auth"
"net/http"
"github.com/gin-gonic/gin"
)
// AuthHandler handles authentication endpoints
type AuthHandler struct {
authService *services.AuthService
}
// NewAuthHandler creates a new authentication handler
func NewAuthHandler(authService *services.AuthService) *AuthHandler {
return &AuthHandler{
authService: authService,
}
}
// Login godoc
// @Summary Login user and get JWT token
// @Description Authenticate user with username and password to receive JWT token
// @Tags Authentication
// @Accept json
// @Produce json
// @Param login body models.LoginRequest true "Login credentials"
// @Success 200 {object} models.TokenResponse
// @Failure 400 {object} map[string]string "Bad request"
// @Failure 401 {object} map[string]string "Unauthorized"
// @Router /api/v1/auth/login [post]
func (h *AuthHandler) Login(c *gin.Context) {
var loginReq models.LoginRequest
// Bind JSON request
if err := c.ShouldBindJSON(&loginReq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Authenticate user
tokenResponse, err := h.authService.Login(loginReq.Username, loginReq.Password)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, tokenResponse)
}
// RefreshToken godoc
// @Summary Refresh JWT token
// @Description Refresh the JWT token using a valid refresh token
// @Tags Authentication
// @Accept json
// @Produce json
// @Param refresh body map[string]string true "Refresh token"
// @Success 200 {object} models.TokenResponse
// @Failure 400 {object} map[string]string "Bad request"
// @Failure 401 {object} map[string]string "Unauthorized"
// @Router /api/v1/auth/refresh [post]
func (h *AuthHandler) RefreshToken(c *gin.Context) {
// For now, this is a placeholder for refresh token functionality
// In a real implementation, you would handle refresh tokens here
c.JSON(http.StatusNotImplemented, gin.H{"error": "refresh token not implemented"})
}
// Register godoc
// @Summary Register new user
// @Description Register a new user account
// @Tags Authentication
// @Accept json
// @Produce json
// @Param register body map[string]string true "Registration data"
// @Success 201 {object} map[string]string
// @Failure 400 {object} map[string]string "Bad request"
// @Router /api/v1/auth/register [post]
func (h *AuthHandler) Register(c *gin.Context) {
var registerReq struct {
Username string `json:"username" binding:"required"`
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
Role string `json:"role" binding:"required"`
}
if err := c.ShouldBindJSON(&registerReq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
err := h.authService.RegisterUser(
registerReq.Username,
registerReq.Email,
registerReq.Password,
registerReq.Role,
)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, gin.H{"message": "user registered successfully"})
}
// Me godoc
// @Summary Get current user info
// @Description Get information about the currently authenticated user
// @Tags Authentication
// @Produce json
// @Security Bearer
// @Success 200 {object} models.User
// @Failure 401 {object} map[string]string "Unauthorized"
// @Router /api/v1/auth/me [get]
func (h *AuthHandler) Me(c *gin.Context) {
// Get user info from context (set by middleware)
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "user not authenticated"})
return
}
// In a real implementation, you would fetch user details from database
c.JSON(http.StatusOK, gin.H{
"id": userID,
"username": c.GetString("username"),
"email": c.GetString("email"),
"role": c.GetString("role"),
})
}

View File

@@ -0,0 +1,95 @@
package handlers
import (
models "api-service/internal/models/auth"
services "api-service/internal/services/auth"
"net/http"
"github.com/gin-gonic/gin"
)
// TokenHandler handles token generation endpoints
type TokenHandler struct {
authService *services.AuthService
}
// NewTokenHandler creates a new token handler
func NewTokenHandler(authService *services.AuthService) *TokenHandler {
return &TokenHandler{
authService: authService,
}
}
// GenerateToken godoc
// @Summary Generate JWT token
// @Description Generate a JWT token for a user
// @Tags Token
// @Accept json
// @Produce json
// @Param token body models.LoginRequest true "User credentials"
// @Success 200 {object} models.TokenResponse
// @Failure 400 {object} map[string]string "Bad request"
// @Failure 401 {object} map[string]string "Unauthorized"
// @Router /api/v1/token/generate [post]
func (h *TokenHandler) GenerateToken(c *gin.Context) {
var loginReq models.LoginRequest
// Bind JSON request
if err := c.ShouldBindJSON(&loginReq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Generate token
tokenResponse, err := h.authService.Login(loginReq.Username, loginReq.Password)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, tokenResponse)
}
// GenerateTokenDirect godoc
// @Summary Generate token directly
// @Description Generate a JWT token directly without password verification (for testing)
// @Tags Token
// @Accept json
// @Produce json
// @Param user body map[string]string true "User info"
// @Success 200 {object} models.TokenResponse
// @Failure 400 {object} map[string]string "Bad request"
// @Router /api/v1/token/generate-direct [post]
func (h *TokenHandler) GenerateTokenDirect(c *gin.Context) {
var req struct {
Username string `json:"username" binding:"required"`
Email string `json:"email" binding:"required"`
Role string `json:"role" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Create a temporary user for token generation
user := &models.User{
ID: "temp-" + req.Username,
Username: req.Username,
Email: req.Email,
Role: req.Role,
}
// Generate token directly
token, err := h.authService.GenerateTokenForUser(user)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, models.TokenResponse{
AccessToken: token,
TokenType: "Bearer",
ExpiresIn: 3600,
})
}

View File

@@ -0,0 +1,24 @@
package healthcheck
import (
"api-service/internal/database"
"net/http"
"github.com/gin-gonic/gin"
)
// HealthCheckHandler handles health check requests
type HealthCheckHandler struct {
dbService database.Service
}
// NewHealthCheckHandler creates a new HealthCheckHandler
func NewHealthCheckHandler(dbService database.Service) *HealthCheckHandler {
return &HealthCheckHandler{dbService: dbService}
}
// CheckHealth checks the health of the application
func (h *HealthCheckHandler) CheckHealth(c *gin.Context) {
healthStatus := h.dbService.Health() // Call the health check function from the database service
c.JSON(http.StatusOK, healthStatus)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,111 @@
package websocket
import (
"sync"
"time"
)
// WebSocketBroadcaster defines the interface for broadcasting messages
type WebSocketBroadcaster interface {
BroadcastMessage(messageType string, data interface{})
}
// Broadcaster handles server-initiated broadcasts to WebSocket clients
type Broadcaster struct {
handler WebSocketBroadcaster
tickers []*time.Ticker
quit chan struct{}
mu sync.Mutex
}
// NewBroadcaster creates a new Broadcaster instance
func NewBroadcaster(handler WebSocketBroadcaster) *Broadcaster {
return &Broadcaster{
handler: handler,
tickers: make([]*time.Ticker, 0),
quit: make(chan struct{}),
}
}
// StartHeartbeat starts sending periodic heartbeat messages to all clients
func (b *Broadcaster) StartHeartbeat(interval time.Duration) {
ticker := time.NewTicker(interval)
b.tickers = append(b.tickers, ticker)
go func() {
defer func() {
// Remove ticker from slice when done
for i, t := range b.tickers {
if t == ticker {
b.tickers = append(b.tickers[:i], b.tickers[i+1:]...)
break
}
}
}()
for {
select {
case <-ticker.C:
b.handler.BroadcastMessage("heartbeat", map[string]interface{}{
"message": "Server heartbeat",
"timestamp": time.Now().Format(time.RFC3339),
})
case <-b.quit:
ticker.Stop()
return
}
}
}()
}
// Stop stops the broadcaster
func (b *Broadcaster) Stop() {
close(b.quit)
for _, ticker := range b.tickers {
if ticker != nil {
ticker.Stop()
}
}
b.tickers = nil
}
// BroadcastNotification sends a notification message to all clients
func (b *Broadcaster) BroadcastNotification(title, message, level string) {
b.handler.BroadcastMessage("notification", map[string]interface{}{
"title": title,
"message": message,
"level": level,
"time": time.Now().Format(time.RFC3339),
})
}
// SimulateDataStream simulates streaming data to clients (useful for demos)
func (b *Broadcaster) SimulateDataStream() {
ticker := time.NewTicker(100 * time.Millisecond)
b.tickers = append(b.tickers, ticker)
go func() {
defer func() {
// Remove ticker from slice when done
for i, t := range b.tickers {
if t == ticker {
b.tickers = append(b.tickers[:i], b.tickers[i+1:]...)
break
}
}
}()
counter := 0
for {
select {
case <-ticker.C:
counter++
b.handler.BroadcastMessage("data_stream", map[string]interface{}{
"id": counter,
"value": counter * 10,
"timestamp": time.Now().Format(time.RFC3339),
"type": "simulated_data",
})
case <-b.quit:
ticker.Stop()
return
}
}
}()
}

View File

@@ -0,0 +1,251 @@
package websocket
import (
"sync"
"testing"
"time"
)
// MockWebSocketHandler is a mock implementation for testing
type MockWebSocketHandler struct {
mu sync.Mutex
messages []map[string]interface{}
broadcasts []string
}
func (m *MockWebSocketHandler) BroadcastMessage(messageType string, data interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.broadcasts = append(m.broadcasts, messageType)
m.messages = append(m.messages, map[string]interface{}{
"type": messageType,
"data": data,
})
}
func (m *MockWebSocketHandler) GetMessages() []map[string]interface{} {
m.mu.Lock()
defer m.mu.Unlock()
result := make([]map[string]interface{}, len(m.messages))
copy(result, m.messages)
return result
}
func (m *MockWebSocketHandler) GetBroadcasts() []string {
m.mu.Lock()
defer m.mu.Unlock()
result := make([]string, len(m.broadcasts))
copy(result, m.broadcasts)
return result
}
func (m *MockWebSocketHandler) Clear() {
m.mu.Lock()
defer m.mu.Unlock()
m.messages = make([]map[string]interface{}, 0)
m.broadcasts = make([]string, 0)
}
func NewMockWebSocketHandler() *MockWebSocketHandler {
return &MockWebSocketHandler{
messages: make([]map[string]interface{}, 0),
broadcasts: make([]string, 0),
}
}
func TestBroadcaster_StartHeartbeat(t *testing.T) {
mockHandler := NewMockWebSocketHandler()
broadcaster := NewBroadcaster(mockHandler)
// Start heartbeat with short interval for testing
broadcaster.StartHeartbeat(100 * time.Millisecond)
// Wait for a few heartbeats
time.Sleep(350 * time.Millisecond)
// Stop the broadcaster
broadcaster.Stop()
// Check if heartbeats were sent
messages := mockHandler.GetMessages()
if len(messages) == 0 {
t.Error("Expected heartbeat messages, but got none")
}
// Check that all messages are heartbeat type
broadcasts := mockHandler.GetBroadcasts()
for _, msgType := range broadcasts {
if msgType != "heartbeat" {
t.Errorf("Expected heartbeat message type, got %s", msgType)
}
}
t.Logf("Received %d heartbeat messages", len(messages))
}
func TestBroadcaster_BroadcastNotification(t *testing.T) {
mockHandler := NewMockWebSocketHandler()
broadcaster := NewBroadcaster(mockHandler)
// Send a notification
broadcaster.BroadcastNotification("Test Title", "Test Message", "info")
// Check if notification was sent
messages := mockHandler.GetMessages()
if len(messages) != 1 {
t.Errorf("Expected 1 message, got %d", len(messages))
return
}
msg := messages[0]
if msg["type"] != "notification" {
t.Errorf("Expected message type 'notification', got %s", msg["type"])
}
data := msg["data"].(map[string]interface{})
if data["title"] != "Test Title" {
t.Errorf("Expected title 'Test Title', got %s", data["title"])
}
if data["message"] != "Test Message" {
t.Errorf("Expected message 'Test Message', got %s", data["message"])
}
if data["level"] != "info" {
t.Errorf("Expected level 'info', got %s", data["level"])
}
t.Logf("Notification sent successfully: %+v", data)
}
func TestBroadcaster_SimulateDataStream(t *testing.T) {
mockHandler := NewMockWebSocketHandler()
broadcaster := NewBroadcaster(mockHandler)
// Start data stream with short interval for testing
broadcaster.SimulateDataStream()
// Wait for a few data points
time.Sleep(550 * time.Millisecond)
// Stop the broadcaster
broadcaster.Stop()
// Check if data stream messages were sent
messages := mockHandler.GetMessages()
if len(messages) == 0 {
t.Error("Expected data stream messages, but got none")
}
// Check that all messages are data_stream type
broadcasts := mockHandler.GetBroadcasts()
for _, msgType := range broadcasts {
if msgType != "data_stream" {
t.Errorf("Expected data_stream message type, got %s", msgType)
}
}
// Check data structure
for i, msg := range messages {
data := msg["data"].(map[string]interface{})
if data["type"] != "simulated_data" {
t.Errorf("Expected data type 'simulated_data', got %s", data["type"])
}
if id, ok := data["id"].(int); ok {
if id != i+1 {
t.Errorf("Expected id %d, got %d", i+1, id)
}
}
if value, ok := data["value"].(int); ok {
expectedValue := (i + 1) * 10
if value != expectedValue {
t.Errorf("Expected value %d, got %d", expectedValue, value)
}
}
}
t.Logf("Received %d data stream messages", len(messages))
}
func TestBroadcaster_Stop(t *testing.T) {
mockHandler := NewMockWebSocketHandler()
broadcaster := NewBroadcaster(mockHandler)
// Start heartbeat
broadcaster.StartHeartbeat(50 * time.Millisecond)
// Wait a bit
time.Sleep(100 * time.Millisecond)
// Stop the broadcaster
broadcaster.Stop()
// Clear previous messages
mockHandler.Clear()
// Wait a bit more to ensure no new messages are sent
time.Sleep(200 * time.Millisecond)
// Check that no new messages were sent after stopping
messages := mockHandler.GetMessages()
if len(messages) > 0 {
t.Errorf("Expected no messages after stopping, but got %d", len(messages))
}
// Clear quit channel to allow reuse in tests
broadcaster.quit = make(chan struct{})
t.Log("Broadcaster stopped successfully")
}
func TestBroadcaster_MultipleOperations(t *testing.T) {
mockHandler := NewMockWebSocketHandler()
broadcaster := NewBroadcaster(mockHandler)
// Start heartbeat
broadcaster.StartHeartbeat(100 * time.Millisecond)
// Send notification
broadcaster.BroadcastNotification("Test", "Message", "warning")
// Start data stream
broadcaster.SimulateDataStream()
// Wait for some activity
time.Sleep(350 * time.Millisecond)
// Stop everything
broadcaster.Stop()
// Check results
messages := mockHandler.GetMessages()
if len(messages) == 0 {
t.Error("Expected messages from multiple operations, but got none")
}
broadcasts := mockHandler.GetBroadcasts()
hasHeartbeat := false
hasNotification := false
hasDataStream := false
for _, msgType := range broadcasts {
switch msgType {
case "heartbeat":
hasHeartbeat = true
case "notification":
hasNotification = true
case "data_stream":
hasDataStream = true
}
}
if !hasHeartbeat {
t.Error("Expected heartbeat messages")
}
if !hasNotification {
t.Error("Expected notification message")
}
if !hasDataStream {
t.Error("Expected data stream messages")
}
t.Logf("Multiple operations test passed: %d total messages", len(messages))
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,59 @@
package middleware
import (
"fmt"
"net/http"
"api-service/internal/config"
"github.com/gin-gonic/gin"
)
// ConfigurableAuthMiddleware provides flexible authentication based on configuration
func ConfigurableAuthMiddleware(cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
// Skip authentication for development/testing if explicitly disabled
if !cfg.Keycloak.Enabled {
fmt.Println("Authentication is disabled - allowing all requests")
c.Next()
return
}
// Use Keycloak authentication when enabled
AuthMiddleware()(c)
}
}
// StrictAuthMiddleware enforces authentication regardless of Keycloak.Enabled setting
func StrictAuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
if appConfig == nil {
fmt.Println("AuthMiddleware: Config not initialized")
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "authentication service not configured"})
return
}
// Always enforce authentication
AuthMiddleware()(c)
}
}
// OptionalKeycloakAuthMiddleware allows requests but adds authentication info if available
func OptionalKeycloakAuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
if appConfig == nil || !appConfig.Keycloak.Enabled {
c.Next()
return
}
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
// No token provided, but continue
c.Next()
return
}
// Try to validate token, but don't fail if invalid
AuthMiddleware()(c)
}
}

View File

@@ -0,0 +1,54 @@
package middleware
import (
models "api-service/internal/models"
"net/http"
"github.com/gin-gonic/gin"
)
// ErrorHandler handles errors globally
func ErrorHandler() gin.HandlerFunc {
return func(c *gin.Context) {
c.Next()
if len(c.Errors) > 0 {
err := c.Errors.Last()
status := http.StatusInternalServerError
// Determine status code based on error type
switch err.Type {
case gin.ErrorTypeBind:
status = http.StatusBadRequest
case gin.ErrorTypeRender:
status = http.StatusUnprocessableEntity
case gin.ErrorTypePrivate:
status = http.StatusInternalServerError
}
response := models.ErrorResponse{
Error: "internal_error",
Message: err.Error(),
Code: status,
}
c.JSON(status, response)
}
}
}
// CORS middleware configuration
func CORSConfig() gin.HandlerFunc {
return gin.HandlerFunc(func(c *gin.Context) {
c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS, PATCH")
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
return
}
c.Next()
})
}

View File

@@ -0,0 +1,77 @@
package middleware
import (
services "api-service/internal/services/auth"
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
// JWTAuthMiddleware validates JWT tokens generated by our auth service
func JWTAuthMiddleware(authService *services.AuthService) gin.HandlerFunc {
return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header missing"})
return
}
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header format must be Bearer {token}"})
return
}
tokenString := parts[1]
// Validate token
claims, err := authService.ValidateToken(tokenString)
if err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
return
}
// Set user info in context
c.Set("user_id", claims.UserID)
c.Set("username", claims.Username)
c.Set("email", claims.Email)
c.Set("role", claims.Role)
c.Next()
}
}
// OptionalAuthMiddleware allows both authenticated and unauthenticated requests
func OptionalAuthMiddleware(authService *services.AuthService) gin.HandlerFunc {
return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
// No token provided, but continue
c.Next()
return
}
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
c.Next()
return
}
tokenString := parts[1]
claims, err := authService.ValidateToken(tokenString)
if err != nil {
// Invalid token, but continue (don't abort)
c.Next()
return
}
// Set user info in context
c.Set("user_id", claims.UserID)
c.Set("username", claims.Username)
c.Set("email", claims.Email)
c.Set("role", claims.Role)
c.Next()
}
}

View File

@@ -0,0 +1,254 @@
package middleware
/** Keycloak Auth Middleware **/
import (
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"math/big"
"net/http"
"strings"
"sync"
"time"
"api-service/internal/config"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/sync/singleflight"
)
var (
ErrInvalidToken = errors.New("invalid token")
)
// JwksCache caches JWKS keys with expiration
type JwksCache struct {
mu sync.RWMutex
keys map[string]*rsa.PublicKey
expiresAt time.Time
sfGroup singleflight.Group
config *config.Config
}
func NewJwksCache(cfg *config.Config) *JwksCache {
return &JwksCache{
keys: make(map[string]*rsa.PublicKey),
config: cfg,
}
}
func (c *JwksCache) GetKey(kid string) (*rsa.PublicKey, error) {
c.mu.RLock()
if key, ok := c.keys[kid]; ok && time.Now().Before(c.expiresAt) {
c.mu.RUnlock()
return key, nil
}
c.mu.RUnlock()
// Fetch keys with singleflight to avoid concurrent fetches
v, err, _ := c.sfGroup.Do("fetch_jwks", func() (interface{}, error) {
return c.fetchKeys()
})
if err != nil {
return nil, err
}
keys := v.(map[string]*rsa.PublicKey)
c.mu.Lock()
c.keys = keys
c.expiresAt = time.Now().Add(1 * time.Hour) // cache for 1 hour
c.mu.Unlock()
key, ok := keys[kid]
if !ok {
return nil, fmt.Errorf("key with kid %s not found", kid)
}
return key, nil
}
func (c *JwksCache) fetchKeys() (map[string]*rsa.PublicKey, error) {
if !c.config.Keycloak.Enabled {
return nil, fmt.Errorf("keycloak authentication is disabled")
}
jwksURL := c.config.Keycloak.JwksURL
if jwksURL == "" {
// Construct JWKS URL from issuer if not explicitly provided
jwksURL = c.config.Keycloak.Issuer + "/protocol/openid-connect/certs"
}
resp, err := http.Get(jwksURL)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var jwksData struct {
Keys []struct {
Kid string `json:"kid"`
Kty string `json:"kty"`
N string `json:"n"`
E string `json:"e"`
} `json:"keys"`
}
if err := json.NewDecoder(resp.Body).Decode(&jwksData); err != nil {
return nil, err
}
keys := make(map[string]*rsa.PublicKey)
for _, key := range jwksData.Keys {
if key.Kty != "RSA" {
continue
}
pubKey, err := parseRSAPublicKey(key.N, key.E)
if err != nil {
continue
}
keys[key.Kid] = pubKey
}
return keys, nil
}
// parseRSAPublicKey parses RSA public key components from base64url strings
func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
nBytes, err := base64UrlDecode(nStr)
if err != nil {
return nil, err
}
eBytes, err := base64UrlDecode(eStr)
if err != nil {
return nil, err
}
var eInt int
for _, b := range eBytes {
eInt = eInt<<8 + int(b)
}
pubKey := &rsa.PublicKey{
N: new(big.Int).SetBytes(nBytes),
E: eInt,
}
return pubKey, nil
}
func base64UrlDecode(s string) ([]byte, error) {
// Add padding if missing
if m := len(s) % 4; m != 0 {
s += strings.Repeat("=", 4-m)
}
return base64.URLEncoding.DecodeString(s)
}
// Global config instance
var appConfig *config.Config
var jwksCacheInstance *JwksCache
// InitializeAuth initializes the auth middleware with config
func InitializeAuth(cfg *config.Config) {
appConfig = cfg
jwksCacheInstance = NewJwksCache(cfg)
}
// AuthMiddleware validates Bearer token as Keycloak JWT token
func AuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
if appConfig == nil {
fmt.Println("AuthMiddleware: Config not initialized")
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "authentication service not configured"})
return
}
if !appConfig.Keycloak.Enabled {
// Skip authentication if Keycloak is disabled but log for debugging
fmt.Println("AuthMiddleware: Keycloak authentication is disabled - allowing all requests")
c.Next()
return
}
fmt.Println("AuthMiddleware: Checking Authorization header") // Debug log
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
fmt.Println("AuthMiddleware: Authorization header missing") // Debug log
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header missing"})
return
}
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
fmt.Println("AuthMiddleware: Invalid Authorization header format") // Debug log
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header format must be Bearer {token}"})
return
}
tokenString := parts[1]
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// Verify signing method
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
fmt.Printf("AuthMiddleware: Unexpected signing method: %v\n", token.Header["alg"]) // Debug log
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
kid, ok := token.Header["kid"].(string)
if !ok {
fmt.Println("AuthMiddleware: kid header not found") // Debug log
return nil, errors.New("kid header not found")
}
return jwksCacheInstance.GetKey(kid)
}, jwt.WithIssuer(appConfig.Keycloak.Issuer), jwt.WithAudience(appConfig.Keycloak.Audience))
if err != nil || !token.Valid {
fmt.Printf("AuthMiddleware: Invalid or expired token: %v\n", err) // Debug log
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or expired token"})
return
}
fmt.Println("AuthMiddleware: Token valid, proceeding") // Debug log
// Token is valid, proceed
c.Next()
}
}
/** JWT Bearer authentication middleware */
// import (
// "net/http"
// "strings"
// "github.com/gin-gonic/gin"
// )
// AuthMiddleware validates Bearer token in Authorization header
func AuthJWTMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header missing"})
return
}
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header format must be Bearer {token}"})
return
}
token := parts[1]
// For now, use a static token for validation. Replace with your logic.
const validToken = "your-static-token"
if token != validToken {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
return
}
c.Next()
}
}

View File

@@ -0,0 +1,31 @@
package models
// LoginRequest represents the login request payload
type LoginRequest struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
}
// TokenResponse represents the token response
type TokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
}
// JWTClaims represents the JWT claims
type JWTClaims struct {
UserID string `json:"user_id"`
Username string `json:"username"`
Email string `json:"email"`
Role string `json:"role"`
}
// User represents a user for authentication
type User struct {
ID string `json:"id"`
Username string `json:"username"`
Email string `json:"email"`
Password string `json:"-"`
Role string `json:"role"`
}

221
internal/models/models.go Normal file
View File

@@ -0,0 +1,221 @@
package models
import (
"database/sql"
"database/sql/driver"
"net/http"
"strconv"
"time"
)
// NullableInt32 - your existing implementation
type NullableInt32 struct {
Int32 int32 `json:"int32,omitempty"`
Valid bool `json:"valid"`
}
// Scan implements the sql.Scanner interface for NullableInt32
func (n *NullableInt32) Scan(value interface{}) error {
var ni sql.NullInt32
if err := ni.Scan(value); err != nil {
return err
}
n.Int32 = ni.Int32
n.Valid = ni.Valid
return nil
}
// Value implements the driver.Valuer interface for NullableInt32
func (n NullableInt32) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Int32, nil
}
// NullableString provides consistent nullable string handling
type NullableString struct {
String string `json:"string,omitempty"`
Valid bool `json:"valid"`
}
// Scan implements the sql.Scanner interface for NullableString
func (n *NullableString) Scan(value interface{}) error {
var ns sql.NullString
if err := ns.Scan(value); err != nil {
return err
}
n.String = ns.String
n.Valid = ns.Valid
return nil
}
// Value implements the driver.Valuer interface for NullableString
func (n NullableString) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.String, nil
}
// NullableTime provides consistent nullable time handling
type NullableTime struct {
Time time.Time `json:"time,omitempty"`
Valid bool `json:"valid"`
}
// Scan implements the sql.Scanner interface for NullableTime
func (n *NullableTime) Scan(value interface{}) error {
var nt sql.NullTime
if err := nt.Scan(value); err != nil {
return err
}
n.Time = nt.Time
n.Valid = nt.Valid
return nil
}
// Value implements the driver.Valuer interface for NullableTime
func (n NullableTime) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Time, nil
}
// Metadata untuk pagination - dioptimalkan
type MetaResponse struct {
Limit int `json:"limit"`
Offset int `json:"offset"`
Total int `json:"total"`
TotalPages int `json:"total_pages"`
CurrentPage int `json:"current_page"`
HasNext bool `json:"has_next"`
HasPrev bool `json:"has_prev"`
}
// Aggregate data untuk summary
type AggregateData struct {
TotalActive int `json:"total_active"`
TotalDraft int `json:"total_draft"`
TotalInactive int `json:"total_inactive"`
ByStatus map[string]int `json:"by_status"`
ByDinas map[string]int `json:"by_dinas,omitempty"`
ByJenis map[string]int `json:"by_jenis,omitempty"`
LastUpdated *time.Time `json:"last_updated,omitempty"`
CreatedToday int `json:"created_today"`
UpdatedToday int `json:"updated_today"`
}
// Error response yang konsisten
type ErrorResponse struct {
Error string `json:"error"`
Code int `json:"code"`
Message string `json:"message"`
Timestamp time.Time `json:"timestamp"`
}
// BaseRequest contains common fields for all BPJS requests
type BaseRequest struct {
RequestID string `json:"request_id,omitempty"`
Timestamp time.Time `json:"timestamp,omitempty"`
}
// BaseResponse contains common response fields
type BaseResponse struct {
Status string `json:"status"`
Message string `json:"message,omitempty"`
RequestID string `json:"request_id,omitempty"`
Timestamp string `json:"timestamp,omitempty"`
}
// ErrorResponse represents error response structure
type ErrorResponseBpjs struct {
Status string `json:"status"`
Message string `json:"message"`
RequestID string `json:"request_id,omitempty"`
Errors map[string]interface{} `json:"errors,omitempty"`
Code string `json:"code,omitempty"`
}
// PaginationRequest contains pagination parameters
type PaginationRequest struct {
Page int `json:"page" validate:"min=1"`
Limit int `json:"limit" validate:"min=1,max=100"`
SortBy string `json:"sort_by,omitempty"`
SortDir string `json:"sort_dir,omitempty" validate:"omitempty,oneof=asc desc"`
}
// PaginationResponse contains pagination metadata
type PaginationResponse struct {
CurrentPage int `json:"current_page"`
TotalPages int `json:"total_pages"`
TotalItems int64 `json:"total_items"`
ItemsPerPage int `json:"items_per_page"`
HasNext bool `json:"has_next"`
HasPrev bool `json:"has_previous"`
}
// MetaInfo contains additional metadata
type MetaInfo struct {
Version string `json:"version"`
Environment string `json:"environment"`
ServerTime string `json:"server_time"`
}
func GetStatusCodeFromMeta(metaCode interface{}) int {
statusCode := http.StatusOK
if metaCode != nil {
switch v := metaCode.(type) {
case string:
if code, err := strconv.Atoi(v); err == nil {
if code >= 100 && code <= 599 {
statusCode = code
} else {
statusCode = http.StatusInternalServerError
}
} else {
statusCode = http.StatusInternalServerError
}
case int:
if v >= 100 && v <= 599 {
statusCode = v
} else {
statusCode = http.StatusInternalServerError
}
case float64:
code := int(v)
if code >= 100 && code <= 599 {
statusCode = code
} else {
statusCode = http.StatusInternalServerError
}
default:
statusCode = http.StatusInternalServerError
}
}
return statusCode
}
// Validation constants
const (
StatusDraft = "draft"
StatusActive = "active"
StatusInactive = "inactive"
StatusDeleted = "deleted"
)
// ValidStatuses untuk validasi
var ValidStatuses = []string{StatusDraft, StatusActive, StatusInactive}
// IsValidStatus helper function
func IsValidStatus(status string) bool {
for _, validStatus := range ValidStatuses {
if status == validStatus {
return true
}
}
return false
}

View File

@@ -0,0 +1,228 @@
package retribusi
import (
"api-service/internal/models"
"encoding/json"
"time"
)
// Retribusi represents the data structure for the retribusi table
// with proper null handling and optimized JSON marshaling
type Retribusi struct {
ID string `json:"id" db:"id"`
Status string `json:"status" db:"status"`
Sort models.NullableInt32 `json:"sort,omitempty" db:"sort"`
UserCreated models.NullableString `json:"user_created,omitempty" db:"user_created"`
DateCreated models.NullableTime `json:"date_created,omitempty" db:"date_created"`
UserUpdated models.NullableString `json:"user_updated,omitempty" db:"user_updated"`
DateUpdated models.NullableTime `json:"date_updated,omitempty" db:"date_updated"`
Jenis models.NullableString `json:"jenis,omitempty" db:"Jenis"`
Pelayanan models.NullableString `json:"pelayanan,omitempty" db:"Pelayanan"`
Dinas models.NullableString `json:"dinas,omitempty" db:"Dinas"`
KelompokObyek models.NullableString `json:"kelompok_obyek,omitempty" db:"Kelompok_obyek"`
KodeTarif models.NullableString `json:"kode_tarif,omitempty" db:"Kode_tarif"`
Tarif models.NullableString `json:"tarif,omitempty" db:"Tarif"`
Satuan models.NullableString `json:"satuan,omitempty" db:"Satuan"`
TarifOvertime models.NullableString `json:"tarif_overtime,omitempty" db:"Tarif_overtime"`
SatuanOvertime models.NullableString `json:"satuan_overtime,omitempty" db:"Satuan_overtime"`
RekeningPokok models.NullableString `json:"rekening_pokok,omitempty" db:"Rekening_pokok"`
RekeningDenda models.NullableString `json:"rekening_denda,omitempty" db:"Rekening_denda"`
Uraian1 models.NullableString `json:"uraian_1,omitempty" db:"Uraian_1"`
Uraian2 models.NullableString `json:"uraian_2,omitempty" db:"Uraian_2"`
Uraian3 models.NullableString `json:"uraian_3,omitempty" db:"Uraian_3"`
}
// Custom JSON marshaling untuk Retribusi agar NULL values tidak muncul di response
func (r Retribusi) MarshalJSON() ([]byte, error) {
type Alias Retribusi
aux := &struct {
Sort *int `json:"sort,omitempty"`
UserCreated *string `json:"user_created,omitempty"`
DateCreated *time.Time `json:"date_created,omitempty"`
UserUpdated *string `json:"user_updated,omitempty"`
DateUpdated *time.Time `json:"date_updated,omitempty"`
Jenis *string `json:"jenis,omitempty"`
Pelayanan *string `json:"pelayanan,omitempty"`
Dinas *string `json:"dinas,omitempty"`
KelompokObyek *string `json:"kelompok_obyek,omitempty"`
KodeTarif *string `json:"kode_tarif,omitempty"`
Tarif *string `json:"tarif,omitempty"`
Satuan *string `json:"satuan,omitempty"`
TarifOvertime *string `json:"tarif_overtime,omitempty"`
SatuanOvertime *string `json:"satuan_overtime,omitempty"`
RekeningPokok *string `json:"rekening_pokok,omitempty"`
RekeningDenda *string `json:"rekening_denda,omitempty"`
Uraian1 *string `json:"uraian_1,omitempty"`
Uraian2 *string `json:"uraian_2,omitempty"`
Uraian3 *string `json:"uraian_3,omitempty"`
*Alias
}{
Alias: (*Alias)(&r),
}
// Convert NullableInt32 to pointer
if r.Sort.Valid {
sort := int(r.Sort.Int32)
aux.Sort = &sort
}
if r.UserCreated.Valid {
aux.UserCreated = &r.UserCreated.String
}
if r.DateCreated.Valid {
aux.DateCreated = &r.DateCreated.Time
}
if r.UserUpdated.Valid {
aux.UserUpdated = &r.UserUpdated.String
}
if r.DateUpdated.Valid {
aux.DateUpdated = &r.DateUpdated.Time
}
if r.Jenis.Valid {
aux.Jenis = &r.Jenis.String
}
if r.Pelayanan.Valid {
aux.Pelayanan = &r.Pelayanan.String
}
if r.Dinas.Valid {
aux.Dinas = &r.Dinas.String
}
if r.KelompokObyek.Valid {
aux.KelompokObyek = &r.KelompokObyek.String
}
if r.KodeTarif.Valid {
aux.KodeTarif = &r.KodeTarif.String
}
if r.Tarif.Valid {
aux.Tarif = &r.Tarif.String
}
if r.Satuan.Valid {
aux.Satuan = &r.Satuan.String
}
if r.TarifOvertime.Valid {
aux.TarifOvertime = &r.TarifOvertime.String
}
if r.SatuanOvertime.Valid {
aux.SatuanOvertime = &r.SatuanOvertime.String
}
if r.RekeningPokok.Valid {
aux.RekeningPokok = &r.RekeningPokok.String
}
if r.RekeningDenda.Valid {
aux.RekeningDenda = &r.RekeningDenda.String
}
if r.Uraian1.Valid {
aux.Uraian1 = &r.Uraian1.String
}
if r.Uraian2.Valid {
aux.Uraian2 = &r.Uraian2.String
}
if r.Uraian3.Valid {
aux.Uraian3 = &r.Uraian3.String
}
return json.Marshal(aux)
}
// Helper methods untuk mendapatkan nilai yang aman
func (r *Retribusi) GetJenis() string {
if r.Jenis.Valid {
return r.Jenis.String
}
return ""
}
func (r *Retribusi) GetDinas() string {
if r.Dinas.Valid {
return r.Dinas.String
}
return ""
}
func (r *Retribusi) GetTarif() string {
if r.Tarif.Valid {
return r.Tarif.String
}
return ""
}
// Response struct untuk GET by ID - diperbaiki struktur
type RetribusiGetByIDResponse struct {
Message string `json:"message"`
Data *Retribusi `json:"data"`
}
// Request struct untuk create - dioptimalkan dengan validasi
type RetribusiCreateRequest struct {
Status string `json:"status" validate:"required,oneof=draft active inactive"`
Jenis *string `json:"jenis,omitempty" validate:"omitempty,min=1,max=255"`
Pelayanan *string `json:"pelayanan,omitempty" validate:"omitempty,min=1,max=255"`
Dinas *string `json:"dinas,omitempty" validate:"omitempty,min=1,max=255"`
KelompokObyek *string `json:"kelompok_obyek,omitempty" validate:"omitempty,min=1,max=255"`
KodeTarif *string `json:"kode_tarif,omitempty" validate:"omitempty,min=1,max=255"`
Uraian1 *string `json:"uraian_1,omitempty"`
Uraian2 *string `json:"uraian_2,omitempty"`
Uraian3 *string `json:"uraian_3,omitempty"`
Tarif *string `json:"tarif,omitempty" validate:"omitempty,numeric"`
Satuan *string `json:"satuan,omitempty" validate:"omitempty,min=1,max=255"`
TarifOvertime *string `json:"tarif_overtime,omitempty" validate:"omitempty,numeric"`
SatuanOvertime *string `json:"satuan_overtime,omitempty" validate:"omitempty,min=1,max=255"`
RekeningPokok *string `json:"rekening_pokok,omitempty" validate:"omitempty,min=1,max=255"`
RekeningDenda *string `json:"rekening_denda,omitempty" validate:"omitempty,min=1,max=255"`
}
// Response struct untuk create
type RetribusiCreateResponse struct {
Message string `json:"message"`
Data *Retribusi `json:"data"`
}
// Update request - sama seperti create tapi dengan ID
type RetribusiUpdateRequest struct {
ID string `json:"-" validate:"required,uuid4"` // ID dari URL path
Status string `json:"status" validate:"required,oneof=draft active inactive"`
Jenis *string `json:"jenis,omitempty" validate:"omitempty,min=1,max=255"`
Pelayanan *string `json:"pelayanan,omitempty" validate:"omitempty,min=1,max=255"`
Dinas *string `json:"dinas,omitempty" validate:"omitempty,min=1,max=255"`
KelompokObyek *string `json:"kelompok_obyek,omitempty" validate:"omitempty,min=1,max=255"`
KodeTarif *string `json:"kode_tarif,omitempty" validate:"omitempty,min=1,max=255"`
Uraian1 *string `json:"uraian_1,omitempty"`
Uraian2 *string `json:"uraian_2,omitempty"`
Uraian3 *string `json:"uraian_3,omitempty"`
Tarif *string `json:"tarif,omitempty" validate:"omitempty,numeric"`
Satuan *string `json:"satuan,omitempty" validate:"omitempty,min=1,max=255"`
TarifOvertime *string `json:"tarif_overtime,omitempty" validate:"omitempty,numeric"`
SatuanOvertime *string `json:"satuan_overtime,omitempty" validate:"omitempty,min=1,max=255"`
RekeningPokok *string `json:"rekening_pokok,omitempty" validate:"omitempty,min=1,max=255"`
RekeningDenda *string `json:"rekening_denda,omitempty" validate:"omitempty,min=1,max=255"`
}
// Response struct untuk update
type RetribusiUpdateResponse struct {
Message string `json:"message"`
Data *Retribusi `json:"data"`
}
// Response struct untuk delete
type RetribusiDeleteResponse struct {
Message string `json:"message"`
ID string `json:"id"`
}
// Enhanced GET response dengan pagination dan aggregation
type RetribusiGetResponse struct {
Message string `json:"message"`
Data []Retribusi `json:"data"`
Meta models.MetaResponse `json:"meta"`
Summary *models.AggregateData `json:"summary,omitempty"`
}
// Filter struct untuk query parameters
type RetribusiFilter struct {
Status *string `json:"status,omitempty" form:"status"`
Jenis *string `json:"jenis,omitempty" form:"jenis"`
Dinas *string `json:"dinas,omitempty" form:"dinas"`
KelompokObyek *string `json:"kelompok_obyek,omitempty" form:"kelompok_obyek"`
Search *string `json:"search,omitempty" form:"search"`
DateFrom *time.Time `json:"date_from,omitempty" form:"date_from"`
DateTo *time.Time `json:"date_to,omitempty" form:"date_to"`
}

View File

@@ -0,0 +1,106 @@
package models
import (
"regexp"
"strings"
"time"
"github.com/go-playground/validator/v10"
)
// CustomValidator wraps the validator
type CustomValidator struct {
Validator *validator.Validate
}
// Validate validates struct
func (cv *CustomValidator) Validate(i interface{}) error {
return cv.Validator.Struct(i)
}
// RegisterCustomValidations registers custom validation rules
func RegisterCustomValidations(v *validator.Validate) {
// Validate Indonesian phone number
v.RegisterValidation("indonesian_phone", validateIndonesianPhone)
// Validate BPJS card number format
v.RegisterValidation("bpjs_card", validateBPJSCard)
// Validate Indonesian NIK
v.RegisterValidation("indonesian_nik", validateIndonesianNIK)
// Validate date format YYYY-MM-DD
v.RegisterValidation("date_format", validateDateFormat)
// Validate ICD-10 code format
v.RegisterValidation("icd10", validateICD10)
// Validate ICD-9-CM procedure code
v.RegisterValidation("icd9cm", validateICD9CM)
}
func validateIndonesianPhone(fl validator.FieldLevel) bool {
phone := fl.Field().String()
if phone == "" {
return true // Optional field
}
// Indonesian phone number pattern: +62, 62, 08, or 8
pattern := `^(\+?62|0?8)[1-9][0-9]{7,11}$`
matched, _ := regexp.MatchString(pattern, phone)
return matched
}
func validateBPJSCard(fl validator.FieldLevel) bool {
card := fl.Field().String()
if len(card) != 13 {
return false
}
// BPJS card should be numeric
pattern := `^\d{13}$`
matched, _ := regexp.MatchString(pattern, card)
return matched
}
func validateIndonesianNIK(fl validator.FieldLevel) bool {
nik := fl.Field().String()
if len(nik) != 16 {
return false
}
// NIK should be numeric
pattern := `^\d{16}$`
matched, _ := regexp.MatchString(pattern, nik)
return matched
}
func validateDateFormat(fl validator.FieldLevel) bool {
dateStr := fl.Field().String()
_, err := time.Parse("2006-01-02", dateStr)
return err == nil
}
func validateICD10(fl validator.FieldLevel) bool {
code := fl.Field().String()
if code == "" {
return true
}
// Basic ICD-10 pattern: Letter followed by 2 digits, optional dot and more digits
pattern := `^[A-Z]\d{2}(\.\d+)?$`
matched, _ := regexp.MatchString(pattern, strings.ToUpper(code))
return matched
}
func validateICD9CM(fl validator.FieldLevel) bool {
code := fl.Field().String()
if code == "" {
return true
}
// Basic ICD-9-CM procedure pattern: 2-4 digits with optional decimal
pattern := `^\d{2,4}(\.\d+)?$`
matched, _ := regexp.MatchString(pattern, code)
return matched
}

View File

@@ -0,0 +1,774 @@
package v1
import (
"api-service/internal/config"
"api-service/internal/database"
authHandlers "api-service/internal/handlers/auth"
healthcheckHandlers "api-service/internal/handlers/healthcheck"
retribusiHandlers "api-service/internal/handlers/retribusi"
"api-service/internal/handlers/websocket"
websocketHandlers "api-service/internal/handlers/websocket"
"api-service/internal/middleware"
services "api-service/internal/services/auth"
"api-service/pkg/logger"
"encoding/json"
"strconv"
"time"
"github.com/gin-gonic/gin"
swaggerFiles "github.com/swaggo/files"
ginSwagger "github.com/swaggo/gin-swagger"
)
func RegisterRoutes(cfg *config.Config) *gin.Engine {
router := gin.New()
// Initialize auth middleware configuration
middleware.InitializeAuth(cfg)
// Add global middleware
router.Use(middleware.CORSConfig())
router.Use(middleware.ErrorHandler())
router.Use(logger.RequestLoggerMiddleware(logger.Default()))
router.Use(gin.Recovery())
// Initialize services with error handling
authService := services.NewAuthService(cfg)
if authService == nil {
logger.Fatal("Failed to initialize auth service")
}
// Initialize database service
dbService := database.New(cfg)
// Initialize WebSocket handler with enhanced features
websocketHandler := websocketHandlers.NewWebSocketHandler(cfg, dbService)
// =============================================================================
// HEALTH CHECK & SYSTEM ROUTES
// =============================================================================
healthCheckHandler := healthcheckHandlers.NewHealthCheckHandler(dbService)
sistem := router.Group("/api/sistem")
{
sistem.GET("/health", healthCheckHandler.CheckHealth)
sistem.GET("/databases", func(c *gin.Context) {
c.JSON(200, gin.H{
"databases": dbService.ListDBs(),
"health": dbService.Health(),
"timestamp": time.Now().Unix(),
})
})
sistem.GET("/info", func(c *gin.Context) {
c.JSON(200, gin.H{
"service": "API Service v1.0.0",
"websocket_active": true,
"connected_clients": websocketHandler.GetConnectedClients(),
"databases": dbService.ListDBs(),
"timestamp": time.Now().Unix(),
})
})
}
// =============================================================================
// SWAGGER DOCUMENTATION
// =============================================================================
router.GET("/swagger/*any", ginSwagger.WrapHandler(
swaggerFiles.Handler,
ginSwagger.DefaultModelsExpandDepth(-1),
ginSwagger.DeepLinking(true),
))
// =============================================================================
// WEBSOCKET TEST CLIENT
// =============================================================================
// router.GET("/websocket-test", func(c *gin.Context) {
// c.Header("Content-Type", "text/html")
// c.String(http.StatusOK, getWebSocketTestHTML())
// })
// =============================================================================
// API v1 GROUP
// =============================================================================
v1 := router.Group("/api/v1")
// =============================================================================
// PUBLIC ROUTES (No Authentication Required)
// =============================================================================
// Authentication routes
authHandler := authHandlers.NewAuthHandler(authService)
tokenHandler := authHandlers.NewTokenHandler(authService)
// Basic auth routes
v1.POST("/auth/login", authHandler.Login)
v1.POST("/auth/register", authHandler.Register)
v1.POST("/auth/refresh", authHandler.RefreshToken)
// Token generation routes
v1.POST("/token/generate", tokenHandler.GenerateToken)
v1.POST("/token/generate-direct", tokenHandler.GenerateTokenDirect)
// =============================================================================
// WEBSOCKET ROUTES
// =============================================================================
// Main WebSocket endpoint with enhanced features
v1.GET("/ws", websocketHandler.HandleWebSocket)
// WebSocket management API
wsAPI := router.Group("/api/websocket")
{
// =============================================================================
// BASIC BROADCASTING
// =============================================================================
wsAPI.POST("/broadcast", func(c *gin.Context) {
var req struct {
Type string `json:"type"`
Message interface{} `json:"message"`
Database string `json:"database,omitempty"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(400, gin.H{"error": err.Error()})
return
}
websocketHandler.BroadcastMessage(req.Type, req.Message)
c.JSON(200, gin.H{
"status": "broadcast sent",
"clients_count": websocketHandler.GetConnectedClients(),
"timestamp": time.Now().Unix(),
})
})
wsAPI.POST("/broadcast/room/:room", func(c *gin.Context) {
room := c.Param("room")
var req struct {
Type string `json:"type"`
Message interface{} `json:"message"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(400, gin.H{"error": err.Error()})
return
}
websocketHandler.BroadcastToRoom(room, req.Type, req.Message)
c.JSON(200, gin.H{
"status": "room broadcast sent",
"room": room,
"clients_count": websocketHandler.GetRoomClientCount(room), // Fix: gunakan GetRoomClientCount
"timestamp": time.Now().Unix(),
})
})
// =============================================================================
// ENHANCED CLIENT TARGETING
// =============================================================================
wsAPI.POST("/send/:clientId", func(c *gin.Context) {
clientID := c.Param("clientId")
var req struct {
Type string `json:"type"`
Message interface{} `json:"message"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(400, gin.H{"error": err.Error()})
return
}
websocketHandler.SendToClient(clientID, req.Type, req.Message)
c.JSON(200, gin.H{
"status": "message sent",
"client_id": clientID,
"timestamp": time.Now().Unix(),
})
})
// Send to client by static ID
wsAPI.POST("/send/static/:staticId", func(c *gin.Context) {
staticID := c.Param("staticId")
logger.Infof("Sending message to static client: %s", staticID)
var req struct {
Type string `json:"type"`
Message interface{} `json:"message"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(400, gin.H{"error": err.Error()})
return
}
success := websocketHandler.SendToClientByStaticID(staticID, req.Type, req.Message)
if success {
c.JSON(200, gin.H{
"status": "message sent to static client",
"static_id": staticID,
"timestamp": time.Now().Unix(),
})
} else {
c.JSON(404, gin.H{
"error": "static client not found",
"static_id": staticID,
"timestamp": time.Now().Unix(),
})
}
})
// Broadcast to all clients from specific IP
wsAPI.POST("/broadcast/ip/:ipAddress", func(c *gin.Context) {
ipAddress := c.Param("ipAddress")
var req struct {
Type string `json:"type"`
Message interface{} `json:"message"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(400, gin.H{"error": err.Error()})
return
}
count := websocketHandler.BroadcastToIP(ipAddress, req.Type, req.Message)
c.JSON(200, gin.H{
"status": "ip broadcast sent",
"ip_address": ipAddress,
"clients_count": count,
"timestamp": time.Now().Unix(),
})
})
// =============================================================================
// CLIENT INFORMATION & STATISTICS
// =============================================================================
wsAPI.GET("/stats", func(c *gin.Context) {
c.JSON(200, gin.H{
"connected_clients": websocketHandler.GetConnectedClients(),
"databases": dbService.ListDBs(),
"database_health": dbService.Health(),
"timestamp": time.Now().Unix(),
})
})
wsAPI.GET("/stats/detailed", func(c *gin.Context) {
stats := websocketHandler.GetDetailedStats()
c.JSON(200, gin.H{
"stats": stats,
"timestamp": time.Now().Unix(),
})
})
wsAPI.GET("/clients", func(c *gin.Context) {
clients := websocketHandler.GetAllClients()
c.JSON(200, gin.H{
"clients": clients,
"count": len(clients),
"timestamp": time.Now().Unix(),
})
})
// Fix: Perbaiki GetClientsByIP untuk menggunakan ClientInfo
wsAPI.GET("/clients/by-ip/:ipAddress", func(c *gin.Context) {
ipAddress := c.Param("ipAddress")
client := websocketHandler.GetClientsByIP(ipAddress)
if client == nil {
c.JSON(404, gin.H{
"error": "client not found",
"ip_address": ipAddress,
"timestamp": time.Now().Unix(),
})
return
}
// Use ClientInfo struct instead of direct field access
clientInfo := websocketHandler.GetAllClients()
var targetClientInfo *websocket.ClientInfo
for i := range clientInfo {
if clientInfo[i].ID == ipAddress {
targetClientInfo = &clientInfo[i]
break
}
}
if targetClientInfo == nil {
c.JSON(404, gin.H{
"error": "ipAddress not found",
"client_id": ipAddress,
"timestamp": time.Now().Unix(),
})
return
}
c.JSON(200, gin.H{
"client": map[string]interface{}{
"id": targetClientInfo.ID,
"static_id": targetClientInfo.StaticID,
"ip_address": targetClientInfo.IPAddress,
"user_id": targetClientInfo.UserID,
"room": targetClientInfo.Room,
"connected_at": targetClientInfo.ConnectedAt.Unix(), // Fixed: use exported field
"last_ping": targetClientInfo.LastPing.Unix(), // Fixed: use exported field
},
"timestamp": time.Now().Unix(),
})
})
// Fix: Perbaiki GetClientByID response
wsAPI.GET("/client/:clientId", func(c *gin.Context) {
clientID := c.Param("clientId")
client := websocketHandler.GetClientByID(clientID)
if client == nil {
c.JSON(404, gin.H{
"error": "client not found",
"client_id": clientID,
"timestamp": time.Now().Unix(),
})
return
}
// Use ClientInfo struct instead of direct field access
clientInfo := websocketHandler.GetAllClients()
var targetClientInfo *websocket.ClientInfo
for i := range clientInfo {
if clientInfo[i].ID == clientID {
targetClientInfo = &clientInfo[i]
break
}
}
if targetClientInfo == nil {
c.JSON(404, gin.H{
"error": "client not found",
"client_id": clientID,
"timestamp": time.Now().Unix(),
})
return
}
c.JSON(200, gin.H{
"client": map[string]interface{}{
"id": targetClientInfo.ID,
"static_id": targetClientInfo.StaticID,
"ip_address": targetClientInfo.IPAddress,
"user_id": targetClientInfo.UserID,
"room": targetClientInfo.Room,
"connected_at": targetClientInfo.ConnectedAt.Unix(), // Fixed: use exported field
"last_ping": targetClientInfo.LastPing.Unix(), // Fixed: use exported field
},
"timestamp": time.Now().Unix(),
})
})
// Fix: Perbaiki GetClientByStaticID response
wsAPI.GET("/client/static/:staticId", func(c *gin.Context) {
staticID := c.Param("staticId")
client := websocketHandler.GetClientByStaticID(staticID)
if client == nil {
c.JSON(404, gin.H{
"error": "static client not found",
"static_id": staticID,
"timestamp": time.Now().Unix(),
})
return
}
// Use ClientInfo struct instead of direct field access
clientInfo := websocketHandler.GetAllClients()
var targetClientInfo *websocket.ClientInfo
for i := range clientInfo {
if clientInfo[i].StaticID == staticID {
targetClientInfo = &clientInfo[i]
break
}
}
if targetClientInfo == nil {
c.JSON(404, gin.H{
"error": "static client not found",
"static_id": staticID,
"timestamp": time.Now().Unix(),
})
return
}
c.JSON(200, gin.H{
"client": map[string]interface{}{
"id": targetClientInfo.ID,
"static_id": targetClientInfo.StaticID,
"ip_address": targetClientInfo.IPAddress,
"user_id": targetClientInfo.UserID,
"room": targetClientInfo.Room,
"connected_at": targetClientInfo.ConnectedAt.Unix(), // Fixed: use exported field
"last_ping": targetClientInfo.LastPing.Unix(), // Fixed: use exported field
},
"timestamp": time.Now().Unix(),
})
})
// =============================================================================
// ACTIVE CLIENTS & CLEANUP
// =============================================================================
// Tambahkan endpoint untuk active clients
wsAPI.GET("/clients/active", func(c *gin.Context) {
// Default: clients active dalam 5 menit terakhir
minutes := c.DefaultQuery("minutes", "5")
minutesInt, err := strconv.Atoi(minutes)
if err != nil {
minutesInt = 5
}
activeClients := websocketHandler.GetActiveClients(time.Duration(minutesInt) * time.Minute)
c.JSON(200, gin.H{
"active_clients": activeClients,
"count": len(activeClients),
"threshold_minutes": minutesInt,
"timestamp": time.Now().Unix(),
})
})
// Tambahkan endpoint untuk cleanup inactive clients
wsAPI.POST("/cleanup/inactive", func(c *gin.Context) {
var req struct {
InactiveMinutes int `json:"inactive_minutes"`
}
if err := c.ShouldBindJSON(&req); err != nil {
req.InactiveMinutes = 30 // Default 30 minutes
}
if req.InactiveMinutes <= 0 {
req.InactiveMinutes = 30
}
cleanedCount := websocketHandler.CleanupInactiveClients(time.Duration(req.InactiveMinutes) * time.Minute)
c.JSON(200, gin.H{
"status": "cleanup completed",
"cleaned_clients": cleanedCount,
"inactive_minutes": req.InactiveMinutes,
"timestamp": time.Now().Unix(),
})
})
// =============================================================================
// DATABASE NOTIFICATIONS
// =============================================================================
wsAPI.POST("/notify/:database/:channel", func(c *gin.Context) {
database := c.Param("database")
channel := c.Param("channel")
var req struct {
Payload interface{} `json:"payload"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(400, gin.H{"error": err.Error()})
return
}
payloadJSON, _ := json.Marshal(req.Payload)
err := dbService.NotifyChange(database, channel, string(payloadJSON))
if err != nil {
c.JSON(500, gin.H{
"error": err.Error(),
"database": database,
"channel": channel,
"timestamp": time.Now().Unix(),
})
return
}
c.JSON(200, gin.H{
"status": "notification sent",
"database": database,
"channel": channel,
"timestamp": time.Now().Unix(),
})
})
// Test database notification
wsAPI.POST("/test-notification", func(c *gin.Context) {
var req struct {
Database string `json:"database"`
Channel string `json:"channel"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(400, gin.H{"error": err.Error()})
return
}
// Default values
if req.Database == "" {
req.Database = "default"
}
if req.Channel == "" {
req.Channel = "system_changes"
}
if req.Message == "" {
req.Message = "Test notification from API"
}
payload := map[string]interface{}{
"operation": "API_TEST",
"table": "manual_test",
"data": map[string]interface{}{
"message": req.Message,
"test_data": req.Data,
"timestamp": time.Now().Unix(),
},
}
payloadJSON, _ := json.Marshal(payload)
err := dbService.NotifyChange(req.Database, req.Channel, string(payloadJSON))
if err != nil {
c.JSON(500, gin.H{"error": err.Error()})
return
}
c.JSON(200, gin.H{
"status": "test notification sent",
"database": req.Database,
"channel": req.Channel,
"payload": payload,
"timestamp": time.Now().Unix(),
})
})
// =============================================================================
// ROOM MANAGEMENT
// =============================================================================
wsAPI.GET("/rooms", func(c *gin.Context) {
rooms := websocketHandler.GetAllRooms()
c.JSON(200, gin.H{
"rooms": rooms,
"count": len(rooms),
"timestamp": time.Now().Unix(),
})
})
wsAPI.GET("/room/:room/clients", func(c *gin.Context) {
room := c.Param("room")
clientCount := websocketHandler.GetRoomClientCount(room)
// Get detailed room info
allRooms := websocketHandler.GetAllRooms()
roomClients := allRooms[room]
c.JSON(200, gin.H{
"room": room,
"client_count": clientCount,
"clients": roomClients,
"timestamp": time.Now().Unix(),
})
})
// =============================================================================
// MONITORING & DEBUGGING
// =============================================================================
wsAPI.GET("/monitor", func(c *gin.Context) {
monitor := websocketHandler.GetMonitoringData()
c.JSON(200, monitor)
})
wsAPI.POST("/ping-client/:clientId", func(c *gin.Context) {
clientID := c.Param("clientId")
websocketHandler.SendToClient(clientID, "server_ping", map[string]interface{}{
"message": "Ping from server",
"timestamp": time.Now().Unix(),
})
c.JSON(200, gin.H{
"status": "ping sent",
"client_id": clientID,
"timestamp": time.Now().Unix(),
})
})
// Disconnect specific client
wsAPI.POST("/disconnect/:clientId", func(c *gin.Context) {
clientID := c.Param("clientId")
success := websocketHandler.DisconnectClient(clientID)
if success {
c.JSON(200, gin.H{
"status": "client disconnected",
"client_id": clientID,
"timestamp": time.Now().Unix(),
})
} else {
c.JSON(404, gin.H{
"error": "client not found",
"client_id": clientID,
"timestamp": time.Now().Unix(),
})
}
})
// =============================================================================
// BULK OPERATIONS
// =============================================================================
// Broadcast to multiple clients
wsAPI.POST("/broadcast/bulk", func(c *gin.Context) {
var req struct {
ClientIDs []string `json:"client_ids"`
Type string `json:"type"`
Message interface{} `json:"message"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(400, gin.H{"error": err.Error()})
return
}
successCount := 0
for _, clientID := range req.ClientIDs {
websocketHandler.SendToClient(clientID, req.Type, req.Message)
successCount++
}
c.JSON(200, gin.H{
"status": "bulk broadcast sent",
"total_clients": len(req.ClientIDs),
"success_count": successCount,
"timestamp": time.Now().Unix(),
})
})
// Disconnect multiple clients
wsAPI.POST("/disconnect/bulk", func(c *gin.Context) {
var req struct {
ClientIDs []string `json:"client_ids"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(400, gin.H{"error": err.Error()})
return
}
successCount := 0
for _, clientID := range req.ClientIDs {
if websocketHandler.DisconnectClient(clientID) {
successCount++
}
}
c.JSON(200, gin.H{
"status": "bulk disconnect completed",
"total_clients": len(req.ClientIDs),
"success_count": successCount,
"timestamp": time.Now().Unix(),
})
})
}
// =============================================================================
// PUBLISHED ROUTES
// =============================================================================
// Retribusi endpoints with WebSocket notifications
retribusiHandler := retribusiHandlers.NewRetribusiHandler()
retribusiGroup := v1.Group("/retribusi")
{
retribusiGroup.GET("", retribusiHandler.GetRetribusi)
retribusiGroup.GET("/dynamic", retribusiHandler.GetRetribusiDynamic)
retribusiGroup.GET("/search", retribusiHandler.SearchRetribusiAdvanced)
retribusiGroup.GET("/id/:id", retribusiHandler.GetRetribusiByID)
// POST/PUT/DELETE with automatic WebSocket notifications
retribusiGroup.POST("", func(c *gin.Context) {
retribusiHandler.CreateRetribusi(c)
// Trigger WebSocket notification after successful creation
if c.Writer.Status() == 200 || c.Writer.Status() == 201 {
websocketHandler.BroadcastMessage("retribusi_created", map[string]interface{}{
"message": "New retribusi record created",
"timestamp": time.Now().Unix(),
})
}
})
retribusiGroup.PUT("/id/:id", func(c *gin.Context) {
id := c.Param("id")
retribusiHandler.UpdateRetribusi(c)
// Trigger WebSocket notification after successful update
if c.Writer.Status() == 200 {
websocketHandler.BroadcastMessage("retribusi_updated", map[string]interface{}{
"message": "Retribusi record updated",
"id": id,
"timestamp": time.Now().Unix(),
})
}
})
retribusiGroup.DELETE("/id/:id", func(c *gin.Context) {
id := c.Param("id")
retribusiHandler.DeleteRetribusi(c)
// Trigger WebSocket notification after successful deletion
if c.Writer.Status() == 200 {
websocketHandler.BroadcastMessage("retribusi_deleted", map[string]interface{}{
"message": "Retribusi record deleted",
"id": id,
"timestamp": time.Now().Unix(),
})
}
})
}
// =============================================================================
// PROTECTED ROUTES (Authentication Required)
// =============================================================================
protected := v1.Group("/")
protected.Use(middleware.ConfigurableAuthMiddleware(cfg))
// Protected WebSocket management (optional)
protectedWS := protected.Group("/ws-admin")
{
protectedWS.GET("/stats", func(c *gin.Context) {
detailedStats := websocketHandler.GetDetailedStats()
c.JSON(200, gin.H{
"admin_stats": detailedStats,
"timestamp": time.Now().Unix(),
})
})
protectedWS.POST("/force-disconnect/:clientId", func(c *gin.Context) {
clientID := c.Param("clientId")
success := websocketHandler.DisconnectClient(clientID)
c.JSON(200, gin.H{
"status": "force disconnect attempted",
"client_id": clientID,
"success": success,
"timestamp": time.Now().Unix(),
})
})
protectedWS.POST("/cleanup/force", func(c *gin.Context) {
var req struct {
InactiveMinutes int `json:"inactive_minutes"`
Force bool `json:"force"`
}
if err := c.ShouldBindJSON(&req); err != nil {
req.InactiveMinutes = 10
req.Force = false
}
cleanedCount := websocketHandler.CleanupInactiveClients(time.Duration(req.InactiveMinutes) * time.Minute)
c.JSON(200, gin.H{
"status": "admin cleanup completed",
"cleaned_clients": cleanedCount,
"inactive_minutes": req.InactiveMinutes,
"force": req.Force,
"timestamp": time.Now().Unix(),
})
})
}
return router
}

53
internal/server/server.go Normal file
View File

@@ -0,0 +1,53 @@
package server
import (
"fmt"
"net/http"
"os"
"strconv"
"time"
_ "github.com/joho/godotenv/autoload"
"api-service/internal/config"
"api-service/internal/database"
v1 "api-service/internal/routes/v1"
)
var dbService database.Service // Global variable to hold the database service instance
type Server struct {
port int
db database.Service
}
func NewServer() *http.Server {
// Load configuration
cfg := config.LoadConfig()
cfg.Validate()
port, _ := strconv.Atoi(os.Getenv("PORT"))
if port == 0 {
port = cfg.Server.Port
}
if dbService == nil { // Check if the database service is already initialized
dbService = database.New(cfg) // Initialize only once
}
NewServer := &Server{
port: port,
db: dbService, // Use the global database service instance
}
// Declare Server config
server := &http.Server{
Addr: fmt.Sprintf(":%d", NewServer.port),
Handler: v1.RegisterRoutes(cfg),
IdleTimeout: time.Minute,
ReadTimeout: 10 * time.Second,
WriteTimeout: 30 * time.Second,
}
return server
}

View File

@@ -0,0 +1,169 @@
package services
import (
"api-service/internal/config"
models "api-service/internal/models/auth"
"errors"
"time"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"
)
// AuthService handles authentication logic
type AuthService struct {
config *config.Config
users map[string]*models.User // In-memory user store for demo
}
// NewAuthService creates a new authentication service
func NewAuthService(cfg *config.Config) *AuthService {
// Initialize with demo users
users := make(map[string]*models.User)
// Add demo users
users["admin"] = &models.User{
ID: "1",
Username: "admin",
Email: "admin@example.com",
Password: "$2a$10$92IXUNpkjO0rOQ5byMi.Ye4oKoEa3Ro9llC/.og/at2.uheWG/igi", // password
Role: "admin",
}
users["user"] = &models.User{
ID: "2",
Username: "user",
Email: "user@example.com",
Password: "$2a$10$92IXUNpkjO0rOQ5byMi.Ye4oKoEa3Ro9llC/.og/at2.uheWG/igi", // password
Role: "user",
}
return &AuthService{
config: cfg,
users: users,
}
}
// Login authenticates user and generates JWT token
func (s *AuthService) Login(username, password string) (*models.TokenResponse, error) {
user, exists := s.users[username]
if !exists {
return nil, errors.New("invalid credentials")
}
// Verify password
err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
if err != nil {
return nil, errors.New("invalid credentials")
}
// Generate JWT token
token, err := s.generateToken(user)
if err != nil {
return nil, err
}
return &models.TokenResponse{
AccessToken: token,
TokenType: "Bearer",
ExpiresIn: 3600, // 1 hour
}, nil
}
// generateToken creates a new JWT token for the user
func (s *AuthService) generateToken(user *models.User) (string, error) {
// Create claims
claims := jwt.MapClaims{
"user_id": user.ID,
"username": user.Username,
"email": user.Email,
"role": user.Role,
"exp": time.Now().Add(time.Hour * 1).Unix(),
"iat": time.Now().Unix(),
}
// Create token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
// Sign token with secret key
secretKey := []byte(s.getJWTSecret())
return token.SignedString(secretKey)
}
// GenerateTokenForUser generates a JWT token for a specific user
func (s *AuthService) GenerateTokenForUser(user *models.User) (string, error) {
// Create claims
claims := jwt.MapClaims{
"user_id": user.ID,
"username": user.Username,
"email": user.Email,
"role": user.Role,
"exp": time.Now().Add(time.Hour * 1).Unix(),
"iat": time.Now().Unix(),
}
// Create token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
// Sign token with secret key
secretKey := []byte(s.getJWTSecret())
return token.SignedString(secretKey)
}
// ValidateToken validates the JWT token
func (s *AuthService) ValidateToken(tokenString string) (*models.JWTClaims, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, errors.New("unexpected signing method")
}
return []byte(s.getJWTSecret()), nil
})
if err != nil {
return nil, err
}
if !token.Valid {
return nil, errors.New("invalid token")
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return nil, errors.New("invalid claims")
}
return &models.JWTClaims{
UserID: claims["user_id"].(string),
Username: claims["username"].(string),
Email: claims["email"].(string),
Role: claims["role"].(string),
}, nil
}
// getJWTSecret returns the JWT secret key
func (s *AuthService) getJWTSecret() string {
// In production, this should come from environment variables
return "your-secret-key-change-this-in-production"
}
// RegisterUser registers a new user (for demo purposes)
func (s *AuthService) RegisterUser(username, email, password, role string) error {
if _, exists := s.users[username]; exists {
return errors.New("username already exists")
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
s.users[username] = &models.User{
ID: string(rune(len(s.users) + 1)),
Username: username,
Email: email,
Password: string(hashedPassword),
Role: role,
}
return nil
}

View File

@@ -0,0 +1,593 @@
package utils
import (
"fmt"
"reflect"
"strings"
"sync"
)
// FilterOperator represents supported filter operators
type FilterOperator string
const (
OpEqual FilterOperator = "_eq"
OpNotEqual FilterOperator = "_neq"
OpLike FilterOperator = "_like"
OpILike FilterOperator = "_ilike"
OpIn FilterOperator = "_in"
OpNotIn FilterOperator = "_nin"
OpGreaterThan FilterOperator = "_gt"
OpGreaterThanEqual FilterOperator = "_gte"
OpLessThan FilterOperator = "_lt"
OpLessThanEqual FilterOperator = "_lte"
OpBetween FilterOperator = "_between"
OpNotBetween FilterOperator = "_nbetween"
OpNull FilterOperator = "_null"
OpNotNull FilterOperator = "_nnull"
OpContains FilterOperator = "_contains"
OpNotContains FilterOperator = "_ncontains"
OpStartsWith FilterOperator = "_starts_with"
OpEndsWith FilterOperator = "_ends_with"
)
// DynamicFilter represents a single filter condition
type DynamicFilter struct {
Column string `json:"column"`
Operator FilterOperator `json:"operator"`
Value interface{} `json:"value"`
LogicOp string `json:"logic_op,omitempty"` // AND, OR
}
// FilterGroup represents a group of filters
type FilterGroup struct {
Filters []DynamicFilter `json:"filters"`
LogicOp string `json:"logic_op"` // AND, OR
}
// DynamicQuery represents the complete query structure
type DynamicQuery struct {
Fields []string `json:"fields,omitempty"`
Filters []FilterGroup `json:"filters,omitempty"`
Sort []SortField `json:"sort,omitempty"`
Limit int `json:"limit"`
Offset int `json:"offset"`
GroupBy []string `json:"group_by,omitempty"`
Having []FilterGroup `json:"having,omitempty"`
}
// SortField represents sorting configuration
type SortField struct {
Column string `json:"column"`
Order string `json:"order"` // ASC, DESC
}
// QueryBuilder builds SQL queries from dynamic filters
type QueryBuilder struct {
tableName string
columnMapping map[string]string // Maps API field names to DB column names
allowedColumns map[string]bool // Security: only allow specified columns
paramCounter int
mu *sync.RWMutex
}
// NewQueryBuilder creates a new query builder instance
func NewQueryBuilder(tableName string) *QueryBuilder {
return &QueryBuilder{
tableName: tableName,
columnMapping: make(map[string]string),
allowedColumns: make(map[string]bool),
paramCounter: 0,
}
}
// SetColumnMapping sets the mapping between API field names and database column names
func (qb *QueryBuilder) SetColumnMapping(mapping map[string]string) *QueryBuilder {
qb.columnMapping = mapping
return qb
}
// SetAllowedColumns sets the list of allowed columns for security
func (qb *QueryBuilder) SetAllowedColumns(columns []string) *QueryBuilder {
qb.allowedColumns = make(map[string]bool)
for _, col := range columns {
qb.allowedColumns[col] = true
}
return qb
}
// BuildQuery builds the complete SQL query
func (qb *QueryBuilder) BuildQuery(query DynamicQuery) (string, []interface{}, error) {
qb.paramCounter = 0
// Build SELECT clause
selectClause := qb.buildSelectClause(query.Fields)
// Build FROM clause
fromClause := fmt.Sprintf("FROM %s", qb.tableName)
// Build WHERE clause
whereClause, whereArgs, err := qb.buildWhereClause(query.Filters)
if err != nil {
return "", nil, err
}
// Build ORDER BY clause
orderClause := qb.buildOrderClause(query.Sort)
// Build GROUP BY clause
groupClause := qb.buildGroupByClause(query.GroupBy)
// Build HAVING clause
havingClause, havingArgs, err := qb.buildHavingClause(query.Having)
if err != nil {
return "", nil, err
}
// Combine all parts
sqlParts := []string{selectClause, fromClause}
args := []interface{}{}
if whereClause != "" {
sqlParts = append(sqlParts, "WHERE "+whereClause)
args = append(args, whereArgs...)
}
if groupClause != "" {
sqlParts = append(sqlParts, groupClause)
}
if havingClause != "" {
sqlParts = append(sqlParts, "HAVING "+havingClause)
args = append(args, havingArgs...)
}
if orderClause != "" {
sqlParts = append(sqlParts, orderClause)
}
// Add pagination
if query.Limit > 0 {
qb.paramCounter++
sqlParts = append(sqlParts, fmt.Sprintf("LIMIT $%d", qb.paramCounter))
args = append(args, query.Limit)
}
if query.Offset > 0 {
qb.paramCounter++
sqlParts = append(sqlParts, fmt.Sprintf("OFFSET $%d", qb.paramCounter))
args = append(args, query.Offset)
}
sql := strings.Join(sqlParts, " ")
return sql, args, nil
}
// buildSelectClause builds the SELECT part of the query
func (qb *QueryBuilder) buildSelectClause(fields []string) string {
if len(fields) == 0 || (len(fields) == 1 && fields[0] == "*") {
return "SELECT *"
}
var selectedFields []string
for _, field := range fields {
if field == "*.*" || field == "*" {
selectedFields = append(selectedFields, "*")
continue
}
// Check if it's an expression (contains spaces, parentheses, etc.)
if strings.Contains(field, " ") || strings.Contains(field, "(") || strings.Contains(field, ")") {
// Expression, add as is
selectedFields = append(selectedFields, field)
continue
}
// Security check: only allow specified columns (check original field name)
if len(qb.allowedColumns) > 0 && !qb.allowedColumns[field] {
continue
}
// Map field name if mapping exists
if mappedCol, exists := qb.columnMapping[field]; exists {
field = mappedCol
}
selectedFields = append(selectedFields, fmt.Sprintf(`"%s"`, field))
}
if len(selectedFields) == 0 {
return "SELECT *"
}
return "SELECT " + strings.Join(selectedFields, ", ")
}
// buildWhereClause builds the WHERE part of the query
func (qb *QueryBuilder) buildWhereClause(filterGroups []FilterGroup) (string, []interface{}, error) {
if len(filterGroups) == 0 {
return "", nil, nil
}
var conditions []string
var args []interface{}
for i, group := range filterGroups {
groupCondition, groupArgs, err := qb.buildFilterGroup(group)
if err != nil {
return "", nil, err
}
if groupCondition != "" {
if i > 0 {
logicOp := "AND"
if group.LogicOp != "" {
logicOp = strings.ToUpper(group.LogicOp)
}
conditions = append(conditions, logicOp)
}
conditions = append(conditions, groupCondition)
args = append(args, groupArgs...)
}
}
return strings.Join(conditions, " "), args, nil
}
// buildFilterGroup builds conditions for a filter group
func (qb *QueryBuilder) buildFilterGroup(group FilterGroup) (string, []interface{}, error) {
if len(group.Filters) == 0 {
return "", nil, nil
}
var conditions []string
var args []interface{}
for i, filter := range group.Filters {
condition, filterArgs, err := qb.buildFilterCondition(filter)
if err != nil {
return "", nil, err
}
if condition != "" {
if i > 0 {
logicOp := "AND"
if filter.LogicOp != "" {
logicOp = strings.ToUpper(filter.LogicOp)
} else if group.LogicOp != "" {
logicOp = strings.ToUpper(group.LogicOp)
}
conditions = append(conditions, logicOp)
}
conditions = append(conditions, condition)
args = append(args, filterArgs...)
}
}
return strings.Join(conditions, " "), args, nil
}
// buildFilterCondition builds a single filter condition
func (qb *QueryBuilder) buildFilterCondition(filter DynamicFilter) (string, []interface{}, error) {
// Security check (check original field name)
if len(qb.allowedColumns) > 0 && !qb.allowedColumns[filter.Column] {
return "", nil, nil
}
// Map column name if mapping exists
column := filter.Column
if mappedCol, exists := qb.columnMapping[column]; exists {
column = mappedCol
}
// Wrap column name in quotes for PostgreSQL
column = fmt.Sprintf(`"%s"`, column)
switch filter.Operator {
case OpEqual:
if filter.Value == nil {
return "", nil, nil
}
qb.paramCounter++
return fmt.Sprintf("%s = $%d", column, qb.paramCounter), []interface{}{filter.Value}, nil
case OpNotEqual:
if filter.Value == nil {
return "", nil, nil
}
qb.paramCounter++
return fmt.Sprintf("%s != $%d", column, qb.paramCounter), []interface{}{filter.Value}, nil
case OpLike:
if filter.Value == nil {
return "", nil, nil
}
qb.paramCounter++
return fmt.Sprintf("%s LIKE $%d", column, qb.paramCounter), []interface{}{filter.Value}, nil
case OpILike:
if filter.Value == nil {
return "", nil, nil
}
qb.paramCounter++
return fmt.Sprintf("%s ILIKE $%d", column, qb.paramCounter), []interface{}{filter.Value}, nil
case OpIn:
values := qb.parseArrayValue(filter.Value)
if len(values) == 0 {
return "", nil, nil
}
var placeholders []string
var args []interface{}
for _, val := range values {
qb.paramCounter++
placeholders = append(placeholders, fmt.Sprintf("$%d", qb.paramCounter))
args = append(args, val)
}
return fmt.Sprintf("%s IN (%s)", column, strings.Join(placeholders, ", ")), args, nil
case OpNotIn:
values := qb.parseArrayValue(filter.Value)
if len(values) == 0 {
return "", nil, nil
}
var placeholders []string
var args []interface{}
for _, val := range values {
qb.paramCounter++
placeholders = append(placeholders, fmt.Sprintf("$%d", qb.paramCounter))
args = append(args, val)
}
return fmt.Sprintf("%s NOT IN (%s)", column, strings.Join(placeholders, ", ")), args, nil
case OpGreaterThan:
if filter.Value == nil {
return "", nil, nil
}
qb.paramCounter++
return fmt.Sprintf("%s > $%d", column, qb.paramCounter), []interface{}{filter.Value}, nil
case OpGreaterThanEqual:
if filter.Value == nil {
return "", nil, nil
}
qb.paramCounter++
return fmt.Sprintf("%s >= $%d", column, qb.paramCounter), []interface{}{filter.Value}, nil
case OpLessThan:
if filter.Value == nil {
return "", nil, nil
}
qb.paramCounter++
return fmt.Sprintf("%s < $%d", column, qb.paramCounter), []interface{}{filter.Value}, nil
case OpLessThanEqual:
if filter.Value == nil {
return "", nil, nil
}
qb.paramCounter++
return fmt.Sprintf("%s <= $%d", column, qb.paramCounter), []interface{}{filter.Value}, nil
case OpBetween:
if filter.Value == nil {
return "", nil, nil
}
values := qb.parseArrayValue(filter.Value)
if len(values) != 2 {
return "", nil, fmt.Errorf("between operator requires exactly 2 values")
}
qb.paramCounter++
param1 := qb.paramCounter
qb.paramCounter++
param2 := qb.paramCounter
return fmt.Sprintf("%s BETWEEN $%d AND $%d", column, param1, param2), []interface{}{values[0], values[1]}, nil
case OpNotBetween:
if filter.Value == nil {
return "", nil, nil
}
values := qb.parseArrayValue(filter.Value)
if len(values) != 2 {
return "", nil, fmt.Errorf("not between operator requires exactly 2 values")
}
qb.paramCounter++
param1 := qb.paramCounter
qb.paramCounter++
param2 := qb.paramCounter
return fmt.Sprintf("%s NOT BETWEEN $%d AND $%d", column, param1, param2), []interface{}{values[0], values[1]}, nil
case OpNull:
return fmt.Sprintf("%s IS NULL", column), nil, nil
case OpNotNull:
return fmt.Sprintf("%s IS NOT NULL", column), nil, nil
case OpContains:
if filter.Value == nil {
return "", nil, nil
}
qb.paramCounter++
value := fmt.Sprintf("%%%v%%", filter.Value)
return fmt.Sprintf("%s ILIKE $%d", column, qb.paramCounter), []interface{}{value}, nil
case OpNotContains:
if filter.Value == nil {
return "", nil, nil
}
qb.paramCounter++
value := fmt.Sprintf("%%%v%%", filter.Value)
return fmt.Sprintf("%s NOT ILIKE $%d", column, qb.paramCounter), []interface{}{value}, nil
case OpStartsWith:
if filter.Value == nil {
return "", nil, nil
}
qb.paramCounter++
value := fmt.Sprintf("%v%%", filter.Value)
return fmt.Sprintf("%s ILIKE $%d", column, qb.paramCounter), []interface{}{value}, nil
case OpEndsWith:
if filter.Value == nil {
return "", nil, nil
}
qb.paramCounter++
value := fmt.Sprintf("%%%v", filter.Value)
return fmt.Sprintf("%s ILIKE $%d", column, qb.paramCounter), []interface{}{value}, nil
default:
return "", nil, fmt.Errorf("unsupported operator: %s", filter.Operator)
}
}
// parseArrayValue parses array values from various formats
func (qb *QueryBuilder) parseArrayValue(value interface{}) []interface{} {
if value == nil {
return nil
}
// If it's already a slice
if reflect.TypeOf(value).Kind() == reflect.Slice {
v := reflect.ValueOf(value)
result := make([]interface{}, v.Len())
for i := 0; i < v.Len(); i++ {
result[i] = v.Index(i).Interface()
}
return result
}
// If it's a string, try to split by comma
if str, ok := value.(string); ok {
if strings.Contains(str, ",") {
parts := strings.Split(str, ",")
result := make([]interface{}, len(parts))
for i, part := range parts {
result[i] = strings.TrimSpace(part)
}
return result
}
return []interface{}{str}
}
return []interface{}{value}
}
// buildOrderClause builds the ORDER BY clause
func (qb *QueryBuilder) buildOrderClause(sortFields []SortField) string {
if len(sortFields) == 0 {
return ""
}
var orderParts []string
for _, sort := range sortFields {
column := sort.Column
// Security check (check original field name)
if len(qb.allowedColumns) > 0 && !qb.allowedColumns[column] {
continue
}
if mappedCol, exists := qb.columnMapping[column]; exists {
column = mappedCol
}
order := "ASC"
if sort.Order != "" {
order = strings.ToUpper(sort.Order)
}
orderParts = append(orderParts, fmt.Sprintf(`"%s" %s`, column, order))
}
if len(orderParts) == 0 {
return ""
}
return "ORDER BY " + strings.Join(orderParts, ", ")
}
// buildGroupByClause builds the GROUP BY clause
func (qb *QueryBuilder) buildGroupByClause(groupFields []string) string {
if len(groupFields) == 0 {
return ""
}
var groupParts []string
for _, field := range groupFields {
column := field
if mappedCol, exists := qb.columnMapping[column]; exists {
column = mappedCol
}
// Security check
if len(qb.allowedColumns) > 0 && !qb.allowedColumns[column] {
continue
}
groupParts = append(groupParts, fmt.Sprintf(`"%s"`, column))
}
if len(groupParts) == 0 {
return ""
}
return "GROUP BY " + strings.Join(groupParts, ", ")
}
// buildHavingClause builds the HAVING clause
func (qb *QueryBuilder) buildHavingClause(havingGroups []FilterGroup) (string, []interface{}, error) {
if len(havingGroups) == 0 {
return "", nil, nil
}
return qb.buildWhereClause(havingGroups)
}
// BuildCountQuery builds a count query
func (qb *QueryBuilder) BuildCountQuery(query DynamicQuery) (string, []interface{}, error) {
qb.paramCounter = 0
// Build FROM clause
fromClause := fmt.Sprintf("FROM %s", qb.tableName)
// Build WHERE clause
whereClause, whereArgs, err := qb.buildWhereClause(query.Filters)
if err != nil {
return "", nil, err
}
// Build GROUP BY clause
groupClause := qb.buildGroupByClause(query.GroupBy)
// Build HAVING clause
havingClause, havingArgs, err := qb.buildHavingClause(query.Having)
if err != nil {
return "", nil, err
}
// Combine parts
sqlParts := []string{"SELECT COUNT(*)", fromClause}
args := []interface{}{}
if whereClause != "" {
sqlParts = append(sqlParts, "WHERE "+whereClause)
args = append(args, whereArgs...)
}
if groupClause != "" {
sqlParts = append(sqlParts, groupClause)
}
if havingClause != "" {
sqlParts = append(sqlParts, "HAVING "+havingClause)
args = append(args, havingArgs...)
}
sql := strings.Join(sqlParts, " ")
return sql, args, nil
}

View File

@@ -0,0 +1,241 @@
package utils
import (
"net/url"
"strconv"
"strings"
"time"
)
// QueryParser parses HTTP query parameters into DynamicQuery
type QueryParser struct {
defaultLimit int
maxLimit int
}
// NewQueryParser creates a new query parser
func NewQueryParser() *QueryParser {
return &QueryParser{
defaultLimit: 10,
maxLimit: 100,
}
}
// SetLimits sets default and maximum limits
func (qp *QueryParser) SetLimits(defaultLimit, maxLimit int) *QueryParser {
qp.defaultLimit = defaultLimit
qp.maxLimit = maxLimit
return qp
}
// ParseQuery parses URL query parameters into DynamicQuery
func (qp *QueryParser) ParseQuery(values url.Values) (DynamicQuery, error) {
query := DynamicQuery{
Limit: qp.defaultLimit,
Offset: 0,
}
// Parse fields
if fields := values.Get("fields"); fields != "" {
if fields == "*.*" || fields == "*" {
query.Fields = []string{"*"}
} else {
query.Fields = strings.Split(fields, ",")
for i, field := range query.Fields {
query.Fields[i] = strings.TrimSpace(field)
}
}
}
// Parse pagination
if limit := values.Get("limit"); limit != "" {
if l, err := strconv.Atoi(limit); err == nil {
if l > 0 && l <= qp.maxLimit {
query.Limit = l
}
}
}
if offset := values.Get("offset"); offset != "" {
if o, err := strconv.Atoi(offset); err == nil && o >= 0 {
query.Offset = o
}
}
// Parse filters
filters, err := qp.parseFilters(values)
if err != nil {
return query, err
}
query.Filters = filters
// Parse sorting
sorts, err := qp.parseSorting(values)
if err != nil {
return query, err
}
query.Sort = sorts
// Parse group by
if groupBy := values.Get("group"); groupBy != "" {
query.GroupBy = strings.Split(groupBy, ",")
for i, field := range query.GroupBy {
query.GroupBy[i] = strings.TrimSpace(field)
}
}
return query, nil
}
// parseFilters parses filter parameters
// Supports format: filter[column][operator]=value
func (qp *QueryParser) parseFilters(values url.Values) ([]FilterGroup, error) {
filterMap := make(map[string]map[string]string)
// Group filters by column
for key, vals := range values {
if strings.HasPrefix(key, "filter[") && strings.HasSuffix(key, "]") {
// Parse filter[column][operator] format
parts := strings.Split(key[7:len(key)-1], "][")
if len(parts) == 2 {
column := parts[0]
operator := parts[1]
if filterMap[column] == nil {
filterMap[column] = make(map[string]string)
}
if len(vals) > 0 {
filterMap[column][operator] = vals[0]
}
}
}
}
if len(filterMap) == 0 {
return nil, nil
}
// Convert to FilterGroup
var filters []DynamicFilter
for column, operators := range filterMap {
for opStr, value := range operators {
operator := FilterOperator(opStr)
// Parse value based on operator
var parsedValue interface{}
switch operator {
case OpIn, OpNotIn:
if value != "" {
parsedValue = strings.Split(value, ",")
}
case OpBetween, OpNotBetween:
if value != "" {
parts := strings.Split(value, ",")
if len(parts) == 2 {
parsedValue = []interface{}{strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])}
}
}
case OpNull, OpNotNull:
parsedValue = nil
default:
parsedValue = value
}
filters = append(filters, DynamicFilter{
Column: column,
Operator: operator,
Value: parsedValue,
})
}
}
if len(filters) == 0 {
return nil, nil
}
return []FilterGroup{{
Filters: filters,
LogicOp: "AND",
}}, nil
}
// parseSorting parses sort parameters
// Supports format: sort=column1,-column2 (- for DESC)
func (qp *QueryParser) parseSorting(values url.Values) ([]SortField, error) {
sortParam := values.Get("sort")
if sortParam == "" {
return nil, nil
}
var sorts []SortField
fields := strings.Split(sortParam, ",")
for _, field := range fields {
field = strings.TrimSpace(field)
if field == "" {
continue
}
order := "ASC"
column := field
if strings.HasPrefix(field, "-") {
order = "DESC"
column = field[1:]
} else if strings.HasPrefix(field, "+") {
column = field[1:]
}
sorts = append(sorts, SortField{
Column: column,
Order: order,
})
}
return sorts, nil
}
// ParseAdvancedFilters parses complex filter structures
// Supports nested filters and logic operators
func (qp *QueryParser) ParseAdvancedFilters(filterParam string) ([]FilterGroup, error) {
// This would be for more complex JSON-based filters
// Implementation depends on your specific needs
return nil, nil
}
// Helper function to parse date values
func parseDate(value string) (interface{}, error) {
// Try different date formats
formats := []string{
"2006-01-02",
"2006-01-02T15:04:05Z",
"2006-01-02T15:04:05.000Z",
"2006-01-02 15:04:05",
}
for _, format := range formats {
if t, err := time.Parse(format, value); err == nil {
return t, nil
}
}
return value, nil
}
// Helper function to parse numeric values
func parseNumeric(value string) interface{} {
// Try integer first
if i, err := strconv.Atoi(value); err == nil {
return i
}
// Try float
if f, err := strconv.ParseFloat(value, 64); err == nil {
return f
}
// Return as string
return value
}

View File

@@ -0,0 +1,141 @@
package validation
import (
"context"
"database/sql"
"fmt"
"time"
)
// ValidationConfig holds configuration for duplicate validation
type ValidationConfig struct {
TableName string
IDColumn string
StatusColumn string
DateColumn string
ActiveStatuses []string
AdditionalFields map[string]interface{}
}
// DuplicateValidator provides methods for validating duplicate entries
type DuplicateValidator struct {
db *sql.DB
}
// NewDuplicateValidator creates a new instance of DuplicateValidator
func NewDuplicateValidator(db *sql.DB) *DuplicateValidator {
return &DuplicateValidator{db: db}
}
// ValidateDuplicate checks for duplicate entries based on the provided configuration
func (dv *DuplicateValidator) ValidateDuplicate(ctx context.Context, config ValidationConfig, identifier interface{}) error {
query := fmt.Sprintf(`
SELECT COUNT(*)
FROM %s
WHERE %s = $1
AND %s = ANY($2)
AND DATE(%s) = CURRENT_DATE
`, config.TableName, config.IDColumn, config.StatusColumn, config.DateColumn)
var count int
err := dv.db.QueryRowContext(ctx, query, identifier, config.ActiveStatuses).Scan(&count)
if err != nil {
return fmt.Errorf("failed to check duplicate: %w", err)
}
if count > 0 {
return fmt.Errorf("data with ID %v already exists with active status today", identifier)
}
return nil
}
// ValidateDuplicateWithCustomFields checks for duplicates with additional custom fields
func (dv *DuplicateValidator) ValidateDuplicateWithCustomFields(ctx context.Context, config ValidationConfig, fields map[string]interface{}) error {
whereClause := fmt.Sprintf("%s = ANY($1) AND DATE(%s) = CURRENT_DATE", config.StatusColumn, config.DateColumn)
args := []interface{}{config.ActiveStatuses}
argIndex := 2
// Add additional field conditions
for fieldName, fieldValue := range config.AdditionalFields {
whereClause += fmt.Sprintf(" AND %s = $%d", fieldName, argIndex)
args = append(args, fieldValue)
argIndex++
}
// Add dynamic fields
for fieldName, fieldValue := range fields {
whereClause += fmt.Sprintf(" AND %s = $%d", fieldName, argIndex)
args = append(args, fieldValue)
argIndex++
}
query := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE %s", config.TableName, whereClause)
var count int
err := dv.db.QueryRowContext(ctx, query, args...).Scan(&count)
if err != nil {
return fmt.Errorf("failed to check duplicate with custom fields: %w", err)
}
if count > 0 {
return fmt.Errorf("duplicate entry found with the specified criteria")
}
return nil
}
// ValidateOncePerDay ensures only one submission per day for a given identifier
func (dv *DuplicateValidator) ValidateOncePerDay(ctx context.Context, tableName, idColumn, dateColumn string, identifier interface{}) error {
query := fmt.Sprintf(`
SELECT COUNT(*)
FROM %s
WHERE %s = $1
AND DATE(%s) = CURRENT_DATE
`, tableName, idColumn, dateColumn)
var count int
err := dv.db.QueryRowContext(ctx, query, identifier).Scan(&count)
if err != nil {
return fmt.Errorf("failed to check daily submission: %w", err)
}
if count > 0 {
return fmt.Errorf("only one submission allowed per day for ID %v", identifier)
}
return nil
}
// GetLastSubmissionTime returns the last submission time for a given identifier
func (dv *DuplicateValidator) GetLastSubmissionTime(ctx context.Context, tableName, idColumn, dateColumn string, identifier interface{}) (*time.Time, error) {
query := fmt.Sprintf(`
SELECT %s
FROM %s
WHERE %s = $1
ORDER BY %s DESC
LIMIT 1
`, dateColumn, tableName, idColumn, dateColumn)
var lastTime time.Time
err := dv.db.QueryRowContext(ctx, query, identifier).Scan(&lastTime)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil // No previous submission
}
return nil, fmt.Errorf("failed to get last submission time: %w", err)
}
return &lastTime, nil
}
// DefaultRetribusiConfig returns default configuration for retribusi validation
func DefaultRetribusiConfig() ValidationConfig {
return ValidationConfig{
TableName: "data_retribusi",
IDColumn: "id",
StatusColumn: "status",
DateColumn: "date_created",
ActiveStatuses: []string{"active", "draft"},
}
}