package utils
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"os"
"sync"
"time"
"github.com/cespare/xxhash/v2"
)
// Pools of reusable scratch buffers used when reading WAV headers. Buffers
// are stored as *[]byte.
var (
	// headerBufferPool hands out 200 KiB buffers.
	headerBufferPool = sync.Pool{
		New: func() any {
			p := new([]byte)
			*p = make([]byte, 200<<10)
			return p
		},
	}

	// minimalHeaderBufferPool hands out 4 KiB buffers.
	minimalHeaderBufferPool = sync.Pool{
		New: func() any {
			p := new([]byte)
			*p = make([]byte, 4<<10)
			return p
		},
	}
)
// getHeaderBuffer takes a buffer from headerBufferPool. It panics if the pool
// yields anything other than a *[]byte.
func getHeaderBuffer() *[]byte {
	item := headerBufferPool.Get()
	return item.(*[]byte)
}
// putHeaderBuffer hands buf back to headerBufferPool, where a later
// getHeaderBuffer call may pick it up again. The caller must not use buf
// after this call.
func putHeaderBuffer(buf *[]byte) {
headerBufferPool.Put(buf)
}
// getMinimalHeaderBuffer takes a buffer from minimalHeaderBufferPool. It
// panics if the pool yields anything other than a *[]byte.
func getMinimalHeaderBuffer() *[]byte {
	item := minimalHeaderBufferPool.Get()
	return item.(*[]byte)
}
// putMinimalHeaderBuffer hands buf back to minimalHeaderBufferPool, where a
// later getMinimalHeaderBuffer call may pick it up again. The caller must not
// use buf after this call.
func putMinimalHeaderBuffer(buf *[]byte) {
minimalHeaderBufferPool.Put(buf)
}
// WAVMetadata holds what the header parsers in this file extract from a WAV
// file, plus file-system facts about the file itself.
type WAVMetadata struct {
	// Duration is the length of the audio in seconds, computed as the data
	// chunk size divided by the byte rate implied by the fmt chunk.
	Duration float64
	// SampleRate is the sample rate read from the fmt chunk.
	SampleRate int
	// Comment is the LIST/INFO "ICMT" value, or empty if absent.
	Comment string
	// Artist is the LIST/INFO "IART" value, or empty if absent.
	Artist string
	// Channels is the channel count read from the fmt chunk.
	Channels int
	// BitsPerSample is the sample width read from the fmt chunk.
	BitsPerSample int
	// FileModTime is the file's modification time as reported by Stat.
	FileModTime time.Time
	// FileSize is the file's size in bytes as reported by Stat.
	FileSize int64
}
// readAndParseHeader opens the file at filepath, reads its leading bytes into
// a scratch buffer obtained from getBuf, and parses them with
// parseWAVFromBytes. The buffer's capacity bounds how much of the file is
// examined; putBuf is called with the buffer before the function returns.
//
// On success the returned metadata also carries FileModTime and FileSize taken
// from the open file's Stat. Open, Stat and read failures are wrapped with
// context; parse errors are returned as-is.
func readAndParseHeader(filepath string, getBuf func() *[]byte, putBuf func(*[]byte)) (*WAVMetadata, error) {
	file, err := os.Open(filepath)
	if err != nil {
		return nil, fmt.Errorf("failed to open file: %w", err)
	}
	defer func() { _ = file.Close() }()

	fileInfo, err := file.Stat()
	if err != nil {
		return nil, fmt.Errorf("failed to get file info: %w", err)
	}

	bufPtr := getBuf()
	defer putBuf(bufPtr)
	buf := (*bufPtr)[:cap(*bufPtr)]

	// A single Read is allowed to return fewer bytes than are available, which
	// would hand the parser a truncated header. io.ReadFull keeps reading until
	// the buffer is full or the file ends; a file shorter than the buffer ends
	// with io.EOF or io.ErrUnexpectedEOF, both of which are expected here.
	n, err := io.ReadFull(file, buf)
	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
		return nil, fmt.Errorf("failed to read header: %w", err)
	}

	metadata, err := parseWAVFromBytes(buf[:n])
	if err != nil {
		return nil, err
	}
	metadata.FileModTime = fileInfo.ModTime()
	metadata.FileSize = fileInfo.Size()
	return metadata, nil
}
// ParseWAVHeader reads the start of the WAV file at filepath using a buffer
// from the larger header pool (getHeaderBuffer) and returns the parsed
// metadata. Errors from readAndParseHeader are returned unchanged.
func ParseWAVHeader(filepath string) (*WAVMetadata, error) {
return readAndParseHeader(filepath, getHeaderBuffer, putHeaderBuffer)
}
func ParseWAVHeaderMinimal(filepath string) (sampleRate int, duration float64, err error) {
metadata, err := readAndParseHeader(filepath, getMinimalHeaderBuffer, putMinimalHeaderBuffer)
if err != nil {
return 0, 0, err
}
return metadata.SampleRate, metadata.Duration, nil
}
// ParseWAVHeaderWithHash parses the WAV header of the file at filepath and
// also computes an xxhash64 digest of the entire file contents.
//
// It returns the parsed metadata (including FileModTime and FileSize from
// Stat) and the digest formatted as 16 lowercase hex digits. Both steps share
// one open file handle: the header is read first, then the file is rewound and
// streamed through the hasher.
func ParseWAVHeaderWithHash(filepath string) (*WAVMetadata, string, error) {
	file, err := os.Open(filepath)
	if err != nil {
		return nil, "", fmt.Errorf("failed to open file: %w", err)
	}
	defer func() { _ = file.Close() }()

	fileInfo, err := file.Stat()
	if err != nil {
		return nil, "", fmt.Errorf("failed to get file info: %w", err)
	}

	headerBufPtr := getHeaderBuffer()
	defer putHeaderBuffer(headerBufPtr)
	headerBuf := (*headerBufPtr)[:cap(*headerBufPtr)]

	// A single Read may return fewer bytes than are available; io.ReadFull
	// keeps going until the buffer is full or the file ends. Files shorter
	// than the buffer end with io.EOF / io.ErrUnexpectedEOF, which is fine.
	n, err := io.ReadFull(file, headerBuf)
	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
		return nil, "", fmt.Errorf("failed to read header: %w", err)
	}

	metadata, err := parseWAVFromBytes(headerBuf[:n])
	if err != nil {
		return nil, "", err
	}
	metadata.FileModTime = fileInfo.ModTime()
	metadata.FileSize = fileInfo.Size()

	// Rewind so the hash covers the file from its first byte.
	if _, err := file.Seek(0, io.SeekStart); err != nil {
		return nil, "", fmt.Errorf("failed to seek: %w", err)
	}

	hashBufPtr := getHashBuffer()
	defer putHashBuffer(hashBufPtr)

	h := xxhash.New()
	if _, err := io.CopyBuffer(h, file, *hashBufPtr); err != nil {
		return nil, "", fmt.Errorf("failed to read file for hash: %w", err)
	}
	return metadata, fmt.Sprintf("%016x", h.Sum64()), nil
}
// parseWAVFromBytes parses the leading bytes of a WAV file.
//
// data must begin with the RIFF/WAVE preamble and be at least 44 bytes long.
// The chunk list is walked for as long as chunk headers fit inside data; the
// "fmt ", "data" and "LIST" chunks are interpreted and all others are skipped.
// data may be a truncated prefix of the file: only the data chunk's header
// (not its payload) needs to be present.
//
// An error is returned if the preamble is wrong, if no usable fmt chunk was
// seen (SampleRate is zero), or if no duration could be computed.
func parseWAVFromBytes(data []byte) (*WAVMetadata, error) {
	if len(data) < 44 {
		return nil, fmt.Errorf("file too small to be valid WAV")
	}
	if string(data[0:4]) != "RIFF" {
		return nil, fmt.Errorf("not a valid WAV file (missing RIFF header)")
	}
	if string(data[8:12]) != "WAVE" {
		return nil, fmt.Errorf("not a valid WAV file (missing WAVE format)")
	}

	metadata := &WAVMetadata{}

	// Chunks start right after the 12-byte preamble. The bound is inclusive so
	// that a chunk header ending exactly at the end of data is still seen
	// (for example the data header of a canonical 44-byte WAV header).
	for offset := 12; offset+8 <= len(data); {
		chunkID := string(data[offset : offset+4])
		chunkSize := int(binary.LittleEndian.Uint32(data[offset+4 : offset+8]))
		offset += 8

		switch chunkID {
		case "fmt ":
			parseFmtChunkData(data[offset:], chunkSize, metadata)
		case "data":
			calcDataChunkDuration(chunkSize, metadata)
		case "LIST":
			parseLISTChunkData(data[offset:], chunkSize, metadata)
		}

		// Stop when the declared payload reaches or passes the end of data, or
		// when the size wrapped negative (possible where int is 32 bits);
		// advancing by a negative size would walk backwards.
		if chunkSize < 0 || chunkSize >= len(data)-offset {
			break
		}
		// Chunks are padded to an even length.
		offset += chunkSize + chunkSize%2
	}

	if metadata.SampleRate == 0 {
		return nil, fmt.Errorf("invalid WAV file: missing or corrupt fmt chunk")
	}
	if metadata.Duration == 0 {
		return nil, fmt.Errorf("invalid WAV file: missing or corrupt data chunk")
	}
	return metadata, nil
}
// parseFmtChunkData fills m.Channels, m.SampleRate and m.BitsPerSample from
// the payload of a "fmt " chunk. data starts at the first payload byte and
// chunkSize is the size declared in the chunk header. Nothing is written
// unless both the declared size and the available bytes are at least 16.
func parseFmtChunkData(data []byte, chunkSize int, m *WAVMetadata) {
	const fmtMinLen = 16
	if chunkSize < fmtMinLen || len(data) < fmtMinLen {
		return
	}
	le := binary.LittleEndian
	m.Channels = int(le.Uint16(data[2:4]))
	m.SampleRate = int(le.Uint32(data[4:8]))
	m.BitsPerSample = int(le.Uint16(data[14:16]))
}
// calcDataChunkDuration sets m.Duration (seconds) from the declared size of
// the data chunk and the format fields already stored in m. It leaves
// m.Duration untouched when any of SampleRate, Channels or BitsPerSample is
// not positive, or when the resulting bytes-per-second figure is not positive
// (for example when BitsPerSample is below 8).
func calcDataChunkDuration(chunkSize int, m *WAVMetadata) {
	if m.SampleRate <= 0 || m.Channels <= 0 || m.BitsPerSample <= 0 {
		return
	}
	byteRate := m.SampleRate * m.Channels * (m.BitsPerSample / 8)
	if byteRate <= 0 {
		return
	}
	m.Duration = float64(chunkSize) / float64(byteRate)
}
// parseLISTChunkData inspects the payload of a "LIST" chunk. data starts at
// the first payload byte and chunkSize is the declared payload size. If the
// whole payload is present in data and its list type is "INFO", the remaining
// bytes are handed to parseINFOChunk; otherwise m is left unchanged.
func parseLISTChunkData(data []byte, chunkSize int, m *WAVMetadata) {
	if chunkSize < 4 || len(data) < chunkSize {
		return
	}
	if listType := data[:4]; string(listType) != "INFO" {
		return
	}
	parseINFOChunk(data[4:chunkSize], m)
}
// parseINFOChunk walks the sub-chunks of a LIST/INFO payload and copies the
// "ICMT" value into metadata.Comment and the "IART" value into
// metadata.Artist. Values are cut at the first NUL byte. Walking stops at the
// first sub-chunk whose declared size runs past the end of data.
func parseINFOChunk(data []byte, metadata *WAVMetadata) {
	limit := len(data) - 8
	for pos := 0; pos < limit; {
		id := string(data[pos : pos+4])
		size := int(binary.LittleEndian.Uint32(data[pos+4 : pos+8]))

		payloadStart := pos + 8
		payloadEnd := payloadStart + size
		if payloadEnd > len(data) {
			break
		}

		text := extractNullTerminatedString(data[payloadStart:payloadEnd])
		if id == "ICMT" {
			metadata.Comment = text
		} else if id == "IART" {
			metadata.Artist = text
		}

		// Sub-chunks are padded to an even length.
		pos = payloadEnd
		if size%2 != 0 {
			pos++
		}
	}
}
// extractNullTerminatedString returns the bytes of data that precede the
// first NUL byte as a string, or all of data when it contains no NUL.
func extractNullTerminatedString(data []byte) string {
	if end := bytes.IndexByte(data, 0); end >= 0 {
		return string(data[:end])
	}
	return string(data)
}
// wavChunkInfo is the result of scanning a WAV file's chunks up to its data
// chunk.
type wavChunkInfo struct {
	// Format fields read from the "fmt " chunk.
	sampleRate, channels, bitsPerSample int

	// dataOffset is the file position of the first byte of the data chunk's
	// payload; dataSize is the payload size declared in the chunk header.
	dataOffset, dataSize int64
}
// parseWAVChunks scans RIFF chunks from the current position of file until it
// reaches the "data" chunk.
//
// Format fields are taken from the "fmt " chunk when it declares at least 16
// bytes. On reaching "data" the function records the payload's file offset and
// declared size and returns immediately, leaving the file positioned at the
// first payload byte. All other chunks, and the pad byte after odd-sized
// chunks, are skipped.
//
// It returns an error if a chunk header or the fmt fields cannot be read, if
// a seek fails, or if the file ends before a data chunk is found.
func parseWAVChunks(file *os.File) (wavChunkInfo, error) {
	// Only the first 16 bytes of the fmt payload are interpreted.
	const fmtFieldsLen = 16

	var (
		info   wavChunkInfo
		header [8]byte
		fmtBuf [fmtFieldsLen]byte
	)

	for {
		if _, err := io.ReadFull(file, header[:]); err != nil {
			if err == io.EOF {
				break
			}
			return info, fmt.Errorf("failed to read chunk header: %w", err)
		}
		chunkID := string(header[0:4])
		chunkSize := int64(binary.LittleEndian.Uint32(header[4:8]))

		// skip is how many payload bytes remain to be stepped over after the
		// switch below has consumed what it needs.
		skip := chunkSize

		switch chunkID {
		case "fmt ":
			// Read just the fixed-size fields instead of allocating the
			// declared chunk size, which comes from the file and may be huge.
			if chunkSize >= fmtFieldsLen {
				if _, err := io.ReadFull(file, fmtBuf[:]); err != nil {
					return info, fmt.Errorf("failed to read fmt chunk: %w", err)
				}
				info.channels = int(binary.LittleEndian.Uint16(fmtBuf[2:4]))
				info.sampleRate = int(binary.LittleEndian.Uint32(fmtBuf[4:8]))
				info.bitsPerSample = int(binary.LittleEndian.Uint16(fmtBuf[14:16]))
				skip -= fmtFieldsLen
			}
		case "data":
			pos, err := file.Seek(0, io.SeekCurrent)
			if err != nil {
				return info, fmt.Errorf("failed to locate data chunk: %w", err)
			}
			info.dataOffset = pos
			info.dataSize = chunkSize
			return info, nil
		}

		// Chunks are padded to an even length; fold the pad byte into the skip.
		skip += chunkSize % 2
		if skip > 0 {
			if _, err := file.Seek(skip, io.SeekCurrent); err != nil {
				return info, fmt.Errorf("failed to skip chunk: %w", err)
			}
		}
	}
	return info, fmt.Errorf("no data chunk found in WAV file")
}
// calcWAVReadRange converts a time window in seconds into a byte range inside
// the data chunk described by info.
//
// startOffset is relative to the start of the data payload and is zero unless
// startSec is positive. A positive endSec selects the bytes up to that time; a
// non-positive endSec selects everything from startOffset to the end of the
// payload. Both ends are clamped to info.dataSize, and readSize is zero when
// the end does not lie after the start.
func calcWAVReadRange(startSec, endSec float64, info wavChunkInfo) (startOffset, readSize int64) {
	frameBytes := int64((info.bitsPerSample / 8) * info.channels)
	rate := float64(info.sampleRate)

	if startSec > 0 {
		firstFrame := int64(startSec * rate)
		startOffset = min(firstFrame*frameBytes, info.dataSize)
	}

	if endSec > 0 {
		lastFrame := int64(endSec * rate)
		if endOffset := min(lastFrame*frameBytes, info.dataSize); endOffset > startOffset {
			readSize = endOffset - startOffset
		}
		return startOffset, readSize
	}
	return startOffset, info.dataSize - startOffset
}
// parseWAVInfo opens the WAV file at filepath, validates its RIFF/WAVE
// preamble and scans its chunks with parseWAVChunks.
//
// On success it returns the still-open file, which the caller must close,
// together with the chunk information. On any failure the file is closed here
// and a nil file, a zero wavChunkInfo and the error are returned.
func parseWAVInfo(filepath string) (f *os.File, info wavChunkInfo, err error) {
	f, err = os.Open(filepath)
	if err != nil {
		return nil, wavChunkInfo{}, fmt.Errorf("failed to open file: %w", err)
	}

	// abort closes the file and builds the failure return values.
	abort := func(cause error) (*os.File, wavChunkInfo, error) {
		_ = f.Close()
		return nil, wavChunkInfo{}, cause
	}

	var preamble [44]byte
	if _, err = io.ReadFull(f, preamble[:]); err != nil {
		return abort(fmt.Errorf("failed to read header: %w", err))
	}

	isRIFF := string(preamble[0:4]) == "RIFF"
	isWAVE := string(preamble[8:12]) == "WAVE"
	if !(isRIFF && isWAVE) {
		return abort(fmt.Errorf("not a valid WAV file"))
	}

	// Chunks begin immediately after the 12-byte RIFF/WAVE preamble.
	if _, err = f.Seek(12, io.SeekStart); err != nil {
		return abort(fmt.Errorf("failed to seek: %w", err))
	}

	if info, err = parseWAVChunks(f); err != nil {
		return abort(err)
	}
	if info.sampleRate == 0 || info.channels == 0 || info.bitsPerSample == 0 {
		return abort(fmt.Errorf("missing or invalid fmt chunk"))
	}
	return f, info, nil
}
// readAudioSegment reads up to readSize bytes of raw audio from file, starting
// startOffset bytes into the data payload described by info.
//
// The returned slice holds only the bytes actually read, so it is shorter than
// readSize when the file ends early (for example when the data chunk header
// declares more bytes than the file contains). It returns nil with no error
// when there is nothing to read.
func readAudioSegment(file *os.File, info wavChunkInfo, startOffset, readSize int64) ([]byte, error) {
	if readSize <= 0 {
		return nil, nil
	}
	pos := info.dataOffset + startOffset

	// The requested size derives from the data chunk header, which is
	// untrusted and may be as large as 0xFFFFFFFF. For regular files, clamp it
	// to what is really there so the allocation below stays proportional to
	// the file.
	if fi, err := file.Stat(); err == nil && fi.Mode().IsRegular() {
		readSize = min(readSize, max(fi.Size()-pos, 0))
		if readSize == 0 {
			return nil, nil
		}
	}

	if _, err := file.Seek(pos, io.SeekStart); err != nil {
		return nil, fmt.Errorf("failed to seek to data segment: %w", err)
	}

	audioData := make([]byte, readSize)
	n, err := io.ReadFull(file, audioData)
	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
		return nil, fmt.Errorf("failed to read audio data: %w", err)
	}
	// Drop the unread tail; returning it would add zero samples that are not
	// in the file.
	return audioData[:n], nil
}
// ReadWAVSegmentSamples reads the portion of the WAV file at filepath between
// startSec and endSec (seconds) and returns it as float64 samples together
// with the file's sample rate.
//
// A non-positive startSec starts at the beginning of the audio and a
// non-positive endSec reads through the end of the data chunk. When the
// selected range is empty an empty, non-nil slice is returned. Sample
// conversion is delegated to convertToFloat64.
func ReadWAVSegmentSamples(filepath string, startSec, endSec float64) ([]float64, int, error) {
	file, info, err := parseWAVInfo(filepath)
	if err != nil {
		return nil, 0, err
	}
	defer func() { _ = file.Close() }()

	startOffset, readSize := calcWAVReadRange(startSec, endSec, info)
	raw, err := readAudioSegment(file, info, startOffset, readSize)
	switch {
	case err != nil:
		return nil, 0, err
	case readSize == 0:
		return []float64{}, info.sampleRate, nil
	}
	return convertToFloat64(raw, info.bitsPerSample, info.channels), info.sampleRate, nil
}
// ReadWAVSamples reads the whole data chunk of the WAV file at filepath. It
// calls ReadWAVSegmentSamples with startSec and endSec both zero, which that
// function treats as "from the start" and "through the end of the data
// chunk". It returns the samples, the file's sample rate, and any error.
func ReadWAVSamples(filepath string) ([]float64, int, error) {
return ReadWAVSegmentSamples(filepath, 0, 0)
}
// convertToFloat64 converts interleaved little-endian PCM bytes into float64
// samples scaled to [-1, 1).
//
// One value is produced per frame, taken from the first channel of that
// frame; the remaining channels are skipped. A trailing partial frame is
// ignored.
//
// Supported widths: 8-bit (unsigned, centred on 128), 16-bit, 24-bit and
// 32-bit signed integers. Any other width whose samples occupy at least two
// bytes is decoded from its low 16 bits, as before. nil is returned when the
// frame size is not positive or the width cannot be decoded (less than two
// bytes per sample and not exactly 8 bits).
func convertToFloat64(data []byte, bitsPerSample, channels int) []float64 {
	bytesPerSample := bitsPerSample / 8
	blockAlign := bytesPerSample * channels
	if bytesPerSample <= 0 || blockAlign <= 0 {
		// Avoids a division by zero below and out-of-range frame reads.
		return nil
	}

	numSamples := len(data) / blockAlign
	samples := make([]float64, numSamples)

	switch bitsPerSample {
	case 8:
		// 8-bit WAV PCM is unsigned with silence at 128.
		for i := 0; i < numSamples; i++ {
			samples[i] = (float64(data[i*blockAlign]) - 128.0) / 128.0
		}
	case 24:
		for i := 0; i < numSamples; i++ {
			b := data[i*blockAlign : i*blockAlign+3]
			// Place the 24 bits in the top of a 32-bit word, then shift back
			// arithmetically to sign-extend.
			sample := int32(uint32(b[0])<<8|uint32(b[1])<<16|uint32(b[2])<<24) >> 8
			samples[i] = float64(sample) / 8388608.0
		}
	case 32:
		for i := 0; i < numSamples; i++ {
			off := i * blockAlign
			sample := int32(binary.LittleEndian.Uint32(data[off : off+4]))
			samples[i] = float64(sample) / 2147483648.0
		}
	default:
		// 16-bit PCM, and the legacy fallback for other widths.
		if bytesPerSample < 2 {
			return nil
		}
		for i := 0; i < numSamples; i++ {
			off := i * blockAlign
			sample := int16(binary.LittleEndian.Uint16(data[off : off+2]))
			samples[i] = float64(sample) / 32768.0
		}
	}
	return samples
}