package tools

import (
	"context"
	"fmt"
	"skraak_mcp/db"
	"skraak_mcp/utils"
	"strings"

	"github.com/modelcontextprotocol/go-sdk/mcp"
)

// ClusterInput defines the input parameters for the create_or_update_cluster tool.
//
// Every field is a pointer so the handler can distinguish "omitted" (nil) from
// "provided". A non-blank ID selects the update path; a nil or blank ID creates
// a new cluster. On update, only the non-nil fields are written. The jsonschema
// tags carry per-field descriptions (presumably surfaced to tool callers in the
// input schema generated by the MCP SDK — confirm against the SDK version in use).
type ClusterInput struct {
	ID                       *string `json:"id,omitempty" jsonschema:"Cluster ID (12 characters). Omit to create a new cluster, provide to update an existing one."`
	DatasetID                *string `json:"dataset_id,omitempty" jsonschema:"ID of the parent dataset (12-character nanoid). Required for create."`
	LocationID               *string `json:"location_id,omitempty" jsonschema:"ID of the parent location (12-character nanoid). Required for create."`
	Name                     *string `json:"name,omitempty" jsonschema:"Cluster name (max 140 characters). Required for create."`
	SampleRate               *int    `json:"sample_rate,omitempty" jsonschema:"Sample rate in Hz (must be positive). Required for create."`
	Path                     *string `json:"path,omitempty" jsonschema:"Normalized folder path (max 255 characters)"`
	CyclicRecordingPatternID *string `json:"cyclic_recording_pattern_id,omitempty" jsonschema:"Optional ID of cyclic recording pattern (12-character nanoid). Set to empty string to clear."`
	Description              *string `json:"description,omitempty" jsonschema:"Optional cluster description (max 255 characters)"`
}

// ClusterOutput is the structured result of the create_or_update_cluster tool.
//
// Cluster holds the row as read back from the database after the write, so it
// reflects server-assigned values such as timestamps. Message is a
// human-readable summary of what was done.
type ClusterOutput struct {
	Cluster db.Cluster `json:"cluster" jsonschema:"The created or updated cluster"`
	Message string     `json:"message" jsonschema:"Success message"`
}

// CreateOrUpdateCluster is the handler for the create_or_update_cluster tool.
//
// A missing or whitespace-only ID means "create a new cluster"; any other ID
// means "update that cluster". The request itself is not inspected.
func CreateOrUpdateCluster(
	ctx context.Context,
	req *mcp.CallToolRequest,
	input ClusterInput,
) (*mcp.CallToolResult, ClusterOutput, error) {
	isCreate := input.ID == nil || strings.TrimSpace(*input.ID) == ""
	if isCreate {
		return createCluster(ctx, input)
	}
	return updateCluster(ctx, input)
}

// validateClusterFields checks the length and range constraints shared by the
// create and update paths. Only fields that are present (non-nil) are checked;
// enforcing required fields is the caller's job.
//
// Lengths are counted in Unicode code points (runes), not bytes, so the limits
// match the "characters" wording of the tool schema and names containing
// multi-byte letters (e.g. macron vowels) are not over-counted.
// len([]rune(s)) is recognised by the compiler and does not allocate.
func validateClusterFields(input ClusterInput) error {
	if input.Name != nil {
		if n := len([]rune(*input.Name)); n > 140 {
			return fmt.Errorf("name must be 140 characters or less (got %d)", n)
		}
	}
	if input.Path != nil {
		if n := len([]rune(*input.Path)); n > 255 {
			return fmt.Errorf("path must be 255 characters or less (got %d)", n)
		}
	}
	if input.Description != nil {
		if n := len([]rune(*input.Description)); n > 255 {
			return fmt.Errorf("description must be 255 characters or less (got %d)", n)
		}
	}
	if input.SampleRate != nil && *input.SampleRate <= 0 {
		return fmt.Errorf("sample_rate must be positive (got %d)", *input.SampleRate)
	}
	return nil
}

// createCluster inserts a new cluster and returns the row as stored.
//
// dataset_id, location_id, name and sample_rate are required. The dataset and
// location must exist and be active, and the location must belong to the
// dataset. A non-blank cyclic_recording_pattern_id (trimmed) must reference an
// existing, active pattern; a nil or blank one is stored as NULL. The optional
// path and description are stored as given (NULL when omitted).
//
// All checks, the insert and the read-back run in a single transaction that
// is rolled back on any early return.
func createCluster(ctx context.Context, input ClusterInput) (*mcp.CallToolResult, ClusterOutput, error) {
	var output ClusterOutput

	// Validate required fields for create
	if input.DatasetID == nil || strings.TrimSpace(*input.DatasetID) == "" {
		return nil, output, fmt.Errorf("dataset_id is required when creating a cluster")
	}
	if input.LocationID == nil || strings.TrimSpace(*input.LocationID) == "" {
		return nil, output, fmt.Errorf("location_id is required when creating a cluster")
	}
	if input.Name == nil || strings.TrimSpace(*input.Name) == "" {
		return nil, output, fmt.Errorf("name is required when creating a cluster")
	}
	if input.SampleRate == nil {
		return nil, output, fmt.Errorf("sample_rate is required when creating a cluster")
	}

	if err := validateClusterFields(input); err != nil {
		return nil, output, err
	}

	datasetID := *input.DatasetID
	locationID := *input.LocationID

	// A blank pattern ID means "no pattern"; trim it the same way the update
	// path does.
	patternID := ""
	if input.CyclicRecordingPatternID != nil {
		patternID = strings.TrimSpace(*input.CyclicRecordingPatternID)
	}

	// Open writable database connection
	database, err := db.OpenWriteableDB(dbPath)
	if err != nil {
		return nil, output, fmt.Errorf("database connection failed: %w", err)
	}
	defer database.Close()

	tx, err := database.BeginTx(ctx, nil)
	if err != nil {
		return nil, output, fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Roll back on every early return, including validation failures where
	// err is nil. After a successful Commit this is a no-op (it returns
	// ErrTxDone, which is deliberately ignored).
	defer func() { _ = tx.Rollback() }()

	// Existence is checked with a standalone EXISTS query because selecting
	// columns of a missing row yields no row at all (Scan returns
	// sql.ErrNoRows), never a row with exists=false.

	// Verify dataset exists and is active
	var datasetExists bool
	if err = tx.QueryRowContext(ctx,
		"SELECT EXISTS(SELECT 1 FROM dataset WHERE id = ?)", datasetID,
	).Scan(&datasetExists); err != nil {
		return nil, output, fmt.Errorf("failed to verify dataset: %w", err)
	}
	if !datasetExists {
		return nil, output, fmt.Errorf("dataset with ID '%s' does not exist", datasetID)
	}
	var datasetActive bool
	var datasetName string
	if err = tx.QueryRowContext(ctx,
		"SELECT active, name FROM dataset WHERE id = ?", datasetID,
	).Scan(&datasetActive, &datasetName); err != nil {
		return nil, output, fmt.Errorf("failed to verify dataset: %w", err)
	}
	if !datasetActive {
		return nil, output, fmt.Errorf("dataset '%s' (ID: %s) is not active", datasetName, datasetID)
	}

	// Verify location exists, is active, and belongs to the specified dataset
	var locationExists bool
	if err = tx.QueryRowContext(ctx,
		"SELECT EXISTS(SELECT 1 FROM location WHERE id = ?)", locationID,
	).Scan(&locationExists); err != nil {
		return nil, output, fmt.Errorf("failed to verify location: %w", err)
	}
	if !locationExists {
		return nil, output, fmt.Errorf("location with ID '%s' does not exist", locationID)
	}
	var locationActive bool
	var locationName, locationDatasetID string
	if err = tx.QueryRowContext(ctx,
		"SELECT active, name, dataset_id FROM location WHERE id = ?", locationID,
	).Scan(&locationActive, &locationName, &locationDatasetID); err != nil {
		return nil, output, fmt.Errorf("failed to verify location: %w", err)
	}
	if !locationActive {
		return nil, output, fmt.Errorf("location '%s' (ID: %s) is not active", locationName, locationID)
	}
	if locationDatasetID != datasetID {
		return nil, output, fmt.Errorf("location '%s' (ID: %s) does not belong to dataset '%s' (ID: %s) - it belongs to dataset ID '%s'",
			locationName, locationID, datasetName, datasetID, locationDatasetID)
	}

	// Verify cyclic recording pattern if provided
	if patternID != "" {
		var patternExists bool
		if err = tx.QueryRowContext(ctx,
			"SELECT EXISTS(SELECT 1 FROM cyclic_recording_pattern WHERE id = ?)", patternID,
		).Scan(&patternExists); err != nil {
			return nil, output, fmt.Errorf("failed to verify cyclic recording pattern: %w", err)
		}
		if !patternExists {
			return nil, output, fmt.Errorf("cyclic recording pattern with ID '%s' does not exist", patternID)
		}
		var patternActive bool
		if err = tx.QueryRowContext(ctx,
			"SELECT active FROM cyclic_recording_pattern WHERE id = ?", patternID,
		).Scan(&patternActive); err != nil {
			return nil, output, fmt.Errorf("failed to verify cyclic recording pattern: %w", err)
		}
		if !patternActive {
			return nil, output, fmt.Errorf("cyclic recording pattern with ID '%s' is not active", patternID)
		}
	}

	// Store NULL rather than an empty string when no pattern is given.
	var patternArg any
	if patternID != "" {
		patternArg = patternID
	}

	// Generate ID
	id, err := utils.GenerateShortID()
	if err != nil {
		return nil, output, fmt.Errorf("failed to generate ID: %w", err)
	}

	// Insert cluster. Nil pointer arguments (path, description) are sent as NULL.
	_, err = tx.ExecContext(ctx,
		"INSERT INTO cluster (id, dataset_id, location_id, name, path, sample_rate, cyclic_recording_pattern_id, description, created_at, last_modified, active) VALUES (?, ?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, TRUE)",
		id, datasetID, locationID, *input.Name, input.Path, *input.SampleRate, patternArg, input.Description,
	)
	if err != nil {
		return nil, output, fmt.Errorf("failed to create cluster: %w", err)
	}

	// Fetch the created cluster
	var cluster db.Cluster
	err = tx.QueryRowContext(ctx,
		"SELECT id, dataset_id, location_id, name, description, created_at, last_modified, active, cyclic_recording_pattern_id, sample_rate FROM cluster WHERE id = ?",
		id,
	).Scan(&cluster.ID, &cluster.DatasetID, &cluster.LocationID, &cluster.Name, &cluster.Description,
		&cluster.CreatedAt, &cluster.LastModified, &cluster.Active, &cluster.CyclicRecordingPatternID, &cluster.SampleRate)
	if err != nil {
		return nil, output, fmt.Errorf("failed to fetch created cluster: %w", err)
	}

	if err = tx.Commit(); err != nil {
		return nil, output, fmt.Errorf("failed to commit transaction: %w", err)
	}

	output.Cluster = cluster
	output.Message = fmt.Sprintf("Successfully created cluster '%s' with ID %s in location '%s' at dataset '%s' (sample rate: %d Hz)",
		cluster.Name, cluster.ID, locationName, datasetName, cluster.SampleRate)

	return &mcp.CallToolResult{}, output, nil
}

// updateCluster applies a partial update to an existing cluster and returns
// the row as stored afterwards.
//
// Only the non-nil fields of input among name, path, sample_rate, description
// and cyclic_recording_pattern_id are written; last_modified is always
// refreshed. dataset_id and location_id are not changed by this path. A blank
// cyclic_recording_pattern_id clears the pattern (NULL); a non-blank one is
// trimmed and must reference an active pattern. input.ID must be non-nil (the
// dispatcher guarantees this).
func updateCluster(ctx context.Context, input ClusterInput) (*mcp.CallToolResult, ClusterOutput, error) {
	var output ClusterOutput
	// The dispatcher routes here based on the trimmed ID, so look the cluster
	// up by the trimmed value as well.
	clusterID := strings.TrimSpace(*input.ID)

	if err := validateClusterFields(input); err != nil {
		return nil, output, err
	}
	// Creation requires a non-blank name; don't let an update blank it out.
	if input.Name != nil && strings.TrimSpace(*input.Name) == "" {
		return nil, output, fmt.Errorf("name cannot be empty")
	}

	patternID := ""
	if input.CyclicRecordingPatternID != nil {
		patternID = strings.TrimSpace(*input.CyclicRecordingPatternID)
	}

	// Build the SET clause before touching the database so an empty request
	// fails fast. Every SQL fragment is a constant; values travel in args.
	var updates []string
	var args []any

	if input.Name != nil {
		updates = append(updates, "name = ?")
		args = append(args, *input.Name)
	}
	if input.Path != nil {
		updates = append(updates, "path = ?")
		args = append(args, *input.Path)
	}
	if input.SampleRate != nil {
		updates = append(updates, "sample_rate = ?")
		args = append(args, *input.SampleRate)
	}
	if input.Description != nil {
		updates = append(updates, "description = ?")
		args = append(args, *input.Description)
	}
	if input.CyclicRecordingPatternID != nil {
		if patternID == "" {
			updates = append(updates, "cyclic_recording_pattern_id = NULL")
		} else {
			updates = append(updates, "cyclic_recording_pattern_id = ?")
			args = append(args, patternID)
		}
	}

	if len(updates) == 0 {
		return nil, output, fmt.Errorf("no fields provided to update")
	}

	// Always update last_modified; the final arg binds the WHERE clause.
	updates = append(updates, "last_modified = now()")
	args = append(args, clusterID)

	// Open writable database
	database, err := db.OpenWriteableDB(dbPath)
	if err != nil {
		return nil, output, fmt.Errorf("failed to open database: %w", err)
	}
	defer database.Close()

	// Verify cluster exists
	var exists bool
	err = database.QueryRowContext(ctx, "SELECT EXISTS(SELECT 1 FROM cluster WHERE id = ?)", clusterID).Scan(&exists)
	if err != nil {
		return nil, output, fmt.Errorf("failed to query cluster: %w", err)
	}
	if !exists {
		return nil, output, fmt.Errorf("cluster not found: %s", clusterID)
	}

	// Validate cyclic_recording_pattern_id if a non-blank one was provided
	if patternID != "" {
		var patternExists bool
		err = database.QueryRowContext(ctx,
			"SELECT EXISTS(SELECT 1 FROM cyclic_recording_pattern WHERE id = ? AND active = true)",
			patternID,
		).Scan(&patternExists)
		if err != nil {
			return nil, output, fmt.Errorf("failed to verify cyclic recording pattern: %w", err)
		}
		if !patternExists {
			return nil, output, fmt.Errorf("cyclic recording pattern not found or not active: %s", patternID)
		}
	}

	query := fmt.Sprintf("UPDATE cluster SET %s WHERE id = ?", strings.Join(updates, ", "))
	if _, err = database.ExecContext(ctx, query, args...); err != nil {
		return nil, output, fmt.Errorf("failed to update cluster: %w", err)
	}

	// Fetch the updated cluster
	var cluster db.Cluster
	err = database.QueryRowContext(ctx,
		"SELECT id, dataset_id, location_id, name, description, created_at, last_modified, active, cyclic_recording_pattern_id, sample_rate FROM cluster WHERE id = ?",
		clusterID,
	).Scan(&cluster.ID, &cluster.DatasetID, &cluster.LocationID, &cluster.Name, &cluster.Description,
		&cluster.CreatedAt, &cluster.LastModified, &cluster.Active, &cluster.CyclicRecordingPatternID, &cluster.SampleRate)
	if err != nil {
		return nil, output, fmt.Errorf("failed to fetch updated cluster: %w", err)
	}

	output.Cluster = cluster
	output.Message = fmt.Sprintf("Successfully updated cluster '%s' (ID: %s)", cluster.Name, cluster.ID)

	return &mcp.CallToolResult{}, output, nil
}