package tools
import (
"os"
"path/filepath"
"testing"
"skraak/utils"
)
func TestCallsFromPreds_EmptyFilterError(t *testing.T) {
// Create a temp CSV file
tmpDir := t.TempDir()
csvPath := filepath.Join(tmpDir, "preds.csv")
csvContent := "file,start_time,end_time,kiwi\n./test.wav,0.0,3.0,1\n"
if err := os.WriteFile(csvPath, []byte(csvContent), 0644); err != nil {
t.Fatal(err)
}
// Create a dummy WAV file (minimal valid WAV)
wavPath := filepath.Join(tmpDir, "test.wav")
createMinimalWAV(t, wavPath, 44100, 10.0)
// Test with empty filter (should error)
input := CallsFromPredsInput{
CSVPath: csvPath,
Filter: "",
WriteDotData: true,
ProgressHandler: nil,
}
output, err := CallsFromPreds(input)
// Should return error
if err == nil {
t.Error("expected error for empty filter, got nil")
}
if output.Error == nil || *output.Error == "" {
t.Error("expected error message in output, got empty")
}
}
// TestCallsFromPreds_NewDataFile checks the fresh-write path. No .data sidecar
// exists and Filter is left empty, so CallsFromPreds must take the filter from
// the CSV name (predsST_<filter>_<date>.csv). It must report one data file
// written and produce a .data file with a single segment carrying a single
// label tagged with that filter.
func TestCallsFromPreds_NewDataFile(t *testing.T) {
	tmpDir := t.TempDir()
	csvPath := filepath.Join(tmpDir, "predsST_test-filter_2025-01-01.csv")
	csvContent := "file,start_time,end_time,kiwi\n./test.wav,0.0,3.0,1\n"
	if err := os.WriteFile(csvPath, []byte(csvContent), 0644); err != nil {
		t.Fatal(err)
	}

	// The audio file referenced by the prediction row.
	wavPath := filepath.Join(tmpDir, "test.wav")
	createMinimalWAV(t, wavPath, 44100, 10.0)

	input := CallsFromPredsInput{
		CSVPath:         csvPath,
		Filter:          "", // parsed from the CSV filename
		WriteDotData:    true,
		ProgressHandler: nil,
	}
	output, err := CallsFromPreds(input)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if output.DataFilesWritten != 1 {
		t.Errorf("expected 1 data file written, got %d", output.DataFilesWritten)
	}
	if output.Filter != "test-filter" {
		t.Errorf("expected filter 'test-filter', got '%s'", output.Filter)
	}

	// The sidecar must exist before its contents can be checked. Any Stat
	// failure is fatal, not just "does not exist".
	dataPath := wavPath + ".data"
	if _, err := os.Stat(dataPath); err != nil {
		t.Fatalf("expected .data file to be created: %v", err)
	}

	df, err := utils.ParseDataFile(dataPath)
	if err != nil {
		t.Fatalf("failed to parse .data file: %v", err)
	}
	// These length checks are fatal because the assertions below index into
	// the slices. A non-fatal failure here would turn into an index panic.
	if len(df.Segments) != 1 {
		t.Fatalf("expected 1 segment, got %d", len(df.Segments))
	}
	labels := df.Segments[0].Labels
	if len(labels) != 1 {
		t.Fatalf("expected 1 label, got %d", len(labels))
	}
	if labels[0].Filter != "test-filter" {
		t.Errorf("expected filter 'test-filter', got '%s'", labels[0].Filter)
	}
}
// TestCallsFromPreds_ExistingDataFileSameFilter sets up a .data sidecar that
// already holds a label from "existing-filter". The run's filter parses from
// the CSV name to that same value. CallsFromPreds must fail (returned error
// and output.Error) and leave the sidecar parsing to its original single
// morepork segment.
func TestCallsFromPreds_ExistingDataFileSameFilter(t *testing.T) {
	tmpDir := t.TempDir()
	csvPath := filepath.Join(tmpDir, "predsST_existing-filter_2025-01-01.csv")
	csvContent := "file,start_time,end_time,kiwi\n./test.wav,0.0,3.0,1\n"
	if err := os.WriteFile(csvPath, []byte(csvContent), 0644); err != nil {
		t.Fatal(err)
	}

	wavPath := filepath.Join(tmpDir, "test.wav")
	createMinimalWAV(t, wavPath, 44100, 10.0)

	// Pre-existing .data sidecar whose only label uses the same filter name
	// the CSV filename encodes.
	dataPath := wavPath + ".data"
	existingData := `[
{"Operator": "Manual", "Reviewer": "David", "Duration": 10.0},
[5.0, 8.0, 0, 44100, [{"species": "morepork", "certainty": 90, "filter": "existing-filter"}]]
]`
	if err := os.WriteFile(dataPath, []byte(existingData), 0644); err != nil {
		t.Fatal(err)
	}

	input := CallsFromPredsInput{
		CSVPath:         csvPath,
		Filter:          "", // parsed from filename -> "existing-filter"
		WriteDotData:    true,
		ProgressHandler: nil,
	}
	output, err := CallsFromPreds(input)
	if err == nil {
		t.Error("expected error for same filter, got nil")
	}
	if output.Error == nil {
		t.Error("expected error message in output")
	}

	// The original sidecar must be untouched.
	df, err := utils.ParseDataFile(dataPath)
	if err != nil {
		t.Fatalf("failed to parse .data file: %v", err)
	}
	// These checks are fatal because the species check below indexes into
	// both slices.
	if len(df.Segments) != 1 {
		t.Fatalf("expected original 1 segment, got %d", len(df.Segments))
	}
	labels := df.Segments[0].Labels
	if len(labels) == 0 {
		t.Fatal("expected original segment to keep its label, got none")
	}
	if labels[0].Species != "morepork" {
		t.Errorf("expected original species 'morepork', got '%s'", labels[0].Species)
	}
}
// TestCallsFromPreds_ExistingDataFileDifferentFilter covers the merge path.
// The .data sidecar holds an "old-filter" segment at 5.0-8.0. The run (filter
// "new-filter" parsed from the CSV name) adds a 0.0-3.0 prediction.
// CallsFromPreds must report one file written, and the sidecar must then hold
// two segments ordered by start time, with both filters represented.
func TestCallsFromPreds_ExistingDataFileDifferentFilter(t *testing.T) {
	tmpDir := t.TempDir()
	csvPath := filepath.Join(tmpDir, "predsST_new-filter_2025-01-01.csv")
	csvContent := "file,start_time,end_time,kiwi\n./test.wav,0.0,3.0,1\n"
	if err := os.WriteFile(csvPath, []byte(csvContent), 0644); err != nil {
		t.Fatal(err)
	}

	wavPath := filepath.Join(tmpDir, "test.wav")
	createMinimalWAV(t, wavPath, 44100, 10.0)

	// Pre-existing .data sidecar labelled by a different filter.
	dataPath := wavPath + ".data"
	existingData := `[
{"Operator": "Manual", "Reviewer": "David", "Duration": 10.0},
[5.0, 8.0, 0, 44100, [{"species": "morepork", "certainty": 90, "filter": "old-filter"}]]
]`
	if err := os.WriteFile(dataPath, []byte(existingData), 0644); err != nil {
		t.Fatal(err)
	}

	input := CallsFromPredsInput{
		CSVPath:         csvPath,
		Filter:          "", // parsed from filename -> "new-filter"
		WriteDotData:    true,
		ProgressHandler: nil,
	}
	output, err := CallsFromPreds(input)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if output.DataFilesWritten != 1 {
		t.Errorf("expected 1 data file written, got %d", output.DataFilesWritten)
	}

	df, err := utils.ParseDataFile(dataPath)
	if err != nil {
		t.Fatalf("failed to parse .data file: %v", err)
	}
	// This check is fatal because the ordering check below reads Segments[1].
	if len(df.Segments) != 2 {
		t.Fatalf("expected 2 segments after merge, got %d", len(df.Segments))
	}
	if df.Segments[0].StartTime > df.Segments[1].StartTime {
		t.Error("expected segments to be sorted by start time")
	}

	// Collect every filter name that appears on any label.
	filters := make(map[string]bool)
	for _, seg := range df.Segments {
		for _, label := range seg.Labels {
			filters[label.Filter] = true
		}
	}
	if !filters["old-filter"] {
		t.Error("expected 'old-filter' to be present")
	}
	if !filters["new-filter"] {
		t.Error("expected 'new-filter' to be present")
	}
}
func TestCallsFromPreds_ExistingDataFileParseError(t *testing.T) {
// Create a temp CSV file
tmpDir := t.TempDir()
csvPath := filepath.Join(tmpDir, "predsST_test-filter_2025-01-01.csv")
csvContent := "file,start_time,end_time,kiwi\n./test.wav,0.0,3.0,1\n"
if err := os.WriteFile(csvPath, []byte(csvContent), 0644); err != nil {
t.Fatal(err)
}
// Create a dummy WAV file
wavPath := filepath.Join(tmpDir, "test.wav")
createMinimalWAV(t, wavPath, 44100, 10.0)
// Create corrupted .data file
dataPath := wavPath + ".data"
corruptedData := `this is not valid json`
if err := os.WriteFile(dataPath, []byte(corruptedData), 0644); err != nil {
t.Fatal(err)
}
// Test (should error due to parse failure)
input := CallsFromPredsInput{
CSVPath: csvPath,
Filter: "",
WriteDotData: true,
ProgressHandler: nil,
}
output, err := CallsFromPreds(input)
// Should return error
if err == nil {
t.Error("expected error for corrupted .data file, got nil")
}
if output.Error == nil {
t.Error("expected error message in output")
}
// Verify original file is unchanged
content, err := os.ReadFile(dataPath)
if err != nil {
t.Fatal(err)
}
if string(content) != corruptedData {
t.Error("expected corrupted file to remain unchanged")
}
}
// TestCallsFromPreds_ExplicitFilter checks that an explicit Filter is used
// when the CSV name ("predictions.csv") does not encode one. The value must
// be reported in output.Filter and written onto the label in the new .data
// sidecar.
func TestCallsFromPreds_ExplicitFilter(t *testing.T) {
	tmpDir := t.TempDir()
	csvPath := filepath.Join(tmpDir, "predictions.csv")
	csvContent := "file,start_time,end_time,kiwi\n./test.wav,0.0,3.0,1\n"
	if err := os.WriteFile(csvPath, []byte(csvContent), 0644); err != nil {
		t.Fatal(err)
	}

	wavPath := filepath.Join(tmpDir, "test.wav")
	createMinimalWAV(t, wavPath, 44100, 10.0)

	input := CallsFromPredsInput{
		CSVPath:         csvPath,
		Filter:          "my-custom-filter",
		WriteDotData:    true,
		ProgressHandler: nil,
	}
	output, err := CallsFromPreds(input)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if output.Filter != "my-custom-filter" {
		t.Errorf("expected filter 'my-custom-filter', got '%s'", output.Filter)
	}

	dataPath := wavPath + ".data"
	df, err := utils.ParseDataFile(dataPath)
	if err != nil {
		t.Fatalf("failed to parse .data file: %v", err)
	}
	// Guard the indexing below. An empty result should fail the test, not
	// panic it.
	if len(df.Segments) == 0 {
		t.Fatal("expected at least 1 segment in .data file, got none")
	}
	labels := df.Segments[0].Labels
	if len(labels) == 0 {
		t.Fatal("expected at least 1 label on first segment, got none")
	}
	if labels[0].Filter != "my-custom-filter" {
		t.Errorf("expected filter 'my-custom-filter' in .data file, got '%s'", labels[0].Filter)
	}
}
func TestCallsFromPreds_NonParsableFilenameNoFilter(t *testing.T) {
// Create a temp CSV file with non-standard name that can't be parsed
tmpDir := t.TempDir()
csvPath := filepath.Join(tmpDir, "random_name.csv")
csvContent := "file,start_time,end_time,kiwi\n./test.wav,0.0,3.0,1\n"
if err := os.WriteFile(csvPath, []byte(csvContent), 0644); err != nil {
t.Fatal(err)
}
// Create a dummy WAV file
wavPath := filepath.Join(tmpDir, "test.wav")
createMinimalWAV(t, wavPath, 44100, 10.0)
// Test with no filter and non-parsable filename (should error)
input := CallsFromPredsInput{
CSVPath: csvPath,
Filter: "",
WriteDotData: true,
ProgressHandler: nil,
}
output, err := CallsFromPreds(input)
// Should return error
if err == nil {
t.Error("expected error for unparsable filename with no filter, got nil")
}
if output.Error == nil {
t.Error("expected error message in output")
}
}
// createMinimalWAV creates a minimal valid WAV file for testing
func createMinimalWAV(t *testing.T, path string, sampleRate int, duration float64) {
t.Helper()
numSamples := int(float64(sampleRate) * duration)
dataSize := numSamples * 2 // 16-bit mono
// WAV header (44 bytes)
header := make([]byte, 44)
// RIFF header
copy(header[0:4], "RIFF")
totalSize := uint32(36 + dataSize)
header[4] = byte(totalSize)
header[5] = byte(totalSize >> 8)
header[6] = byte(totalSize >> 16)
header[7] = byte(totalSize >> 24)
copy(header[8:12], "WAVE")
// fmt chunk
copy(header[12:16], "fmt ")
chunkSize := uint32(16)
header[16] = byte(chunkSize)
header[17] = byte(chunkSize >> 8)
header[18] = byte(chunkSize >> 16)
header[19] = byte(chunkSize >> 24)
audioFormat := uint16(1) // PCM
header[20] = byte(audioFormat)
header[21] = byte(audioFormat >> 8)
numChannels := uint16(1)
header[22] = byte(numChannels)
header[23] = byte(numChannels >> 8)
header[24] = byte(sampleRate)
header[25] = byte(sampleRate >> 8)
header[26] = byte(sampleRate >> 16)
header[27] = byte(sampleRate >> 24)
byteRate := uint32(sampleRate * 2)
header[28] = byte(byteRate)
header[29] = byte(byteRate >> 8)
header[30] = byte(byteRate >> 16)
header[31] = byte(byteRate >> 24)
blockAlign := uint16(2)
header[32] = byte(blockAlign)
header[33] = byte(blockAlign >> 8)
bitsPerSample := uint16(16)
header[34] = byte(bitsPerSample)
header[35] = byte(bitsPerSample >> 8)
// data chunk
copy(header[36:40], "data")
header[40] = byte(dataSize)
header[41] = byte(dataSize >> 8)
header[42] = byte(dataSize >> 16)
header[43] = byte(dataSize >> 24)
// Create file with header and silence
file, err := os.Create(path)
if err != nil {
t.Fatal(err)
}
defer file.Close()
if _, err := file.Write(header); err != nil {
t.Fatal(err)
}
// Write silence (zeros)
silence := make([]byte, dataSize)
if _, err := file.Write(silence); err != nil {
t.Fatal(err)
}
}