package utils

import (
	"math"
	"math/rand"
	"testing"

	"github.com/madelynnblue/go-dsp/fft"
)

// referencePower returns the power spectrum of samples computed with go-dsp's
// FFTReal, serving as an independent ground truth for PowerSpectrumFFT.
//
// The result has len(samples)/2+1 bins (the non-redundant half of a real
// signal's spectrum); bin k holds |X[k]|^2, where X is the transform returned
// by fft.FFTReal.
func referencePower(samples []float64) []float64 {
	spectrum := fft.FFTReal(samples)
	power := make([]float64, len(samples)/2+1)
	for k := range power {
		c := spectrum[k]
		power[k] = real(c)*real(c) + imag(c)*imag(c)
	}
	return power
}

// TestPowerSpectrumFFT_Sinusoid feeds a pure tone and checks that the spectral
// peak lands on the tone's bin, that (for a tone centred exactly on a bin) the
// peak has the expected height, and that every bin matches the go-dsp
// reference.
func TestPowerSpectrumFFT_Sinusoid(t *testing.T) {
	// 512-point FFT of a pure 1kHz sine at 16kHz sample rate
	// Expected: peak at bin k = 1000 * 512 / 16000 = 32
	n := 512
	sampleRate := 16000.0
	freq := 1000.0

	samples := make([]float64, n)
	for i := range samples {
		samples[i] = math.Sin(2.0 * math.Pi * freq * float64(i) / sampleRate)
	}

	power := make([]float64, n/2+1)
	scratch := make([]complex128, n)
	PowerSpectrumFFT(samples, power, scratch)

	// Find peak bin. Power values are non-negative, so 0 is a safe starting max.
	maxBin := 0
	maxVal := 0.0
	for k, v := range power {
		if v > maxVal {
			maxVal = v
			maxBin = k
		}
	}

	// Round rather than truncate: a tone that falls between bins peaks at the
	// nearest bin, not necessarily the one below it.
	binPos := freq * float64(n) / sampleRate
	expectedBin := int(math.Round(binPos))
	if maxBin != expectedBin {
		t.Errorf("peak at bin %d, expected %d", maxBin, expectedBin)
	}

	// When the tone sits exactly on a bin there is no leakage, and a
	// unit-amplitude sine puts |X[k]| = n/2 into that bin. This uses the same
	// unscaled-DFT convention under which the impulse test expects 1.0 per bin.
	// It checks the peak height without relying on the reference.
	if binPos == math.Trunc(binPos) {
		wantPeak := float64(n) * float64(n) / 4
		if math.Abs(maxVal-wantPeak) > 1e-6*wantPeak {
			t.Errorf("peak power %g, expected %g", maxVal, wantPeak)
		}
	}

	// Compare against reference
	ref := referencePower(samples)
	for k := range power {
		if math.Abs(power[k]-ref[k]) > 1e-6*math.Abs(ref[k])+1e-10 {
			t.Errorf("bin %d: got %g, ref %g", k, power[k], ref[k])
		}
	}
}

// TestPowerSpectrumFFT_Random compares PowerSpectrumFFT with the go-dsp
// reference on 512 uniform samples in [-1, 1), drawn from a fixed seed so the
// input is reproducible across runs.
func TestPowerSpectrumFFT_Random(t *testing.T) {
	const size = 512
	src := rand.New(rand.NewSource(42))

	input := make([]float64, size)
	for j := range input {
		input[j] = 2*src.Float64() - 1
	}

	got := make([]float64, size/2+1)
	work := make([]complex128, size)
	PowerSpectrumFFT(input, got, work)

	want := referencePower(input)
	for bin, p := range got {
		// The small additive term keeps the ratio finite if a reference bin is 0.
		relErr := math.Abs(p-want[bin]) / (math.Abs(want[bin]) + 1e-15)
		if relErr > 1e-8 {
			t.Errorf("bin %d: got %g, ref %g (relErr=%g)", bin, p, want[bin], relErr)
		}
	}
}

// TestPowerSpectrumFFT_DC checks a constant all-ones signal. The output must
// match the go-dsp reference, bin 0 must carry the full DC energy
// (sum of samples)^2 = n^2, and every other bin must be negligible next to it.
func TestPowerSpectrumFFT_DC(t *testing.T) {
	n := 512
	samples := make([]float64, n)
	for i := range samples {
		samples[i] = 1.0
	}

	power := make([]float64, n/2+1)
	scratch := make([]complex128, n)
	PowerSpectrumFFT(samples, power, scratch)

	ref := referencePower(samples)
	for k := range power {
		if math.Abs(power[k]-ref[k]) > 1e-6 {
			t.Errorf("bin %d: got %g, ref %g", k, power[k], ref[k])
		}
	}

	// X[0] is the plain sum of the samples (n ones), so its power is n^2,
	// using the same unscaled-DFT convention the impulse test relies on.
	wantDC := float64(n) * float64(n)
	if math.Abs(power[0]-wantDC) > 1e-9*wantDC {
		t.Errorf("DC bin: got %g, want %g", power[0], wantDC)
	}

	// DC bin should have all the energy: check it dominates every other bin,
	// not just bin 1.
	for k := 1; k < len(power); k++ {
		if power[0] < power[k]*1000 {
			t.Errorf("DC bin should dominate: power[0]=%g, power[%d]=%g", power[0], k, power[k])
		}
	}
}

// TestPowerSpectrumFFT_Silence verifies that an all-zero input produces an
// exactly-zero power spectrum, with no rounding residue in any bin.
func TestPowerSpectrumFFT_Silence(t *testing.T) {
	const size = 512
	zeros := make([]float64, size)
	out := make([]float64, size/2+1)
	work := make([]complex128, size)

	PowerSpectrumFFT(zeros, out, work)

	for bin := range out {
		if got := out[bin]; got != 0 {
			t.Errorf("bin %d: expected 0, got %g", bin, got)
		}
	}
}

// TestPowerSpectrumFFT_Impulse feeds a unit impulse at index 0. The test
// requires the output to match the go-dsp reference and to be flat at 1.0 in
// every bin.
func TestPowerSpectrumFFT_Impulse(t *testing.T) {
	const size = 512
	impulse := make([]float64, size)
	impulse[0] = 1.0

	got := make([]float64, size/2+1)
	work := make([]complex128, size)
	PowerSpectrumFFT(impulse, got, work)

	want := referencePower(impulse)
	for bin := range got {
		if diff := math.Abs(got[bin] - want[bin]); diff > 1e-10 {
			t.Errorf("bin %d: got %g, ref %g", bin, got[bin], want[bin])
		}
	}

	// A unit impulse has a flat spectrum, so every bin should be ~1.0.
	for bin, p := range got {
		if math.Abs(p-1.0) > 1e-10 {
			t.Errorf("bin %d: expected ~1.0, got %g", bin, p)
		}
	}
}

// TestPowerSpectrumFFT_DifferentSizes compares PowerSpectrumFFT with the
// go-dsp reference over a range of power-of-two lengths. Each length runs as
// its own subtest, so a failure names the size and can be rerun alone with
// -run.
func TestPowerSpectrumFFT_DifferentSizes(t *testing.T) {
	// One generator is shared across sizes. Subtests run sequentially (none
	// call t.Parallel), so when all of them run each size sees the same samples
	// a plain loop would produce.
	rng := rand.New(rand.NewSource(99))

	cases := []struct {
		name string
		n    int
	}{
		{"n=2", 2},
		{"n=4", 4},
		{"n=8", 8},
		{"n=16", 16},
		{"n=64", 64},
		{"n=256", 256},
		{"n=1024", 1024},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			n := tc.n
			samples := make([]float64, n)
			for i := range samples {
				samples[i] = rng.Float64()*2 - 1
			}

			power := make([]float64, n/2+1)
			scratch := make([]complex128, n)
			PowerSpectrumFFT(samples, power, scratch)

			ref := referencePower(samples)
			for k := range power {
				relErr := math.Abs(power[k]-ref[k]) / (math.Abs(ref[k]) + 1e-15)
				if relErr > 1e-8 {
					t.Errorf("n=%d bin %d: got %g, ref %g (relErr=%g)", n, k, power[k], ref[k], relErr)
				}
			}
		})
	}
}

func BenchmarkPowerSpectrumFFT_512(b *testing.B) {
	n := 512
	rng := rand.New(rand.NewSource(42))
	samples := make([]float64, n)
	for i := range samples {
		samples[i] = rng.Float64()*2 - 1
	}
	power := make([]float64, n/2+1)
	scratch := make([]complex128, n)

	b.ResetTimer()
	for range b.N {
		PowerSpectrumFFT(samples, power, scratch)
	}
}

func BenchmarkGodsFFTReal_512(b *testing.B) {
	n := 512
	rng := rand.New(rand.NewSource(42))
	samples := make([]float64, n)
	for i := range samples {
		samples[i] = rng.Float64()*2 - 1
	}

	b.ResetTimer()
	for range b.N {
		fft.FFTReal(samples)
	}
}