package tools

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"skraak/db"
	"skraak/utils"
	"strings"
)

// ClusterInput defines the input parameters for the create_or_update_cluster tool.
//
// Every field is a pointer so that an omitted field (nil) can be distinguished
// from an explicitly supplied value. A non-blank ID selects an update, in which
// only the non-nil fields are written; otherwise a new cluster is created, and
// DatasetID, LocationID, Name (non-blank) and SampleRate are required.
//
// When supplied, SampleRate must be positive and pass utils.ValidateSampleRate;
// Name, Description and Path are length-limited. On update, a blank
// CyclicRecordingPatternID clears the cluster's pattern.
type ClusterInput struct {
	ID                       *string `json:"id,omitempty"`
	DatasetID                *string `json:"dataset_id,omitempty"`
	LocationID               *string `json:"location_id,omitempty"`
	Name                     *string `json:"name,omitempty"`
	SampleRate               *int    `json:"sample_rate,omitempty"`
	Path                     *string `json:"path,omitempty"`
	CyclicRecordingPatternID *string `json:"cyclic_recording_pattern_id,omitempty"`
	Description              *string `json:"description,omitempty"`
}

// ClusterOutput defines the output structure of create_or_update_cluster.
//
// Cluster holds the row as read back from the database after the operation
// (on create this may be a pre-existing active cluster with the same name in
// the location). Message is a human-readable summary of what happened.
type ClusterOutput struct {
	Cluster db.Cluster `json:"cluster"`
	Message string     `json:"message"`
}

// CreateOrUpdateCluster creates a new cluster or updates an existing one within a location.
//
// The operation is chosen by input.ID: when it is nil or blank a new cluster is
// created, otherwise the cluster with that ID is updated.
func CreateOrUpdateCluster(
	ctx context.Context,
	input ClusterInput,
) (ClusterOutput, error) {
	isCreate := input.ID == nil || strings.TrimSpace(*input.ID) == ""
	if isCreate {
		return createCluster(ctx, input)
	}
	return updateCluster(ctx, input)
}

// validateClusterFields validates fields common to both create and update.
//
// The optional name, description and path are length-checked against their
// limits. A supplied sample rate must be positive and also pass the bounds
// enforced by utils.ValidateSampleRate. Checks run in a fixed order and the
// first failure is returned.
func validateClusterFields(input ClusterInput) error {
	checks := []func() error{
		func() error { return utils.ValidateOptionalStringLength(input.Name, "name", utils.MaxNameLen) },
		func() error {
			return utils.ValidateOptionalStringLength(input.Description, "description", utils.MaxDescriptionLen)
		},
		func() error { return utils.ValidateOptionalStringLength(input.Path, "path", utils.MaxPathLen) },
	}

	if rate := input.SampleRate; rate != nil {
		checks = append(checks,
			func() error { return utils.ValidatePositive(*rate, "sample_rate") },
			func() error { return utils.ValidateSampleRate(*rate) },
		)
	}

	for _, check := range checks {
		if err := check(); err != nil {
			return err
		}
	}
	return nil
}

// createCluster inserts a new cluster into a location, or returns the existing
// active cluster of the same name in that location.
//
// Required fields (dataset_id, location_id, name, sample_rate), ID formats and
// shared field limits are validated before any database work. Inside a single
// write transaction the dataset, location and optional cyclic recording
// pattern are verified. Then either the existing same-named active cluster is
// returned, or a new one is inserted. On error the returned ClusterOutput is
// the zero value.
func createCluster(ctx context.Context, input ClusterInput) (ClusterOutput, error) {
	if err := validateCreateClusterFields(input); err != nil {
		return ClusterOutput{}, err
	}
	if err := validateCreateClusterIDs(input); err != nil {
		return ClusterOutput{}, err
	}
	if err := validateClusterFields(input); err != nil {
		return ClusterOutput{}, err
	}
	if err := utils.ValidateOptionalShortID(input.CyclicRecordingPatternID, "cyclic_recording_pattern_id"); err != nil {
		return ClusterOutput{}, err
	}

	var output ClusterOutput
	err := db.WithWriteTx(ctx, dbPath, "create_or_update_cluster", func(_ *sql.DB, tx *db.LoggedTx) error {
		datasetName, locationName, verr := verifyClusterParentRefs(ctx, tx, input)
		if verr != nil {
			return verr
		}

		// Cluster names are unique per location, so an existing active cluster
		// with this name is returned instead of attempting a conflicting insert.
		existing, findErr := findExistingClusterInLocation(ctx, tx, *input.LocationID, *input.Name)
		switch {
		case findErr == nil:
			output = ClusterOutput{
				Cluster: existing,
				Message: fmt.Sprintf("Cluster '%s' already exists in location '%s' (ID: %s) - returning existing cluster", existing.Name, locationName, existing.ID),
			}
			return nil // commit transaction
		case !errors.Is(findErr, sql.ErrNoRows):
			// Only "no rows" means the name is free; any other failure must
			// abort rather than be mistaken for "not found".
			return fmt.Errorf("failed to check for existing cluster: %w", findErr)
		}

		result, insErr := insertNewCluster(ctx, tx, input, datasetName, locationName)
		if insErr != nil {
			return insErr
		}
		output = result
		return nil // commit transaction
	})
	if err != nil {
		return ClusterOutput{}, err
	}
	return output, nil
}

// verifyClusterParentRefs validates, inside tx, the references a new cluster
// depends on:
//   - the dataset, via db.DatasetExistsAndActive;
//   - the location, via db.LocationBelongsToDataset;
//   - when a non-blank cyclic recording pattern ID is supplied, that the
//     pattern exists and is active.
//
// It returns the dataset name and the location name, for use in messages.
// Callers must have ensured DatasetID and LocationID are non-nil.
func verifyClusterParentRefs(ctx context.Context, tx *db.LoggedTx, input ClusterInput) (string, string, error) {
	datasetID := *input.DatasetID

	datasetName, err := db.DatasetExistsAndActive(tx, datasetID)
	if err != nil {
		return "", "", err
	}

	locationName, err := db.LocationBelongsToDataset(tx, *input.LocationID, datasetID)
	if err != nil {
		return "", "", err
	}

	if pattern := input.CyclicRecordingPatternID; pattern != nil && strings.TrimSpace(*pattern) != "" {
		if perr := verifyPatternExists(ctx, tx, *pattern); perr != nil {
			return "", "", perr
		}
	}

	return datasetName, locationName, nil
}

// insertNewCluster inserts a new active cluster row inside tx and returns it as
// read back from the database, together with a success message. The caller
// owns the transaction and is responsible for committing it.
//
// Callers must have ensured DatasetID, LocationID, Name and SampleRate are
// non-nil. Path and Description are stored as NULL when nil. A nil or blank
// CyclicRecordingPatternID is stored as NULL, which is the same way updates
// clear the pattern. Otherwise its trimmed value is stored.
func insertNewCluster(ctx context.Context, tx *db.LoggedTx, input ClusterInput, datasetName, locationName string) (ClusterOutput, error) {
	id, err := utils.GenerateShortID()
	if err != nil {
		return ClusterOutput{}, fmt.Errorf("failed to generate ID: %w", err)
	}

	// An untyped nil binds SQL NULL; a blank string would otherwise be written
	// into the foreign-key column even though verification skipped it.
	var patternID any
	if p := input.CyclicRecordingPatternID; p != nil {
		if trimmed := strings.TrimSpace(*p); trimmed != "" {
			patternID = trimmed
		}
	}

	// path is included so a path supplied on create is not silently dropped
	// (it is validated above and written by updates).
	_, err = tx.ExecContext(ctx,
		"INSERT INTO cluster (id, dataset_id, location_id, name, path, sample_rate, cyclic_recording_pattern_id, description, created_at, last_modified, active) VALUES (?, ?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, TRUE)",
		id, *input.DatasetID, *input.LocationID, *input.Name, input.Path, *input.SampleRate, patternID, input.Description,
	)
	if err != nil {
		return ClusterOutput{}, fmt.Errorf("failed to create cluster: %w", err)
	}

	cluster, err := fetchClusterByID(ctx, tx, id)
	if err != nil {
		return ClusterOutput{}, fmt.Errorf("failed to fetch created cluster: %w", err)
	}

	return ClusterOutput{
		Cluster: cluster,
		Message: fmt.Sprintf("Successfully created cluster '%s' with ID %s in location '%s' at dataset '%s' (sample rate: %d Hz)",
			cluster.Name, cluster.ID, locationName, datasetName, cluster.SampleRate),
	}, nil
}

// validateCreateClusterFields reports the first required create-time field
// that is missing. dataset_id, location_id and name must be present and
// non-blank, and sample_rate must be present. Fields are checked in that
// order.
func validateCreateClusterFields(input ClusterInput) error {
	isBlank := func(s *string) bool { return s == nil || strings.TrimSpace(*s) == "" }

	switch {
	case isBlank(input.DatasetID):
		return fmt.Errorf("dataset_id is required when creating a cluster")
	case isBlank(input.LocationID):
		return fmt.Errorf("location_id is required when creating a cluster")
	case isBlank(input.Name):
		return fmt.Errorf("name is required when creating a cluster")
	case input.SampleRate == nil:
		return fmt.Errorf("sample_rate is required when creating a cluster")
	default:
		return nil
	}
}

// validateCreateClusterIDs checks that dataset_id and location_id are
// well-formed short IDs, reporting a dataset_id problem first.
// Callers must have ensured both pointers are non-nil.
func validateCreateClusterIDs(input ClusterInput) error {
	err := utils.ValidateShortID(*input.DatasetID, "dataset_id")
	if err == nil {
		err = utils.ValidateShortID(*input.LocationID, "location_id")
	}
	return err
}

// verifyPatternExists confirms, inside tx, that the cyclic recording pattern
// patternID exists and is active. A missing row and an inactive row produce
// distinct errors; a failed query is wrapped.
func verifyPatternExists(ctx context.Context, tx *db.LoggedTx, patternID string) error {
	const query = "SELECT EXISTS(SELECT 1 FROM cyclic_recording_pattern WHERE id = ?), COALESCE((SELECT active FROM cyclic_recording_pattern WHERE id = ?), false)"

	var found, isActive bool
	if err := tx.QueryRowContext(ctx, query, patternID, patternID).Scan(&found, &isActive); err != nil {
		return fmt.Errorf("failed to verify cyclic recording pattern: %w", err)
	}

	switch {
	case !found:
		return fmt.Errorf("cyclic recording pattern with ID '%s' does not exist", patternID)
	case !isActive:
		return fmt.Errorf("cyclic recording pattern with ID '%s' is not active", patternID)
	default:
		return nil
	}
}

// findExistingClusterInLocation looks up the active cluster called name in the
// location locationID and returns its full row.
//
// If the lookup finds nothing, the Scan error is returned unwrapped, so callers
// can test it against sql.ErrNoRows. This assumes tx.QueryRowContext yields a
// standard row — TODO confirm.
func findExistingClusterInLocation(ctx context.Context, tx *db.LoggedTx, locationID, name string) (db.Cluster, error) {
	row := tx.QueryRowContext(ctx,
		"SELECT id FROM cluster WHERE location_id = ? AND name = ? AND active = true",
		locationID, name,
	)

	var clusterID string
	if err := row.Scan(&clusterID); err != nil {
		return db.Cluster{}, err
	}
	return fetchClusterByID(ctx, tx, clusterID)
}

// fetchClusterByID reads the cluster row with the given id inside tx.
//
// The Cluster is returned together with any Scan error. When the error is
// non-nil, the Cluster must be ignored: it may be zero or only partly filled.
func fetchClusterByID(ctx context.Context, tx *db.LoggedTx, id string) (db.Cluster, error) {
	const query = "SELECT id, dataset_id, location_id, name, description, created_at, last_modified, active, cyclic_recording_pattern_id, sample_rate FROM cluster WHERE id = ?"

	var cluster db.Cluster
	dest := []any{
		&cluster.ID,
		&cluster.DatasetID,
		&cluster.LocationID,
		&cluster.Name,
		&cluster.Description,
		&cluster.CreatedAt,
		&cluster.LastModified,
		&cluster.Active,
		&cluster.CyclicRecordingPatternID,
		&cluster.SampleRate,
	}
	err := tx.QueryRowContext(ctx, query, id).Scan(dest...)
	return cluster, err
}

// validateClusterActive checks that the cluster clusterID exists and is active.
// It queries through the connection pool (database), not a transaction.
// A missing cluster, an inactive cluster and a failed query each produce a
// distinct error.
func validateClusterActive(database *sql.DB, clusterID string) error {
	const query = "SELECT EXISTS(SELECT 1 FROM cluster WHERE id = ?), COALESCE((SELECT active FROM cluster WHERE id = ?), false)"

	var found, isActive bool
	if err := database.QueryRow(query, clusterID, clusterID).Scan(&found, &isActive); err != nil {
		return fmt.Errorf("failed to query cluster: %w", err)
	}

	switch {
	case !found:
		return fmt.Errorf("cluster not found: %s", clusterID)
	case !isActive:
		return fmt.Errorf("cluster '%s' is not active (cannot update inactive clusters)", clusterID)
	}
	return nil
}

// validateClusterCyclicPattern validates the cyclic recording pattern referenced
// by input, if any. A nil or blank pattern ID is accepted without a query;
// otherwise the trimmed ID is checked with validateCyclicPattern.
func validateClusterCyclicPattern(database *sql.DB, input ClusterInput) error {
	if p := input.CyclicRecordingPatternID; p != nil {
		if patternID := strings.TrimSpace(*p); patternID != "" {
			return validateCyclicPattern(database, patternID)
		}
	}
	return nil
}

// validateCyclicPattern checks, through the connection pool, that the cyclic
// recording pattern patternID exists and is active. A missing pattern, an
// inactive pattern and a failed query each produce a distinct error.
func validateCyclicPattern(database *sql.DB, patternID string) error {
	const query = "SELECT EXISTS(SELECT 1 FROM cyclic_recording_pattern WHERE id = ?), COALESCE((SELECT active FROM cyclic_recording_pattern WHERE id = ?), false)"

	var found, isActive bool
	if err := database.QueryRow(query, patternID, patternID).Scan(&found, &isActive); err != nil {
		return fmt.Errorf("failed to verify cyclic recording pattern: %w", err)
	}

	switch {
	case !found:
		return fmt.Errorf("cyclic recording pattern not found: %s", patternID)
	case !isActive:
		return fmt.Errorf("cyclic recording pattern '%s' is not active", patternID)
	}
	return nil
}

// buildClusterUpdateQuery builds a parameterised UPDATE statement for the
// cluster clusterID. It sets only the fields that are non-nil in input.
//
// Column names are fixed strings and every user-supplied value is bound as an
// argument. A cyclic_recording_pattern_id that is blank after trimming clears
// the column to NULL; otherwise the trimmed value is bound. last_modified is
// always refreshed with CURRENT_TIMESTAMP, the same standard-SQL expression
// used by the cluster INSERT. The last argument is clusterID, for the WHERE
// clause.
//
// It returns an error if input supplies no updatable field.
func buildClusterUpdateQuery(input ClusterInput, clusterID string) (string, []any, error) {
	// Up to five updatable columns, plus last_modified / the WHERE argument.
	const maxFields = 5
	sets := make([]string, 0, maxFields+1)
	args := make([]any, 0, maxFields+1)

	set := func(column string, value any) {
		sets = append(sets, column+" = ?")
		args = append(args, value)
	}

	if input.Name != nil {
		set("name", *input.Name)
	}
	if input.Path != nil {
		set("path", *input.Path)
	}
	if input.SampleRate != nil {
		set("sample_rate", *input.SampleRate)
	}
	if input.Description != nil {
		set("description", *input.Description)
	}
	if input.CyclicRecordingPatternID != nil {
		if patternID := strings.TrimSpace(*input.CyclicRecordingPatternID); patternID != "" {
			set("cyclic_recording_pattern_id", patternID)
		} else {
			sets = append(sets, "cyclic_recording_pattern_id = NULL")
		}
	}

	if len(sets) == 0 {
		return "", nil, fmt.Errorf("no fields provided to update")
	}

	sets = append(sets, "last_modified = CURRENT_TIMESTAMP")
	args = append(args, clusterID)

	query := "UPDATE cluster SET " + strings.Join(sets, ", ") + " WHERE id = ?"
	return query, args, nil
}

// validateClusterUpdateInput performs the database-free validation for an
// update and returns the ID of the cluster to update.
//
// It checks that:
//   - an ID was supplied and is a well-formed short ID;
//   - the shared field limits hold;
//   - the name is not being set to a blank value (create requires a non-blank
//     name, so an update must not erase it);
//   - a non-blank cyclic_recording_pattern_id is a well-formed short ID.
//
// A blank pattern ID is allowed, because the update query treats it as
// "clear the pattern".
func validateClusterUpdateInput(input ClusterInput) (string, error) {
	if input.ID == nil {
		// Guard against a nil dereference if called outside CreateOrUpdateCluster.
		return "", fmt.Errorf("id is required when updating a cluster")
	}
	clusterID := *input.ID

	if err := utils.ValidateShortID(clusterID, "cluster_id"); err != nil {
		return "", err
	}
	if err := validateClusterFields(input); err != nil {
		return "", err
	}
	if input.Name != nil && strings.TrimSpace(*input.Name) == "" {
		return "", fmt.Errorf("name cannot be empty when updating a cluster")
	}
	if input.CyclicRecordingPatternID != nil {
		if trimmed := strings.TrimSpace(*input.CyclicRecordingPatternID); trimmed != "" {
			if err := utils.ValidateShortID(trimmed, "cyclic_recording_pattern_id"); err != nil {
				return "", err
			}
		}
	}
	return clusterID, nil
}

// updateCluster applies a partial update to an existing, active cluster and
// returns the row as read back after the update.
//
// Input is validated before any database work. The existence/active checks,
// the UPDATE and the read-back all run through the same write transaction.
// That gives them one consistent view, and they never need a second pooled
// connection while the transaction holds one. (If the pool is limited to a
// single connection, such a second request could block.) On error the returned
// ClusterOutput is the zero value.
func updateCluster(ctx context.Context, input ClusterInput) (ClusterOutput, error) {
	clusterID, err := validateClusterUpdateInput(input)
	if err != nil {
		return ClusterOutput{}, err
	}

	var output ClusterOutput
	err = db.WithWriteTx(ctx, dbPath, "create_or_update_cluster", func(_ *sql.DB, tx *db.LoggedTx) error {
		if cerr := verifyClusterActiveTx(ctx, tx, clusterID); cerr != nil {
			return cerr
		}

		// A blank pattern ID means "clear the pattern", so only a non-blank
		// one needs to reference an existing, active pattern.
		if p := input.CyclicRecordingPatternID; p != nil {
			if patternID := strings.TrimSpace(*p); patternID != "" {
				if perr := verifyPatternExists(ctx, tx, patternID); perr != nil {
					return perr
				}
			}
		}

		query, args, qerr := buildClusterUpdateQuery(input, clusterID)
		if qerr != nil {
			return qerr
		}

		if _, xerr := tx.ExecContext(ctx, query, args...); xerr != nil {
			return fmt.Errorf("failed to update cluster: %w", xerr)
		}

		cluster, ferr := fetchClusterByID(ctx, tx, clusterID)
		if ferr != nil {
			return fmt.Errorf("failed to fetch updated cluster: %w", ferr)
		}

		output = ClusterOutput{
			Cluster: cluster,
			Message: fmt.Sprintf("Successfully updated cluster '%s' (ID: %s)", cluster.Name, cluster.ID),
		}
		return nil
	})
	if err != nil {
		return ClusterOutput{}, err
	}
	return output, nil
}

// verifyClusterActiveTx is the transactional counterpart of validateClusterActive.
// It checks inside tx that clusterID exists and is active, and reports
// failures with the same error messages.
func verifyClusterActiveTx(ctx context.Context, tx *db.LoggedTx, clusterID string) error {
	var exists, active bool
	err := tx.QueryRowContext(ctx,
		"SELECT EXISTS(SELECT 1 FROM cluster WHERE id = ?), COALESCE((SELECT active FROM cluster WHERE id = ?), false)",
		clusterID, clusterID,
	).Scan(&exists, &active)
	if err != nil {
		return fmt.Errorf("failed to query cluster: %w", err)
	}
	if !exists {
		return fmt.Errorf("cluster not found: %s", clusterID)
	}
	if !active {
		return fmt.Errorf("cluster '%s' is not active (cannot update inactive clusters)", clusterID)
	}
	return nil
}