package tools

import (
	"context"
	"database/sql"
	"fmt"
	"skraak_mcp/db"
	"skraak_mcp/utils"
	"strings"

	"github.com/modelcontextprotocol/go-sdk/mcp"
)

// PatternInput defines the input parameters for the create_or_update_pattern tool.
//
// Every field is an optional pointer so the handler can distinguish an
// omitted field (nil) from an explicitly supplied value: a nil or blank ID
// selects the create path, and on update only the non-nil duration fields
// are written.
type PatternInput struct {
	ID            *string `json:"id,omitempty" jsonschema:"Pattern ID (12 characters). Omit to create a new pattern, provide to update an existing one."`
	RecordSeconds *int    `json:"record_seconds,omitempty" jsonschema:"Number of seconds to record (must be positive). Required for create."`
	SleepSeconds  *int    `json:"sleep_seconds,omitempty" jsonschema:"Number of seconds to sleep between recordings (must be positive for create, >= 0 for update)."`
}

// PatternOutput defines the structured result of the create_or_update_pattern
// tool: the pattern row as read back from the database after the operation,
// plus a human-readable message describing what happened.
type PatternOutput struct {
	Pattern db.CyclicRecordingPattern `json:"pattern" jsonschema:"The created or updated recording pattern"`
	Message string                    `json:"message" jsonschema:"Success message"`
}

// CreateOrUpdatePattern is the handler for the create_or_update_pattern tool.
//
// It dispatches on input.ID: a nil, empty or whitespace-only ID creates a new
// pattern, while any other ID updates the pattern with that ID. The request
// value itself is not inspected.
func CreateOrUpdatePattern(ctx context.Context, req *mcp.CallToolRequest, input PatternInput) (*mcp.CallToolResult, PatternOutput, error) {
	wantsCreate := input.ID == nil || strings.TrimSpace(*input.ID) == ""
	if wantsCreate {
		return createPattern(ctx, input)
	}
	return updatePattern(ctx, input)
}

// createPattern creates a new active cyclic recording pattern, or returns the
// already-existing active pattern that has the same record/sleep durations.
//
// record_seconds and sleep_seconds are both required and must be positive.
// The duplicate lookup, the insert and the read-back all run inside a single
// transaction on a writable connection opened from dbPath.
//
// On success it returns an empty CallToolResult together with the pattern and
// a descriptive message. On failure it returns a nil result, a zero-valued
// PatternOutput and an error describing the step that failed.
func createPattern(ctx context.Context, input PatternInput) (*mcp.CallToolResult, PatternOutput, error) {
	// Shared column list so every read of a pattern row scans the same shape.
	const selectPattern = "SELECT id, record_s, sleep_s, created_at, last_modified, active FROM cyclic_recording_pattern"

	var output PatternOutput

	// Validate required fields before touching the database.
	switch {
	case input.RecordSeconds == nil:
		return nil, output, fmt.Errorf("record_seconds is required when creating a pattern")
	case input.SleepSeconds == nil:
		return nil, output, fmt.Errorf("sleep_seconds is required when creating a pattern")
	case *input.RecordSeconds <= 0:
		return nil, output, fmt.Errorf("record_seconds must be positive (got %d)", *input.RecordSeconds)
	case *input.SleepSeconds <= 0:
		return nil, output, fmt.Errorf("sleep_seconds must be positive (got %d)", *input.SleepSeconds)
	}

	// Open writable database connection.
	database, err := db.OpenWriteableDB(dbPath)
	if err != nil {
		return nil, output, fmt.Errorf("database connection failed: %w", err)
	}
	defer database.Close()

	tx, err := database.BeginTx(ctx, nil)
	if err != nil {
		return nil, output, fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Deferred unconditionally so every early return releases the
	// transaction without depending on the state of a shared err variable.
	// After a successful Commit, Rollback only reports sql.ErrTxDone, which
	// carries no information worth surfacing.
	defer func() { _ = tx.Rollback() }()

	// Look for an active pattern with the same durations, reading the whole
	// row in one query so no second fetch by ID is needed.
	var existing db.CyclicRecordingPattern
	err = tx.QueryRowContext(ctx,
		selectPattern+" WHERE record_s = ? AND sleep_s = ? AND active = true",
		*input.RecordSeconds, *input.SleepSeconds,
	).Scan(&existing.ID, &existing.RecordS, &existing.SleepS, &existing.CreatedAt, &existing.LastModified, &existing.Active)

	switch {
	case err == nil:
		// Pattern already exists: return it instead of creating a duplicate.
		if err = tx.Commit(); err != nil {
			return nil, output, fmt.Errorf("failed to commit transaction: %w", err)
		}

		output.Pattern = existing
		output.Message = fmt.Sprintf("Pattern already exists with ID %s (record %ds, sleep %ds) - returning existing pattern",
			existing.ID, existing.RecordS, existing.SleepS)

		return &mcp.CallToolResult{}, output, nil
	case err != sql.ErrNoRows:
		return nil, output, fmt.Errorf("failed to check for existing pattern: %w", err)
	}

	// No matching pattern: generate an ID and insert a new row.
	id, err := utils.GenerateShortID()
	if err != nil {
		return nil, output, fmt.Errorf("failed to generate ID: %w", err)
	}

	_, err = tx.ExecContext(ctx,
		"INSERT INTO cyclic_recording_pattern (id, record_s, sleep_s, created_at, last_modified, active) VALUES (?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, TRUE)",
		id, *input.RecordSeconds, *input.SleepSeconds,
	)
	if err != nil {
		return nil, output, fmt.Errorf("failed to create pattern: %w", err)
	}

	// Read the row back so the caller sees the database-assigned timestamps.
	var created db.CyclicRecordingPattern
	err = tx.QueryRowContext(ctx, selectPattern+" WHERE id = ?", id).
		Scan(&created.ID, &created.RecordS, &created.SleepS, &created.CreatedAt, &created.LastModified, &created.Active)
	if err != nil {
		return nil, output, fmt.Errorf("failed to fetch created pattern: %w", err)
	}

	if err = tx.Commit(); err != nil {
		return nil, output, fmt.Errorf("failed to commit transaction: %w", err)
	}

	output.Pattern = created
	output.Message = fmt.Sprintf("Successfully created cyclic recording pattern with ID %s (record %ds, sleep %ds)",
		created.ID, created.RecordS, created.SleepS)

	return &mcp.CallToolResult{}, output, nil
}

// updatePattern applies the supplied duration fields to the pattern whose ID
// is *input.ID and returns the row as read back after the update.
//
// input.ID must be non-nil; surrounding whitespace is trimmed, matching the
// blank-ID check performed by CreateOrUpdatePattern. record_seconds, when
// given, must be greater than 0; sleep_seconds, when given, must be greater
// than or equal to 0. At least one of the two must be provided. All input
// validation happens before the database is opened.
//
// The existence check, the UPDATE and the read-back run inside a single
// transaction bound to ctx, so a cancelled request aborts the work and a
// failure part-way through leaves the row untouched.
//
// On success it returns an empty CallToolResult together with the pattern and
// a descriptive message. On failure it returns a nil result, a zero-valued
// PatternOutput and an error describing the step that failed.
func updatePattern(ctx context.Context, input PatternInput) (*mcp.CallToolResult, PatternOutput, error) {
	var output PatternOutput
	patternID := strings.TrimSpace(*input.ID)

	// Validate fields if provided.
	if input.RecordSeconds != nil && *input.RecordSeconds <= 0 {
		return nil, output, fmt.Errorf("record_seconds must be greater than 0: %d", *input.RecordSeconds)
	}
	if input.SleepSeconds != nil && *input.SleepSeconds < 0 {
		return nil, output, fmt.Errorf("sleep_seconds must be greater than or equal to 0: %d", *input.SleepSeconds)
	}

	// Build the SET clause from the fields that were supplied. Only fixed
	// column fragments are interpolated into the statement; every value is
	// passed as a bound parameter.
	updates := make([]string, 0, 3)
	args := make([]any, 0, 3)

	if input.RecordSeconds != nil {
		updates = append(updates, "record_s = ?")
		args = append(args, *input.RecordSeconds)
	}
	if input.SleepSeconds != nil {
		updates = append(updates, "sleep_s = ?")
		args = append(args, *input.SleepSeconds)
	}

	if len(updates) == 0 {
		return nil, output, fmt.Errorf("no fields provided to update")
	}

	// Always update last_modified; the ID is the final bound parameter.
	updates = append(updates, "last_modified = now()")
	args = append(args, patternID)

	query := fmt.Sprintf("UPDATE cyclic_recording_pattern SET %s WHERE id = ?", strings.Join(updates, ", "))

	// Open writable database.
	database, err := db.OpenWriteableDB(dbPath)
	if err != nil {
		return nil, output, fmt.Errorf("failed to open database: %w", err)
	}
	defer database.Close()

	tx, err := database.BeginTx(ctx, nil)
	if err != nil {
		return nil, output, fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Deferred unconditionally so every early return releases the
	// transaction. After a successful Commit, Rollback only reports
	// sql.ErrTxDone, which carries no information worth surfacing.
	defer func() { _ = tx.Rollback() }()

	// Verify pattern exists.
	var exists bool
	err = tx.QueryRowContext(ctx,
		"SELECT EXISTS(SELECT 1 FROM cyclic_recording_pattern WHERE id = ?)",
		patternID,
	).Scan(&exists)
	if err != nil {
		return nil, output, fmt.Errorf("failed to query pattern: %w", err)
	}
	if !exists {
		return nil, output, fmt.Errorf("pattern not found: %s", patternID)
	}

	if _, err = tx.ExecContext(ctx, query, args...); err != nil {
		return nil, output, fmt.Errorf("failed to update pattern: %w", err)
	}

	// Fetch the updated pattern.
	var pattern db.CyclicRecordingPattern
	err = tx.QueryRowContext(ctx,
		"SELECT id, record_s, sleep_s, created_at, last_modified, active FROM cyclic_recording_pattern WHERE id = ?",
		patternID,
	).Scan(&pattern.ID, &pattern.RecordS, &pattern.SleepS, &pattern.CreatedAt, &pattern.LastModified, &pattern.Active)
	if err != nil {
		return nil, output, fmt.Errorf("failed to fetch updated pattern: %w", err)
	}

	if err = tx.Commit(); err != nil {
		return nil, output, fmt.Errorf("failed to commit transaction: %w", err)
	}

	output.Pattern = pattern
	output.Message = fmt.Sprintf("Successfully updated pattern (ID: %s, record %ds, sleep %ds)",
		pattern.ID, pattern.RecordS, pattern.SleepS)

	return &mcp.CallToolResult{}, output, nil
}