package utils

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
	"os"
	"sync"
	"time"

	"github.com/cespare/xxhash/v2"
)

// Reusable read buffers for WAV header parsing, so that parsing many files in
// a row does not allocate a fresh buffer for each one. The pools hold pointers
// to slices (*[]byte) rather than the slices themselves.
var (
	// headerBufferPool hands out 200KB buffers, used when the full header
	// (including metadata chunks) is read.
	headerBufferPool = sync.Pool{
		New: func() any {
			full := make([]byte, 200<<10)
			return &full
		},
	}

	// minimalHeaderBufferPool hands out 4KB buffers, used when only the fmt
	// chunk and the data chunk header are needed. Files whose data chunk header
	// lies beyond the first 4KB cannot be parsed from such a buffer.
	minimalHeaderBufferPool = sync.Pool{
		New: func() any {
			small := make([]byte, 4<<10)
			return &small
		},
	}
)

// getHeaderBuffer takes a buffer pointer out of headerBufferPool (200KB
// buffers). Hand it back with putHeaderBuffer once it is no longer needed.
func getHeaderBuffer() *[]byte {
	pooled := headerBufferPool.Get()
	return pooled.(*[]byte)
}

// putHeaderBuffer returns a 200KB buffer to headerBufferPool so a later
// getHeaderBuffer call can reuse it.
//
// A nil pointer is dropped instead of being pooled: sync.Pool would store it
// and hand it out again, and the next user would dereference it.
func putHeaderBuffer(buf *[]byte) {
	if buf == nil {
		return
	}
	headerBufferPool.Put(buf)
}

// getMinimalHeaderBuffer takes a buffer pointer out of minimalHeaderBufferPool
// (4KB buffers). Hand it back with putMinimalHeaderBuffer once it is no longer
// needed.
func getMinimalHeaderBuffer() *[]byte {
	pooled := minimalHeaderBufferPool.Get()
	return pooled.(*[]byte)
}

// putMinimalHeaderBuffer returns a 4KB buffer to minimalHeaderBufferPool so a
// later getMinimalHeaderBuffer call can reuse it.
//
// A nil pointer is dropped instead of being pooled: sync.Pool would store it
// and hand it out again, and the next user would dereference it.
func putMinimalHeaderBuffer(buf *[]byte) {
	if buf == nil {
		return
	}
	minimalHeaderBufferPool.Put(buf)
}

// WAVMetadata contains metadata extracted from WAV file headers.
//
// SampleRate, Channels and BitsPerSample are read from the "fmt " chunk,
// Duration is derived from the size declared by the "data" chunk, and Comment
// and Artist come from a LIST/INFO chunk when one is present in the bytes that
// were read. FileModTime and FileSize are not part of the WAV data; the
// functions that open the file fill them in from the file's stat information.
type WAVMetadata struct {
	Duration      float64   // Duration in seconds (data chunk size divided by bytes per second)
	SampleRate    int       // Sample rate in Hz
	Comment       string    // Comment from INFO chunk, ICMT subchunk (may contain AudioMoth data)
	Artist        string    // Artist from INFO chunk, IART subchunk
	Channels      int       // Number of audio channels
	BitsPerSample int       // Bits per sample
	FileModTime   time.Time // File modification time (fallback timestamp)
	FileSize      int64     // File size in bytes
}

// readAndParseHeader opens the WAV file at filepath, reads its leading bytes
// into a buffer borrowed through getBuf, and parses them with
// parseWAVFromBytes. The buffer is handed back through putBuf before the
// function returns, so the returned metadata never aliases it (all parsed
// strings are copies).
//
// At most cap(buffer) bytes are read; chunks located beyond that point are not
// seen by the parser. On success FileModTime and FileSize are filled in from
// the file's stat information.
//
// It returns an error if the file cannot be opened, stat'ed or read, or if the
// bytes read do not parse as a WAV header.
func readAndParseHeader(filepath string, getBuf func() *[]byte, putBuf func(*[]byte)) (*WAVMetadata, error) {
	file, err := os.Open(filepath)
	if err != nil {
		return nil, fmt.Errorf("failed to open file: %w", err)
	}
	// The handle is read-only, so an error from Close carries nothing actionable.
	defer func() { _ = file.Close() }()

	fileInfo, err := file.Stat()
	if err != nil {
		return nil, fmt.Errorf("failed to get file info: %w", err)
	}

	bufPtr := getBuf()
	defer putBuf(bufPtr)
	buf := (*bufPtr)[:cap(*bufPtr)]

	// A single Read call may return fewer bytes than are available, which would
	// silently truncate the header. io.ReadFull keeps reading until the buffer is
	// full or the file ends; files shorter than the buffer end with io.EOF (empty)
	// or io.ErrUnexpectedEOF (partial), both of which are expected here.
	n, err := io.ReadFull(file, buf)
	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
		return nil, fmt.Errorf("failed to read header: %w", err)
	}
	buf = buf[:n]

	metadata, err := parseWAVFromBytes(buf)
	if err != nil {
		return nil, err
	}

	metadata.FileModTime = fileInfo.ModTime()
	metadata.FileSize = fileInfo.Size()
	return metadata, nil
}

// ParseWAVHeader extracts metadata from the header of the WAV file at filepath
// without reading the audio payload. It reads through a pooled 200KB buffer, so
// only chunks that start within the first 200KB of the file are considered.
// The returned metadata also carries the file's modification time and size.
func ParseWAVHeader(filepath string) (*WAVMetadata, error) {
	metadata, err := readAndParseHeader(filepath, getHeaderBuffer, putHeaderBuffer)
	return metadata, err
}

// ParseWAVHeaderMinimal reads only the first 4KB of a WAV file to extract essential metadata.
// This is optimized for batch processing where INFO chunks (comment/artist) are not needed.
// It's ~50x faster than ParseWAVHeader for large files due to reduced I/O.
// Returns (sampleRate, duration, error) - the minimal data needed for .data file generation.
func ParseWAVHeaderMinimal(filepath string) (sampleRate int, duration float64, err error) {
	metadata, err := readAndParseHeader(filepath, getMinimalHeaderBuffer, putMinimalHeaderBuffer)
	if err != nil {
		return 0, 0, err
	}
	return metadata.SampleRate, metadata.Duration, nil
}

// ParseWAVHeaderWithHash extracts the WAV metadata and computes the xxHash64
// of the whole file using a single open file handle. This avoids opening the
// file twice, as separate metadata and hash calls would.
//
// The header is parsed from up to the first 200KB (pooled buffer); the handle
// is then rewound and the entire file, header included, is streamed through
// the hasher.
//
// Returns (metadata, hash, error). The hash is the 64-bit digest formatted as
// 16 lowercase hexadecimal digits. On error the metadata is nil and the hash
// is empty.
func ParseWAVHeaderWithHash(filepath string) (*WAVMetadata, string, error) {
	// readAndParseHeader cannot be reused here because the file handle is
	// needed again for hashing.
	file, err := os.Open(filepath)
	if err != nil {
		return nil, "", fmt.Errorf("failed to open file: %w", err)
	}
	// The handle is read-only, so an error from Close carries nothing actionable.
	defer func() { _ = file.Close() }()

	fileInfo, err := file.Stat()
	if err != nil {
		return nil, "", fmt.Errorf("failed to get file info: %w", err)
	}

	headerBufPtr := getHeaderBuffer()
	defer putHeaderBuffer(headerBufPtr)
	headerBuf := (*headerBufPtr)[:cap(*headerBufPtr)]

	// A single Read call may return fewer bytes than are available; io.ReadFull
	// reads until the buffer is full or the file ends. io.EOF (empty file) and
	// io.ErrUnexpectedEOF (file shorter than the buffer) are expected outcomes.
	n, err := io.ReadFull(file, headerBuf)
	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
		return nil, "", fmt.Errorf("failed to read header: %w", err)
	}
	headerBuf = headerBuf[:n]

	metadata, err := parseWAVFromBytes(headerBuf)
	if err != nil {
		return nil, "", err
	}
	metadata.FileModTime = fileInfo.ModTime()
	metadata.FileSize = fileInfo.Size()

	// Hash: rewind to the start so the digest covers the entire file.
	if _, err := file.Seek(0, io.SeekStart); err != nil {
		return nil, "", fmt.Errorf("failed to seek: %w", err)
	}

	hashBufPtr := getHashBuffer()
	defer putHashBuffer(hashBufPtr)
	hashBuf := *hashBufPtr

	h := xxhash.New()
	if _, err := io.CopyBuffer(h, file, hashBuf); err != nil {
		return nil, "", fmt.Errorf("failed to read file for hash: %w", err)
	}

	hash := fmt.Sprintf("%016x", h.Sum64())
	return metadata, hash, nil
}

// parseWAVFromBytes parses WAV metadata from the leading bytes of a file.
//
// data must start with the RIFF/WAVE header. The chunks that follow are
// walked in order: "fmt " supplies the format fields, "data" supplies the
// duration (computed from its declared size, so the audio payload itself does
// not have to be present in data), and "LIST" supplies INFO metadata. Other
// chunks are skipped. The walk ends at the first chunk whose payload extends
// past the end of data.
//
// It returns an error when data is shorter than 44 bytes, lacks the RIFF/WAVE
// signature, or yields no sample rate or no non-zero duration.
func parseWAVFromBytes(data []byte) (*WAVMetadata, error) {
	if len(data) < 44 {
		return nil, fmt.Errorf("file too small to be valid WAV")
	}
	if string(data[0:4]) != "RIFF" {
		return nil, fmt.Errorf("not a valid WAV file (missing RIFF header)")
	}
	if string(data[8:12]) != "WAVE" {
		return nil, fmt.Errorf("not a valid WAV file (missing WAVE format)")
	}

	metadata := &WAVMetadata{}
	offset := 12
	// offset+8 <= len(data) (rather than offset < len(data)-8) so that a chunk
	// header ending exactly at the end of the buffer is still processed. The
	// data chunk only needs its header to be visible.
	for offset+8 <= len(data) {
		chunkID := string(data[offset : offset+4])
		chunkSize := int(binary.LittleEndian.Uint32(data[offset+4 : offset+8]))
		offset += 8

		// Where int is 32 bits wide, a size above MaxInt32 wraps negative; such
		// a chunk cannot be interpreted or skipped safely.
		if chunkSize < 0 {
			break
		}

		switch chunkID {
		case "fmt ":
			parseFmtChunkData(data[offset:], chunkSize, metadata)
		case "data":
			calcDataChunkDuration(chunkSize, metadata)
		case "LIST":
			parseLISTChunkData(data[offset:], chunkSize, metadata)
		}

		// A payload reaching past the buffer (typically the data chunk) ends the
		// walk. Comparing against the remaining length avoids overflowing offset.
		if chunkSize > len(data)-offset {
			break
		}
		// Chunks are word-aligned: an odd-sized payload is followed by a pad byte.
		offset += chunkSize + chunkSize%2
	}

	if metadata.SampleRate == 0 {
		return nil, fmt.Errorf("invalid WAV file: missing or corrupt fmt chunk")
	}
	if metadata.Duration == 0 {
		return nil, fmt.Errorf("invalid WAV file: missing or corrupt data chunk")
	}
	return metadata, nil
}

// parseFmtChunkData copies the channel count, sample rate and bits-per-sample
// fields of a fmt chunk into m. data must begin at the chunk payload and
// chunkSize is the size declared in the chunk header. Nothing is written when
// either is shorter than the 16 bytes that hold those fields.
func parseFmtChunkData(data []byte, chunkSize int, m *WAVMetadata) {
	const fmtCoreLen = 16
	if chunkSize < fmtCoreLen || len(data) < fmtCoreLen {
		return
	}

	le := binary.LittleEndian
	m.Channels = int(le.Uint16(data[2:4]))
	m.SampleRate = int(le.Uint32(data[4:8]))
	m.BitsPerSample = int(le.Uint16(data[14:16]))
}

// calcDataChunkDuration sets m.Duration (seconds) to the data chunk size
// divided by the byte rate implied by the format fields already stored in m.
// m is left untouched when any format field is non-positive or when the
// resulting byte rate is not positive (for example fewer than 8 bits per
// sample), so the fmt chunk must have been parsed first.
func calcDataChunkDuration(chunkSize int, m *WAVMetadata) {
	if m.SampleRate <= 0 || m.Channels <= 0 || m.BitsPerSample <= 0 {
		return
	}

	byteRate := m.SampleRate * m.Channels * (m.BitsPerSample / 8)
	if byteRate <= 0 {
		return
	}
	m.Duration = float64(chunkSize) / float64(byteRate)
}

// parseLISTChunkData inspects a LIST chunk and, when its list type is "INFO",
// forwards the subchunks that follow to parseINFOChunk. data must begin at the
// chunk payload; the chunk is ignored when it is smaller than the 4-byte list
// type or when its declared size exceeds the bytes available in data.
func parseLISTChunkData(data []byte, chunkSize int, m *WAVMetadata) {
	if chunkSize < 4 || len(data) < chunkSize {
		return
	}
	if listType := string(data[:4]); listType != "INFO" {
		return
	}
	parseINFOChunk(data[4:chunkSize], m)
}

// parseINFOChunk walks the subchunks of an INFO list and stores the comment
// (ICMT) and artist (IART) values in metadata. data is the list payload after
// the 4-byte "INFO" type. Other subchunk IDs are skipped.
//
// Each value is cut at its first NUL byte. If an ID occurs more than once the
// last occurrence wins. The walk stops at the first subchunk whose declared
// size does not fit in the remaining bytes.
func parseINFOChunk(data []byte, metadata *WAVMetadata) {
	offset := 0
	// offset+8 <= len(data) guarantees that a full subchunk header is available,
	// including one that ends exactly at the end of data.
	for offset+8 <= len(data) {
		subchunkID := string(data[offset : offset+4])
		subchunkSize := int(binary.LittleEndian.Uint32(data[offset+4 : offset+8]))
		offset += 8

		// Compare against the remaining length instead of computing
		// offset+subchunkSize, which can overflow where int is 32 bits wide. A
		// negative size (wrapped uint32) is rejected for the same reason.
		if subchunkSize < 0 || subchunkSize > len(data)-offset {
			break
		}

		// Extract null-terminated string
		value := extractNullTerminatedString(data[offset : offset+subchunkSize])

		switch subchunkID {
		case "ICMT": // Comment
			metadata.Comment = value
		case "IART": // Artist
			metadata.Artist = value
		}

		// Move to next subchunk (word-aligned: odd sizes are followed by a pad byte)
		offset += subchunkSize + subchunkSize%2
	}
}

// extractNullTerminatedString returns the bytes of data that precede the first
// NUL byte as a string, or all of data when it contains no NUL byte.
func extractNullTerminatedString(data []byte) string {
	if end := bytes.IndexByte(data, 0); end >= 0 {
		return string(data[:end])
	}
	return string(data)
}

// wavChunkInfo holds parsed WAV format and data chunk locations.
// It is produced by parseWAVChunks and consumed when computing byte ranges and
// decoding samples.
type wavChunkInfo struct {
	sampleRate    int   // sample rate field of the fmt chunk
	channels      int   // channel count field of the fmt chunk
	bitsPerSample int   // bits-per-sample field of the fmt chunk
	dataOffset    int64 // file position immediately after the data chunk header
	dataSize      int64 // size declared in the data chunk header, in bytes
}

// parseWAVChunks reads WAV chunks starting at the file's current position
// (expected to be just past the 12-byte RIFF/WAVE header) and returns the
// format fields together with the location of the data chunk.
//
// Scanning stops at the first "data" chunk, leaving the file positioned at the
// first byte of its payload. Format fields are only populated if a "fmt "
// chunk of at least 16 bytes precedes it; callers must validate them.
//
// It returns an error when a chunk header is truncated, when reading or
// seeking fails, or when the file ends without a data chunk.
func parseWAVChunks(file *os.File) (wavChunkInfo, error) {
	// fmtCoreSize is the number of leading fmt-chunk bytes that hold the fields
	// read here; any extension bytes beyond it are skipped, not loaded.
	const fmtCoreSize = 16

	var (
		info        wavChunkInfo
		chunkHeader [8]byte // reused for every chunk header
	)
	for {
		if _, err := io.ReadFull(file, chunkHeader[:]); err != nil {
			if err == io.EOF {
				// Clean end of file between chunks: no data chunk was seen.
				break
			}
			return info, fmt.Errorf("failed to read chunk header: %w", err)
		}

		chunkID := string(chunkHeader[0:4])
		chunkSize := int64(binary.LittleEndian.Uint32(chunkHeader[4:8]))

		switch chunkID {
		case "fmt ":
			// Read only the fixed-size core instead of allocating chunkSize bytes:
			// the size field comes from the file and may claim up to 4GB.
			toSkip := chunkSize
			if chunkSize >= fmtCoreSize {
				var fmtData [fmtCoreSize]byte
				if _, err := io.ReadFull(file, fmtData[:]); err != nil {
					return info, fmt.Errorf("failed to read fmt chunk: %w", err)
				}
				info.channels = int(binary.LittleEndian.Uint16(fmtData[2:4]))
				info.sampleRate = int(binary.LittleEndian.Uint32(fmtData[4:8]))
				info.bitsPerSample = int(binary.LittleEndian.Uint16(fmtData[14:16]))
				toSkip -= fmtCoreSize
			}
			if _, err := file.Seek(toSkip, io.SeekCurrent); err != nil {
				return info, fmt.Errorf("failed to skip chunk: %w", err)
			}
		case "data":
			// The current position is the start of the audio payload.
			pos, err := file.Seek(0, io.SeekCurrent)
			if err != nil {
				return info, fmt.Errorf("failed to locate data chunk: %w", err)
			}
			info.dataOffset = pos
			info.dataSize = chunkSize
			return info, nil
		default:
			if _, err := file.Seek(chunkSize, io.SeekCurrent); err != nil {
				return info, fmt.Errorf("failed to skip chunk: %w", err)
			}
		}

		// Word align: an odd-sized payload is followed by one pad byte.
		if chunkSize%2 != 0 {
			if _, err := file.Seek(1, io.SeekCurrent); err != nil {
				return info, fmt.Errorf("failed to skip padding: %w", err)
			}
		}
	}
	return info, fmt.Errorf("no data chunk found in WAV file")
}

// calcWAVReadRange translates a time window into a byte range inside the data
// chunk. startOffset is relative to the start of the data payload and readSize
// is the number of bytes to read from there.
//
// A startSec that is not positive selects the beginning of the data. An endSec
// that is not positive selects everything up to the end of the data. Both
// positions are clamped to info.dataSize, and readSize is 0 when the end does
// not lie after the start.
func calcWAVReadRange(startSec, endSec float64, info wavChunkInfo) (startOffset, readSize int64) {
	frameBytes := int64((info.bitsPerSample / 8) * info.channels)
	rate := float64(info.sampleRate)

	if startSec > 0 {
		firstFrame := int64(startSec * rate)
		startOffset = min(firstFrame*frameBytes, info.dataSize)
	}

	switch {
	case endSec > 0:
		lastFrame := int64(endSec * rate)
		if endOffset := min(lastFrame*frameBytes, info.dataSize); endOffset > startOffset {
			readSize = endOffset - startOffset
		}
	default:
		readSize = info.dataSize - startOffset
	}
	return startOffset, readSize
}

// parseWAVInfo opens a WAV file, validates its RIFF/WAVE header, and parses
// its chunks up to the data chunk.
//
// On success it returns the open file, positioned at the start of the data
// payload, together with the parsed chunk info; the caller must close the
// file. On any error the file is closed here and a nil file is returned.
//
// It returns an error if the file cannot be opened, is shorter than 44 bytes,
// lacks the RIFF/WAVE signature, has no data chunk, or has a zero sample rate,
// channel count or bit depth.
func parseWAVInfo(filepath string) (f *os.File, info wavChunkInfo, err error) {
	file, err := os.Open(filepath)
	if err != nil {
		return nil, wavChunkInfo{}, fmt.Errorf("failed to open file: %w", err)
	}
	// Close through the local handle, not the named result f: a return
	// statement assigns f (nil on the error paths) before deferred functions
	// run, so closing f would hit a nil *os.File and leak the descriptor.
	defer func() {
		if err != nil {
			_ = file.Close()
		}
	}()

	headerBuf := make([]byte, 44)
	if _, err = io.ReadFull(file, headerBuf); err != nil {
		return nil, wavChunkInfo{}, fmt.Errorf("failed to read header: %w", err)
	}
	if string(headerBuf[0:4]) != "RIFF" || string(headerBuf[8:12]) != "WAVE" {
		return nil, wavChunkInfo{}, fmt.Errorf("not a valid WAV file")
	}

	// Chunks begin right after the 12-byte RIFF/WAVE header.
	if _, err = file.Seek(12, io.SeekStart); err != nil {
		return nil, wavChunkInfo{}, fmt.Errorf("failed to seek: %w", err)
	}

	info, err = parseWAVChunks(file)
	if err != nil {
		return nil, wavChunkInfo{}, err
	}
	if info.sampleRate == 0 || info.channels == 0 || info.bitsPerSample == 0 {
		return nil, wavChunkInfo{}, fmt.Errorf("missing or invalid fmt chunk")
	}

	return file, info, nil
}

// readAudioSegment reads raw audio bytes from an already-parsed WAV file.
//
// startOffset is relative to the start of the data payload (info.dataOffset)
// and readSize is the number of bytes requested. It returns nil when readSize
// is not positive.
//
// The result holds only bytes that were actually read. If the file ends before
// readSize bytes are available (a truncated file, or a header declaring more
// data than exists), the result is shorter than requested, not zero-padded.
// For regular files the request is first clamped to the bytes remaining in
// the file, so a bogus declared size cannot force a huge allocation.
func readAudioSegment(file *os.File, info wavChunkInfo, startOffset, readSize int64) ([]byte, error) {
	if readSize <= 0 {
		return nil, nil
	}

	segmentStart := info.dataOffset + startOffset

	// Best effort: if Stat fails or the file is not regular, read as requested.
	if st, err := file.Stat(); err == nil && st.Mode().IsRegular() {
		readSize = min(readSize, max(st.Size()-segmentStart, 0))
	}

	if _, err := file.Seek(segmentStart, io.SeekStart); err != nil {
		return nil, fmt.Errorf("failed to seek to data segment: %w", err)
	}

	audioData := make([]byte, readSize)
	n, err := io.ReadFull(file, audioData)
	// Hitting the end of the file early is tolerated; any other error is not.
	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
		return nil, fmt.Errorf("failed to read audio data: %w", err)
	}
	return audioData[:n], nil
}

// ReadWAVSegmentSamples decodes the audio between startSec and endSec of the
// WAV file at filepath.
//
// A startSec at or below 0 begins at the start of the audio. An endSec at or
// below 0, or beyond the end of the audio, reads through to the end.
// Returns (samples, sampleRate, error). Samples are float64 values of the
// first channel only; an empty, non-nil slice is returned when the requested
// range contains no data.
func ReadWAVSegmentSamples(filepath string, startSec, endSec float64) ([]float64, int, error) {
	wav, format, err := parseWAVInfo(filepath)
	if err != nil {
		return nil, 0, err
	}
	defer func() { _ = wav.Close() }()

	from, length := calcWAVReadRange(startSec, endSec, format)

	raw, err := readAudioSegment(wav, format, from, length)
	switch {
	case err != nil:
		return nil, 0, err
	case length == 0:
		return []float64{}, format.sampleRate, nil
	}

	return convertToFloat64(raw, format.bitsPerSample, format.channels), format.sampleRate, nil
}

// ReadWAVSamples decodes all audio samples of the WAV file at filepath.
// It is ReadWAVSegmentSamples with no start or end bound.
// Returns (samples, sampleRate, error). For multi-channel files only the first
// (left) channel is returned.
func ReadWAVSamples(filepath string) ([]float64, int, error) {
	// A bound of 0 means "unbounded" on both sides of the range.
	const unbounded = 0
	return ReadWAVSegmentSamples(filepath, unbounded, unbounded)
}

// convertToFloat64 converts raw audio bytes to float64 samples
// Returns mono (left channel only for stereo)
func convertToFloat64(data []byte, bitsPerSample, channels int) []float64 {
	bytesPerSample := bitsPerSample / 8
	blockAlign := bytesPerSample * channels
	numSamples := len(data) / blockAlign

	samples := make([]float64, numSamples)

	switch bitsPerSample {
	case 16:
		for i := range numSamples {
			// Read first (left) channel only for stereo
			offset := i * blockAlign
			sample := int16(binary.LittleEndian.Uint16(data[offset : offset+2]))
			samples[i] = float64(sample) / 32768.0
		}

	case 24:
		for i := range numSamples {
			offset := i * blockAlign
			// 24-bit signed, little-endian
			b := data[offset : offset+3]
			sample := int32(b[0]) | int32(b[1])<<8 | int32(b[2])<<16
			// Sign extend
			if sample >= 0x800000 {
				sample -= 0x1000000
			}
			samples[i] = float64(sample) / 8388608.0
		}

	case 32:
		for i := range numSamples {
			offset := i * blockAlign
			sample := int32(binary.LittleEndian.Uint32(data[offset : offset+4]))
			samples[i] = float64(sample) / 2147483648.0
		}

	default:
		// Fallback: treat as 16-bit
		for i := range numSamples {
			offset := i * blockAlign
			sample := int16(binary.LittleEndian.Uint16(data[offset : offset+2]))
			samples[i] = float64(sample) / 32768.0
		}
	}

	return samples
}