package utils

import (
	"image"
	"math"
	"strings"
	"sync"

	"github.com/madelynnblue/go-dsp/window"
)

// Hann windows memoized by length. An entry is stored once and never modified
// afterwards; hannCacheMu guards every access to the map.
var (
	hannCacheMu sync.RWMutex
	hannCache   = make(map[int][]float64)
)

// getCachedHannWindow returns the Hann window of length size. The first
// request for a given size computes it with window.Hann and stores it, and
// later requests return the stored slice. The slice is shared between all
// callers, so it must be treated as read-only.
func getCachedHannWindow(size int) []float64 {
	// Fast path: a shared lock lets concurrent readers proceed together.
	hannCacheMu.RLock()
	cached, found := hannCache[size]
	hannCacheMu.RUnlock()
	if found {
		return cached
	}

	// Slow path: take the exclusive lock and look again, because another
	// goroutine may have stored the entry between RUnlock and Lock.
	hannCacheMu.Lock()
	defer hannCacheMu.Unlock()
	if cached, found = hannCache[size]; !found {
		cached = window.Hann(size)
		hannCache[size] = cached
	}
	return cached
}

// DefaultMaxSampleRate is the highest sample rate, in Hz, used for
// spectrograms. Audio at a higher rate is resampled down to this rate before
// the spectrogram is computed, for better visualization.
const DefaultMaxSampleRate = 16_000

// SpectrogramConfig holds the short-time Fourier transform (STFT) parameters
// used to build a spectrogram.
type SpectrogramConfig struct {
	WindowSize int // FFT window length in samples (e.g. 512); gives WindowSize/2+1 frequency bins
	HopSize    int // Samples between successive window starts (e.g. 256 = 50% overlap for a 512 window); must be positive
	SampleRate int // Sample rate of the input audio in Hz
}

// DefaultSpectrogramConfig returns the default STFT configuration, matching
// the Julia implementation: a 512-sample window advanced by half its length
// (256 samples, i.e. 50% overlap), tagged with the given sample rate.
func DefaultSpectrogramConfig(sampleRate int) SpectrogramConfig {
	const windowSize = 512
	return SpectrogramConfig{
		WindowSize: windowSize,
		HopSize:    windowSize / 2, // 50% overlap
		SampleRate: sampleRate,
	}
}

// GenerateSpectrogram computes a spectrogram of samples with a Hann-windowed
// short-time Fourier transform and returns it quantized to uint8 (0-255).
// Each frame's power spectrum comes from PowerSpectrumFFT. The whole matrix is
// then converted to dB and linearly normalized by normalizeFlat.
//
// The result is indexed [row][column]:
//   - rows are frequency bins (cfg.WindowSize/2 + 1 of them), flipped so that
//     low frequencies are at the bottom;
//   - columns are time frames, one every cfg.HopSize samples.
//
// It returns nil when cfg.WindowSize or cfg.HopSize is not positive, or when
// samples is shorter than one window.
func GenerateSpectrogram(samples []float64, cfg SpectrogramConfig) [][]uint8 {
	// A non-positive hop would divide by zero in the frame count below, and a
	// negative window would make the buffer allocations panic. Neither
	// describes a usable STFT.
	if cfg.WindowSize <= 0 || cfg.HopSize <= 0 || len(samples) < cfg.WindowSize {
		return nil
	}

	hannWindow := getCachedHannWindow(cfg.WindowSize)

	// With the guard above this is always >= 1.
	numFrames := (len(samples)-cfg.WindowSize)/cfg.HopSize + 1
	// Real input: only the non-negative half of the spectrum is kept.
	numFreqBins := cfg.WindowSize/2 + 1

	// Power matrix stored flat, laid out as bins x frames, in one allocation.
	powerFlat := make([]float64, numFreqBins*numFrames)

	// Scratch buffers reused for every frame, so the loop does not allocate.
	frameData := make([]float64, cfg.WindowSize)
	scratch := make([]complex128, cfg.WindowSize)
	framePower := make([]float64, numFreqBins)

	for frame := range numFrames {
		segment := samples[frame*cfg.HopSize:][:cfg.WindowSize]
		for i, s := range segment {
			frameData[i] = s * hannWindow[i]
		}

		PowerSpectrumFFT(frameData, framePower, scratch)

		// Scatter this frame into column `frame` of the bins x frames matrix.
		for bin, p := range framePower {
			powerFlat[bin*numFrames+frame] = p
		}
	}

	// Convert to dB, normalize to 0-255 and flip vertically.
	return normalizeFlat(powerFlat, numFreqBins, numFrames)
}

// normalizeFlat maps a row-major matrix of power values onto 0-255 levels on
// a decibel scale.
//
// Input layout and handling:
//   - power is laid out as [row0_col0, row0_col1, ..., row1_col0, ...] and must
//     hold at least rows*cols values.
//   - Only the first rows*cols values are used, and they are overwritten in
//     place with their dB values.
//   - Power values that are zero, negative, or NaN are replaced by the smallest
//     positive value present (or by 1e-20 if there is none), so log10 stays
//     finite.
//   - The dB values are rescaled linearly so the minimum maps to 0 and the
//     maximum to approximately 255. Fractional levels are truncated.
//
// Output layout:
//   - The rows are returned in reverse order (output row i is input row
//     rows-1-i). When rows are frequency bins, this puts low frequencies at
//     the bottom.
//   - All rows share one backing array but are capacity-limited, so appending
//     to one row cannot overwrite the next.
//
// It returns nil if rows or cols is not positive, or if power is too short.
func normalizeFlat(power []float64, rows, cols int) [][]uint8 {
	if rows <= 0 || cols <= 0 || len(power) < rows*cols {
		return nil
	}
	power = power[:rows*cols]

	// Pass 1: find the smallest strictly positive power. It stands in for
	// non-positive values so that log10 never sees them.
	floor := math.Inf(1)
	for _, v := range power {
		if v > 0 && v < floor {
			floor = v
		}
	}
	if math.IsInf(floor, 1) {
		floor = 1e-20 // no positive values at all
	}

	// Pass 2: convert to dB in place and track the dB range.
	minDB, maxDB := math.Inf(1), math.Inf(-1)
	for i, v := range power {
		// !(v > 0) also catches NaN, which would otherwise reach the uint8
		// conversion, where the result is implementation-dependent.
		if !(v > 0) {
			v = floor
		}
		db := 10 * math.Log10(v)
		power[i] = db
		if db < minDB {
			minDB = db
		}
		if db > maxDB {
			maxDB = db
		}
	}

	rangeDB := maxDB - minDB
	if rangeDB == 0 {
		rangeDB = 1 // flat input: everything maps to level 0
	}
	scale := 255.0 / rangeDB

	// Pass 3: quantize into the result, flipping rows vertically.
	resultFlat := make([]uint8, rows*cols)
	result := make([][]uint8, rows)
	for i := range result {
		lo, hi := i*cols, (i+1)*cols
		row := resultFlat[lo:hi:hi]
		src := power[(rows-1-i)*cols:][:cols]
		for j, db := range src {
			level := (db - minDB) * scale
			// Clamp before converting: out-of-range or NaN float->uint8
			// conversions are implementation-dependent in Go.
			switch {
			case !(level > 0):
				row[j] = 0
			case level >= 255:
				row[j] = 255
			default:
				row[j] = uint8(level)
			}
		}
		result[i] = row
	}

	return result
}

// ExtractSegmentSamples returns the samples between startSec and endSec
// (seconds) of audio sampled at sampleRate Hz.
//
// Each time is converted to a sample index by truncating time*sampleRate. The
// index is clamped to [0, len(samples)], so a range that starts before zero or
// runs past the end is cut to fit. An infinite endSec therefore means "to the
// end". It returns nil when the clamped range is empty or either bound
// evaluates to NaN.
//
// The result is a view into samples, not a copy. Its capacity equals its
// length, so appending to it reallocates instead of overwriting the samples
// that follow in the caller's slice.
func ExtractSegmentSamples(samples []float64, sampleRate int, startSec, endSec float64) []float64 {
	rate := float64(sampleRate)
	startPos, endPos := startSec*rate, endSec*rate
	if math.IsNaN(startPos) || math.IsNaN(endPos) {
		return nil
	}

	// Clamp in floating point before converting. A float->int conversion of a
	// value outside the int range (including ±Inf) is implementation-dependent.
	n := len(samples)
	toIndex := func(pos float64) int {
		switch {
		case pos <= 0:
			return 0
		case pos >= float64(n):
			return n
		default:
			return int(pos)
		}
	}

	startIdx, endIdx := toIndex(startPos), toIndex(endPos)
	if startIdx >= endIdx {
		return nil
	}
	return samples[startIdx:endIdx:endIdx]
}

// GenerateSegmentSpectrogram renders the audio between startTime and endTime
// as a square spectrogram image.
//
// The audio is read from the WAV file named by dataFilePath with any ".data"
// suffix removed, and only the requested segment is loaded. Audio above
// DefaultMaxSampleRate is resampled down to it before the STFT. The image
// uses the L4 colormap when color is true and grayscale otherwise. It is then
// resized to imgSize x imgSize after passing imgSize through ClampImageSize
// (clamped to [224, 896]).
//
// A read error is returned as is. When there are no samples, no spectrogram,
// or no image, the result is (nil, nil), so callers must check the image as
// well as the error.
func GenerateSegmentSpectrogram(dataFilePath string, startTime, endTime float64, color bool, imgSize int) (image.Image, error) {
	samples, rate, err := ReadWAVSegmentSamples(strings.TrimSuffix(dataFilePath, ".data"), startTime, endTime)
	switch {
	case err != nil:
		return nil, err
	case len(samples) == 0:
		return nil, nil
	}

	// Cap the analysis rate; higher rates are downsampled first.
	if rate > DefaultMaxSampleRate {
		samples = ResampleRate(samples, rate, DefaultMaxSampleRate)
		rate = DefaultMaxSampleRate
	}

	spec := GenerateSpectrogram(samples, DefaultSpectrogramConfig(rate))
	if spec == nil {
		return nil, nil
	}

	var img image.Image
	if !color {
		img = CreateGrayscaleImage(spec)
	} else {
		img = CreateRGBImage(ApplyL4Colormap(spec))
	}
	if img == nil {
		return nil, nil
	}

	side := ClampImageSize(imgSize)
	return ResizeImage(img, side, side), nil
}