package tools

import (
	"context"
	"database/sql"
	"encoding/base64"
	"fmt"
	"regexp"
	"strings"
	"time"

	"skraak/db"
)

// dbPath is the database location used by the tools in this package.
// It holds the empty string until SetDBPath is called.
var dbPath string

// SetDBPath records path as the database location for the tools package.
// Called from main.go during initialization.
func SetDBPath(path string) {
	dbPath = path
}

// Request, response, and metadata types for the execute_sql tool.
type (
	// ExecuteSQLInput holds the arguments accepted by the execute_sql tool:
	// the SQL text, optional positional bind parameters, and an optional row
	// limit (nil selects the package default).
	ExecuteSQLInput struct {
		Query      string `json:"query"`
		Parameters []any  `json:"parameters,omitempty"`
		Limit      *int   `json:"limit,omitempty"`
	}

	// ColumnInfo describes one result column: its name and the type name
	// reported by the database driver.
	ColumnInfo struct {
		Name         string `json:"name"`
		DatabaseType string `json:"database_type"`
	}

	// ExecuteSQLOutput is the execute_sql tool's result: the returned rows,
	// their count, per-column metadata, whether rows were cut off by the
	// limit, and the query text reported back to the caller.
	ExecuteSQLOutput struct {
		Rows     []map[string]any `json:"rows"`
		RowCount int              `json:"row_count"`
		Columns  []ColumnInfo     `json:"columns"`
		Limited  bool             `json:"limited"`
		Query    string           `json:"query_executed"`
	}
)

// Validation patterns
var (
	// selectPattern requires the query to start with SELECT or WITH
	// (case-insensitive, leading whitespace allowed) followed by whitespace.
	selectPattern = regexp.MustCompile(`(?i)^\s*(SELECT|WITH)\s+`)

	// forbiddenPattern flags keywords associated with write or
	// permission-changing statements. The match is purely textual, so these
	// words are also rejected inside string literals and comments.
	forbiddenPattern = regexp.MustCompile(`(?i)\b(INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|TRUNCATE|GRANT|REVOKE)\b`)

	// limitPattern detects a LIMIT clause the caller already wrote, with either
	// a literal row count or a bind placeholder (?, ?NNN, $1, :name, @name).
	// Placeholders must be recognised too; otherwise a second LIMIT gets
	// appended and the query fails to parse. The match is unanchored, so a
	// LIMIT inside a subquery or CTE also counts as present.
	limitPattern = regexp.MustCompile(`(?i)\bLIMIT\s+(\d+|\?\d*|[$:@]\w+)`)
)

const (
	// defaultLimit is the row cap used when the caller does not supply one.
	defaultLimit = 1000
	// maxLimit is the largest caller-supplied row cap that is accepted.
	maxLimit = 10000
)

// ExecuteSQL validates and runs a SELECT/WITH query against the database at
// dbPath (opened through db.WithReadDB) and returns the rows as JSON-friendly
// maps.
//
// Row limiting works as follows:
//   - If the query has no LIMIT clause, one is appended using input.Limit (or
//     defaultLimit) plus one extra row. The extra row only detects truncation
//     and is never returned. Limited then reports whether more rows existed,
//     and Query shows the statement with the user-facing limit.
//   - If the query already contains a LIMIT, it runs as written, input.Limit
//     is not applied, and Limited is always false.
//
// input.Parameters are passed as positional bind arguments. An empty result
// yields an empty, non-nil Rows slice. Validation, execution, and scanning
// errors are returned as-is or wrapped with context.
func ExecuteSQL(
	ctx context.Context,
	input ExecuteSQLInput,
) (ExecuteSQLOutput, error) {
	if err := validateSQLQuery(input.Query, input.Limit); err != nil {
		return ExecuteSQLOutput{}, err
	}

	limit := resolveLimit(input.Limit)
	query, autoAddedLimit := applyLimit(input.Query, limit)

	var output ExecuteSQLOutput
	err := db.WithReadDB(dbPath, func(database *sql.DB) error {
		rows, rerr := executeSQLQuery(ctx, database, query, input.Parameters)
		if rerr != nil {
			// Wrap for consistency with the scan/iteration errors below.
			return fmt.Errorf("query execution failed: %w", rerr)
		}
		defer rows.Close()

		columnInfo, columns, cerr := buildColumnInfo(rows)
		if cerr != nil {
			return cerr
		}

		results, serr := scanResultRows(rows, columns)
		if serr != nil {
			return serr
		}

		// Handle empty results (return empty array, not error): a nil slice
		// would encode as JSON null.
		if results == nil {
			results = []map[string]any{}
		}

		// Detect truncation: the auto-added LIMIT asked for limit+1 rows, so
		// receiving more than limit means more data exists.
		limited := false
		if autoAddedLimit && len(results) > limit {
			limited = true
			results = results[:limit]
		}

		queryReported := buildQueryReported(input.Query, autoAddedLimit, limit)

		output = ExecuteSQLOutput{
			Rows:     results,
			RowCount: len(results),
			Columns:  columnInfo,
			Limited:  limited,
			Query:    queryReported,
		}
		return nil
	})
	return output, err
}

// validateSQLQuery checks that query is a SELECT/WITH statement acceptable to
// the execute_sql tool and that the optional row limit is in range.
//
// It returns an error when:
//   - query is empty or contains only whitespace;
//   - query does not begin with SELECT or WITH (see selectPattern);
//   - query contains a keyword matched by forbiddenPattern anywhere in its
//     text, including inside string literals or comments;
//   - limit is non-nil and *limit is outside the range [1, maxLimit].
//
// It returns nil when every check passes.
func validateSQLQuery(query string, limit *int) error {
	if strings.TrimSpace(query) == "" {
		return fmt.Errorf("query cannot be empty")
	}
	if !selectPattern.MatchString(query) {
		return fmt.Errorf("only SELECT and WITH queries are allowed")
	}
	if forbiddenPattern.MatchString(query) {
		// List every keyword forbiddenPattern matches so a rejected caller
		// can see why; keep this message in sync with that pattern.
		return fmt.Errorf("query contains forbidden keywords (INSERT/UPDATE/DELETE/DROP/CREATE/ALTER/TRUNCATE/GRANT/REVOKE)")
	}
	if limit != nil && (*limit < 1 || *limit > maxLimit) {
		return fmt.Errorf("limit must be between 1 and %d", maxLimit)
	}
	return nil
}

// resolveLimit picks the row cap to enforce. It uses the caller-supplied
// limit when one is given, and defaultLimit otherwise. It does not validate
// the value.
func resolveLimit(limit *int) int {
	if limit == nil {
		return defaultLimit
	}
	return *limit
}

// applyLimit appends "LIMIT limit+1" to query unless limitPattern finds a LIMIT
// clause already present anywhere in the text (including a subquery). The
// extra row lets the caller detect that results were truncated.
//
// Before appending, the query is normalised:
//   - Surrounding whitespace and any trailing semicolons are removed. With a
//     trailing ";" the appended clause would become a separate statement and
//     the limit would not apply.
//   - If the last line contains "--", the clause goes on a new line so a
//     trailing line comment cannot swallow it.
//
// It returns the query to execute and whether a LIMIT was added. When no
// LIMIT is added, the query is returned unchanged.
func applyLimit(query string, limit int) (string, bool) {
	if limitPattern.MatchString(query) {
		return query, false
	}
	base := strings.TrimRight(strings.TrimSpace(query), "; \t\r\n")
	sep := " "
	// LastIndex returns -1 when there is no newline, so lastLine is the whole query.
	if lastLine := base[strings.LastIndex(base, "\n")+1:]; strings.Contains(lastLine, "--") {
		sep = "\n"
	}
	return fmt.Sprintf("%s%sLIMIT %d", base, sep, limit+1), true
}

// executeSQLQuery runs query against database, binding params as positional
// arguments. It returns the open result set, which the caller must close.
func executeSQLQuery(ctx context.Context, database *sql.DB, query string, params []any) (*sql.Rows, error) {
	// Spreading an empty or nil params slice passes no bind arguments, so one
	// call covers both the parameterised and the plain query.
	return database.QueryContext(ctx, query, params...)
}

// buildColumnInfo reads the result-set column names and the type name the
// driver reports for each column.
//
// It returns three values:
//   - the per-column metadata;
//   - the raw column names, in result order, for use when scanning rows;
//   - an error, wrapped with context, if either lookup fails.
func buildColumnInfo(rows *sql.Rows) ([]ColumnInfo, []string, error) {
	names, err := rows.Columns()
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get columns: %w", err)
	}
	types, err := rows.ColumnTypes()
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get column types: %w", err)
	}

	info := make([]ColumnInfo, 0, len(names))
	for idx := range names {
		info = append(info, ColumnInfo{
			Name:         names[idx],
			DatabaseType: types[idx].DatabaseTypeName(),
		})
	}
	return info, names, nil
}

// scanResultRows reads every remaining row from rows into one map per row.
// Each map is keyed by column name, and each value is passed through
// convertValue.
//
// columns must be the result-set column names in order, as returned by
// rows.Columns. If two columns share a name, the later column's value
// overwrites the earlier one in the row map.
//
// It returns nil (not an empty slice) when there are no rows. Scan and
// iteration errors are wrapped with context.
func scanResultRows(rows *sql.Rows, columns []string) ([]map[string]any, error) {
	var results []map[string]any

	// The scan buffers are loop-invariant. Each Scan overwrites every slot,
	// and database/sql copies driver []byte values when the destination is
	// *any. Each value is also copied into the row map before the next Scan,
	// so reusing the buffers across rows is safe.
	values := make([]any, len(columns))
	dest := make([]any, len(columns))
	for i := range values {
		dest[i] = &values[i]
	}

	for rows.Next() {
		if err := rows.Scan(dest...); err != nil {
			return nil, fmt.Errorf("row scan failed: %w", err)
		}
		row := make(map[string]any, len(columns))
		for i, col := range columns {
			row[col] = convertValue(values[i])
		}
		results = append(results, row)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("row iteration failed: %w", err)
	}
	return results, nil
}

// buildQueryReported returns the query text reported to the caller as
// query_executed.
//
// When autoAddedLimit is false, originalQuery is returned unchanged. Otherwise
// the query is reported with "LIMIT limit", the user-facing limit rather than
// the limit+1 used to detect truncation. The text is normalised first:
//   - Surrounding whitespace and trailing semicolons are removed, so the
//     clause belongs to the statement instead of forming a separate one.
//   - If the last line contains "--", the clause goes on a new line so the
//     line comment cannot swallow it.
func buildQueryReported(originalQuery string, autoAddedLimit bool, limit int) string {
	if !autoAddedLimit {
		return originalQuery
	}
	base := strings.TrimRight(strings.TrimSpace(originalQuery), "; \t\r\n")
	sep := " "
	// LastIndex returns -1 when there is no newline, so lastLine is the whole query.
	if lastLine := base[strings.LastIndex(base, "\n")+1:]; strings.Contains(lastLine, "--") {
		sep = "\n"
	}
	return fmt.Sprintf("%s%sLIMIT %d", base, sep, limit)
}

// convertValue converts a value scanned from the database into a form that
// encodes cleanly as JSON:
//   - nil stays nil;
//   - time.Time becomes an RFC 3339 string (consistent with existing code);
//   - []byte becomes a standard base64 string;
//   - Go integer and floating-point kinds, string, and bool pass through unchanged;
//   - anything else is rendered with fmt's %v verb.
//
// NOTE(review): NaN and ±Inf floats pass through unchanged. If the result is
// encoded with encoding/json, those values make Marshal fail; confirm how the
// caller's JSON layer handles them.
func convertValue(val any) any {
	if val == nil {
		return nil
	}

	switch v := val.(type) {
	case time.Time:
		return v.Format(time.RFC3339)
	case []byte:
		return base64.StdEncoding.EncodeToString(v)
	case int64, int32, int16, int8, int,
		uint64, uint32, uint16, uint8, uint,
		float64, float32,
		string, bool:
		// Some drivers may return narrower numeric kinds (e.g. int32 or
		// float32). Pass them through so they encode as JSON numbers instead
		// of being stringified by the default case.
		return v
	default:
		return fmt.Sprintf("%v", v)
	}
}