package utils

import (
	"math"
	"testing"
)

// Reference values verified against opensoundscape.utils.generate_clip_times_df
// at https://github.com/kitzeslab/opensoundscape/blob/master/opensoundscape/utils.py

func TestGenerateClipTimes_FullModeBasic(t *testing.T) {
	// full_duration=10, clip_duration=4, overlap=0.5, final="full"
	// increment = 3.5
	// raw starts: 0, 3.5, 7   (next would be 10.5 ≥ 10)
	// raw ends:   4, 7.5, 11
	// "full": last clip start shifts back by (11-10)=1 → start=6, end=10
	// → [(0,4), (3.5,7.5), (6,10)]
	got, err := GenerateClipTimes(10, 4, 0.5, FinalClipFull, 10)
	if err != nil {
		t.Fatal(err)
	}
	want := []ClipWindow{{0, 4}, {3.5, 7.5}, {6, 10}}
	assertClips(t, got, want)
}

func TestGenerateClipTimes_NoneMode(t *testing.T) {
	// final="none" discards every window whose end lies past full_duration.
	// With full_duration=10, clip_duration=4 and no overlap the raw windows are
	// [0,4], [4,8] and [8,12]; only the first two fit inside the audio.
	clips, err := GenerateClipTimes(10, 4, 0, FinalClipNone, 10)
	if err != nil {
		t.Fatal(err)
	}
	want := []ClipWindow{{0, 4}, {4, 8}}
	assertClips(t, clips, want)
}

func TestGenerateClipTimes_RemainderMode(t *testing.T) {
	// final="remainder" keeps the overrunning window but truncates its end to
	// full_duration. Raw windows for full=10, dur=4, overlap=0 are
	// [0,4], [4,8], [8,12]; the last one is cut down to [8,10].
	clips, err := GenerateClipTimes(10, 4, 0, FinalClipRemainder, 10)
	if err != nil {
		t.Fatal(err)
	}
	want := []ClipWindow{{0, 4}, {4, 8}, {8, 10}}
	assertClips(t, clips, want)
}

func TestGenerateClipTimes_ExtendMode(t *testing.T) {
	// final="extend" leaves the overrunning window untouched: with full=10,
	// dur=4, overlap=0 the last window [8,12] reaches 2s past full_duration.
	clips, err := GenerateClipTimes(10, 4, 0, FinalClipExtend, 10)
	if err != nil {
		t.Fatal(err)
	}
	want := []ClipWindow{{0, 4}, {4, 8}, {8, 12}}
	assertClips(t, clips, want)
}

func TestGenerateClipTimes_AudioShorterThanClip(t *testing.T) {
	// The audio (2s) is shorter than a single clip (4s), final="full".
	// The only raw window is [0,4]. Pulling it back to end at 2 would start it
	// at -2, which is clamped to 0, so the result is the single window [0,2].
	clips, err := GenerateClipTimes(2, 4, 0, FinalClipFull, 10)
	if err != nil {
		t.Fatal(err)
	}
	want := []ClipWindow{{0, 2}}
	assertClips(t, clips, want)
}

// TestGenerateClipTimes_DedupAfterFullShift checks that final="full" never
// emits the same window twice, both when no window needs shifting and when
// shifting makes several windows collide.
func TestGenerateClipTimes_DedupAfterFullShift(t *testing.T) {
	t.Run("no_shift", func(t *testing.T) {
		// full=8, dur=4, overlap=0: raw windows [0,4], [4,8] both end inside
		// the audio, so nothing moves and nothing can collide.
		got, err := GenerateClipTimes(8, 4, 0, FinalClipFull, 10)
		if err != nil {
			t.Fatal(err)
		}
		assertClips(t, got, []ClipWindow{{0, 4}, {4, 8}})
	})

	t.Run("shift_collides", func(t *testing.T) {
		// full=10, dur=4, overlap=3: step is 1, raw starts 0..9, raw ends 4..13.
		// The windows starting at 7, 8 and 9 overrun 10 and are each pulled
		// back to [6,10], which already exists as the window starting at 6.
		// The reference (opensoundscape) drops such duplicates, so exactly one
		// [6,10] must remain.
		got, err := GenerateClipTimes(10, 4, 3, FinalClipFull, 10)
		if err != nil {
			t.Fatal(err)
		}
		want := []ClipWindow{{0, 4}, {1, 5}, {2, 6}, {3, 7}, {4, 8}, {5, 9}, {6, 10}}
		assertClips(t, got, want)
	})
}

func TestGenerateClipTimes_InvalidArgs(t *testing.T) {
	// Each call violates exactly one precondition and must be rejected.
	if _, err := GenerateClipTimes(10, 0, 0, FinalClipFull, 10); err == nil {
		t.Error("expected error for clip_duration=0")
	}
	// An overlap equal to the clip length would make the step between
	// consecutive clip starts (clip_duration - clip_overlap) zero.
	if _, err := GenerateClipTimes(10, 4, 4, FinalClipFull, 10); err == nil {
		t.Error("expected error for clip_overlap >= clip_duration")
	}
	if _, err := GenerateClipTimes(0, 4, 0, FinalClipFull, 10); err == nil {
		t.Error("expected error for full_duration=0")
	}
}

// TestParseFinalClipMode checks the accepted spellings of each final-clip
// mode, that the empty string maps to FinalClipNone, and that unknown or
// differently-cased names are rejected.
func TestParseFinalClipMode(t *testing.T) {
	tests := []struct {
		name    string // explicit subtest name; the "" input would otherwise be "#00"
		input   string
		want    FinalClipMode
		wantErr bool
	}{
		{"none", "none", FinalClipNone, false},
		{"empty_defaults_to_none", "", FinalClipNone, false},
		{"remainder", "remainder", FinalClipRemainder, false},
		{"full", "full", FinalClipFull, false},
		{"extend", "extend", FinalClipExtend, false},
		{"unknown_rejected", "invalid", 0, true},
		{"uppercase_rejected", "FULL", 0, true}, // parsing is case-sensitive
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := ParseFinalClipMode(tt.input)
			if tt.wantErr {
				if err == nil {
					t.Errorf("ParseFinalClipMode(%q): expected error, got %d", tt.input, got)
				}
				return
			}
			// On error the returned mode is meaningless, so stop here.
			if err != nil {
				t.Fatalf("ParseFinalClipMode(%q): unexpected error: %v", tt.input, err)
			}
			if got != tt.want {
				t.Errorf("ParseFinalClipMode(%q) = %d, want %d", tt.input, got, tt.want)
			}
		})
	}
}

// assertClips fails t unless got and want have the same length and every
// window's Start and End agree with the expected value to within an absolute
// tolerance of 1e-9. A length mismatch stops the test immediately; individual
// window mismatches are reported and checking continues, so every differing
// window is listed.
func assertClips(t *testing.T, got, want []ClipWindow) {
	t.Helper()
	const tol = 1e-9
	if nGot, nWant := len(got), len(want); nGot != nWant {
		t.Fatalf("len(got)=%d, len(want)=%d\ngot=%v\nwant=%v", nGot, nWant, got, want)
	}
	for i, g := range got {
		w := want[i]
		// Written as "> tol" (not "!(<= tol)") so the comparison semantics,
		// including for NaN, match a plain absolute-difference check.
		badStart := math.Abs(g.Start-w.Start) > tol
		badEnd := math.Abs(g.End-w.End) > tol
		if badStart || badEnd {
			t.Errorf("clip %d: got (%v,%v), want (%v,%v)", i, g.Start, g.End, w.Start, w.End)
		}
	}
}