package utils

import (
	"math"
	"testing"
)

// TestResampleRate exercises ResampleRate: an identical source and target
// rate must return the input unchanged, downsampling must yield the expected
// number of samples, a ramp must stay roughly increasing, and empty input
// must give empty output.
func TestResampleRate(t *testing.T) {
	t.Run("should return same samples for same rate", func(t *testing.T) {
		samples := []float64{0.1, 0.2, 0.3, 0.4, 0.5}
		result := ResampleRate(samples, 16000, 16000)
		// Fatalf rather than Errorf: the loop below indexes result at every
		// position of samples and would panic on a shorter result.
		if len(result) != len(samples) {
			t.Fatalf("length mismatch: got %d, want %d", len(result), len(samples))
		}
		for i := range samples {
			if result[i] != samples[i] {
				t.Errorf("sample %d mismatch: got %f, want %f", i, result[i], samples[i])
			}
		}
	})

	t.Run("should downsample from 250000 to 16000", func(t *testing.T) {
		// 250000 / 16000 = 15.625 ratio
		samples := make([]float64, 2500) // 0.01 seconds at 250kHz
		for i := range samples {
			samples[i] = float64(i) / float64(len(samples))
		}
		result := ResampleRate(samples, 250000, 16000)
		expectedLen := 160 // 0.01 seconds at 16kHz
		if len(result) != expectedLen {
			t.Errorf("length mismatch: got %d, want %d", len(result), expectedLen)
		}
	})

	t.Run("should downsample from 44100 to 16000", func(t *testing.T) {
		// 44100 / 16000 = 2.75625 ratio
		samples := make([]float64, 441) // 0.01 seconds at 44.1kHz
		for i := range samples {
			samples[i] = float64(i) / float64(len(samples))
		}
		result := ResampleRate(samples, 44100, 16000)
		expectedLen := 160 // 0.01 seconds at 16kHz
		if len(result) != expectedLen {
			t.Errorf("length mismatch: got %d, want %d", len(result), expectedLen)
		}
	})

	t.Run("should preserve signal shape", func(t *testing.T) {
		// Create a simple ramp signal
		samples := []float64{0.0, 0.25, 0.5, 0.75, 1.0}
		result := ResampleRate(samples, 50000, 16000)
		// Should still be a roughly increasing signal: each sample may dip at
		// most 0.1 below its predecessor.
		// NOTE(review): if the output holds fewer than two samples this loop
		// checks nothing — confirm the expected output length for this input.
		for i := 1; i < len(result); i++ {
			if result[i] < result[i-1]-0.1 {
				t.Errorf("signal not preserved: result[%d]=%f < result[%d]=%f", i, result[i], i-1, result[i-1])
			}
		}
	})

	t.Run("should handle empty samples", func(t *testing.T) {
		result := ResampleRate([]float64{}, 44100, 16000)
		if len(result) != 0 {
			t.Errorf("expected empty result, got %d samples", len(result))
		}
	})
}

// TestResample exercises Resample with a speed factor: speed 1.0 must return
// the input unchanged, the output length must equal the input length divided
// by the speed, intermediate values must be interpolated, and empty and
// single-sample inputs must be handled.
func TestResample(t *testing.T) {
	t.Run("should return same samples for speed 1.0", func(t *testing.T) {
		samples := []float64{0.1, 0.2, 0.3, 0.4, 0.5}
		result := Resample(samples, 1.0)
		// Fatalf rather than Errorf: the loop below indexes result at every
		// position of samples and would panic on a shorter result.
		if len(result) != len(samples) {
			t.Fatalf("length mismatch: got %d, want %d", len(result), len(samples))
		}
		for i := range samples {
			if result[i] != samples[i] {
				t.Errorf("sample %d mismatch: got %f, want %f", i, result[i], samples[i])
			}
		}
	})

	t.Run("should double samples for half speed", func(t *testing.T) {
		samples := []float64{0.0, 1.0, 0.0, -1.0, 0.0}
		result := Resample(samples, 0.5)
		// Half speed = 2x more samples
		expectedLen := len(samples) * 2
		if len(result) != expectedLen {
			t.Errorf("length mismatch: got %d, want %d", len(result), expectedLen)
		}
	})

	t.Run("should halve samples for double speed", func(t *testing.T) {
		samples := []float64{0.0, 0.5, 1.0, 0.5, 0.0, -0.5, -1.0, -0.5, 0.0}
		result := Resample(samples, 2.0)
		// Double speed = half the samples (integer division: 9 / 2 = 4)
		expectedLen := len(samples) / 2
		if len(result) != expectedLen {
			t.Errorf("length mismatch: got %d, want %d", len(result), expectedLen)
		}
	})

	t.Run("should use linear interpolation", func(t *testing.T) {
		// With samples [0, 1] at half speed the output has 4 samples
		// (2 / 0.5 = 4); output index 1 falls halfway between the two inputs.
		samples := []float64{0.0, 1.0}
		result := Resample(samples, 0.5)
		// Fatalf rather than Errorf: result[1] is read below and would panic
		// if fewer than two samples came back.
		if len(result) != 4 {
			t.Fatalf("length mismatch: got %d, want 4", len(result))
		}
		// Check interpolation: index 1 should be ~0.5 (midpoint)
		expected := 0.5
		if math.Abs(result[1]-expected) > 0.01 {
			t.Errorf("interpolated value mismatch: got %f, want ~%f", result[1], expected)
		}
	})

	t.Run("should handle empty samples", func(t *testing.T) {
		result := Resample([]float64{}, 0.5)
		if len(result) != 0 {
			t.Errorf("expected empty result, got %d samples", len(result))
		}
	})

	t.Run("should handle single sample", func(t *testing.T) {
		samples := []float64{0.5}
		result := Resample(samples, 0.5)
		// 1 / 0.5 = 2 samples
		if len(result) != 2 {
			t.Errorf("length mismatch: got %d, want 2", len(result))
		}
	})
}

// TestResampleQuality checks that half-speed resampling of one sine period
// keeps the basic features of the signal: the first sample stays near zero
// and a value near the +1.0 peak is still present.
func TestResampleQuality(t *testing.T) {
	t.Run("should preserve zero crossings", func(t *testing.T) {
		// One full sine period spread over sampleRate samples, starting at 0.
		sampleRate := 1000
		samples := make([]float64, sampleRate)
		for i := range samples {
			samples[i] = math.Sin(2 * math.Pi * float64(i) / float64(sampleRate))
		}

		// Resample to half speed
		result := Resample(samples, 0.5)

		// Guard before indexing: an empty result should fail the test with a
		// clear message rather than panic on result[0].
		if len(result) == 0 {
			t.Fatal("resampled signal is empty")
		}

		// First sample should still be ~0 (sine at 0)
		if math.Abs(result[0]) > 0.01 {
			t.Errorf("first sample not near zero: got %f", result[0])
		}

		// Peak should still be ~1.0 (sine max): at least one sample must lie
		// within 0.1 of 1.0.
		peakFound := false
		for _, s := range result {
			if math.Abs(s-1.0) < 0.1 {
				peakFound = true
				break
			}
		}
		if !peakFound {
			t.Error("peak not preserved in resampled signal")
		}
	})
}