package db
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"os"
"path/filepath"
"reflect"
"strings"
"sync"
"time"
gonanoid "github.com/matoous/go-nanoid/v2"
)
// LoggedTx wraps a *sql.Tx and accumulates a record of the mutating
// statements executed through it, so that they can be written to the event
// log once the transaction commits.
type LoggedTx struct {
// tx is the underlying transaction that every call is delegated to.
tx *sql.Tx
// queries holds the successful mutations recorded so far; guarded by mu.
queries []QueryRecord
// mu protects queries.
mu sync.Mutex
// toolName is emitted as the "tool" field of the transaction event.
toolName string
// startTime is captured when the transaction begins and is used to compute
// the duration reported in the transaction event.
startTime time.Time
}
// QueryRecord is one recorded statement: the SQL text and the arguments it
// was executed with. It has a custom MarshalJSON that normalizes each
// parameter before encoding.
type QueryRecord struct {
// SQL is the statement text exactly as it was passed to Exec/Prepare.
SQL string `json:"sql"`
// Parameters are the bound arguments, in call order.
Parameters []any `json:"parameters"`
}
// TransactionEvent is the JSON document written to the event log for a
// committed transaction that recorded at least one mutation.
type TransactionEvent struct {
// ID is the event identifier generated by gonanoid.New(21) in writeEvent.
ID string `json:"id"`
// Timestamp is the time at which the event was built (after the commit).
Timestamp time.Time `json:"timestamp"`
// Tool is the tool name given to BeginLoggedTx; omitted from JSON when empty.
Tool string `json:"tool,omitempty"`
// Queries are the recorded mutations, in execution order.
Queries []QueryRecord `json:"queries"`
// Success is set to true by writeEvent, which only runs after a successful commit.
Success bool `json:"success"`
// Duration is the time elapsed since the transaction began, in milliseconds.
Duration int64 `json:"duration_ms"`
}
// LoggedStmt wraps a prepared statement created through a LoggedTx so that
// successful mutating executions are recorded on that transaction.
type LoggedStmt struct {
// stmt is the underlying prepared statement.
stmt *sql.Stmt
// tx is the owning LoggedTx whose query list receives the records.
tx *LoggedTx
// sql is the statement text the statement was prepared from.
sql string
}
// EventLogConfig controls whether and where transaction events are written.
type EventLogConfig struct {
// Enabled turns event writing on; when false, writeEvent returns without writing.
Enabled bool
// Path is the event log file. ensureEventLogFile creates its parent
// directory if needed and opens the file in append mode.
Path string
}
// Package-level event log state shared by every LoggedTx.
var (
// eventLogConfig is the active configuration, replaced by SetEventLogConfig.
eventLogConfig EventLogConfig
// eventLogMu is held by SetEventLogConfig, writeEvent and CloseEventLog
// while they touch the configuration, the file and the encoder.
eventLogMu sync.Mutex
// eventLogFile is the lazily opened log file; nil until the first write and
// again after the file has been closed.
eventLogFile *os.File
// eventLogEnc is the JSON encoder bound to eventLogFile; nil whenever
// eventLogFile is nil.
eventLogEnc *json.Encoder
)
// SetEventLogConfig installs cfg as the active event log configuration.
//
// If a log file is currently open and cfg points at a different path, the
// open file is closed and forgotten, so ensureEventLogFile opens the new
// path on the next write. When the path is unchanged the open file is kept.
func SetEventLogConfig(cfg EventLogConfig) {
	eventLogMu.Lock()
	defer eventLogMu.Unlock()

	pathChanged := eventLogConfig.Path != cfg.Path
	if pathChanged && eventLogFile != nil {
		// The close error is discarded on purpose: the old file is being
		// abandoned either way.
		_ = eventLogFile.Close()
		eventLogFile, eventLogEnc = nil, nil
	}
	eventLogConfig = cfg
}
// BeginLoggedTx starts a transaction on db with default options and wraps it
// in a LoggedTx tagged with toolName. The start time is captured here so the
// event written at commit can report the transaction's duration.
//
// It returns the error from db.BeginTx unchanged if the transaction cannot
// be started.
func BeginLoggedTx(ctx context.Context, db *sql.DB, toolName string) (*LoggedTx, error) {
	sqlTx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}

	logged := new(LoggedTx)
	logged.tx = sqlTx
	logged.queries = []QueryRecord{}
	logged.toolName = toolName
	logged.startTime = time.Now()
	return logged, nil
}
// ExecContext runs query inside the transaction. When the statement succeeds
// and isMutation classifies it as a mutation, the SQL text and its arguments
// are recorded so that Commit can write them to the event log.
//
// The arguments are copied before being recorded. A caller that spreads an
// existing slice into this call (args...) may reuse or overwrite that slice
// afterwards; the record must not alias it, otherwise the parameters logged
// at commit time could differ from the ones that were executed.
//
// The result and error of the underlying ExecContext are returned unchanged.
func (l *LoggedTx) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
	result, err := l.tx.ExecContext(ctx, query, args...)
	if err != nil || !isMutation(query) {
		return result, err
	}

	record := QueryRecord{
		SQL:        query,
		Parameters: append([]any(nil), args...),
	}

	l.mu.Lock()
	l.queries = append(l.queries, record)
	l.mu.Unlock()

	return result, err
}
// Exec is ExecContext with a background context; mutations are recorded in
// the same way.
func (l *LoggedTx) Exec(query string, args ...any) (sql.Result, error) {
	ctx := context.Background()
	return l.ExecContext(ctx, query, args...)
}
// QueryRowContext forwards a single-row query to the underlying transaction.
// Nothing is recorded for the event log.
func (l *LoggedTx) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
	row := l.tx.QueryRowContext(ctx, query, args...)
	return row
}
// QueryRow forwards a single-row query to the underlying transaction's
// QueryRow. Nothing is recorded for the event log.
func (l *LoggedTx) QueryRow(query string, args ...any) *sql.Row {
	row := l.tx.QueryRow(query, args...)
	return row
}
// QueryContext forwards a multi-row query to the underlying transaction.
// Nothing is recorded for the event log.
func (l *LoggedTx) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
	rows, err := l.tx.QueryContext(ctx, query, args...)
	return rows, err
}
// Query forwards a multi-row query to the underlying transaction's Query.
// Nothing is recorded for the event log.
func (l *LoggedTx) Query(query string, args ...any) (*sql.Rows, error) {
	rows, err := l.tx.Query(query, args...)
	return rows, err
}
// UnderlyingTx exposes the wrapped *sql.Tx. Statements executed directly on
// the returned transaction do not pass through the recording wrappers and
// are therefore not added to the recorded queries.
func (l *LoggedTx) UnderlyingTx() *sql.Tx {
	underlying := l.tx
	return underlying
}
// PrepareContext prepares query on the underlying transaction and returns a
// LoggedStmt that remembers the SQL text and reports its successful
// mutations back to this transaction. A preparation error is returned as is.
func (l *LoggedTx) PrepareContext(ctx context.Context, query string) (*LoggedStmt, error) {
	prepared, prepErr := l.tx.PrepareContext(ctx, query)
	if prepErr != nil {
		return nil, prepErr
	}

	wrapped := new(LoggedStmt)
	wrapped.stmt = prepared
	wrapped.tx = l
	wrapped.sql = query
	return wrapped, nil
}
// Prepare is PrepareContext with a background context.
func (l *LoggedTx) Prepare(query string) (*LoggedStmt, error) {
	ctx := context.Background()
	return l.PrepareContext(ctx, query)
}
func (l *LoggedTx) Rollback() error {
l.mu.Lock()
l.queries = nil l.mu.Unlock()
return l.tx.Rollback()
}
// Commit commits the underlying transaction and, if it succeeds and at least
// one mutation was recorded, hands the recorded queries to writeEvent.
//
// A commit error is returned unchanged and nothing is logged. Event log
// failures are not returned: writeEvent reports them on stderr, so a logging
// problem never turns a successful commit into an error.
func (l *LoggedTx) Commit() error {
	if err := l.tx.Commit(); err != nil {
		return err
	}

	l.mu.Lock()
	queries := l.queries
	l.mu.Unlock()

	// The enabled flag is deliberately not read here: it is shared state
	// guarded by eventLogMu, and writeEvent checks it while holding that lock.
	// Reading it unlocked would race with SetEventLogConfig and CloseEventLog.
	if len(queries) > 0 {
		l.writeEvent(queries)
	}
	return nil
}
// writeEvent appends one TransactionEvent describing queries to the event
// log. It holds eventLogMu for its whole duration.
//
// It is best-effort: if logging is disabled it does nothing, and if the log
// file cannot be opened, the event ID cannot be generated, or encoding fails,
// a warning is printed to stderr and the event is dropped.
func (l *LoggedTx) writeEvent(queries []QueryRecord) {
	eventLogMu.Lock()
	defer eventLogMu.Unlock()

	if !eventLogConfig.Enabled {
		return
	}

	if openErr := ensureEventLogFile(); openErr != nil {
		fmt.Fprintf(os.Stderr, "Warning: failed to open event log: %v\n", openErr)
		return
	}

	eventID, idErr := gonanoid.New(21)
	if idErr != nil {
		fmt.Fprintf(os.Stderr, "Warning: failed to generate event ID: %v\n", idErr)
		return
	}

	var event TransactionEvent
	event.ID = eventID
	event.Timestamp = time.Now()
	event.Tool = l.toolName
	event.Queries = queries
	event.Success = true
	event.Duration = time.Since(l.startTime).Milliseconds()

	encodeErr := eventLogEnc.Encode(event)
	if encodeErr == nil {
		return
	}
	fmt.Fprintf(os.Stderr, "Warning: failed to write event log: %v\n", encodeErr)
}
// ExecContext executes the prepared statement with args. When it succeeds
// and the statement's SQL is classified as a mutation by isMutation, the SQL
// text and arguments are recorded on the owning LoggedTx.
//
// The arguments are copied before being recorded. Prepared statements are
// often executed in a loop with a reused argument slice (args...); without
// the copy every record would alias that slice and show whatever values it
// held last rather than the ones that were executed.
//
// The result and error of the underlying ExecContext are returned unchanged.
func (s *LoggedStmt) ExecContext(ctx context.Context, args ...any) (sql.Result, error) {
	result, err := s.stmt.ExecContext(ctx, args...)
	if err != nil || !isMutation(s.sql) {
		return result, err
	}

	record := QueryRecord{
		SQL:        s.sql,
		Parameters: append([]any(nil), args...),
	}

	s.tx.mu.Lock()
	s.tx.queries = append(s.tx.queries, record)
	s.tx.mu.Unlock()

	return result, err
}
// Exec is ExecContext with a background context; mutations are recorded in
// the same way.
func (s *LoggedStmt) Exec(args ...any) (sql.Result, error) {
	ctx := context.Background()
	return s.ExecContext(ctx, args...)
}
// QueryRowContext forwards a single-row query to the underlying prepared
// statement. Nothing is recorded for the event log.
func (s *LoggedStmt) QueryRowContext(ctx context.Context, args ...any) *sql.Row {
	row := s.stmt.QueryRowContext(ctx, args...)
	return row
}
// QueryRow forwards a single-row query to the underlying prepared
// statement's QueryRow. Nothing is recorded for the event log.
func (s *LoggedStmt) QueryRow(args ...any) *sql.Row {
	row := s.stmt.QueryRow(args...)
	return row
}
// QueryContext forwards a multi-row query to the underlying prepared
// statement. Nothing is recorded for the event log.
func (s *LoggedStmt) QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) {
	rows, err := s.stmt.QueryContext(ctx, args...)
	return rows, err
}
// Query forwards a multi-row query to the underlying prepared statement's
// Query. Nothing is recorded for the event log.
func (s *LoggedStmt) Query(args ...any) (*sql.Rows, error) {
	rows, err := s.stmt.Query(args...)
	return rows, err
}
// Close closes the underlying prepared statement and returns its error.
// Records already added to the owning transaction are left untouched.
func (s *LoggedStmt) Close() error {
	closeErr := s.stmt.Close()
	return closeErr
}
// isMutation reports whether sqlStr looks like a data-modifying statement.
//
// A statement that starts with INSERT, UPDATE or DELETE (case-insensitive,
// leading whitespace ignored) is a mutation. A statement that starts with
// WITH is a mutation only if INSERT, UPDATE or DELETE appears in it as a
// whole word, so that identifiers which merely contain one of those words
// (updated_at, is_deleted, inserted_by, ...) do not make a read-only CTE
// look like a write.
//
// This is a keyword heuristic, not a SQL parser: a keyword inside a string
// literal, a comment or a quoted identifier of a WITH statement still counts.
func isMutation(sqlStr string) bool {
	upper := strings.ToUpper(strings.TrimSpace(sqlStr))

	if !strings.HasPrefix(upper, "WITH") {
		return strings.HasPrefix(upper, "INSERT") ||
			strings.HasPrefix(upper, "UPDATE") ||
			strings.HasPrefix(upper, "DELETE")
	}

	// Split on every rune that cannot be part of an identifier. The text is
	// already upper-cased, so only A-Z need to be considered for ASCII letters;
	// non-ASCII runes are kept inside words so they never create a boundary in
	// the middle of an identifier.
	words := strings.FieldsFunc(upper, func(r rune) bool {
		isWordRune := r == '_' ||
			(r >= '0' && r <= '9') ||
			(r >= 'A' && r <= 'Z') ||
			r >= 0x80
		return !isWordRune
	})
	for _, word := range words {
		switch word {
		case "INSERT", "UPDATE", "DELETE":
			return true
		}
	}
	return false
}
// ensureEventLogFile makes sure the event log file is open and that an
// encoder is attached to it. It is a no-op when a file is already open.
//
// Otherwise it creates the parent directory of the configured path (mode
// 0755), opens the file for appending, creating it with mode 0644 if it does
// not exist, and installs a JSON encoder with HTML escaping turned off.
// It reads and writes the package-level event log state without locking, so
// the caller is expected to hold eventLogMu, as writeEvent does.
func ensureEventLogFile() error {
	if eventLogFile != nil {
		return nil
	}

	logPath := eventLogConfig.Path
	if mkErr := os.MkdirAll(filepath.Dir(logPath), 0755); mkErr != nil {
		return fmt.Errorf("failed to create event log directory: %w", mkErr)
	}

	openFlags := os.O_APPEND | os.O_CREATE | os.O_WRONLY
	file, openErr := os.OpenFile(logPath, openFlags, 0644)
	if openErr != nil {
		return fmt.Errorf("failed to open event log file: %w", openErr)
	}

	encoder := json.NewEncoder(file)
	encoder.SetEscapeHTML(false)

	eventLogFile, eventLogEnc = file, encoder
	return nil
}
// CloseEventLog disables event logging and closes the log file if one is
// open, returning the error from closing it. The file and encoder are
// forgotten even when the close fails. With no open file it returns nil.
func CloseEventLog() error {
	eventLogMu.Lock()
	defer eventLogMu.Unlock()

	eventLogConfig.Enabled = false
	if eventLogFile == nil {
		return nil
	}

	closeErr := eventLogFile.Close()
	eventLogFile, eventLogEnc = nil, nil
	return closeErr
}
// MarshalJSON encodes the record as {"sql": ..., "parameters": [...]} with
// every parameter first passed through marshalParam. The parameters array is
// always emitted as a JSON array, never null.
//
// The payload is an anonymous struct rather than QueryRecord itself so that
// json.Marshal does not call back into this method.
func (q QueryRecord) MarshalJSON() ([]byte, error) {
	params := make([]any, 0, len(q.Parameters))
	for _, raw := range q.Parameters {
		params = append(params, marshalParam(raw))
	}

	return json.Marshal(struct {
		SQL        string `json:"sql"`
		Parameters []any  `json:"parameters"`
	}{
		SQL:        q.SQL,
		Parameters: params,
	})
}
// marshalParam converts a query argument into the value that is written to
// the event log:
//
//   - nil and nil pointers become nil;
//   - pointers are followed until a non-pointer value is reached;
//   - time.Time is formatted as an RFC 3339 string with nanoseconds;
//   - strings, booleans, []byte and the built-in integer and float types are
//     returned unchanged;
//   - anything else, including named types derived from the built-in ones,
//     is rendered with fmt's %v verb.
func marshalParam(param any) any {
	// Peel off pointer layers; stop at the first non-pointer value.
	for {
		if param == nil {
			return nil
		}
		ref := reflect.ValueOf(param)
		if ref.Kind() != reflect.Pointer {
			break
		}
		if ref.IsNil() {
			return nil
		}
		param = ref.Elem().Interface()
	}

	switch value := param.(type) {
	case time.Time:
		return value.Format(time.RFC3339Nano)
	case string, bool, []byte,
		int, int8, int16, int32, int64,
		uint, uint8, uint16, uint32, uint64,
		float32, float64:
		return param
	}
	return fmt.Sprintf("%v", param)
}