package tools

import (
	"encoding/binary"
	"math"
	"os"
	"testing"

	"skraak/utils"
)

// benchWAV is the recording, relative to this package's directory, that every
// benchmark and TestPipelineDimensions in this file decodes.
const (
	benchWAV = "../audio/20211028_211500.WAV"
)


func BenchmarkReadWAV(b *testing.B) {
	b.ReportAllocs()
	for i := 0; i < b.N; i++ {
		_, _, err := utils.ReadWAVSamples(benchWAV)
		if err != nil {
			b.Fatal(err)
		}
	}
}

func BenchmarkConvertToFloat64_16bit(b *testing.B) {
	// Simulate 16-bit mono WAV data (same size as test file: 14.32M samples)
	numSamples := 14320000
	data := make([]byte, numSamples*2)
	for i := range numSamples {
		binary.LittleEndian.PutUint16(data[i*2:], uint16(i%65536))
	}
	b.ResetTimer()
	b.ReportAllocs()
	for i := 0; i < b.N; i++ {
		_ = convertToFloat64Bench(data, 16, 1)
	}
}

// convertToFloat64Bench is a local copy of the conversion routine in utils,
// which is unexported there and so cannot be benchmarked directly.
//
// bitsPerSample and channels only determine the frame stride. Each frame is
// always decoded as a little-endian signed 16-bit value taken from its first
// two bytes (i.e. the first channel). That value is scaled by 1/32768 into
// [-1, 1). Trailing bytes that do not fill a whole frame are ignored.
func convertToFloat64Bench(data []byte, bitsPerSample, channels int) []float64 {
	frameBytes := (bitsPerSample / 8) * channels
	out := make([]float64, len(data)/frameBytes)
	for idx := range out {
		start := idx * frameBytes
		v := int16(binary.LittleEndian.Uint16(data[start : start+2]))
		out[idx] = float64(v) / 32768.0
	}
	return out
}

// BenchmarkWriteWAV measures writing the 872–895 s segment of benchWAV to a
// fresh WAV file on every iteration. Each iteration creates, writes and
// removes its own temporary file.
//
// Temporary files are placed in b.TempDir() so anything left behind by a
// failed run is cleaned up automatically.
func BenchmarkWriteWAV(b *testing.B) {
	samples, sr, err := utils.ReadWAVSamples(benchWAV)
	if err != nil {
		// Without the fixture the loop would time writing a meaningless
		// (likely empty) segment.
		b.Fatal(err)
	}
	segSamples := utils.ExtractSegmentSamples(samples, sr, 872, 895)
	b.Logf("segment samples=%d", len(segSamples))
	dir := b.TempDir()

	b.ResetTimer()
	b.ReportAllocs()
	for range b.N {
		f, err := os.CreateTemp(dir, "bench_*.wav")
		if err != nil {
			b.Fatal(err)
		}
		path := f.Name()
		// Release our handle before WriteWAVFile, which takes the path and
		// presumably opens the file itself.
		if err := f.Close(); err != nil {
			b.Fatal(err)
		}
		// NOTE(review): WriteWAVFile's result (if any) is not checked because
		// its signature is not visible here. Check it if it returns an error.
		utils.WriteWAVFile(path, segSamples, sr)
		os.Remove(path) // best-effort; b.TempDir cleans up regardless
	}
}


// BenchmarkResampleRate_48k measures resampling the whole benchmark recording
// with a 48000 → 16000 ratio. The input rate is passed explicitly, not read
// from the file header, so this exercises that ratio whatever the
// recording's native rate is.
func BenchmarkResampleRate_48k(b *testing.B) {
	samples, _, err := utils.ReadWAVSamples(benchWAV)
	if err != nil {
		// Otherwise the loop would silently time resampling whatever
		// (likely empty) buffer came back.
		b.Fatal(err)
	}
	b.Logf("resampling %d samples 48000->16000", len(samples))

	b.ResetTimer()
	b.ReportAllocs()
	for range b.N {
		utils.ResampleRate(samples, 48000, 16000)
	}
}

// BenchmarkResampleRate_250k measures resampling the whole benchmark recording
// with a 250000 → 16000 ratio. The input rate is passed explicitly, not read
// from the file header, so this exercises that ratio whatever the
// recording's native rate is.
func BenchmarkResampleRate_250k(b *testing.B) {
	samples, _, err := utils.ReadWAVSamples(benchWAV)
	if err != nil {
		// Otherwise the loop would silently time resampling whatever
		// (likely empty) buffer came back.
		b.Fatal(err)
	}
	b.Logf("resampling %d samples 250000->16000", len(samples))

	b.ResetTimer()
	b.ReportAllocs()
	for range b.N {
		utils.ResampleRate(samples, 250000, 16000)
	}
}


// BenchmarkExtractSegment measures cutting the 872–895 s segment out of the
// fully decoded recording. It fails if the extracted segment is empty.
func BenchmarkExtractSegment(b *testing.B) {
	samples, sr, err := utils.ReadWAVSamples(benchWAV)
	if err != nil {
		b.Fatal(err)
	}
	b.Logf("full file: %d samples, sr=%d", len(samples), sr)

	b.ResetTimer()
	b.ReportAllocs()
	for range b.N {
		seg := utils.ExtractSegmentSamples(samples, sr, 872, 895)
		if len(seg) == 0 {
			b.Fatal("empty segment")
		}
	}
}

// BenchmarkPowerSpectrumFFT_512 measures a single 512-point frame. Each
// iteration applies a Hann window to the first 512 samples of the 872–895 s
// segment, then calls utils.PowerSpectrumFFT. The frame, output and scratch
// buffers are allocated once, before the timer starts.
func BenchmarkPowerSpectrumFFT_512(b *testing.B) {
	n := 512
	samples, sr, err := utils.ReadWAVSamples(benchWAV)
	if err != nil {
		b.Fatal(err)
	}
	segSamples := utils.ExtractSegmentSamples(samples, sr, 872, 895)
	if len(segSamples) < n {
		// The windowing loop indexes segSamples[0:n]. Fail clearly here
		// rather than panic inside the timed region.
		b.Fatalf("segment has %d samples, need at least %d", len(segSamples), n)
	}
	frameData := make([]float64, n)
	power := make([]float64, n/2+1)
	scratch := make([]complex128, n)

	b.ResetTimer()
	b.ReportAllocs()
	for range b.N {
		// Windowing is deliberately inside the timed region: it simulates
		// the per-frame Hann step that precedes the FFT.
		for j := range n {
			frameData[j] = segSamples[j] * 0.5 * (1.0 - math.Cos(2.0*math.Pi*float64(j)/float64(n-1)))
		}
		utils.PowerSpectrumFFT(frameData, power, scratch)
	}
}

// BenchmarkSpectrogram_23s measures generating a spectrogram for the 872–895 s
// segment of benchWAV, using the default configuration for 16000 Hz.
func BenchmarkSpectrogram_23s(b *testing.B) {
	samples, sr, err := utils.ReadWAVSamples(benchWAV)
	if err != nil {
		b.Fatal(err)
	}
	segSamples := utils.ExtractSegmentSamples(samples, sr, 872, 895)
	cfg := utils.DefaultSpectrogramConfig(16000)
	b.Logf("segment samples=%d, windowSize=%d, hopSize=%d", len(segSamples), cfg.WindowSize, cfg.HopSize)

	b.ResetTimer()
	b.ReportAllocs()
	for range b.N {
		spect := utils.GenerateSpectrogram(segSamples, cfg)
		if spect == nil {
			b.Fatal("nil spectrogram")
		}
	}
}

// BenchmarkSpectrogram_60s measures generating a spectrogram for the first
// 60 s of benchWAV, using the default configuration for 16000 Hz.
func BenchmarkSpectrogram_60s(b *testing.B) {
	samples, sr, err := utils.ReadWAVSamples(benchWAV)
	if err != nil {
		b.Fatal(err)
	}
	segSamples := utils.ExtractSegmentSamples(samples, sr, 0, 60)
	cfg := utils.DefaultSpectrogramConfig(16000)
	b.Logf("60s segment samples=%d", len(segSamples))

	b.ResetTimer()
	b.ReportAllocs()
	for range b.N {
		spect := utils.GenerateSpectrogram(segSamples, cfg)
		if spect == nil {
			b.Fatal("nil spectrogram")
		}
	}
}


// BenchmarkCreateGrayscaleImage measures rendering a precomputed spectrogram of
// the 872–895 s segment into a grayscale image. The spectrogram is built once,
// before the timer starts.
func BenchmarkCreateGrayscaleImage(b *testing.B) {
	samples, sr, err := utils.ReadWAVSamples(benchWAV)
	if err != nil {
		b.Fatal(err)
	}
	segSamples := utils.ExtractSegmentSamples(samples, sr, 872, 895)
	cfg := utils.DefaultSpectrogramConfig(16000)
	spect := utils.GenerateSpectrogram(segSamples, cfg)

	b.ResetTimer()
	b.ReportAllocs()
	for range b.N {
		img := utils.CreateGrayscaleImage(spect)
		if img == nil {
			b.Fatal("nil image")
		}
	}
}

// BenchmarkCreateRGBImage measures producing a colour image from a
// precomputed spectrogram of the 872–895 s segment. The timed loop includes
// BOTH utils.ApplyL4Colormap and utils.CreateRGBImage. Subtract
// BenchmarkApplyL4Colormap to estimate CreateRGBImage alone.
func BenchmarkCreateRGBImage(b *testing.B) {
	samples, sr, err := utils.ReadWAVSamples(benchWAV)
	if err != nil {
		b.Fatal(err)
	}
	segSamples := utils.ExtractSegmentSamples(samples, sr, 872, 895)
	cfg := utils.DefaultSpectrogramConfig(16000)
	spect := utils.GenerateSpectrogram(segSamples, cfg)

	b.ResetTimer()
	b.ReportAllocs()
	for range b.N {
		colorData := utils.ApplyL4Colormap(spect)
		img := utils.CreateRGBImage(colorData)
		if img == nil {
			b.Fatal("nil image")
		}
	}
}

// BenchmarkApplyL4Colormap measures applying the L4 colormap to a precomputed
// spectrogram of the 872–895 s segment.
func BenchmarkApplyL4Colormap(b *testing.B) {
	samples, sr, err := utils.ReadWAVSamples(benchWAV)
	if err != nil {
		b.Fatal(err)
	}
	segSamples := utils.ExtractSegmentSamples(samples, sr, 872, 895)
	cfg := utils.DefaultSpectrogramConfig(16000)
	spect := utils.GenerateSpectrogram(segSamples, cfg)

	b.ResetTimer()
	b.ReportAllocs()
	for range b.N {
		colorData := utils.ApplyL4Colormap(spect)
		if colorData == nil {
			b.Fatal("nil colormap")
		}
	}
}

// BenchmarkResizeGray224 measures resizing the grayscale spectrogram image of
// the 872–895 s segment to 224x224. The source image is built before the
// timer starts.
func BenchmarkResizeGray224(b *testing.B) {
	samples, sr, err := utils.ReadWAVSamples(benchWAV)
	if err != nil {
		b.Fatal(err)
	}
	segSamples := utils.ExtractSegmentSamples(samples, sr, 872, 895)
	cfg := utils.DefaultSpectrogramConfig(16000)
	spect := utils.GenerateSpectrogram(segSamples, cfg)
	img := utils.CreateGrayscaleImage(spect)

	b.ResetTimer()
	b.ReportAllocs()
	for range b.N {
		resized := utils.ResizeImage(img, 224, 224)
		if resized == nil {
			b.Fatal("nil resize")
		}
	}
}

// BenchmarkResizeGray448 measures resizing the grayscale spectrogram image of
// the 872–895 s segment to 448x448. The source image is built before the
// timer starts.
func BenchmarkResizeGray448(b *testing.B) {
	samples, sr, err := utils.ReadWAVSamples(benchWAV)
	if err != nil {
		b.Fatal(err)
	}
	segSamples := utils.ExtractSegmentSamples(samples, sr, 872, 895)
	cfg := utils.DefaultSpectrogramConfig(16000)
	spect := utils.GenerateSpectrogram(segSamples, cfg)
	img := utils.CreateGrayscaleImage(spect)

	b.ResetTimer()
	b.ReportAllocs()
	for range b.N {
		resized := utils.ResizeImage(img, 448, 448)
		if resized == nil {
			b.Fatal("nil resize")
		}
	}
}


// BenchmarkWritePNG_224 measures encoding a precomputed 224x224 grayscale
// spectrogram image to a fresh PNG file on every iteration. Files are
// created in b.TempDir() so leftovers from a failed run are cleaned up
// automatically.
func BenchmarkWritePNG_224(b *testing.B) {
	samples, sr, err := utils.ReadWAVSamples(benchWAV)
	if err != nil {
		b.Fatal(err)
	}
	segSamples := utils.ExtractSegmentSamples(samples, sr, 872, 895)
	cfg := utils.DefaultSpectrogramConfig(16000)
	spect := utils.GenerateSpectrogram(segSamples, cfg)
	img := utils.CreateGrayscaleImage(spect)
	resized := utils.ResizeImage(img, 224, 224)
	dir := b.TempDir()

	b.ResetTimer()
	b.ReportAllocs()
	for range b.N {
		f, err := os.CreateTemp(dir, "bench_*.png")
		if err != nil {
			b.Fatal(err)
		}
		// NOTE(review): WritePNG's result (if any) is not checked because its
		// signature is not visible here. Check it if it returns an error.
		utils.WritePNG(resized, f)
		// A failed Close on a written file can mean lost data. Don't ignore it.
		if err := f.Close(); err != nil {
			b.Fatal(err)
		}
		os.Remove(f.Name()) // best-effort; b.TempDir cleans up regardless
	}
}


// BenchmarkFullPipelineGray224 measures the end-to-end grayscale path for the
// 872–895 s segment of benchWAV. Only decoding the file is excluded from
// timing. Each iteration:
//   - extracts the segment;
//   - resamples to 16000 Hz if the source rate is higher;
//   - builds the spectrogram and renders it as a grayscale image;
//   - resizes it to 224x224;
//   - writes both a PNG and a WAV file, then deletes them.
func BenchmarkFullPipelineGray224(b *testing.B) {
	samples, sr, err := utils.ReadWAVSamples(benchWAV)
	if err != nil {
		b.Fatal(err)
	}
	dir := b.TempDir()

	b.ResetTimer()
	b.ReportAllocs()
	for range b.N {
		segSamples := utils.ExtractSegmentSamples(samples, sr, 872, 895)
		outputSR := sr
		if sr > 16000 {
			segSamples = utils.ResampleRate(segSamples, sr, 16000)
			outputSR = 16000
		}
		cfg := utils.DefaultSpectrogramConfig(outputSR)
		spect := utils.GenerateSpectrogram(segSamples, cfg)
		img := utils.CreateGrayscaleImage(spect)
		resized := utils.ResizeImage(img, 224, 224)

		pngFile, err := os.CreateTemp(dir, "bench_*.png")
		if err != nil {
			b.Fatal(err)
		}
		utils.WritePNG(resized, pngFile)
		if err := pngFile.Close(); err != nil {
			b.Fatal(err)
		}
		// Give the WAV its own sibling path, derived from the unique PNG name.
		// Previously it reused the already-deleted PNG path.
		wavPath := pngFile.Name() + ".wav"
		utils.WriteWAVFile(wavPath, segSamples, outputSR)

		// Best-effort cleanup; b.TempDir removes anything left over.
		os.Remove(pngFile.Name())
		os.Remove(wavPath)
	}
}

// BenchmarkFullPipelineColor448 measures the end-to-end colour path for the
// 872–895 s segment of benchWAV. Only decoding the file is excluded from
// timing. Each iteration:
//   - extracts the segment;
//   - resamples to 16000 Hz if the source rate is higher;
//   - builds the spectrogram and applies the L4 colormap;
//   - renders an RGB image and resizes it to 448x448;
//   - writes both a PNG and a WAV file, then deletes them.
func BenchmarkFullPipelineColor448(b *testing.B) {
	samples, sr, err := utils.ReadWAVSamples(benchWAV)
	if err != nil {
		b.Fatal(err)
	}
	dir := b.TempDir()

	b.ResetTimer()
	b.ReportAllocs()
	for range b.N {
		segSamples := utils.ExtractSegmentSamples(samples, sr, 872, 895)
		outputSR := sr
		if sr > 16000 {
			segSamples = utils.ResampleRate(segSamples, sr, 16000)
			outputSR = 16000
		}
		cfg := utils.DefaultSpectrogramConfig(outputSR)
		spect := utils.GenerateSpectrogram(segSamples, cfg)
		colorData := utils.ApplyL4Colormap(spect)
		img := utils.CreateRGBImage(colorData)
		resized := utils.ResizeImage(img, 448, 448)

		pngFile, err := os.CreateTemp(dir, "bench_*.png")
		if err != nil {
			b.Fatal(err)
		}
		utils.WritePNG(resized, pngFile)
		if err := pngFile.Close(); err != nil {
			b.Fatal(err)
		}
		// Give the WAV its own sibling path, derived from the unique PNG name.
		// Previously it reused the already-deleted PNG path.
		wavPath := pngFile.Name() + ".wav"
		utils.WriteWAVFile(wavPath, segSamples, outputSR)

		// Best-effort cleanup; b.TempDir removes anything left over.
		os.Remove(pngFile.Name())
		os.Remove(wavPath)
	}
}


func TestPipelineDimensions(t *testing.T) {
	samples, sr, _ := utils.ReadWAVSamples(benchWAV)
	segSamples := utils.ExtractSegmentSamples(samples, sr, 872, 895)

	t.Logf("Input: %d samples, sr=%d, segment=%d samples (%.1fs)",
		len(samples), sr, len(segSamples), float64(len(segSamples))/float64(sr))

	cfg := utils.DefaultSpectrogramConfig(16000)
	numFrames := (len(segSamples)-cfg.WindowSize)/cfg.HopSize + 1
	numBins := cfg.WindowSize/2 + 1
	t.Logf("Spectrogram: %d freq bins x %d time frames = %d values",
		numBins, numFrames, numBins*numFrames)

	spect := utils.GenerateSpectrogram(segSamples, cfg)
	t.Logf("Output: %d x %d (freq x time)", len(spect), len(spect[0]))

	img := utils.CreateGrayscaleImage(spect)
	t.Logf("Grayscale image: %dx%d pixels, %d bytes",
		img.Bounds().Dx(), img.Bounds().Dy(), img.Bounds().Dx()*img.Bounds().Dy())

	resized := utils.ResizeImage(img, 224, 224)
	t.Logf("Resized 224: %dx%d", resized.Bounds().Dx(), resized.Bounds().Dy())

	resized448 := utils.ResizeImage(img, 448, 448)
	t.Logf("Resized 448: %dx%d", resized448.Bounds().Dx(), resized448.Bounds().Dy())
}