package tools

import (
	"context"
	"database/sql"
	"encoding/csv"
	"fmt"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"time"

	"github.com/modelcontextprotocol/go-sdk/mcp"

	"skraak_mcp/db"
	"skraak_mcp/utils"
)

// BulkFileImportInput defines the input parameters for the bulk_file_import tool
type BulkFileImportInput struct {
	DatasetID   string `json:"dataset_id" jsonschema:"required,Dataset ID (12 characters)"`
	CSVPath     string `json:"csv_path" jsonschema:"required,Absolute path to CSV file. Format: location_name,location_id,directory_path,date_range,sample_rate,file_count (header required). date_range becomes cluster name. Example: 'Site A',abc123,/path/files,20240101-20240107,48000,100"`
	LogFilePath string `json:"log_file_path" jsonschema:"required,Absolute path for progress log file (monitors import progress)"`
}

// BulkFileImportOutput defines the output structure for the bulk_file_import tool
type BulkFileImportOutput struct {
	TotalLocations    int      `json:"total_locations"`
	ClustersCreated   int      `json:"clusters_created"`
	ClustersExisting  int      `json:"clusters_existing"`
	TotalFilesScanned int      `json:"total_files_scanned"`
	FilesImported     int      `json:"files_imported"`
	FilesDuplicate    int      `json:"files_duplicate"`
	FilesError        int      `json:"files_error"`
	ProcessingTime    string   `json:"processing_time"`
	Errors            []string `json:"errors,omitempty"`
}

// bulkLocationData holds CSV row data for a location
type bulkLocationData struct {
	LocationName  string // display name of the location (CSV column 1)
	LocationID    string // ID of an existing location (CSV column 2)
	DirectoryPath string // absolute directory containing the WAV files (CSV column 3)
	DateRange     string // date range string; used as the cluster name (CSV column 4)
	SampleRate    int    // sample rate for the cluster, e.g. 48000 (CSV column 5)
	FileCount     int    // expected file count from the CSV; not currently checked during import (CSV column 6)
}

// bulkImportStats tracks import statistics for a single cluster
type bulkImportStats struct {
	TotalFiles     int
	ImportedFiles  int
	DuplicateFiles int
	ErrorFiles     int
}

// progressLogger handles writing to both log file and internal buffer
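// Writes are best-effort (file write errors are ignored); the buffer keeps an
// in-memory copy of every line for error reporting.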
type progressLogger struct {
	file   *os.File
	buffer *strings.Builder
}

// Log writes a formatted message with timestamp to both log file and buffer
func (l *progressLogger) Log(format string, args ...interface{}) {
	timestamp := time.Now().Format("2006-01-02 15:04:05")
	message := fmt.Sprintf(format, args...)
	line := fmt.Sprintf("[%s] %s\n", timestamp, message)

	// Write to file
	l.file.WriteString(line)
	l.file.Sync() // Ensure immediate write for tail monitoring

	// Also keep in memory for potential error reporting
	l.buffer.WriteString(line)
}

// BulkFileImport implements the bulk_file_import MCP tool
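// The import runs in phases: validate the input, read the CSV, create or reuse
// a cluster for each location/date range, then import the WAV files for each
// cluster. Progress is appended to the log file as it happens, so it can be
// followed externally (for example with tail -f on the log file path).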
func BulkFileImport(
	ctx context.Context,
	req *mcp.CallToolRequest,
	input BulkFileImportInput,
) (*mcp.CallToolResult, BulkFileImportOutput, error) {
	startTime := time.Now()
	var output BulkFileImportOutput

	// Open log file
	logFile, err := os.OpenFile(input.LogFilePath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
	if err != nil {
		return nil, output, fmt.Errorf("failed to open log file: %w", err)
	}
	defer logFile.Close()

	logger := &progressLogger{
		file:   logFile,
		buffer: &strings.Builder{},
	}

	logger.Log("Starting bulk file import for dataset %s", input.DatasetID)

	// Validate input parameters
	logger.Log("Validating input parameters...")
	if err := bulkValidateInput(input); err != nil {
		logger.Log("ERROR: Validation failed: %v", err)
		output.Errors = []string{fmt.Sprintf("validation failed: %v", err)}
		output.ProcessingTime = time.Since(startTime).String()
		return nil, output, fmt.Errorf("validation failed: %w", err)
	}
	logger.Log("Validation complete")

	// Read CSV
	logger.Log("Reading CSV file: %s", input.CSVPath)
	locations, err := bulkReadCSV(input.CSVPath)
	if err != nil {
		logger.Log("ERROR: Failed to read CSV: %v", err)
		output.Errors = []string{fmt.Sprintf("failed to read CSV: %v", err)}
		output.ProcessingTime = time.Since(startTime).String()
		return nil, output, fmt.Errorf("failed to read CSV: %w", err)
	}
	logger.Log("Loaded %d locations from CSV", len(locations))
	output.TotalLocations = len(locations)

	// Phase 1: Create/Validate Clusters
	logger.Log("=== Phase 1: Creating/Validating Clusters ===")
	clusterIDMap := make(map[string]string) // "locationID|dateRange" -> clusterID

	database, err := db.OpenWriteableDB(dbPath)
	if err != nil {
		logger.Log("ERROR: Failed to open database: %v", err)
		output.Errors = []string{fmt.Sprintf("failed to open database: %v", err)}
		output.ProcessingTime = time.Since(startTime).String()
		return nil, output, fmt.Errorf("failed to open database: %w", err)
	}
	defer database.Close()

	for i, loc := range locations {
		logger.Log("[%d/%d] Processing location: %s", i+1, len(locations), loc.LocationName)

		// Check if cluster already exists
		var existingClusterID string
		err := database.QueryRow(`
			SELECT id FROM cluster
			WHERE location_id = ? AND name = ? AND active = true
		`, loc.LocationID, loc.DateRange).Scan(&existingClusterID)

		var clusterID string
		if errors.Is(err, sql.ErrNoRows) {
			// Create cluster
			clusterID, err = bulkCreateCluster(database, input.DatasetID, loc.LocationID, loc.DateRange, loc.SampleRate)
			if err != nil {
				errMsg := fmt.Sprintf("Failed to create cluster for location %s: %v", loc.LocationName, err)
				logger.Log("ERROR: %s", errMsg)
				output.Errors = append(output.Errors, errMsg)
				output.ProcessingTime = time.Since(startTime).String()
				return nil, output, fmt.Errorf("failed to create cluster: %w", err)
			}
			logger.Log("  Created cluster: %s", clusterID)
			output.ClustersCreated++
		} else if err != nil {
			errMsg := fmt.Sprintf("Failed to check cluster for location %s: %v", loc.LocationName, err)
			logger.Log("ERROR: %s", errMsg)
			output.Errors = append(output.Errors, errMsg)
			output.ProcessingTime = time.Since(startTime).String()
			return nil, output, fmt.Errorf("failed to check cluster: %w", err)
		} else {
			clusterID = existingClusterID
			logger.Log("  Using existing cluster: %s", clusterID)
			output.ClustersExisting++
		}

		compositeKey := loc.LocationID + "|" + loc.DateRange
		clusterIDMap[compositeKey] = clusterID
	}

	logger.Log("=== Phase 2: Importing Files ===")

	totalImported := 0
	totalDuplicates := 0
	totalErrors := 0
	totalScanned := 0

	for i, loc := range locations {
		compositeKey := loc.LocationID + "|" + loc.DateRange
		clusterID, ok := clusterIDMap[compositeKey]
		if !ok {
			continue // Should not happen, but safety check
		}

		logger.Log("[%d/%d] Importing files for: %s", i+1, len(locations), loc.LocationName)
		logger.Log("  Directory: %s", loc.DirectoryPath)

		// Check if directory exists
		if _, err := os.Stat(loc.DirectoryPath); os.IsNotExist(err) {
			logger.Log("  WARNING: Directory not found, skipping")
			continue
		}

		// Import files
		stats, err := bulkImportFilesForCluster(database, logger, loc.DirectoryPath, input.DatasetID, loc.LocationID, clusterID)
		if err != nil {
			errMsg := fmt.Sprintf("Failed to import files for location %s: %v", loc.LocationName, err)
			logger.Log("ERROR: %s", errMsg)
			output.Errors = append(output.Errors, errMsg)
			output.TotalFilesScanned = totalScanned
			output.FilesImported = totalImported
			output.FilesDuplicate = totalDuplicates
			output.FilesError = totalErrors
			output.ProcessingTime = time.Since(startTime).String()
			return nil, output, fmt.Errorf("failed to import files: %w", err)
		}

		logger.Log("  Scanned: %d files", stats.TotalFiles)
		logger.Log("  Imported: %d, Duplicates: %d", stats.ImportedFiles, stats.DuplicateFiles)
		if stats.ErrorFiles > 0 {
			logger.Log("  Errors: %d files", stats.ErrorFiles)
		}

		totalScanned += stats.TotalFiles
		totalImported += stats.ImportedFiles
		totalDuplicates += stats.DuplicateFiles
		totalErrors += stats.ErrorFiles
	}

	logger.Log("=== Import Complete ===")
	logger.Log("Total files scanned: %d", totalScanned)
	logger.Log("Files imported: %d", totalImported)
	logger.Log("Duplicates skipped: %d", totalDuplicates)
	logger.Log("Errors: %d", totalErrors)
	logger.Log("Processing time: %s", time.Since(startTime).Round(time.Second))

	output.TotalFilesScanned = totalScanned
	output.FilesImported = totalImported
	output.FilesDuplicate = totalDuplicates
	output.FilesError = totalErrors
	output.ProcessingTime = time.Since(startTime).String()

	return &mcp.CallToolResult{}, output, nil
}

// bulkValidateInput validates input parameters
func bulkValidateInput(input BulkFileImportInput) error {
	// Verify CSV file exists
	if _, err := os.Stat(input.CSVPath); err != nil {
		return fmt.Errorf("CSV file not accessible: %w", err)
	}

	// Verify log file path is writable
	logDir := filepath.Dir(input.LogFilePath)
	if _, err := os.Stat(logDir); err != nil {
		return fmt.Errorf("log file directory not accessible: %w", err)
	}

	// Open database for validation queries
	database, err := db.OpenReadOnlyDB(dbPath)
	if err != nil {
		return fmt.Errorf("failed to open database: %w", err)
	}
	defer database.Close()

	// Verify dataset exists
	var datasetExists bool
	err = database.QueryRow("SELECT EXISTS(SELECT 1 FROM dataset WHERE id = ? AND active = true)", input.DatasetID).Scan(&datasetExists)
	if err != nil {
		return fmt.Errorf("failed to query dataset: %w", err)
	}
	if !datasetExists {
		return fmt.Errorf("dataset not found or inactive: %s", input.DatasetID)
	}

	return nil
}

// bulkReadCSV reads and parses the CSV file
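// Expected columns (header row required):
//
//	location_name,location_id,directory_path,date_range,sample_rate,file_count
//	'Site A',abc123,/path/files,20240101-20240107,48000,100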
func bulkReadCSV(path string) ([]bulkLocationData, error) {
	file, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	reader := csv.NewReader(file)
	records, err := reader.ReadAll()
	if err != nil {
		return nil, err
	}

	if len(records) == 0 {
		return nil, fmt.Errorf("CSV file is empty")
	}

	var locations []bulkLocationData
	for i, record := range records {
		if i == 0 {
			continue // Skip header
		}

		if len(record) < 6 {
			return nil, fmt.Errorf("CSV row %d has insufficient columns (expected 6, got %d)", i+1, len(record))
		}

		sampleRate, err := strconv.Atoi(record[4])
		if err != nil {
			return nil, fmt.Errorf("invalid sample_rate in row %d: %v", i+1, err)
		}

		fileCount, err := strconv.Atoi(record[5])
		if err != nil {
			return nil, fmt.Errorf("invalid file_count in row %d: %v", i+1, err)
		}

		locations = append(locations, bulkLocationData{
			LocationName:  record[0],
			LocationID:    record[1],
			DirectoryPath: record[2],
			DateRange:     record[3],
			SampleRate:    sampleRate,
			FileCount:     fileCount,
		})
	}

	return locations, nil
}

// bulkCreateCluster creates a new cluster in the database
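// The cluster name is the CSV date_range value; the path is derived from the
// location name with spaces and slashes replaced by underscores.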
func bulkCreateCluster(database *sql.DB, datasetID, locationID, name string, sampleRate int) (string, error) {
	// Generate a 12-character nanoid
	clusterID, err := utils.GenerateShortID()
	if err != nil {
		return "", fmt.Errorf("failed to generate cluster ID: %v", err)
	}
	now := time.Now().UTC()

	// Get location name for the path
	var locationName string
	err = database.QueryRow("SELECT name FROM location WHERE id = ?", locationID).Scan(&locationName)
	if err != nil {
		return "", fmt.Errorf("failed to get location name: %v", err)
	}

	// Derive the cluster path from the location name: replace spaces and forward slashes with underscores
	path := strings.ReplaceAll(locationName, " ", "_")
	path = strings.ReplaceAll(path, "/", "_")

	_, err = database.Exec(`
		INSERT INTO cluster (id, dataset_id, location_id, name, path, sample_rate, active, created_at, last_modified)
		VALUES (?, ?, ?, ?, ?, ?, true, ?, ?)
	`, clusterID, datasetID, locationID, name, path, sampleRate, now, now)

	if err != nil {
		return "", fmt.Errorf("failed to insert cluster: %v", err)
	}

	return clusterID, nil
}

// bulkImportFilesForCluster imports all WAV files for a single cluster
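// Scanning and insertion are delegated to utils.ImportCluster with Recursive
// set to true, using the same logic as import_files.go.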
func bulkImportFilesForCluster(database *sql.DB, logger *progressLogger, folderPath, datasetID, locationID, clusterID string) (*bulkImportStats, error) {
	stats := &bulkImportStats{}

	// Check if directory exists
	if _, err := os.Stat(folderPath); os.IsNotExist(err) {
		logger.Log("  WARNING: Directory not found, skipping")
		return stats, nil
	}

	// Import the cluster using the same shared logic as import_files.go
	logger.Log("  Importing cluster %s", clusterID)
	clusterOutput, err := utils.ImportCluster(database, utils.ClusterImportInput{
		FolderPath: folderPath,
		DatasetID:  datasetID,
		LocationID: locationID,
		ClusterID:  clusterID,
		Recursive:  true,
	})
	if err != nil {
		return nil, err
	}

	// Map to bulk import stats
	stats.TotalFiles = clusterOutput.TotalFiles
	stats.ImportedFiles = clusterOutput.ImportedFiles
	stats.DuplicateFiles = clusterOutput.SkippedFiles
	stats.ErrorFiles = clusterOutput.FailedFiles

	// Log the first few file-level errors so the progress log stays readable
	for i, fileErr := range clusterOutput.Errors {
		if i >= 5 {
			logger.Log("    ... and %d more errors", len(clusterOutput.Errors)-5)
			break
		}
		logger.Log("    ERROR: %s: %s", fileErr.FileName, fileErr.Error)
	}

	logger.Log("  Complete: %d imported, %d duplicates, %d errors", stats.ImportedFiles, stats.DuplicateFiles, stats.ErrorFiles)

	return stats, nil
}