package utils

import (
	"fmt"
	"math"
)

// ClipWindow is the time span of one clip within an audio file, in the same
// units as the durations given to GenerateClipTimes. A window normally spans
// clipDuration, but under FinalClipRemainder the last one may be shorter and
// under FinalClipExtend it may end after the file does.
type ClipWindow struct {
	Start, End float64 // offsets of the clip's beginning and end
}

// FinalClipMode selects what GenerateClipTimes does with the last clip when the
// regular clip grid runs past full_duration. The modes follow the final_clip
// options of opensoundscape.utils.generate_clip_times_df:
//   - FinalClipNone:      drop every clip whose end exceeds full_duration
//   - FinalClipRemainder: cut the final clip's end back to full_duration (shorter clip)
//   - FinalClipFull:      slide the final clip earlier so its end equals full_duration
//   - FinalClipExtend:    leave the final clip running past full_duration
type FinalClipMode int

// The numeric values are written out explicitly so that reordering these
// declarations can never silently renumber the modes.
const (
	FinalClipNone      FinalClipMode = 0
	FinalClipRemainder FinalClipMode = 1
	FinalClipFull      FinalClipMode = 2
	FinalClipExtend    FinalClipMode = 3
)

// ParseFinalClipMode converts a CLI flag value into a FinalClipMode.
//
// Accepted values are "none" (the empty string means the same), "remainder",
// "full" and "extend"; matching is exact and case-sensitive. Any other value
// yields a zero FinalClipMode together with a descriptive error.
func ParseFinalClipMode(s string) (FinalClipMode, error) {
	byName := map[string]FinalClipMode{
		"":          FinalClipNone,
		"none":      FinalClipNone,
		"remainder": FinalClipRemainder,
		"full":      FinalClipFull,
		"extend":    FinalClipExtend,
	}
	if mode, ok := byName[s]; ok {
		return mode, nil
	}
	return 0, fmt.Errorf("invalid final-clip mode %q (want one of: none, remainder, full, extend)", s)
}

// roundTo rounds x to `precision` decimal places the way numpy.round does:
// scale by 10^precision, round half to even (numpy uses rint), then unscale.
// Pass precision < 0 to skip rounding and return x unchanged.
//
// If 10^precision or x*10^precision overflows to ±Inf, x is far too coarse to
// carry digits at that precision, so rounding is a no-op and x is returned
// as-is rather than the ±Inf/NaN that scale-and-divide would produce.
func roundTo(x float64, precision int) float64 {
	if precision < 0 {
		return x
	}
	scale := math.Pow(10, float64(precision))
	scaled := x * scale
	if math.IsInf(scale, 0) || math.IsInf(scaled, 0) {
		return x
	}
	// numpy.round sends ties to the even neighbour (np.round(2.5) == 2.0),
	// whereas math.Round would round half away from zero.
	return math.RoundToEven(scaled) / scale
}

// GenerateClipTimes ports opensoundscape.utils.generate_clip_times_df.
//
// Args mirror the Python signature: clipDuration > 0, clipOverlap in
// [0, clipDuration), fullDuration > 0, and all three must be finite (NaN or
// ±Inf is rejected with an error). roundingPrecision defaults to 10 in OPSO;
// pass -1 to skip rounding. An unknown finalClip value is also an error.
//
// The result is the ordered list of (start, end) windows for one audio file.
// Every returned Start and End is rounded to roundingPrecision decimals, and
// duplicates are then removed (which can happen under FinalClipFull when the
// shifted final clip coincides with the previous one).
func GenerateClipTimes(fullDuration, clipDuration, clipOverlap float64, finalClip FinalClipMode, roundingPrecision int) ([]ClipWindow, error) {
	if clipDuration <= 0 {
		return nil, fmt.Errorf("clipDuration must be > 0, got %v", clipDuration)
	}
	if clipOverlap < 0 || clipOverlap >= clipDuration {
		return nil, fmt.Errorf("clipOverlap must be in [0, clipDuration), got %v with clipDuration=%v", clipOverlap, clipDuration)
	}
	if fullDuration <= 0 {
		return nil, fmt.Errorf("fullDuration must be > 0, got %v", fullDuration)
	}
	// Every comparison above is false for NaN, so NaN inputs slip through them,
	// and an infinite fullDuration would make start generation run without bound.
	if math.IsNaN(fullDuration) || math.IsInf(fullDuration, 0) {
		return nil, fmt.Errorf("fullDuration must be finite, got %v", fullDuration)
	}
	if math.IsNaN(clipDuration) || math.IsInf(clipDuration, 0) {
		return nil, fmt.Errorf("clipDuration must be finite, got %v", clipDuration)
	}
	if math.IsNaN(clipOverlap) || math.IsInf(clipOverlap, 0) {
		return nil, fmt.Errorf("clipOverlap must be finite, got %v", clipOverlap)
	}

	starts, ends := buildClipStartsEnds(fullDuration, clipDuration, clipOverlap, roundingPrecision)

	var windows []ClipWindow
	switch finalClip {
	case FinalClipNone:
		windows = clipWindowsNone(starts, ends, fullDuration)
	case FinalClipRemainder:
		windows = clipWindowsRemainder(starts, ends, fullDuration)
	case FinalClipFull:
		windows = clipWindowsFull(starts, ends, fullDuration)
	case FinalClipExtend:
		windows = clipWindowsExtend(starts, ends)
	default:
		return nil, fmt.Errorf("invalid FinalClipMode %d", finalClip)
	}

	// Round after the final-clip adjustment. Ends (start + clipDuration) and
	// FinalClipFull's shifted starts are computed from the already-rounded
	// starts, so they can carry noise such as 0.1+0.2 = 0.30000000000000004;
	// left unrounded, that noise would also stop dedupClips from merging
	// windows that coincide. roundTo is a no-op when roundingPrecision < 0.
	for i := range windows {
		windows[i].Start = roundTo(windows[i].Start, roundingPrecision)
		windows[i].End = roundTo(windows[i].End, roundingPrecision)
	}
	return dedupClips(windows), nil
}

// buildClipStartsEnds lays out the regular clip grid before any final-clip
// handling. Starts are i*(clipDuration-clipOverlap) for i = 0, 1, ... while the
// unrounded start is below fullDuration, each rounded to roundingPrecision
// (precision < 0 skips rounding). Ends are start + clipDuration and are not
// rounded here. If no start qualifies, a single start of 0 is used.
//
// The caller must ensure clipDuration-clipOverlap > 0 and that fullDuration is
// finite; otherwise the loop may not terminate.
func buildClipStartsEnds(fullDuration, clipDuration, clipOverlap float64, roundingPrecision int) ([]float64, []float64) {
	increment := clipDuration - clipOverlap
	var starts []float64
	// Derive each start from its index instead of accumulating `s += increment`.
	// Repeated addition drifts: ten additions of 0.1 give 0.9999999999999999, so
	// with fullDuration = 1 an 11th start slipped in and rounded to 1.0 (a clip
	// beginning at the very end of the file). Repeated addition also stalls once
	// s+increment rounds back to s. Index-times-step is how numpy.arange fills
	// its values.
	for i := 0; ; i++ {
		s := float64(i) * increment
		if !(s < fullDuration) {
			break
		}
		starts = append(starts, roundTo(s, roundingPrecision))
	}
	if len(starts) == 0 {
		starts = []float64{0}
	}
	ends := make([]float64, len(starts))
	for i, s := range starts {
		ends[i] = s + clipDuration
	}
	return starts, ends
}

// clipWindowsNone implements FinalClipNone: it pairs starts[i] with ends[i]
// and keeps only the windows that finish at or before fullDuration, dropping
// the rest. The result is never nil, even when every window is dropped.
func clipWindowsNone(starts, ends []float64, fullDuration float64) []ClipWindow {
	kept := make([]ClipWindow, 0, len(starts))
	for idx, start := range starts {
		if end := ends[idx]; end <= fullDuration {
			kept = append(kept, ClipWindow{Start: start, End: end})
		}
	}
	return kept
}

// clipWindowsRemainder implements FinalClipRemainder: every window is kept,
// but any end that lies past fullDuration is pulled back to fullDuration, so
// the trailing clip can be shorter than the others.
func clipWindowsRemainder(starts, ends []float64, fullDuration float64) []ClipWindow {
	trimmed := make([]ClipWindow, len(starts))
	for idx := range trimmed {
		w := ClipWindow{Start: starts[idx], End: ends[idx]}
		if w.End > fullDuration {
			w.End = fullDuration
		}
		trimmed[idx] = w
	}
	return trimmed
}

// clipWindowsFull implements FinalClipFull: a window that runs past
// fullDuration is slid earlier by its overshoot (end - fullDuration) so that it
// ends exactly at fullDuration. Its length is preserved up to floating-point
// rounding, unless the slide would push the start below 0, in which case the
// start is clamped to 0. Windows that already fit are passed through unchanged.
//
// NOTE(review): the shifted start is computed as start - (end - fullDuration),
// which can differ in the last bits from fullDuration - clipDuration, so it may
// not compare equal to an earlier window with the "same" start unless the
// values are rounded before deduplication.
func clipWindowsFull(starts, ends []float64, fullDuration float64) []ClipWindow {
	shifted := make([]ClipWindow, 0, len(starts))
	for idx, start := range starts {
		end := ends[idx]
		if end > fullDuration {
			overshoot := end - fullDuration
			start -= overshoot
			if start < 0 {
				start = 0
			}
			end = fullDuration
		}
		shifted = append(shifted, ClipWindow{Start: start, End: end})
	}
	return shifted
}

// clipWindowsExtend implements FinalClipExtend: starts[i] and ends[i] are
// zipped into windows verbatim, so a final window may end after fullDuration.
func clipWindowsExtend(starts, ends []float64) []ClipWindow {
	windows := make([]ClipWindow, len(starts))
	for idx := range windows {
		windows[idx] = ClipWindow{Start: starts[idx], End: ends[idx]}
	}
	return windows
}

// dedupClips returns in with every repeated window removed, keeping the first
// occurrence of each and preserving the original order. Duplicates need not be
// adjacent. This matches the default (keep="first") behaviour of
// pandas.DataFrame.drop_duplicates(), which OPSO's generate_clip_times_df
// applies at the end.
//
// Inputs of length 0 or 1 are returned as-is (same backing array); longer
// inputs yield a newly allocated slice. Windows containing NaN never compare
// equal to anything, so they are never treated as duplicates.
func dedupClips(in []ClipWindow) []ClipWindow {
	if len(in) <= 1 {
		return in
	}
	seen := make(map[ClipWindow]struct{}, len(in))
	out := make([]ClipWindow, 0, len(in))
	for _, c := range in {
		if _, dup := seen[c]; dup {
			continue
		}
		seen[c] = struct{}{}
		out = append(out, c)
	}
	return out
}