package tools

import (
	"encoding/csv"
	"encoding/json"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"sort"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"

	"skraak/utils"
)

// Constants for the clustering algorithm and .data file output.
const (
	CLUSTER_GAP_MULTIPLIER     = 2  // default gap multiplier: detections cluster together when gap <= multiplier * clip_duration; 3 has been used for kiwi (long ~30s calls)
	MIN_DETECTIONS_PER_CLUSTER = 0  // clusters with <= this many detections are dropped; 0 lets single detections pass through, 1 filters them out (used for kiwi)
	DEFAULT_CERTAINTY          = 70 // certainty value written into every auto-generated .data label
	DOT_DATA_WORKERS           = 8  // Number of parallel workers for .data file writing
)

// ClusteredCall represents a clustered bird call detection: a run of
// consecutive positive prediction clips merged into one continuous call.
type ClusteredCall struct {
	File      string  `json:"file"`       // audio file path exactly as it appears in the predictions CSV
	StartTime float64 `json:"start_time"` // start of the first clip in the cluster (seconds)
	EndTime   float64 `json:"end_time"`   // start of the last clip + clip duration (seconds)
	EbirdCode string  `json:"ebird_code"` // species column name from the CSV header
	Segments  int     `json:"segments"`   // number of prediction clips merged into this call
}

// CallsFromPredsInput defines the input for the calls-from-preds tool.
type CallsFromPredsInput struct {
	CSVPath         string          `json:"csv_path"`       // path to the predictions CSV; its directory is also where WAVs are looked up
	Filter          string          `json:"filter"`         // filter name; when empty it is parsed from the CSV filename
	WriteDotData    bool            `json:"write_dot_data"` // when true, also write AviaNZ .data files next to the WAV files
	GapMultiplier   int             `json:"gap_multiplier"` // overrides CLUSTER_GAP_MULTIPLIER when > 0
	MinDetections   int             `json:"min_detections"` // overrides MIN_DETECTIONS_PER_CLUSTER when >= 0 (so the zero value 0 is itself an override)
	ProgressHandler ProgressHandler `json:"-"`              // Optional progress callback (not serialized)
}

// ProgressHandler is a callback function for reporting progress during long operations.
// processed: number of items processed so far
// total: total number of items to process
// message: optional status message (may be empty on per-item updates)
type ProgressHandler func(processed, total int, message string)

// CallsFromPredsOutput defines the output for the calls-from-preds tool.
type CallsFromPredsOutput struct {
	Calls            []ClusteredCall `json:"calls"`              // clustered calls, sorted by file then start time
	TotalCalls       int             `json:"total_calls"`        // len(Calls)
	ClipDuration     float64         `json:"clip_duration"`      // end_time - start_time of the first CSV row (seconds)
	GapThreshold     float64         `json:"gap_threshold"`      // gap multiplier * ClipDuration
	SpeciesCount     map[string]int  `json:"species_count"`      // number of calls per ebird code
	DataFilesWritten int             `json:"data_files_written"` // only populated when WriteDotData was requested
	DataFilesSkipped int             `json:"data_files_skipped"` // files with no matching WAV or unreadable WAV header
	Filter           string          `json:"filter"`             // filter actually used (input or parsed from filename)
	Error            *string         `json:"error,omitempty"`    // set alongside any non-nil error return
}

// AviaNZ .data file types

// AviaNZMeta is the metadata element in a .data file (the first element of
// the top-level JSON array, ahead of the segments).
type AviaNZMeta struct {
	Operator string  `json:"Operator"`
	Reviewer *string `json:"Reviewer,omitempty"` // pointer so the field can be omitted entirely when nil
	Duration float64 `json:"Duration"`           // audio duration as reported by the WAV header parser
}

// AviaNZLabel represents a species label in a segment.
type AviaNZLabel struct {
	Species   string `json:"species"`
	Certainty int    `json:"certainty"` // DEFAULT_CERTAINTY for auto-generated labels
	Filter    string `json:"filter"`
}

// AviaNZSegment represents a detection segment [start, end, freq_low, freq_high, labels].
// The heterogeneous element types (floats, ints, label slice) force a fixed-size array of any.
type AviaNZSegment [5]any

// predFileSpeciesKey groups detections by file and ebird code.
type predFileSpeciesKey struct {
	File      string
	EbirdCode string
}

// CallsFromPreds reads a predictions CSV and clusters detections into continuous bird calls.
// The filter comes from input.Filter or, when empty, is parsed from the CSV filename.
// When input.WriteDotData is set, AviaNZ .data files are also written next to the WAVs.
// On failure the returned output carries the error message in its Error field as well.
func CallsFromPreds(input CallsFromPredsInput) (CallsFromPredsOutput, error) {
	var out CallsFromPredsOutput

	filter := input.Filter
	if filter == "" {
		filter = ParseFilterFromFilename(input.CSVPath)
	}
	if filter == "" {
		msg := "Filter must be specified via --filter flag or parsable from CSV filename"
		out.Error = &msg
		return out, fmt.Errorf("%s", msg)
	}
	out.Filter = filter

	_, detections, clipDuration, err := readPredCSV(input.CSVPath)
	if err != nil {
		msg := err.Error()
		out.Error = &msg
		return out, err
	}
	out.ClipDuration = clipDuration

	// Fall back to the package defaults when the caller leaves these unset.
	gapMult := CLUSTER_GAP_MULTIPLIER
	if input.GapMultiplier > 0 {
		gapMult = input.GapMultiplier
	}
	minDet := MIN_DETECTIONS_PER_CLUSTER
	if input.MinDetections >= 0 {
		minDet = input.MinDetections
	}
	threshold := float64(gapMult) * clipDuration
	out.GapThreshold = threshold

	calls, counts := clusterDetections(detections, clipDuration, threshold, minDet)
	out.Calls = calls
	out.TotalCalls = len(calls)
	out.SpeciesCount = counts

	if !input.WriteDotData {
		return out, nil
	}

	written, skipped, err := writeDotFiles(input.CSVPath, filter, calls, input.ProgressHandler)
	if err != nil {
		msg := fmt.Sprintf("Error writing .data files: %v", err)
		out.Error = &msg
		return out, fmt.Errorf("%s", msg)
	}
	out.DataFilesWritten = written
	out.DataFilesSkipped = skipped
	return out, nil
}

// readPredCSV opens and reads a predictions CSV, returning column mappings,
// detections grouped by file+species, and the clip duration derived from the
// first data row.
func readPredCSV(csvPath string) (predCSVColumns, map[predFileSpeciesKey][]float64, float64, error) {
	var none predCSVColumns

	f, err := os.Open(csvPath)
	if err != nil {
		return none, nil, 0, fmt.Errorf("failed to open CSV file: %w", err)
	}
	defer func() { _ = f.Close() }()

	r := csv.NewReader(f)
	// Each record is consumed before the next Read, so the slice may be reused.
	r.ReuseRecord = true

	header, err := r.Read()
	if err != nil {
		return none, nil, 0, fmt.Errorf("failed to read CSV header: %w", err)
	}

	cols, err := findPredCSVColumns(header)
	if err != nil {
		return none, nil, 0, err
	}

	detections, clipDuration, err := readPredCSVRows(r, cols)
	if err != nil {
		return none, nil, 0, err
	}
	return cols, detections, clipDuration, nil
}

// predCSVColumns holds the column indices for a predictions CSV.
type predCSVColumns struct {
	fileIdx      int      // index of the "file" column
	startTimeIdx int      // index of the "start_time" column
	endTimeIdx   int      // index of the "end_time" column
	ebirdCodes   []string // header names of the species (ebird code) columns
	ebirdIdx     []int    // column indices matching ebirdCodes
}

// findPredCSVColumns parses the CSV header to find column indices.
// Every column other than file/start_time/end_time (and a small ignore list)
// is treated as a species column.
func findPredCSVColumns(header []string) (predCSVColumns, error) {
	cols := predCSVColumns{fileIdx: -1, startTimeIdx: -1, endTimeIdx: -1}
	ignored := map[string]bool{"NotKiwi": true, "0.0": true}

	for i, name := range header {
		switch name {
		case "file":
			cols.fileIdx = i
		case "start_time":
			cols.startTimeIdx = i
		case "end_time":
			cols.endTimeIdx = i
		default:
			if !ignored[name] {
				cols.ebirdCodes = append(cols.ebirdCodes, name)
				cols.ebirdIdx = append(cols.ebirdIdx, i)
			}
		}
	}

	switch {
	case cols.fileIdx < 0 || cols.startTimeIdx < 0 || cols.endTimeIdx < 0:
		return cols, fmt.Errorf("CSV must have 'file', 'start_time', and 'end_time' columns")
	case len(cols.ebirdCodes) == 0:
		return cols, fmt.Errorf("CSV must have at least one ebird code column")
	}
	return cols, nil
}

// readPredCSVRows reads all CSV data rows and returns detections grouped by
// file+species, plus the clip duration (end_time - start_time of the first
// data row). An empty body yields an empty map and a zero clip duration.
//
// The previous version duplicated the whole read/parse sequence for the first
// row; this folds it into a single loop with identical behavior, including the
// distinct error message for a failure on the first row.
func readPredCSVRows(reader *csv.Reader, cols predCSVColumns) (map[predFileSpeciesKey][]float64, float64, error) {
	detections := make(map[predFileSpeciesKey][]float64)
	clipDuration := 0.0
	firstRow := true

	for {
		record, err := reader.Read()
		if err == io.EOF {
			break
		}
		if err != nil {
			if firstRow {
				return nil, 0, fmt.Errorf("failed to read first CSV row: %w", err)
			}
			return nil, 0, fmt.Errorf("failed to read CSV row: %w", err)
		}

		// Unparsable times deliberately fall back to 0 (errors ignored).
		startTime, _ := strconv.ParseFloat(record[cols.startTimeIdx], 64)
		if firstRow {
			// Clip duration is derived from the first row only.
			endTime, _ := strconv.ParseFloat(record[cols.endTimeIdx], 64)
			clipDuration = endTime - startTime
			firstRow = false
		}

		addDetectionsFromRow(record, cols, startTime, detections)
	}

	return detections, clipDuration, nil
}

// addDetectionsFromRow adds positive detections from a single CSV row.
// A species cell equal to "1" records the row's start time under that
// file+species key; any other value is ignored.
func addDetectionsFromRow(record []string, cols predCSVColumns, startTime float64, detections map[predFileSpeciesKey][]float64) {
	name := record[cols.fileIdx]
	for i, colIdx := range cols.ebirdIdx {
		if record[colIdx] != "1" {
			continue
		}
		k := predFileSpeciesKey{File: name, EbirdCode: cols.ebirdCodes[i]}
		detections[k] = append(detections[k], startTime)
	}
}

// clusterDetections groups each file+species detection list into clusters and
// produces ClusteredCalls sorted by (file, start time), together with a count
// of calls per ebird code. Clusters with <= minDetections members are dropped.
func clusterDetections(detections map[predFileSpeciesKey][]float64, clipDuration, gapThreshold float64, minDetections int) ([]ClusteredCall, map[string]int) {
	speciesCount := make(map[string]int)
	var calls []ClusteredCall

	for key, times := range detections {
		sort.Float64s(times)
		for _, group := range clusterStartTimes(times, gapThreshold) {
			if len(group) <= minDetections {
				continue
			}
			calls = append(calls, ClusteredCall{
				File:      key.File,
				StartTime: group[0],
				EndTime:   group[len(group)-1] + clipDuration,
				EbirdCode: key.EbirdCode,
				Segments:  len(group),
			})
			speciesCount[key.EbirdCode]++
		}
	}

	// Deterministic order: by file, then by start time within a file.
	sort.Slice(calls, func(a, b int) bool {
		if calls[a].File == calls[b].File {
			return calls[a].StartTime < calls[b].StartTime
		}
		return calls[a].File < calls[b].File
	})

	return calls, speciesCount
}

// DirCache caches directory entries for fast WAV file lookup.
// The directory is scanned exactly once at construction and lookups are plain
// map reads, so a DirCache is safe for concurrent read-only use afterwards.
type DirCache struct {
	dir    string
	wavMap map[string]string // lowercase basename -> WAV filename with original case (e.g. "20230610_150000" -> "20230610_150000.WAV")
	dirMap map[string]string // lowercase basename -> filename for any file (used by from-raven for .selections.txt etc.)
}

// NewDirCache creates a DirCache by scanning the directory once.
// An unreadable directory yields an empty (but still usable) cache.
func NewDirCache(dir string) *DirCache {
	dc := &DirCache{
		dir:    dir,
		wavMap: make(map[string]string),
		dirMap: make(map[string]string),
	}
	entries, err := os.ReadDir(dir)
	if err != nil {
		return dc
	}
	for _, e := range entries {
		if e.IsDir() {
			continue
		}
		name := e.Name()
		ext := filepath.Ext(name)
		key := strings.ToLower(strings.TrimSuffix(name, ext))
		dc.dirMap[key] = name
		if strings.EqualFold(ext, ".wav") {
			dc.wavMap[key] = name
		}
	}
	return dc
}

// FindWAV looks up a WAV file by basename (case-insensitive).
// Returns the full path with correct case, or empty string if not found.
func (dc *DirCache) FindWAV(baseName string) string {
	name, ok := dc.wavMap[strings.ToLower(baseName)]
	if !ok {
		return ""
	}
	return filepath.Join(dc.dir, name)
}

// FindFile looks up any file by basename (case-insensitive).
// Returns the full path with correct case, or empty string if not found.
func (dc *DirCache) FindFile(baseName string) string {
	name, ok := dc.dirMap[strings.ToLower(baseName)]
	if !ok {
		return ""
	}
	return filepath.Join(dc.dir, name)
}

// findWAVFile finds a WAV file in the directory with case-insensitive matching.
// baseName is the filename without extension (e.g., "20230610_150000").
// Returns the full path with correct case, or empty string if not found.
// Deprecated: Use DirCache.FindWAV for batch operations to avoid repeated directory scans.
func findWAVFile(dir, baseName string) string {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return ""
	}
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		name := entry.Name()
		ext := filepath.Ext(name)
		base := strings.TrimSuffix(name, ext)
		// Match the basename case-insensitively too, consistent with this
		// function's doc and with DirCache.FindWAV (the previous == comparison
		// made only the extension case-insensitive).
		if strings.EqualFold(base, baseName) && strings.EqualFold(ext, ".wav") {
			return filepath.Join(dir, name)
		}
	}
	return ""
}

// writeDotFiles writes AviaNZ .data files for each audio file with calls.
// Small batches run sequentially; larger batches use a parallel worker pool.
// Returns (written, skipped, error).
func writeDotFiles(csvPath, filter string, calls []ClusteredCall, progress ProgressHandler) (int, int, error) {
	// The WAV files are expected next to the CSV file.
	dir := filepath.Dir(csvPath)

	// Group calls by basename of their audio file.
	grouped := make(map[string][]ClusteredCall)
	for _, c := range calls {
		name := filepath.Base(c.File)
		grouped[name] = append(grouped[name], c)
	}

	if progress != nil {
		progress(0, len(grouped), "Processing WAV files")
	}

	// The sequential path avoids goroutine overhead for small batches.
	if len(grouped) < 10 {
		return writeDotFilesSequential(dir, filter, grouped, progress)
	}
	return writeDotFilesParallel(dir, filter, grouped, progress)
}

// dotDataJob represents a single file to process: the audio filename and all
// clustered calls detected within it.
type dotDataJob struct {
	filename  string
	fileCalls []ClusteredCall
}

// dotDataResult represents the result of processing a single file.
// written=false with err=nil means the file was skipped (no matching WAV or
// unreadable WAV header); a non-nil err reports a failed .data write.
type dotDataResult struct {
	filename string
	written  bool
	err      error
}

// writeDotFilesSequential processes files one at a time (for small batches).
// Returns (written, skipped, error): files with no matching WAV or an
// unreadable WAV header count as skipped; a failed .data write aborts.
func writeDotFilesSequential(csvDir, filter string, callsByFile map[string][]ClusteredCall, progress ProgressHandler) (int, int, error) {
	dataFilesWritten := 0
	dataFilesSkipped := 0
	total := len(callsByFile)
	processed := 0

	// Scan the directory once up front instead of once per file — the old
	// findWAVFile helper re-reads the directory on every call (see its
	// deprecation note).
	cache := NewDirCache(csvDir)

	for filename, fileCalls := range callsByFile {
		// Find WAV file with correct case
		baseName := strings.TrimSuffix(filename, filepath.Ext(filename))
		wavPath := cache.FindWAV(baseName)
		if wavPath == "" {
			dataFilesSkipped++
			processed++
			if progress != nil {
				progress(processed, total, "")
			}
			continue
		}

		dataPath := wavPath + ".data"

		sampleRate, duration, err := utils.ParseWAVHeaderMinimal(wavPath)
		if err != nil {
			// Unreadable WAV header: skip this file rather than abort the batch.
			dataFilesSkipped++
			processed++
			if progress != nil {
				progress(processed, total, "")
			}
			continue
		}

		// Build segments and metadata
		meta, segments := buildAviaNZMetaAndSegments(fileCalls, filter, duration, sampleRate)

		if err := writeDotDataFileSafe(dataPath, segments, filter, meta); err != nil {
			return dataFilesWritten, dataFilesSkipped, fmt.Errorf("failed to write %s: %w", dataPath, err)
		}

		dataFilesWritten++
		processed++
		if progress != nil {
			progress(processed, total, "")
		}
	}

	return dataFilesWritten, dataFilesSkipped, nil
}

// writeDotFilesParallel processes files concurrently using a worker pool.
// Results are tallied as they arrive; the first write error (if any) is
// returned only after every job has been drained, so counts stay complete.
func writeDotFilesParallel(csvDir, filter string, callsByFile map[string][]ClusteredCall, progress ProgressHandler) (int, int, error) {
	total := len(callsByFile)
	var done atomic.Int32

	// Buffered so that neither producers nor workers ever block.
	jobs := make(chan dotDataJob, total)
	results := make(chan dotDataResult, total)

	var wg sync.WaitGroup
	for range DOT_DATA_WORKERS {
		wg.Add(1)
		go dotDataWorker(csvDir, filter, jobs, results, &wg)
	}

	// Queue every file, then signal workers there is no more work.
	for name, fileCalls := range callsByFile {
		jobs <- dotDataJob{filename: name, fileCalls: fileCalls}
	}
	close(jobs)

	// Close results once all workers have exited, ending the collection loop.
	go func() {
		wg.Wait()
		close(results)
	}()

	written, skipped := 0, 0
	var firstErr error
	for res := range results {
		if res.err != nil && firstErr == nil {
			firstErr = res.err
		}
		if res.written {
			written++
		} else {
			skipped++
		}
		if progress != nil {
			progress(int(done.Add(1)), total, "")
		}
	}

	return written, skipped, firstErr
}

// dotDataWorker processes files from the jobs channel until it is closed,
// sending exactly one result per job. Write failures are reported through the
// result's err field rather than stopping the worker.
func dotDataWorker(csvDir, filter string, jobs <-chan dotDataJob, results chan<- dotDataResult, wg *sync.WaitGroup) {
	defer wg.Done()

	// Scan the directory once per worker instead of once per job — the old
	// findWAVFile helper re-reads the directory on every call (see its
	// deprecation note).
	cache := NewDirCache(csvDir)

	for job := range jobs {
		// Find WAV file with correct case
		baseName := strings.TrimSuffix(job.filename, filepath.Ext(job.filename))
		wavPath := cache.FindWAV(baseName)
		if wavPath == "" {
			// No matching WAV: counted as skipped by the collector, not an error.
			results <- dotDataResult{filename: job.filename, written: false, err: nil}
			continue
		}

		dataPath := wavPath + ".data"

		sampleRate, duration, err := utils.ParseWAVHeaderMinimal(wavPath)
		if err != nil {
			// Unreadable WAV header: skip rather than fail the batch.
			results <- dotDataResult{filename: job.filename, written: false, err: nil}
			continue
		}

		// Build segments and metadata
		meta, segments := buildAviaNZMetaAndSegments(job.fileCalls, filter, duration, sampleRate)

		if err := writeDotDataFileSafe(dataPath, segments, filter, meta); err != nil {
			results <- dotDataResult{filename: job.filename, written: false, err: fmt.Errorf("failed to write %s: %w", dataPath, err)}
			continue
		}

		results <- dotDataResult{filename: job.filename, written: true, err: nil}
	}
}

// buildAviaNZMetaAndSegments creates the metadata element and detection
// segments for a .data file from the clustered calls of one audio file.
func buildAviaNZMetaAndSegments(calls []ClusteredCall, filter string, duration float64, sampleRate int) (AviaNZMeta, []AviaNZSegment) {
	reviewer := "None"
	meta := AviaNZMeta{
		Operator: "Auto",
		Reviewer: &reviewer,
		Duration: duration,
	}

	var segments []AviaNZSegment
	for _, c := range calls {
		// Segment layout: [start, end, freq_low, freq_high, labels].
		// freq_low=0 and freq_high=sampleRate mark a full-band segment.
		segments = append(segments, AviaNZSegment{
			c.StartTime,
			c.EndTime,
			0,
			sampleRate,
			[]AviaNZLabel{{
				Species:   c.EbirdCode,
				Certainty: DEFAULT_CERTAINTY,
				Filter:    filter,
			}},
		})
	}

	return meta, segments
}

// writeAviaNZDataFile writes a new .data file to disk (does not check for
// existing files). The JSON array is encoded compactly, without indentation.
func writeAviaNZDataFile(path string, data []any) error {
	f, err := os.Create(path)
	if err != nil {
		return fmt.Errorf("failed to create file: %w", err)
	}
	defer func() { _ = f.Close() }()

	enc := json.NewEncoder(f)
	enc.SetIndent("", "") // empty prefix+indent disables indentation
	if err := enc.Encode(data); err != nil {
		return fmt.Errorf("failed to encode JSON: %w", err)
	}
	return nil
}

// writeDotDataFileSafe safely writes or merges .data files:
//   - file absent: write a new file
//   - file present with the same filter: error (refuse to clobber)
//   - file present with a different filter: merge segments and rewrite
//   - file present but unparsable: error (refuse to clobber)
func writeDotDataFileSafe(path string, newSegments []AviaNZSegment, filter string, meta AviaNZMeta) error {
	if _, statErr := os.Stat(path); statErr != nil {
		// No existing file — write a fresh one.
		return writeAviaNZDataFile(path, buildDataFileFromSegments(meta, newSegments))
	}

	existing, err := utils.ParseDataFile(path)
	if err != nil {
		return fmt.Errorf("cannot parse existing %s: %w (refusing to clobber)", path, err)
	}

	// Never merge into a file that already carries this filter.
	for _, seg := range existing.Segments {
		if seg.HasFilterLabel(filter) {
			return fmt.Errorf("%s already contains filter '%s' (refusing to clobber)", path, filter)
		}
	}

	// Different filter: appending is safe, so merge the new segments in.
	for _, newSeg := range newSegments {
		existing.Segments = append(existing.Segments, convertAviaNZSegment(newSeg, filter))
	}

	// Keep the merged segments ordered by start time.
	sort.Slice(existing.Segments, func(i, j int) bool {
		return existing.Segments[i].StartTime < existing.Segments[j].StartTime
	})

	return existing.Write(path)
}

// convertAviaNZSegment converts an in-memory AviaNZSegment (as produced by
// buildAviaNZMetaAndSegments, so seg[4] holds []AviaNZLabel) to *utils.Segment.
func convertAviaNZSegment(seg AviaNZSegment, filter string) *utils.Segment {
	srcLabels := seg[4].([]AviaNZLabel)
	labels := make([]*utils.Label, len(srcLabels))
	for i, l := range srcLabels {
		labels[i] = &utils.Label{
			Species:   l.Species,
			Certainty: l.Certainty,
			Filter:    filter,
		}
	}

	// Frequency bounds may be stored as int or float64 depending on how the
	// segment was created; anything else falls through to 0.
	toFloat := func(v any) float64 {
		switch x := v.(type) {
		case int:
			return float64(x)
		case float64:
			return x
		}
		return 0
	}

	return &utils.Segment{
		StartTime: seg[0].(float64),
		EndTime:   seg[1].(float64),
		FreqLow:   toFloat(seg[2]),
		FreqHigh:  toFloat(seg[3]),
		Labels:    labels,
	}
}

// buildDataFileFromSegments assembles the .data file structure: the metadata
// element first, followed by each segment in order.
func buildDataFileFromSegments(meta AviaNZMeta, segments []AviaNZSegment) []any {
	out := make([]any, 1, 1+len(segments))
	out[0] = meta
	for _, s := range segments {
		out = append(out, s)
	}
	return out
}

// ParseFilterFromFilename extracts the filter name from a preds CSV filename:
// "predsST_opensoundscape-kiwi-1.2_2025-11-12.csv" -> "opensoundscape-kiwi-1.2".
// The ".csv" extension is matched case-insensitively (".CSV" also works).
// Returns empty string if the name does not split into exactly three '_' parts.
func ParseFilterFromFilename(csvPath string) string {
	filename := filepath.Base(csvPath)

	// Strip the extension case-insensitively. The previous version trimmed
	// only a lowercase ".csv", so uppercase ".CSV" files never parsed.
	name := filename
	if ext := filepath.Ext(filename); strings.EqualFold(ext, ".csv") {
		name = strings.TrimSuffix(filename, ext)
	}

	// Expected shape: prefix_filter_date.
	parts := strings.Split(name, "_")
	if len(parts) == 3 {
		return parts[1]
	}
	return ""
}

// clusterStartTimes groups consecutive start times into clusters where the
// gap between neighbouring times is <= gapThreshold. Callers sort the input
// first (see clusterDetections). Returns nil for empty input.
func clusterStartTimes(startTimes []float64, gapThreshold float64) [][]float64 {
	if len(startTimes) == 0 {
		return nil
	}

	var clusters [][]float64
	current := []float64{startTimes[0]}

	for _, ts := range startTimes[1:] {
		prev := current[len(current)-1]
		if ts-prev <= gapThreshold {
			current = append(current, ts)
		} else {
			// Gap too large: seal the current cluster and start a new one.
			clusters = append(clusters, current)
			current = []float64{ts}
		}
	}
	// The final cluster is still open — include it.
	return append(clusters, current)
}