package db

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"reflect"
	"strings"
	"sync"
	"time"

	gonanoid "github.com/matoous/go-nanoid/v2"
)

// LoggedTx wraps *sql.Tx and records every Exec/ExecContext call (including
// those made through statements from Prepare/PrepareContext) that succeeds and
// whose SQL isMutation classifies as a write. Reads are passed through
// unrecorded. On a successful Commit the recorded queries are written to the
// event log as one TransactionEvent; Rollback discards them.
//
// Fields: tx is the wrapped transaction; queries holds the recorded
// mutations and is guarded by mu; toolName is copied into
// TransactionEvent.Tool; startTime is set by BeginLoggedTx and used to
// compute TransactionEvent.Duration.
type LoggedTx struct {
	tx        *sql.Tx
	queries   []QueryRecord
	mu        sync.Mutex
	toolName  string
	startTime time.Time
}

// QueryRecord is one recorded mutation: the SQL text and the arguments it was
// executed with. Its custom MarshalJSON normalizes each entry of Parameters
// through marshalParam before encoding, so the struct tags below only
// describe the resulting field names.
type QueryRecord struct {
	SQL        string `json:"sql"`
	Parameters []any  `json:"parameters"`
}

// TransactionEvent is one entry in the event log: a committed transaction and
// the mutations it performed. Each event is written with json.Encoder.Encode,
// which terminates it with a newline, so the log is one JSON object per line.
// Tool is omitted when empty. Success is always true for events written by
// LoggedTx, because failed commits and rollbacks are never logged. Duration is
// in milliseconds (see the json tag).
type TransactionEvent struct {
	ID        string        `json:"id"`
	Timestamp time.Time     `json:"timestamp"`
	Tool      string        `json:"tool,omitempty"`
	Queries   []QueryRecord `json:"queries"`
	Success   bool          `json:"success"`
	Duration  int64         `json:"duration_ms"`
}

// LoggedStmt wraps a *sql.Stmt prepared through a LoggedTx. Successful Exec
// calls whose SQL is a mutation are recorded on the owning transaction (tx).
// The statement text is kept in sql because *sql.Stmt does not expose it.
type LoggedStmt struct {
	stmt *sql.Stmt
	tx   *LoggedTx
	sql  string
}

// EventLogConfig holds configuration for event logging.
// Enabled turns recording of committed transactions on or off. Path is the
// file events are appended to. Its parent directory is created when the log
// file is first opened.
type EventLogConfig struct {
	Enabled bool
	Path    string
}

// Process-wide event log state. eventLogMu guards the other three variables;
// hold it when reading or writing them. eventLogFile and eventLogEnc are
// opened lazily by ensureEventLogFile and reset to nil together whenever the
// file is closed.
var (
	eventLogConfig EventLogConfig
	eventLogMu     sync.Mutex
	eventLogFile   *os.File
	eventLogEnc    *json.Encoder
)

// SetEventLogConfig installs cfg as the process-wide event log configuration.
//
// If a log file is open and cfg names a different path, that file is closed
// (best effort, the close error is dropped) so the next logged commit opens
// the new path. Changing only Enabled leaves an already-open file open.
func SetEventLogConfig(cfg EventLogConfig) {
	eventLogMu.Lock()
	defer eventLogMu.Unlock()

	pathChanged := eventLogConfig.Path != cfg.Path
	if pathChanged && eventLogFile != nil {
		// No caller to report to; a failed close must not block reconfiguration.
		_ = eventLogFile.Close()
		eventLogFile, eventLogEnc = nil, nil
	}

	eventLogConfig = cfg
}

// BeginLoggedTx begins a transaction on db with default options and wraps it in
// a LoggedTx that records successful mutations for the event log.
//
// toolName is optional; when non-empty it is written as the "tool" field of
// the logged event. The transaction's duration is measured from the moment
// BeginTx returns.
func BeginLoggedTx(ctx context.Context, db *sql.DB, toolName string) (*LoggedTx, error) {
	sqlTx, beginErr := db.BeginTx(ctx, nil)
	if beginErr != nil {
		return nil, beginErr
	}

	lt := &LoggedTx{tx: sqlTx, toolName: toolName, startTime: time.Now()}
	lt.queries = make([]QueryRecord, 0)
	return lt, nil
}

// ExecContext runs query inside the transaction. If the statement succeeds and
// isMutation classifies it as a write, the SQL text and a copy of args are
// recorded for the event log written on Commit. Failed statements and reads
// are not recorded.
//
// args is copied because a caller that spreads a reused buffer
// (tx.ExecContext(ctx, q, buf...)) would otherwise rewrite entries that were
// already recorded before Commit serializes them.
// NOTE: pointer arguments are still dereferenced only when the event is
// encoded, so the values they point to are captured as of Commit.
func (l *LoggedTx) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
	result, err := l.tx.ExecContext(ctx, query, args...)
	if err == nil && isMutation(query) {
		params := append([]any(nil), args...)
		l.mu.Lock()
		l.queries = append(l.queries, QueryRecord{SQL: query, Parameters: params})
		l.mu.Unlock()
	}
	return result, err
}

// Exec is ExecContext with context.Background(). The same mutations are
// recorded for the event log.
func (l *LoggedTx) Exec(query string, args ...any) (sql.Result, error) {
	bg := context.Background()
	return l.ExecContext(bg, query, args...)
}

// QueryRowContext runs a single-row query on the wrapped transaction. Reads are
// passed straight through and never recorded in the event log.
func (l *LoggedTx) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
	underlying := l.tx
	return underlying.QueryRowContext(ctx, query, args...)
}

// QueryRow runs a single-row query on the wrapped transaction using
// context.Background(). It is a read and is never recorded.
func (l *LoggedTx) QueryRow(query string, args ...any) *sql.Row {
	return l.tx.QueryRowContext(context.Background(), query, args...)
}

// QueryContext runs a multi-row query on the wrapped transaction. Reads are
// passed straight through and never recorded in the event log.
func (l *LoggedTx) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
	underlying := l.tx
	return underlying.QueryContext(ctx, query, args...)
}

// Query runs a multi-row query on the wrapped transaction using
// context.Background(). It is a read and is never recorded.
func (l *LoggedTx) Query(query string, args ...any) (*sql.Rows, error) {
	return l.tx.QueryContext(context.Background(), query, args...)
}

// UnderlyingTx exposes the raw *sql.Tx for packages that should not import db
// (for example utils). Statements executed directly on the returned *sql.Tx
// bypass LoggedTx, so they are never recorded; use the LoggedTx methods
// whenever event logging is wanted.
func (l *LoggedTx) UnderlyingTx() *sql.Tx {
	raw := l.tx
	return raw
}

// PrepareContext prepares query on the wrapped transaction and returns a
// LoggedStmt bound to this LoggedTx, so its successful mutating Execs are
// recorded here. The prepare error, if any, is returned unchanged.
func (l *LoggedTx) PrepareContext(ctx context.Context, query string) (*LoggedStmt, error) {
	prepared, err := l.tx.PrepareContext(ctx, query)
	if err == nil {
		return &LoggedStmt{sql: query, tx: l, stmt: prepared}, nil
	}
	return nil, err
}

// Prepare is PrepareContext with context.Background().
func (l *LoggedTx) Prepare(query string) (*LoggedStmt, error) {
	bg := context.Background()
	return l.PrepareContext(bg, query)
}

// Rollback aborts the transaction and returns the underlying Rollback error.
// Recorded mutations are dropped first, regardless of whether the rollback
// itself succeeds, so nothing from this transaction reaches the event log.
func (l *LoggedTx) Rollback() error {
	func() {
		l.mu.Lock()
		defer l.mu.Unlock()
		l.queries = nil
	}()
	return l.tx.Rollback()
}

// Commit commits the transaction. Only if the commit succeeds and at least one
// mutation was recorded does it hand the recorded queries to writeEvent.
// writeEvent reports log failures on stderr and never turns a successful
// commit into an error. The commit error, if any, is returned unchanged and
// nothing is logged.
func (l *LoggedTx) Commit() error {
	if err := l.tx.Commit(); err != nil {
		return err
	}

	l.mu.Lock()
	queries := l.queries
	l.mu.Unlock()

	// Do not consult eventLogConfig.Enabled here. It is guarded by eventLogMu,
	// and reading it unlocked raced with SetEventLogConfig/CloseEventLog.
	// writeEvent checks Enabled while holding the lock.
	if len(queries) > 0 {
		l.writeEvent(queries)
	}

	return nil
}

// writeEvent appends one TransactionEvent describing a committed transaction to
// the event log. It is best-effort: a failure to open the log, generate an
// event ID, or encode the event is reported on stderr and otherwise ignored.
// It takes eventLogMu for the whole write, which serializes writers, and does
// nothing if logging is disabled.
func (l *LoggedTx) writeEvent(queries []QueryRecord) {
	// Measure before taking the global lock. The reported duration should
	// cover the transaction itself, not time spent waiting on other writers
	// or lazily opening the log file.
	elapsed := time.Since(l.startTime)

	eventLogMu.Lock()
	defer eventLogMu.Unlock()

	if !eventLogConfig.Enabled {
		return
	}

	// Ensure file is open
	if err := ensureEventLogFile(); err != nil {
		// Log to stderr but don't fail the commit
		fmt.Fprintf(os.Stderr, "Warning: failed to open event log: %v\n", err)
		return
	}

	id, err := gonanoid.New(21)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Warning: failed to generate event ID: %v\n", err)
		return
	}

	event := TransactionEvent{
		ID:        id,
		Timestamp: time.Now(),
		Tool:      l.toolName,
		Queries:   queries,
		Success:   true,
		Duration:  elapsed.Milliseconds(),
	}

	if err := eventLogEnc.Encode(event); err != nil {
		fmt.Fprintf(os.Stderr, "Warning: failed to write event log: %v\n", err)
	}
}

// LoggedStmt methods

// ExecContext executes the prepared statement. If it succeeds and the
// statement's SQL is a mutation, the SQL text and a copy of args are recorded
// on the owning LoggedTx for the event log written on Commit.
//
// args is copied because prepared statements are often executed in a loop
// with a reused argument slice (stmt.ExecContext(ctx, buf...)). Storing the
// caller's slice would make every recorded entry show the last iteration's
// values.
// NOTE: pointer arguments are still dereferenced only when the event is
// encoded, so the values they point to are captured as of Commit.
func (s *LoggedStmt) ExecContext(ctx context.Context, args ...any) (sql.Result, error) {
	result, err := s.stmt.ExecContext(ctx, args...)
	if err == nil && isMutation(s.sql) {
		params := append([]any(nil), args...)
		s.tx.mu.Lock()
		s.tx.queries = append(s.tx.queries, QueryRecord{SQL: s.sql, Parameters: params})
		s.tx.mu.Unlock()
	}
	return result, err
}

// Exec runs the prepared statement with context.Background(). Recording
// behaves exactly as in ExecContext.
func (s *LoggedStmt) Exec(args ...any) (sql.Result, error) {
	bg := context.Background()
	return s.ExecContext(bg, args...)
}

// QueryRowContext runs the prepared statement as a single-row query. Reads are
// passed straight through and never recorded.
func (s *LoggedStmt) QueryRowContext(ctx context.Context, args ...any) *sql.Row {
	prepared := s.stmt
	return prepared.QueryRowContext(ctx, args...)
}

// QueryRow runs the prepared statement as a single-row query using
// context.Background(). It is a read and is never recorded.
func (s *LoggedStmt) QueryRow(args ...any) *sql.Row {
	return s.stmt.QueryRowContext(context.Background(), args...)
}

// QueryContext runs the prepared statement as a multi-row query. Reads are
// passed straight through and never recorded.
func (s *LoggedStmt) QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) {
	prepared := s.stmt
	return prepared.QueryContext(ctx, args...)
}

// Query runs the prepared statement as a multi-row query using
// context.Background(). It is a read and is never recorded.
func (s *LoggedStmt) Query(args ...any) (*sql.Rows, error) {
	return s.stmt.QueryContext(context.Background(), args...)
}

// Close releases the prepared statement and returns its Close error. It only
// closes the statement; the owning transaction and any mutations already
// recorded from this statement are left untouched.
func (s *LoggedStmt) Close() error {
	prepared := s.stmt
	return prepared.Close()
}

// isMutation reports whether sqlStr is a data-modifying statement (INSERT,
// UPDATE, DELETE or REPLACE) that should be recorded in the event log.
//
// Leading whitespace and SQL comments ("-- ..." lines and "/* ... */" blocks)
// are skipped before the leading keyword is examined. This keeps annotated
// statements from being silently dropped from the log. Keywords are matched
// case-insensitively as whole words.
//
// A statement starting with WITH (a CTE) counts as a mutation if any whole
// word in it is one of those keywords. Whole-word matching stops identifiers
// such as "updated_at" or "deleted" from causing false positives. Keywords
// inside string literals or later comments can still match, which errs toward
// logging.
func isMutation(sqlStr string) bool {
	body := stripLeadingSQLComments(sqlStr)

	end := strings.IndexFunc(body, isSQLWordSeparator)
	if end < 0 {
		end = len(body)
	}
	first := body[:end]

	if !strings.EqualFold(first, "WITH") {
		return isDMLKeyword(first)
	}
	for _, word := range strings.FieldsFunc(body[end:], isSQLWordSeparator) {
		if isDMLKeyword(word) {
			return true
		}
	}
	return false
}

// stripLeadingSQLComments removes leading whitespace, "--" line comments and
// "/* */" block comments from s. An unterminated comment consumes the rest.
func stripLeadingSQLComments(s string) string {
	for {
		s = strings.TrimSpace(s)
		switch {
		case strings.HasPrefix(s, "--"):
			nl := strings.IndexByte(s, '\n')
			if nl < 0 {
				return ""
			}
			s = s[nl+1:]
		case strings.HasPrefix(s, "/*"):
			closeIdx := strings.Index(s[2:], "*/")
			if closeIdx < 0 {
				return ""
			}
			s = s[2+closeIdx+2:]
		default:
			return s
		}
	}
}

// isSQLWordSeparator reports whether r ends a keyword/identifier word: any
// ASCII rune other than a letter, digit or '_'. Non-ASCII runes count as
// word characters.
func isSQLWordSeparator(r rune) bool {
	switch {
	case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9', r == '_':
		return false
	default:
		return r < 0x80
	}
}

// isDMLKeyword reports whether word is a data-modifying statement keyword.
func isDMLKeyword(word string) bool {
	for _, kw := range [...]string{"INSERT", "UPDATE", "DELETE", "REPLACE"} {
		if strings.EqualFold(word, kw) {
			return true
		}
	}
	return false
}

// ensureEventLogFile lazily opens eventLogConfig.Path for appending. It
// creates the parent directory (0755) and the file (0644) when missing, and
// installs a JSON encoder with HTML escaping disabled. It is a no-op while a
// file is already open.
//
// It does no locking itself; the caller is expected to hold eventLogMu
// (writeEvent does).
func ensureEventLogFile() error {
	if eventLogFile != nil {
		return nil
	}

	path := eventLogConfig.Path
	if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
		return fmt.Errorf("failed to create event log directory: %w", err)
	}

	const flags = os.O_APPEND | os.O_CREATE | os.O_WRONLY
	f, err := os.OpenFile(path, flags, 0644)
	if err != nil {
		return fmt.Errorf("failed to open event log file: %w", err)
	}

	enc := json.NewEncoder(f)
	enc.SetEscapeHTML(false)
	eventLogFile, eventLogEnc = f, enc
	return nil
}

// CloseEventLog disables event logging and closes the log file if one is open,
// returning the file's Close error. Logging stays off until SetEventLogConfig
// re-enables it. Calling it again without re-enabling logging is a no-op that
// returns nil.
func CloseEventLog() error {
	eventLogMu.Lock()
	defer eventLogMu.Unlock()

	// Turn logging off first so a later commit does not lazily reopen the file.
	eventLogConfig.Enabled = false

	f := eventLogFile
	if f == nil {
		return nil
	}
	eventLogFile, eventLogEnc = nil, nil
	return f.Close()
}

// MarshalJSON implements json.Marshaler for QueryRecord. Each parameter is
// normalized through marshalParam: nil and nil pointers become null, other
// pointers are dereferenced, time.Time becomes an RFC 3339 string, and so on.
//
// The record is encoded with HTML escaping disabled. encoding/json does not
// unescape a Marshaler's output, so encoding with json.Marshal here turned SQL
// such as "a < b" into "a \u003c b", even though the event log's encoder uses
// SetEscapeHTML(false). Emitting unescaped JSON lets the outer encoder decide:
// callers of plain json.Marshal still get HTML-escaped output.
func (q QueryRecord) MarshalJSON() ([]byte, error) {
	type queryRecordJSON struct {
		SQL        string `json:"sql"`
		Parameters []any  `json:"parameters"`
	}

	out := queryRecordJSON{
		SQL:        q.SQL,
		Parameters: make([]any, len(q.Parameters)),
	}
	for i, param := range q.Parameters {
		out.Parameters[i] = marshalParam(param)
	}

	var sb strings.Builder
	enc := json.NewEncoder(&sb)
	enc.SetEscapeHTML(false)
	if err := enc.Encode(out); err != nil {
		return nil, err
	}
	// Encoder.Encode terminates the value with a newline; strip it so the
	// result is exactly one JSON value.
	return []byte(strings.TrimSuffix(sb.String(), "\n")), nil
}

// marshalParam converts a recorded SQL argument into a value encoding/json can
// serialize:
//   - nil and nil pointers become nil (JSON null).
//   - driver.Valuer implementations (sql.NullString, sql.NullTime, ... and
//     custom types) are replaced by the result of Value() and converted
//     recursively. A NULL nullable therefore logs as null and a valid one as
//     its plain value. If Value() fails, or returns another Valuer, the %v
//     formatting is used instead.
//   - Other non-nil pointers are dereferenced and converted recursively.
//   - time.Time becomes an RFC 3339 string with nanoseconds.
//   - Strings, booleans, built-in numeric types and []byte pass through
//     unchanged ([]byte is encoded by encoding/json as base64).
//   - Anything else is formatted with %v.
func marshalParam(param any) any {
	if param == nil {
		return nil
	}

	rv := reflect.ValueOf(param)
	if rv.Kind() == reflect.Pointer && rv.IsNil() {
		return nil
	}

	// Check Valuer before dereferencing. Both sql.NullX and *sql.NullX, as
	// well as pointer-receiver Valuers, then log the value the driver
	// received. database/sql already called Value() during Exec, so calling
	// it again here is expected to be safe.
	if valuer, ok := param.(driver.Valuer); ok {
		value, err := valuer.Value()
		if err != nil {
			return fmt.Sprintf("%v", param)
		}
		if _, nested := value.(driver.Valuer); nested {
			// A well-behaved Valuer returns a primitive; avoid unbounded recursion.
			return fmt.Sprintf("%v", value)
		}
		return marshalParam(value)
	}

	if rv.Kind() == reflect.Pointer {
		return marshalParam(rv.Elem().Interface())
	}

	switch v := param.(type) {
	case time.Time:
		return v.Format(time.RFC3339Nano)
	case string, []byte, bool,
		int, int8, int16, int32, int64,
		uint, uint8, uint16, uint32, uint64,
		float32, float64:
		return v
	default:
		return fmt.Sprintf("%v", v)
	}
}