package tools

import (
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"sync/atomic"
)

// parallelResult abstracts over the per-file results produced by the birda
// and raven workers, so aggregateResults can fold either kind the same way.
type parallelResult interface {
	// filePath is the input file this result describes; it is shown (base
	// name only) in progress reports and is the file maybeDeleteFile removes.
	filePath() string
	// getError returns the failure for this file, or nil.
	getError() error
	// wasWritten reports whether a data file was written for this input
	// (counted in dataFilesWritten); it also gates source-file deletion.
	wasWritten() bool
	// wasSkipped reports whether the data file was skipped (counted in
	// dataFilesSkipped).
	wasSkipped() bool
	// getCalls returns the calls this file contributes to the aggregate.
	getCalls() []ClusteredCall
}

// aggregateStats is the fan-in summary built by aggregateResults from the
// results of a parallel fan-out.
type aggregateStats struct {
	calls        []ClusteredCall // every call from every result, in arrival order
	speciesCount map[string]int  // number of collected calls per EbirdCode

	dataFilesWritten, dataFilesSkipped int // results reporting wasWritten / wasSkipped
	filesProcessed, filesDeleted       int // results received / source files removed

	firstErr error // first error encountered; later ones are discarded
}

// aggregateResults drains results until the channel is closed and folds every
// parallelResult into a single aggregateStats.
//
// For each result, in arrival order, it:
//   - keeps the first non-nil error (later errors are dropped),
//   - counts written and skipped data files,
//   - appends the result's calls and tallies them per EbirdCode,
//   - deletes the source file when deleteFiles is set (see maybeDeleteFile),
//   - reports progress via progressHandler, if non-nil.
//
// progressHandler receives (current, total, base name of the file). total is
// passed through unchanged. current is the value of processed after it is
// atomically incremented; processed is only touched when progressHandler is
// non-nil. If processed is nil, the number of results aggregated so far is
// reported instead of dereferencing the nil pointer.
func aggregateResults(
	results <-chan parallelResult,
	total int,
	processed *atomic.Int32,
	deleteFiles bool,
	progressHandler func(int, int, string),
) aggregateStats {
	stats := aggregateStats{speciesCount: make(map[string]int)}

	for result := range results {
		if err := result.getError(); err != nil && stats.firstErr == nil {
			stats.firstErr = err
		}

		if result.wasWritten() {
			stats.dataFilesWritten++
		}
		if result.wasSkipped() {
			stats.dataFilesSkipped++
		}

		// A single variadic append copies the whole batch in one step and
		// grows the slice at most once, instead of appending call by call.
		calls := result.getCalls()
		stats.calls = append(stats.calls, calls...)
		for _, call := range calls {
			stats.speciesCount[call.EbirdCode]++
		}

		stats.filesProcessed++

		// Deletion runs before progress is reported, so a file announced to
		// the handler has already been removed when deletion succeeded.
		stats.maybeDeleteFile(deleteFiles, result)

		if progressHandler != nil {
			current := stats.filesProcessed
			if processed != nil {
				current = int(processed.Add(1))
			}
			progressHandler(current, total, filepath.Base(result.filePath()))
		}
	}

	return stats
}

// maybeDeleteFile removes the result's source file, but only when deleteFiles
// is true and the result reports that its data file was written. A successful
// removal increments filesDeleted. A failed removal is recorded in firstErr
// only if no earlier error was recorded; nothing is returned, so the caller
// carries on either way.
func (s *aggregateStats) maybeDeleteFile(deleteFiles bool, result parallelResult) {
	if !(deleteFiles && result.wasWritten()) {
		return
	}

	path := result.filePath()
	removeErr := os.Remove(path)
	if removeErr == nil {
		s.filesDeleted++
		return
	}

	if s.firstErr != nil {
		return
	}
	s.firstErr = fmt.Errorf("failed to delete %s: %w", path, removeErr)
}

// sortCallsByFileAndTime orders calls in place: ascending by File, and for
// calls from the same File, ascending by StartTime. The sort is not stable,
// so calls that tie on both keys may end up in any relative order.
func sortCallsByFileAndTime(calls []ClusteredCall) {
	less := func(a, b int) bool {
		left, right := &calls[a], &calls[b]
		if left.File == right.File {
			return left.StartTime < right.StartTime
		}
		return left.File < right.File
	}
	sort.Slice(calls, less)
}