package tools

import (
	"context"
	"fmt"
	"skraak_mcp/db"
	"skraak_mcp/utils"
	"strings"

	"github.com/modelcontextprotocol/go-sdk/mcp"
)

// DatasetInput defines the input parameters for the create_or_update_dataset tool.
//
// Every field is an optional pointer so that "not supplied" (nil) can be told
// apart from an explicit value. A non-nil, non-blank ID selects update mode;
// otherwise a new dataset is created. In update mode, nil fields are left
// unchanged. The jsonschema tags carry the per-field descriptions (presumably
// surfaced to clients in the tool's input schema by the MCP SDK).
type DatasetInput struct {
	ID          *string `json:"id,omitempty" jsonschema:"Dataset ID (12 characters). Omit to create a new dataset, provide to update an existing one."`
	Name        *string `json:"name,omitempty" jsonschema:"Dataset name (max 255 characters). Required for create."`
	Description *string `json:"description,omitempty" jsonschema:"Optional dataset description (max 255 characters)"`
	Type        *string `json:"type,omitempty" jsonschema:"Dataset type: 'organise'/'test'/'train' (defaults to 'organise' on create)"`
}

// DatasetOutput defines the structured result of the create_or_update_dataset
// tool: the dataset row as read back from the database after the write, plus
// a human-readable summary message.
type DatasetOutput struct {
	Dataset db.Dataset `json:"dataset" jsonschema:"The created or updated dataset"`
	Message string     `json:"message" jsonschema:"Success message"`
}

// CreateOrUpdateDataset implements the create_or_update_dataset tool handler.
//
// It dispatches on input.ID: when the ID is absent or whitespace-only a new
// dataset is created; otherwise the dataset with that ID is updated. The
// incoming request object itself is not inspected.
func CreateOrUpdateDataset(
	ctx context.Context,
	req *mcp.CallToolRequest,
	input DatasetInput,
) (*mcp.CallToolResult, DatasetOutput, error) {
	if input.ID == nil || strings.TrimSpace(*input.ID) == "" {
		return createDataset(ctx, input)
	}
	return updateDataset(ctx, input)
}

// createDataset inserts a new, active dataset and returns the row as stored.
//
// Validation, all performed before any database work:
//   - Name is required, must not be blank, and is at most 255 characters.
//   - Description is optional and at most 255 characters.
//   - Type is optional, matched case-insensitively after trimming, and
//     defaults to organise.
//
// Lengths are counted in characters (runes), not bytes, so the limit matches
// the "max 255 characters" advertised in DatasetInput's schema and the
// reported "(got N)" count is meaningful for non-ASCII text.
//
// The INSERT and the read-back of the new row run in a single transaction
// that is rolled back on any error.
func createDataset(ctx context.Context, input DatasetInput) (*mcp.CallToolResult, DatasetOutput, error) {
	var output DatasetOutput

	// Validate name (required for create).
	if input.Name == nil || strings.TrimSpace(*input.Name) == "" {
		return nil, output, fmt.Errorf("name is required when creating a dataset")
	}
	// len([]rune(s)) counts runes; the compiler lowers it to a count without
	// allocating the rune slice.
	if n := len([]rune(*input.Name)); n > 255 {
		return nil, output, fmt.Errorf("name must be 255 characters or less (got %d)", n)
	}

	// Validate description length if provided.
	if input.Description != nil {
		if n := len([]rune(*input.Description)); n > 255 {
			return nil, output, fmt.Errorf("description must be 255 characters or less (got %d)", n)
		}
	}

	// Validate and normalise the type; organise is the default.
	datasetType := db.DatasetTypeOrganise
	if input.Type != nil {
		switch strings.ToLower(strings.TrimSpace(*input.Type)) {
		case "organise":
			datasetType = db.DatasetTypeOrganise
		case "test":
			datasetType = db.DatasetTypeTest
		case "train":
			datasetType = db.DatasetTypeTrain
		default:
			return nil, output, fmt.Errorf("invalid type '%s': must be 'organise', 'test', or 'train'", *input.Type)
		}
	}

	// Generate the ID before touching the database so a failure here does not
	// leave a connection or transaction open needlessly.
	id, err := utils.GenerateShortID()
	if err != nil {
		return nil, output, fmt.Errorf("failed to generate ID: %w", err)
	}

	database, err := db.OpenWriteableDB(dbPath)
	if err != nil {
		return nil, output, fmt.Errorf("database connection failed: %w", err)
	}
	defer database.Close()

	tx, err := database.BeginTx(ctx, nil)
	if err != nil {
		return nil, output, fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Roll back unconditionally on exit. Assuming tx is a *sql.Tx (as
	// BeginTx(ctx, nil) suggests), Rollback after a successful Commit is a
	// no-op returning sql.ErrTxDone, so its error is intentionally ignored.
	defer func() { _ = tx.Rollback() }()

	// Description is passed as *string so a nil pointer is stored as NULL.
	if _, err := tx.ExecContext(ctx,
		"INSERT INTO dataset (id, name, description, type, created_at, last_modified, active) VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, TRUE)",
		id, *input.Name, input.Description, string(datasetType),
	); err != nil {
		return nil, output, fmt.Errorf("failed to create dataset: %w", err)
	}

	// Read the row back inside the same transaction so the returned values
	// (timestamps, defaults) are exactly what was stored.
	var dataset db.Dataset
	if err := tx.QueryRowContext(ctx,
		"SELECT id, name, description, created_at, last_modified, active, type FROM dataset WHERE id = ?",
		id,
	).Scan(&dataset.ID, &dataset.Name, &dataset.Description, &dataset.CreatedAt, &dataset.LastModified, &dataset.Active, &dataset.Type); err != nil {
		return nil, output, fmt.Errorf("failed to fetch created dataset: %w", err)
	}

	if err := tx.Commit(); err != nil {
		return nil, output, fmt.Errorf("failed to commit transaction: %w", err)
	}

	output.Dataset = dataset
	output.Message = fmt.Sprintf("Successfully created dataset '%s' with ID %s (type: %s)",
		dataset.Name, dataset.ID, dataset.Type)

	return &mcp.CallToolResult{}, output, nil
}

// updateDataset applies a partial update to the dataset identified by
// input.ID and returns the row as stored afterwards.
//
// Only the non-nil fields of input are written, and last_modified is always
// refreshed. Validation happens before the database is opened:
//   - A supplied name must not be blank and is at most 255 characters.
//   - A supplied description is at most 255 characters.
//   - A supplied type is matched case-insensitively after trimming.
//   - At least one field must be supplied.
//
// Lengths are counted in characters (runes), not bytes. The existence check,
// the UPDATE and the read-back all run in one transaction, using ctx, so the
// returned row reflects exactly this update.
func updateDataset(ctx context.Context, input DatasetInput) (*mcp.CallToolResult, DatasetOutput, error) {
	var output DatasetOutput
	// The dispatcher selects update mode based on the trimmed ID, so look up
	// the trimmed value too; surrounding whitespace would otherwise turn a
	// valid ID into "dataset not found".
	datasetID := strings.TrimSpace(*input.ID)

	// Validate each supplied field and collect its SET clause. Only fixed
	// column fragments enter the SQL text; every value is a bound parameter.
	updates := make([]string, 0, 4)
	args := make([]any, 0, 5)

	if input.Name != nil {
		if strings.TrimSpace(*input.Name) == "" {
			return nil, output, fmt.Errorf("name cannot be empty")
		}
		// len([]rune(s)) counts runes without allocating the rune slice.
		if n := len([]rune(*input.Name)); n > 255 {
			return nil, output, fmt.Errorf("name must be 255 characters or less (got %d)", n)
		}
		updates = append(updates, "name = ?")
		args = append(args, *input.Name)
	}
	if input.Description != nil {
		if n := len([]rune(*input.Description)); n > 255 {
			return nil, output, fmt.Errorf("description must be 255 characters or less (got %d)", n)
		}
		updates = append(updates, "description = ?")
		args = append(args, *input.Description)
	}
	if input.Type != nil {
		typeValue := strings.ToLower(strings.TrimSpace(*input.Type))
		switch typeValue {
		case "organise", "test", "train":
			updates = append(updates, "type = ?")
			args = append(args, typeValue)
		default:
			return nil, output, fmt.Errorf("invalid dataset type: %s (must be 'organise', 'test', or 'train')", *input.Type)
		}
	}

	if len(updates) == 0 {
		return nil, output, fmt.Errorf("no fields provided to update")
	}

	// Always refresh last_modified; the final argument binds "WHERE id = ?".
	updates = append(updates, "last_modified = CURRENT_TIMESTAMP")
	args = append(args, datasetID)
	query := "UPDATE dataset SET " + strings.Join(updates, ", ") + " WHERE id = ?"

	database, err := db.OpenWriteableDB(dbPath)
	if err != nil {
		return nil, output, fmt.Errorf("failed to open database: %w", err)
	}
	defer database.Close()

	tx, err := database.BeginTx(ctx, nil)
	if err != nil {
		return nil, output, fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Roll back unconditionally on exit. Assuming tx is a *sql.Tx (as
	// BeginTx(ctx, nil) suggests), Rollback after a successful Commit is a
	// no-op returning sql.ErrTxDone, so its error is intentionally ignored.
	defer func() { _ = tx.Rollback() }()

	// Verify the dataset exists so a bad ID gets a clear error rather than a
	// silent zero-row update.
	var exists bool
	if err := tx.QueryRowContext(ctx, "SELECT EXISTS(SELECT 1 FROM dataset WHERE id = ?)", datasetID).Scan(&exists); err != nil {
		return nil, output, fmt.Errorf("failed to query dataset: %w", err)
	}
	if !exists {
		return nil, output, fmt.Errorf("dataset not found: %s", datasetID)
	}

	if _, err := tx.ExecContext(ctx, query, args...); err != nil {
		return nil, output, fmt.Errorf("failed to update dataset: %w", err)
	}

	// Read the row back inside the same transaction.
	var dataset db.Dataset
	if err := tx.QueryRowContext(ctx,
		"SELECT id, name, description, created_at, last_modified, active, type FROM dataset WHERE id = ?",
		datasetID,
	).Scan(&dataset.ID, &dataset.Name, &dataset.Description, &dataset.CreatedAt, &dataset.LastModified, &dataset.Active, &dataset.Type); err != nil {
		return nil, output, fmt.Errorf("failed to fetch updated dataset: %w", err)
	}

	if err := tx.Commit(); err != nil {
		return nil, output, fmt.Errorf("failed to commit transaction: %w", err)
	}

	output.Dataset = dataset
	output.Message = fmt.Sprintf("Successfully updated dataset '%s' (ID: %s)", dataset.Name, dataset.ID)

	return &mcp.CallToolResult{}, output, nil
}