Files
api_antrean/internal/database/database.go
2025-11-02 03:08:38 +00:00

857 lines
23 KiB
Go

package database
import (
	"context"
	"crypto/tls"
	"database/sql"
	"database/sql/driver"
	"fmt"
	"log"
	"strconv"
	"sync"
	"time"

	"api-service/internal/config"

	_ "github.com/jackc/pgx/v5"
	"github.com/jmoiron/sqlx"
	"github.com/lib/pq"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
	_ "gorm.io/driver/mysql"
	_ "gorm.io/driver/postgres"
	_ "gorm.io/driver/sqlserver"
)
// DatabaseType names a database engine. Its string value is what a
// configuration's Type field is converted to and compared against.
type DatabaseType string

// Engines this package knows how to connect to.
const (
	// Postgres is PostgreSQL.
	Postgres = DatabaseType("postgres")
	// MySQL is MySQL / MariaDB.
	MySQL = DatabaseType("mysql")
	// SQLServer is Microsoft SQL Server.
	SQLServer = DatabaseType("sqlserver")
	// SQLite is an SQLite file database.
	SQLite = DatabaseType("sqlite")
	// MongoDB is a MongoDB deployment (not accessed through database/sql).
	MongoDB = DatabaseType("mongodb")
)
// Service gives access to every configured database: SQL primaries and their
// read replicas (as *sql.DB and *sqlx.DB), MongoDB clients, and PostgreSQL
// LISTEN/NOTIFY.
type Service interface {
	// Connection lookup by configured name.
	GetDB(name string) (*sql.DB, error)
	GetPrimaryDB(name string) (*sql.DB, error)
	GetReadDB(name string) (*sql.DB, error)
	GetSQLXDB(name string) (*sqlx.DB, error)
	GetMongoClient(name string) (*mongo.Client, error)

	// Introspection.
	ListDBs() []string
	GetDBType(name string) (DatabaseType, error)
	Health() map[string]map[string]string

	// Parameterised query helpers that run against the database named dbName.
	ExecuteQuery(ctx context.Context, dbName string, query string, args ...interface{}) (*sql.Rows, error)
	ExecuteQueryRow(ctx context.Context, dbName string, query string, args ...interface{}) *sql.Row
	Exec(ctx context.Context, dbName string, query string, args ...interface{}) (sql.Result, error)

	// PostgreSQL LISTEN/NOTIFY.
	ListenForChanges(ctx context.Context, dbName string, channels []string, callback func(string, string)) error
	NotifyChange(dbName, channel, payload string) error

	// Lifecycle.
	Close() error
}
type service struct {
sqlDatabases map[string]*sql.DB
sqlxDatabases map[string]*sqlx.DB // Tambahkan map untuk sqlx.DB
mongoClients map[string]*mongo.Client
readReplicas map[string][]*sql.DB
configs map[string]config.DatabaseConfig
readConfigs map[string][]config.DatabaseConfig
mu sync.RWMutex
readBalancer map[string]int
listeners map[string]*pq.Listener
listenersMu sync.RWMutex
}
// dbManager is the single service instance handed out by New.
var dbManager *service

// once makes New construct dbManager only on its first call.
var once sync.Once
// New returns the process-wide database Service, building it from cfg on the
// first call. Later calls ignore cfg and return the same instance.
//
// Connection failures are logged rather than returned: a primary or replica
// that cannot be reached is simply absent from the resulting service.
func New(cfg *config.Config) Service {
	once.Do(func() {
		svc := &service{
			sqlDatabases:  make(map[string]*sql.DB),
			sqlxDatabases: make(map[string]*sqlx.DB),
			mongoClients:  make(map[string]*mongo.Client),
			readReplicas:  make(map[string][]*sql.DB),
			configs:       make(map[string]config.DatabaseConfig),
			readConfigs:   make(map[string][]config.DatabaseConfig),
			readBalancer:  make(map[string]int),
			listeners:     make(map[string]*pq.Listener),
		}
		dbManager = svc

		log.Println("Initializing database service...")
		svc.loadFromConfig(cfg)

		// Connect the primaries first, then every replica of each primary.
		for name, primaryCfg := range svc.configs {
			err := svc.addDatabase(name, primaryCfg)
			if err != nil {
				log.Printf("Failed to connect to database %s: %v", name, err)
			}
		}
		for name, replicaCfgs := range svc.readConfigs {
			for idx, replicaCfg := range replicaCfgs {
				err := svc.addReadReplica(name, idx, replicaCfg)
				if err != nil {
					log.Printf("Failed to connect to read replica %s[%d]: %v", name, idx, err)
				}
			}
		}
	})
	return dbManager
}
// loadFromConfig copies the primary and read-replica configurations from cfg
// into the service, keyed by database name. A nil cfg loads nothing; it used
// to panic on the nil dereference.
//
// Replica slices are copied so that later mutation of cfg cannot change the
// service's view of its replicas.
func (s *service) loadFromConfig(cfg *config.Config) {
	if cfg == nil {
		log.Println("No database configuration provided; no databases will be loaded")
		return
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	for name, dbConfig := range cfg.Databases {
		s.configs[name] = dbConfig
	}
	for name, replicaConfigs := range cfg.ReadReplicas {
		s.readConfigs[name] = append([]config.DatabaseConfig(nil), replicaConfigs...)
	}
}
// addDatabase connects the primary database `name` described by config and
// registers it (SQL engines via configureSQLDB, MongoDB via addMongoDB).
//
// If a database pointing at the same target is already connected, the call
// logs a warning and returns nil without opening a second connection.
// Only *connected* databases are compared. The old check compared against
// every loaded config, so both members of a duplicate pair skipped each
// other and neither was connected. SQLite targets are compared by Path,
// since they have no host/port/database.
func (s *service) addDatabase(name string, config config.DatabaseConfig) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	for existingName, existingConfig := range s.configs {
		_, sqlConnected := s.sqlDatabases[existingName]
		_, mongoConnected := s.mongoClients[existingName]
		if existingName == name || (!sqlConnected && !mongoConnected) {
			continue
		}
		if sameConnectionTarget(existingConfig, config) {
			log.Printf("⚠️ Database %s appears to be a duplicate of %s (same host:port:database), skipping connection", name, existingName)
			return nil
		}
	}

	var db *sql.DB
	var err error
	switch DatabaseType(config.Type) {
	case Postgres:
		db, err = s.openPostgresConnection(config)
	case MySQL:
		db, err = s.openMySQLConnection(config)
	case SQLServer:
		db, err = s.openSQLServerConnection(config)
	case SQLite:
		db, err = s.openSQLiteConnection(config)
	case MongoDB:
		return s.addMongoDB(name, config)
	default:
		return fmt.Errorf("unsupported database type: %s", config.Type)
	}
	if err != nil {
		log.Printf("❌ Error connecting to database %s: %v", name, err)
		return err
	}

	log.Printf("✅ Successfully connected to database: %s", name)
	return s.configureSQLDB(name, db, config)
}

// sameConnectionTarget reports whether a and b describe the same database:
// same engine type, host, port, database name and (for SQLite) file path.
func sameConnectionTarget(a, b config.DatabaseConfig) bool {
	return a.Type == b.Type &&
		a.Host == b.Host &&
		a.Port == b.Port &&
		a.Database == b.Database &&
		a.Path == b.Path
}
// addReadReplica opens the read replica described by config and stores it at
// s.readReplicas[name][index], padding the slice with nil slots as needed.
//
// Replicas get the same pool limits as primaries and are pinged before being
// registered. Previously neither happened, so an unreachable replica was
// handed out by GetReadDB and pool settings were ignored. Opening and pinging
// happen outside s.mu so network latency does not block other users of the
// service. A handle already occupying the slot is closed rather than leaked.
func (s *service) addReadReplica(name string, index int, config config.DatabaseConfig) error {
	var db *sql.DB
	var err error
	switch DatabaseType(config.Type) {
	case Postgres:
		db, err = s.openPostgresConnection(config)
	case MySQL:
		db, err = s.openMySQLConnection(config)
	case SQLServer:
		db, err = s.openSQLServerConnection(config)
	case SQLite:
		db, err = s.openSQLiteConnection(config)
	default:
		return fmt.Errorf("unsupported database type for read replica: %s", config.Type)
	}
	if err != nil {
		return err
	}

	db.SetMaxOpenConns(config.MaxOpenConns)
	db.SetMaxIdleConns(config.MaxIdleConns)
	db.SetConnMaxLifetime(config.ConnMaxLifetime)
	db.SetConnMaxIdleTime(config.MaxIdleTime)

	// A zero Timeout would yield an already-expired context and fail every ping.
	timeout := config.Timeout
	if timeout <= 0 {
		timeout = 5 * time.Second
	}
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	if err := db.PingContext(ctx); err != nil {
		db.Close()
		return fmt.Errorf("failed to ping read replica: %w", err)
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	for len(s.readReplicas[name]) <= index {
		s.readReplicas[name] = append(s.readReplicas[name], nil)
	}
	if previous := s.readReplicas[name][index]; previous != nil {
		previous.Close()
	}
	s.readReplicas[name][index] = db

	log.Printf("Successfully connected to read replica %s[%d]", name, index)
	return nil
}
// openPostgresConnection builds a keyword/value connection string for the
// "pgx" database/sql driver from config and opens a pool with it. sql.Open
// does not necessarily establish a connection; callers ping afterwards.
//
// Fixes relative to the previous version:
//   - statement_timeout is sent in milliseconds. PostgreSQL interprets a
//     unitless statement_timeout as milliseconds, so sending seconds made
//     every statement timeout 1000x too short.
//   - Every string value is single-quoted with backslash escaping, so a
//     password containing spaces or quotes can neither break the string nor
//     inject extra settings.
//   - Empty values are omitted (driver defaults apply). Unquoted, an empty
//     value such as "sslcert=" makes the parser take the following
//     "key=value" token as its value.
//
// connect_timeout stays in whole seconds, as the driver expects.
func (s *service) openPostgresConnection(config config.DatabaseConfig) (*sql.DB, error) {
	// quote renders v as a single-quoted connection-string value.
	quote := func(v string) string {
		buf := make([]byte, 0, len(v)+2)
		buf = append(buf, '\'')
		for i := 0; i < len(v); i++ {
			if c := v[i]; c == '\'' || c == '\\' {
				buf = append(buf, '\\')
			}
			buf = append(buf, v[i])
		}
		return string(append(buf, '\''))
	}

	connStr := fmt.Sprintf("port=%d connect_timeout=%d statement_timeout=%d",
		config.Port,
		int(config.ConnectTimeout.Seconds()),
		int64(config.StatementTimeout.Seconds()*1000),
	)
	add := func(key, value string) {
		if value != "" {
			connStr += " " + key + "=" + quote(value)
		}
	}
	add("host", config.Host)
	add("user", config.Username)
	add("password", config.Password)
	add("dbname", config.Database)
	add("sslmode", config.SSLMode)
	add("search_path", config.Schema)
	if config.RequireSSL {
		add("sslcert", config.SSLCert)
		add("sslkey", config.SSLKey)
		add("sslrootcert", config.SSLRootCert)
	}

	db, err := sql.Open("pgx", connStr)
	if err != nil {
		return nil, fmt.Errorf("failed to open PostgreSQL connection: %w", err)
	}
	return db, nil
}
// openMySQLConnection builds a go-sql-driver/mysql DSN from config and opens a
// pool with it (sql.Open does not necessarily connect).
//
// TLS: the driver's `tls` parameter accepts true/false/skip-verify/preferred
// or the name of a config registered with mysql.RegisterTLSConfig. It has no
// ssl-ca/ssl-cert/ssl-key parameters. Unknown DSN keys are sent to the server
// as session variables ("SET ssl-ca=..."), which fails on the first connect.
// Certificate paths are therefore rejected here with an explicit error rather
// than producing a DSN that cannot work.
func (s *service) openMySQLConnection(config config.DatabaseConfig) (*sql.DB, error) {
	connStr := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true&timeout=%s&readTimeout=%s&writeTimeout=%s",
		config.Username,
		config.Password,
		config.Host,
		config.Port,
		config.Database,
		config.Timeout,
		config.ReadTimeout,
		config.WriteTimeout,
	)

	if config.RequireSSL {
		if config.SSLRootCert != "" || config.SSLCert != "" || config.SSLKey != "" {
			return nil, fmt.Errorf("failed to open MySQL connection: TLS certificate files (SSLRootCert/SSLCert/SSLKey) cannot be passed in the DSN; register them with mysql.RegisterTLSConfig and reference the config by name")
		}
		// Verify the server certificate against the system trust store.
		connStr += "&tls=true"
	}

	db, err := sql.Open("mysql", connStr)
	if err != nil {
		return nil, fmt.Errorf("failed to open MySQL connection: %w", err)
	}
	return db, nil
}
// openSQLServerConnection builds a sqlserver:// URL for go-mssqldb from config
// and opens a pool with it (sql.Open does not necessarily connect).
//
// With RequireSSL, traffic is encrypted. When SSLRootCert is set, the server
// certificate is verified against that CA file via go-mssqldb's `certificate`
// parameter. Previously the path only selected trustServerCertificate=false
// and was never handed to the driver. Without a CA file the server
// certificate is trusted unverified, as before.
//
// NOTE(review): Username, Password and the certificate path are inserted into
// the URL unescaped; values containing '@', ':', '/', '?', '#', '&' or '+'
// will not survive URL parsing. Consider building the URL with net/url.
func (s *service) openSQLServerConnection(config config.DatabaseConfig) (*sql.DB, error) {
	connectTimeoutSec := int(config.ConnectTimeout.Seconds())
	connStr := fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s&connection timeout=%d",
		config.Username,
		config.Password,
		config.Host,
		config.Port,
		config.Database,
		connectTimeoutSec,
	)

	if config.RequireSSL {
		connStr += "&encrypt=true"
		if config.SSLRootCert != "" {
			connStr += "&trustServerCertificate=false&certificate=" + config.SSLRootCert
		} else {
			connStr += "&trustServerCertificate=true"
		}
	}

	db, err := sql.Open("sqlserver", connStr)
	if err != nil {
		return nil, fmt.Errorf("failed to open SQL Server connection: %w", err)
	}
	return db, nil
}
// openSQLiteConnection opens the SQLite file at config.Path with the
// "sqlite3" driver and enables foreign keys and WAL journaling.
// If configuring fails, the pool is closed instead of being leaked.
//
// NOTE(review): PRAGMA foreign_keys is per-connection in SQLite. This Exec
// runs on one pooled connection only, so connections the pool opens later
// will not enforce foreign keys. A driver DSN option (e.g. mattn's
// `_foreign_keys=on`) would cover all of them — confirm which driver
// registers "sqlite3" in this project. journal_mode=WAL persists in the
// database file.
func (s *service) openSQLiteConnection(config config.DatabaseConfig) (*sql.DB, error) {
	db, err := sql.Open("sqlite3", config.Path)
	if err != nil {
		return nil, fmt.Errorf("failed to open SQLite connection: %w", err)
	}

	if _, err := db.Exec("PRAGMA foreign_keys = ON; PRAGMA journal_mode = WAL;"); err != nil {
		db.Close()
		return nil, fmt.Errorf("failed to configure SQLite: %w", err)
	}
	return db, nil
}
// addMongoDB connects to the MongoDB deployment described by config, verifies
// it with a ping and registers the client under name. The caller must hold
// s.mu (addDatabase does), since this writes s.mongoClients.
//
// Changes from the previous version:
//   - Credentials are passed via options.Credential instead of being spliced
//     into the URI. The driver requires URI user info to be percent-encoded,
//     so passwords containing '@', ':' or '/' used to fail. A config without
//     a username no longer produces an invalid "mongodb://:@host" URI.
//   - On ping failure the client is disconnected instead of leaking its
//     background monitoring goroutines and pools.
//   - A zero Timeout falls back to 5s instead of an already-expired context.
func (s *service) addMongoDB(name string, config config.DatabaseConfig) error {
	timeout := config.Timeout
	if timeout <= 0 {
		timeout = 5 * time.Second
	}
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	uri := fmt.Sprintf("mongodb://%s:%d/%s", config.Host, config.Port, config.Database)
	clientOptions := options.Client().ApplyURI(uri)

	if config.Username != "" {
		clientOptions.SetAuth(options.Credential{
			Username: config.Username,
			Password: config.Password,
			// Same auth database the "/<db>" URI path used to select
			// (empty means the driver default).
			AuthSource:  config.Database,
			PasswordSet: true,
		})
	}

	if config.RequireSSL {
		clientOptions.SetTLSConfig(&tls.Config{
			// NOTE(review): presumably mirrors libpq's sslmode=require
			// (encrypt without verifying the certificate). SSLRootCert is
			// not used for MongoDB.
			InsecureSkipVerify: config.SSLMode == "require",
			MinVersion:         tls.VersionTLS12,
		})
	}

	clientOptions.SetConnectTimeout(config.ConnectTimeout)
	clientOptions.SetServerSelectionTimeout(timeout)

	client, err := mongo.Connect(ctx, clientOptions)
	if err != nil {
		return fmt.Errorf("failed to connect to MongoDB: %w", err)
	}

	if err := client.Ping(ctx, nil); err != nil {
		// Best-effort cleanup; ctx may already be expired, so use a fresh one.
		// The ping error is what the caller needs to see.
		discCtx, discCancel := context.WithTimeout(context.Background(), timeout)
		defer discCancel()
		_ = client.Disconnect(discCtx)
		return fmt.Errorf("failed to ping MongoDB: %w", err)
	}

	s.mongoClients[name] = client
	log.Printf("Successfully connected to MongoDB: %s", name)
	return nil
}
// configureSQLDB applies config's pool limits to db, pings it, and registers
// it under name in both s.sqlDatabases and s.sqlxDatabases. The sqlx handle
// wraps the same *sql.DB, so both share one pool. The caller must hold s.mu
// (addDatabase does).
//
// On any failure db is closed and nothing is registered. Previously an
// unsupported driver type left db registered in s.sqlDatabases after
// returning an error. A zero Timeout falls back to 5s instead of creating an
// already-expired ping context that fails every time.
func (s *service) configureSQLDB(name string, db *sql.DB, config config.DatabaseConfig) error {
	// Resolve the sqlx driver name first so an unsupported type registers nothing.
	var driverName string
	switch DatabaseType(config.Type) {
	case Postgres:
		driverName = "pgx"
	case MySQL:
		driverName = "mysql"
	case SQLServer:
		driverName = "sqlserver"
	case SQLite:
		driverName = "sqlite3"
	default:
		db.Close()
		return fmt.Errorf("unsupported database type for sqlx: %s", config.Type)
	}

	db.SetMaxOpenConns(config.MaxOpenConns)
	db.SetMaxIdleConns(config.MaxIdleConns)
	db.SetConnMaxLifetime(config.ConnMaxLifetime)
	db.SetConnMaxIdleTime(config.MaxIdleTime)

	timeout := config.Timeout
	if timeout <= 0 {
		timeout = 5 * time.Second
	}
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	if err := db.PingContext(ctx); err != nil {
		db.Close()
		return fmt.Errorf("failed to ping database: %w", err)
	}

	s.sqlDatabases[name] = db
	s.sqlxDatabases[name] = sqlx.NewDb(db, driverName)
	log.Printf("Successfully connected to SQL database: %s", name)
	return nil
}
// Health pings every registered SQL primary, read replica and MongoDB client,
// each with a one-second deadline, and returns a status map per connection.
// Replica entries are keyed "<name>_replica_<index>"; nil replica slots are
// skipped. The service read lock is held for the whole check.
//
// Each ping's context is now cancelled as soon as that ping returns. The
// previous version deferred cancel() inside the loops, which kept every timer
// alive until Health returned.
func (s *service) Health() map[string]map[string]string {
	s.mu.RLock()
	defer s.mu.RUnlock()

	result := make(map[string]map[string]string)

	for name, db := range s.sqlDatabases {
		result[name] = primarySQLHealth(db)
	}

	for name, replicas := range s.readReplicas {
		for i, db := range replicas {
			if db == nil {
				continue
			}
			result[fmt.Sprintf("%s_replica_%d", name, i)] = replicaSQLHealth(db)
		}
	}

	for name, client := range s.mongoClients {
		result[name] = mongoHealth(client)
	}

	return result
}

// pingSQLHealth pings db with a one-second deadline, releasing the context
// as soon as the ping returns.
func pingSQLHealth(db *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel()
	return db.PingContext(ctx)
}

// primarySQLHealth builds the health entry for a primary SQL connection,
// including pool statistics and a load hint in "message".
func primarySQLHealth(db *sql.DB) map[string]string {
	stats := map[string]string{"type": "sql", "role": "primary"}
	if err := pingSQLHealth(db); err != nil {
		stats["status"] = "down"
		stats["error"] = fmt.Sprintf("db down: %v", err)
		return stats
	}

	stats["status"] = "up"
	stats["message"] = "It's healthy"

	dbStats := db.Stats()
	stats["open_connections"] = strconv.Itoa(dbStats.OpenConnections)
	stats["in_use"] = strconv.Itoa(dbStats.InUse)
	stats["idle"] = strconv.Itoa(dbStats.Idle)
	stats["wait_count"] = strconv.FormatInt(dbStats.WaitCount, 10)
	stats["wait_duration"] = dbStats.WaitDuration.String()
	stats["max_idle_closed"] = strconv.FormatInt(dbStats.MaxIdleClosed, 10)
	stats["max_lifetime_closed"] = strconv.FormatInt(dbStats.MaxLifetimeClosed, 10)

	// Later checks deliberately override earlier messages.
	if dbStats.OpenConnections > 40 {
		stats["message"] = "The database is experiencing heavy load."
	}
	if dbStats.WaitCount > 1000 {
		stats["message"] = "The database has a high number of wait events, indicating potential bottlenecks."
	}
	if dbStats.MaxIdleClosed > int64(dbStats.OpenConnections)/2 {
		stats["message"] = "Many idle connections are being closed, consider revising the connection pool settings."
	}
	if dbStats.MaxLifetimeClosed > int64(dbStats.OpenConnections)/2 {
		stats["message"] = "Many connections are being closed due to max lifetime, consider increasing max lifetime or revising the connection usage pattern."
	}
	return stats
}

// replicaSQLHealth builds the health entry for a read-replica connection.
func replicaSQLHealth(db *sql.DB) map[string]string {
	stats := map[string]string{"type": "sql", "role": "replica"}
	if err := pingSQLHealth(db); err != nil {
		stats["status"] = "down"
		stats["error"] = fmt.Sprintf("read replica down: %v", err)
		return stats
	}

	stats["status"] = "up"
	stats["message"] = "Read replica healthy"

	dbStats := db.Stats()
	stats["open_connections"] = strconv.Itoa(dbStats.OpenConnections)
	stats["in_use"] = strconv.Itoa(dbStats.InUse)
	stats["idle"] = strconv.Itoa(dbStats.Idle)
	return stats
}

// mongoHealth builds the health entry for a MongoDB client, pinging it with
// a one-second deadline.
func mongoHealth(client *mongo.Client) map[string]string {
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel()

	stats := map[string]string{"type": "mongodb"}
	if err := client.Ping(ctx, nil); err != nil {
		stats["status"] = "down"
		stats["error"] = fmt.Sprintf("mongodb down: %v", err)
		return stats
	}
	stats["status"] = "up"
	stats["message"] = "It's healthy"
	return stats
}
// GetDB looks up the primary SQL connection registered under name and
// reports an error if no such SQL database was connected.
func (s *service) GetDB(name string) (*sql.DB, error) {
	s.mu.RLock()
	conn, ok := s.sqlDatabases[name]
	s.mu.RUnlock()

	if ok {
		return conn, nil
	}
	return nil, fmt.Errorf("database %s not found", name)
}
// GetSQLXDB returns the sqlx handle for the SQL database registered under
// name. It wraps the same *sql.DB that GetDB returns, so both share one pool.
func (s *service) GetSQLXDB(name string) (*sqlx.DB, error) {
	s.mu.RLock()
	xdb, ok := s.sqlxDatabases[name]
	s.mu.RUnlock()

	if ok {
		return xdb, nil
	}
	return nil, fmt.Errorf("database %s not found", name)
}
// GetReadDB returns a read-replica connection for name, chosen round-robin.
// Replica slots that failed to connect (nil) are skipped. If no usable
// replica exists, the primary for name is returned, or an error if that is
// missing too.
//
// Fixes: the round-robin cursor used to be written while holding only the
// read lock, a data race that can crash with "concurrent map writes". The
// method also re-acquired the read lock via GetDB, which can deadlock behind
// a waiting writer. A single nil slot no longer sends reads to the primary
// while other replicas are up.
func (s *service) GetReadDB(name string) (*sql.DB, error) {
	// Exclusive lock: s.readBalancer is mutated below.
	s.mu.Lock()
	defer s.mu.Unlock()

	replicas := s.readReplicas[name]
	for range replicas {
		s.readBalancer[name] = (s.readBalancer[name] + 1) % len(replicas)
		if selected := replicas[s.readBalancer[name]]; selected != nil {
			return selected, nil
		}
	}

	// No usable replica: fall back to the primary (read directly, lock held).
	primary, exists := s.sqlDatabases[name]
	if !exists {
		return nil, fmt.Errorf("database %s not found", name)
	}
	return primary, nil
}
// GetMongoClient returns the MongoDB client registered under name, or an
// error if no MongoDB database with that name was connected.
func (s *service) GetMongoClient(name string) (*mongo.Client, error) {
	s.mu.RLock()
	client, found := s.mongoClients[name]
	s.mu.RUnlock()

	if !found {
		return nil, fmt.Errorf("MongoDB client %s not found", name)
	}
	return client, nil
}
// ListDBs returns the names of all connected databases: SQL primaries first,
// then MongoDB clients. Order within each group is unspecified (map order).
// Read replicas are not listed separately.
func (s *service) ListDBs() []string {
	s.mu.RLock()
	defer s.mu.RUnlock()

	out := make([]string, 0, len(s.sqlDatabases)+len(s.mongoClients))
	for sqlName := range s.sqlDatabases {
		out = append(out, sqlName)
	}
	for mongoName := range s.mongoClients {
		out = append(out, mongoName)
	}
	return out
}
// GetDBType reports the configured engine type for name. It consults the
// loaded configuration, so it answers even for databases that failed to
// connect.
func (s *service) GetDBType(name string) (DatabaseType, error) {
	s.mu.RLock()
	cfg, found := s.configs[name]
	s.mu.RUnlock()

	if !found {
		return "", fmt.Errorf("database %s not found", name)
	}
	return DatabaseType(cfg.Type), nil
}
// Close shuts down every LISTEN connection, SQL primary, read replica and
// MongoDB client, then resets all registries, including the stored configs
// and the read-replica round-robin state. Errors are collected and reported
// together; the registries are reset even when some closes fail.
//
// Fixes: the listener map is now swapped out under listenersMu, the lock that
// ListenForChanges and its goroutines use for it. Before, it was reset under
// s.mu only, racing those goroutines. Each Mongo disconnect is also bounded so
// in-flight operations cannot stall shutdown indefinitely.
func (s *service) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	var errs []error

	s.listenersMu.Lock()
	listeners := s.listeners
	s.listeners = make(map[string]*pq.Listener)
	s.listenersMu.Unlock()
	for name, listener := range listeners {
		if err := listener.Close(); err != nil {
			errs = append(errs, fmt.Errorf("failed to close listener for %s: %w", name, err))
		}
	}

	// sqlxDatabases wrap these same *sql.DB values, so this closes them too.
	for name, db := range s.sqlDatabases {
		if err := db.Close(); err != nil {
			errs = append(errs, fmt.Errorf("failed to close database %s: %w", name, err))
		} else {
			log.Printf("Disconnected from SQL database: %s", name)
		}
	}

	for name, replicas := range s.readReplicas {
		for i, db := range replicas {
			if db == nil {
				continue
			}
			if err := db.Close(); err != nil {
				errs = append(errs, fmt.Errorf("failed to close read replica %s[%d]: %w", name, i, err))
			} else {
				log.Printf("Disconnected from read replica: %s[%d]", name, i)
			}
		}
	}

	for name, client := range s.mongoClients {
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		err := client.Disconnect(ctx)
		cancel()
		if err != nil {
			errs = append(errs, fmt.Errorf("failed to disconnect MongoDB client %s: %w", name, err))
		} else {
			log.Printf("Disconnected from MongoDB: %s", name)
		}
	}

	s.sqlDatabases = make(map[string]*sql.DB)
	s.sqlxDatabases = make(map[string]*sqlx.DB)
	s.mongoClients = make(map[string]*mongo.Client)
	s.readReplicas = make(map[string][]*sql.DB)
	s.readBalancer = make(map[string]int)
	s.configs = make(map[string]config.DatabaseConfig)
	s.readConfigs = make(map[string][]config.DatabaseConfig)

	if len(errs) > 0 {
		return fmt.Errorf("errors closing databases: %v", errs)
	}
	return nil
}
// GetPrimaryDB returns the primary connection registered under name.
// Primaries are what GetDB serves, so this simply delegates to it.
func (s *service) GetPrimaryDB(name string) (*sql.DB, error) {
	primary, err := s.GetDB(name)
	return primary, err
}
// ExecuteQuery runs query with args on the primary connection for dbName and
// returns the resulting rows; the caller must close them. Both the lookup
// error and the query error are wrapped with context.
func (s *service) ExecuteQuery(ctx context.Context, dbName string, query string, args ...interface{}) (*sql.Rows, error) {
	conn, lookupErr := s.GetDB(dbName)
	if lookupErr != nil {
		return nil, fmt.Errorf("failed to get database %s: %w", dbName, lookupErr)
	}

	// args are bound by the driver as parameters, never spliced into query.
	rows, queryErr := conn.QueryContext(ctx, query, args...)
	if queryErr != nil {
		return nil, fmt.Errorf("failed to execute query: %w", queryErr)
	}
	return rows, nil
}
// ExecuteQueryRow runs query with args on the primary connection for dbName
// and returns at most one row. Like (*sql.DB).QueryRowContext it never
// returns nil. If dbName is unknown, the returned Row's Scan (and Err) report
// the lookup error.
//
// Fix: a zero-value sql.Row used to be returned on lookup failure. Its Scan
// dereferences a nil *sql.Rows and panics.
func (s *service) ExecuteQueryRow(ctx context.Context, dbName string, query string, args ...interface{}) *sql.Row {
	db, err := s.GetDB(dbName)
	if err != nil {
		return rowWithError(fmt.Errorf("failed to get database %s: %w", dbName, err))
	}
	// args are bound by the driver as parameters, never spliced into query.
	return db.QueryRowContext(ctx, query, args...)
}

// rowWithError returns a *sql.Row whose Scan and Err report err.
// database/sql has no exported constructor for such a Row, so this queries a
// throwaway DB whose connector always fails with err. The failure happens
// while acquiring a connection, before any SQL is prepared or sent.
func rowWithError(err error) *sql.Row {
	failing := sql.OpenDB(failingConnector{err: err})
	defer failing.Close()
	return failing.QueryRowContext(context.Background(), "")
}

// failingConnector is a driver.Connector whose Connect always fails with err.
type failingConnector struct{ err error }

func (c failingConnector) Connect(context.Context) (driver.Conn, error) { return nil, c.err }

func (c failingConnector) Driver() driver.Driver { return failingDriver{err: c.err} }

// failingDriver is the driver.Driver paired with failingConnector.
type failingDriver struct{ err error }

func (d failingDriver) Open(string) (driver.Conn, error) { return nil, d.err }
// Exec runs a statement with args on the primary connection for dbName and
// returns its sql.Result. Both the lookup error and the execution error are
// wrapped with context.
func (s *service) Exec(ctx context.Context, dbName string, query string, args ...interface{}) (sql.Result, error) {
	conn, lookupErr := s.GetDB(dbName)
	if lookupErr != nil {
		return nil, fmt.Errorf("failed to get database %s: %w", dbName, lookupErr)
	}

	// args are bound by the driver as parameters, never spliced into query.
	res, execErr := conn.ExecContext(ctx, query, args...)
	if execErr != nil {
		return nil, fmt.Errorf("failed to execute query: %w", execErr)
	}
	return res, nil
}
// ListenForChanges subscribes to the given PostgreSQL NOTIFY channels on the
// database configured as dbName. It then invokes callback(channel, payload)
// for each notification on a background goroutine, until ctx is done or the
// listener is closed (for example by Close). It returns once every LISTEN has
// been issued. Only PostgreSQL databases are supported.
//
// The listener uses its own lib/pq connection, built from a keyword/value
// string with quoted values. The previous postgres:// URL inserted
// credentials unescaped, so a password containing '@', ':' or '/' broke it.
// A listener is registered for Close only after every LISTEN succeeded, so a
// failed call no longer leaves an already-closed listener behind. A second
// call for the same dbName replaces the registry entry; the earlier listener
// keeps running until its own ctx is done.
func (s *service) ListenForChanges(ctx context.Context, dbName string, channels []string, callback func(string, string)) error {
	s.mu.RLock()
	config, exists := s.configs[dbName]
	s.mu.RUnlock()
	if !exists {
		return fmt.Errorf("database %s not found", dbName)
	}
	if DatabaseType(config.Type) != Postgres {
		return fmt.Errorf("LISTEN/NOTIFY only supported for PostgreSQL databases")
	}

	// quote renders v as a single-quoted connection-string value.
	quote := func(v string) string {
		buf := make([]byte, 0, len(v)+2)
		buf = append(buf, '\'')
		for i := 0; i < len(v); i++ {
			if c := v[i]; c == '\'' || c == '\\' {
				buf = append(buf, '\\')
			}
			buf = append(buf, v[i])
		}
		return string(append(buf, '\''))
	}
	// connect_timeout is in whole seconds for lib/pq.
	connStr := fmt.Sprintf("port=%d connect_timeout=%d", config.Port, int(config.ConnectTimeout.Seconds()))
	add := func(key, value string) {
		if value != "" { // empty values are omitted so pq's defaults apply
			connStr += " " + key + "=" + quote(value)
		}
	}
	add("host", config.Host)
	add("user", config.Username)
	add("password", config.Password)
	add("dbname", config.Database)
	add("sslmode", config.SSLMode)
	if config.RequireSSL {
		add("sslcert", config.SSLCert)
		add("sslkey", config.SSLKey)
		add("sslrootcert", config.SSLRootCert)
	}

	listener := pq.NewListener(
		connStr,
		10*time.Second,
		time.Minute,
		func(ev pq.ListenerEventType, err error) {
			if err != nil {
				log.Printf("Database listener (%s) error: %v", dbName, err)
			}
		},
	)

	for _, channel := range channels {
		if err := listener.Listen(channel); err != nil {
			_ = listener.Close() // best effort; the Listen error is what matters
			return fmt.Errorf("failed to listen to channel %s: %w", channel, err)
		}
		log.Printf("Listening to database channel: %s on %s", channel, dbName)
	}

	s.listenersMu.Lock()
	s.listeners[dbName] = listener
	s.listenersMu.Unlock()

	go s.runListener(ctx, dbName, listener, callback)
	return nil
}

// runListener delivers notifications from listener to callback until ctx is
// done or listener.Notify is closed, pinging the server every 90 seconds. On
// exit it closes the listener and removes it from s.listeners only if it is
// still the registered entry for dbName, so it never evicts a newer listener.
func (s *service) runListener(ctx context.Context, dbName string, listener *pq.Listener, callback func(string, string)) {
	keepAlive := time.NewTicker(90 * time.Second)
	defer func() {
		keepAlive.Stop()
		_ = listener.Close() // may already be closed by Close(); that error is expected
		s.listenersMu.Lock()
		if s.listeners[dbName] == listener {
			delete(s.listeners, dbName)
		}
		s.listenersMu.Unlock()
		log.Printf("Database listener for %s stopped", dbName)
	}()

	for {
		select {
		case n, ok := <-listener.Notify:
			if !ok {
				// Notify was closed: the listener has shut down. Without this
				// check a closed channel yields nil forever and spins the loop.
				return
			}
			// pq sends nil after re-establishing a connection; no payload to deliver.
			if n != nil {
				callback(n.Channel, n.Extra)
			}
		case <-ctx.Done():
			return
		case <-keepAlive.C:
			// Ping off this goroutine so notifications keep draining while it waits.
			go func() {
				if err := listener.Ping(); err != nil {
					log.Printf("Listener ping failed for %s: %v", dbName, err)
				}
			}()
		}
	}
}
// NotifyChange publishes payload on a PostgreSQL NOTIFY channel using the
// primary connection for dbName. The lookup is checked first, then the stored
// config must exist and be of type Postgres. Successful sends are logged,
// including the payload.
func (s *service) NotifyChange(dbName, channel, payload string) error {
	conn, lookupErr := s.GetDB(dbName)
	if lookupErr != nil {
		return fmt.Errorf("failed to get database %s: %w", dbName, lookupErr)
	}

	s.mu.RLock()
	cfg, found := s.configs[dbName]
	s.mu.RUnlock()

	switch {
	case !found:
		return fmt.Errorf("database %s configuration not found", dbName)
	case DatabaseType(cfg.Type) != Postgres:
		return fmt.Errorf("NOTIFY only supported for PostgreSQL databases")
	}

	// pg_notify takes channel and payload as bind parameters, so neither is
	// interpolated into SQL text.
	if _, execErr := conn.Exec("SELECT pg_notify($1, $2)", channel, payload); execErr != nil {
		return fmt.Errorf("failed to send notification: %w", execErr)
	}

	log.Printf("Sent notification to channel %s on %s: %s", channel, dbName, payload)
	return nil
}