package utils

import (
	"math"
	"sync"
)

// FFT twiddle factors and bit-reversal tables, cached per size.
//
// Plans are created lazily by getFFTPlan and are never evicted.
var (
	// fftCacheMu guards fftCache: lookups take the read lock, inserts the
	// write lock.
	fftCacheMu sync.RWMutex
	// fftCache maps an FFT size n to the plan built for that size.
	fftCache   = map[int]*fftPlan{}
)

// fftPlan holds pre-computed data for a given FFT size.
//
// A plan is filled in once by getFFTPlan and then stored in the cache, where
// it is handed out to every caller asking for the same size.
type fftPlan struct {
	n       int          // FFT size N this plan was built for
	twiddle []complex128 // twiddle factors: exp(-2*pi*i*k/N) for k=0..N/2-1 (length N/2)
	bitrev  []int        // bit-reversal permutation table: bitrev[i] is i with its low log2(N) bits reversed (length N)
}

// getFFTPlan returns the FFT plan for size n (must be a power of 2), building
// and caching it on first use.
//
// Lookups of an existing plan only take the read lock. On a miss the write
// lock is taken and the cache is checked again, because another goroutine may
// have stored the plan between the two lock acquisitions.
func getFFTPlan(n int) *fftPlan {
	fftCacheMu.RLock()
	cached, found := fftCache[n]
	fftCacheMu.RUnlock()
	if found {
		return cached
	}

	fftCacheMu.Lock()
	defer fftCacheMu.Unlock()
	if cached, found = fftCache[n]; found {
		return cached
	}

	// Twiddle factors: exp(-2*pi*i*k/N) for k = 0..N/2-1.
	twiddle := make([]complex128, n/2)
	for k := 0; k < len(twiddle); k++ {
		angle := -2.0 * math.Pi * float64(k) / float64(n)
		im, re := math.Sincos(angle)
		twiddle[k] = complex(re, im)
	}

	// Number of index bits, i.e. log2(n) for a power-of-2 n.
	width := 0
	for rest := n; rest > 1; rest >>= 1 {
		width++
	}

	// Bit-reversal permutation: entry i is i with its low `width` bits reversed.
	perm := make([]int, n)
	for i := 0; i < len(perm); i++ {
		perm[i] = reverseBitsN(i, width)
	}

	plan := &fftPlan{n: n, twiddle: twiddle, bitrev: perm}
	fftCache[n] = plan
	return plan
}

// reverseBitsN returns the value formed by the lowest `bits` bits of v taken
// in reverse order. Bits of v above that width are ignored, and a
// non-positive width yields 0.
func reverseBitsN(v, bits int) int {
	reversed := 0
	for remaining := bits; remaining > 0; remaining-- {
		// Shift the result up and append the current lowest bit of v.
		reversed <<= 1
		reversed |= v & 1
		v >>= 1
	}
	return reversed
}

// PowerSpectrumFFT computes the power spectrum of a real-valued signal using
// an iterative radix-2 FFT.
//
// samples: real input of length N (must be power of 2, N >= 2)
// power:   output buffer of length >= N/2+1; receives |X[k]|^2 for k=0..N/2
// scratch: working buffer of length >= N; contents are overwritten
//
// All buffers are caller-provided to enable zero-allocation across repeated calls.
//
// The preconditions are checked before any buffer is written and before a
// plan is requested. PowerSpectrumFFT panics if len(samples) is zero or not a
// power of 2, or if power or scratch is too short. A length of 1 is accepted
// and yields power[0] = samples[0]^2.
func PowerSpectrumFFT(samples []float64, power []float64, scratch []complex128) {
	n := len(samples)

	// A non-power-of-2 length would otherwise give a wrong spectrum or an
	// index panic half-way through the butterflies.
	if n == 0 || n&(n-1) != 0 {
		panic("utils: PowerSpectrumFFT: len(samples) must be a power of 2")
	}
	numBins := n/2 + 1
	if len(power) < numBins {
		panic("utils: PowerSpectrumFFT: power buffer is shorter than N/2+1")
	}
	if len(scratch) < n {
		panic("utils: PowerSpectrumFFT: scratch buffer is shorter than N")
	}

	// Work on exactly the part of each buffer that is used.
	power = power[:numBins]
	scratch = scratch[:n]

	plan := getFFTPlan(n)

	// Bit-reversal copy: load real samples into scratch in bit-reversed order
	for i, j := range plan.bitrev {
		scratch[j] = complex(samples[i], 0)
	}

	// Iterative Cooley-Tukey butterfly (decimation-in-time)
	for size := 2; size <= n; size <<= 1 {
		half := size >> 1
		step := n / size // twiddle index step: twiddle[j*step] = exp(-2*pi*i*j/size)

		for start := 0; start < n; start += size {
			tw := 0
			for j := range half {
				u := scratch[start+j]
				v := scratch[start+j+half] * plan.twiddle[tw]
				scratch[start+j] = u + v
				scratch[start+j+half] = u - v
				tw += step
			}
		}
	}

	// Extract power spectrum: |X[k]|^2 = re^2 + im^2 for k = 0..N/2
	for k := range power {
		re := real(scratch[k])
		im := imag(scratch[k])
		power[k] = re*re + im*im
	}
}