package tools
import (
"encoding/csv"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"skraak/utils"
)
const (
CLUSTER_GAP_MULTIPLIER = 2 // detections at most this many clip-lengths apart merge into one cluster
MIN_DETECTIONS_PER_CLUSTER = 0 // clusters need more than this many segments to survive
DEFAULT_CERTAINTY = 70 // certainty assigned to auto-generated AviaNZ labels
DOT_DATA_WORKERS = 8 // goroutines used by writeDotFilesParallel
)
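// ClusteredCall is one call built by merging consecutive per-clip detections
// of the same species in the same file.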
type ClusteredCall struct {
File string `json:"file"`
StartTime float64 `json:"start_time"`
EndTime float64 `json:"end_time"`
EbirdCode string `json:"ebird_code"`
Segments int `json:"segments"`
}
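// CallsFromPredsInput configures CallsFromPreds. An empty Filter is parsed
// from the CSV filename; GapMultiplier <= 0 and MinDetections < 0 fall back
// to CLUSTER_GAP_MULTIPLIER and MIN_DETECTIONS_PER_CLUSTER respectively.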
type CallsFromPredsInput struct {
CSVPath string `json:"csv_path"`
Filter string `json:"filter"`
WriteDotData bool `json:"write_dot_data"`
GapMultiplier int `json:"gap_multiplier"`
MinDetections int `json:"min_detections"`
ProgressHandler ProgressHandler `json:"-"`
}
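// ProgressHandler receives progress updates while .data files are written.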
type ProgressHandler func(processed, total int, message string)
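// CallsFromPredsOutput is the JSON-serialisable result of CallsFromPreds.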
type CallsFromPredsOutput struct {
Calls []ClusteredCall `json:"calls"`
TotalCalls int `json:"total_calls"`
ClipDuration float64 `json:"clip_duration"`
GapThreshold float64 `json:"gap_threshold"`
SpeciesCount map[string]int `json:"species_count"`
DataFilesWritten int `json:"data_files_written"`
DataFilesSkipped int `json:"data_files_skipped"`
Filter string `json:"filter"`
Error *string `json:"error,omitempty"`
}
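// AviaNZMeta is the metadata object stored as the first element of an
// AviaNZ .data file.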
type AviaNZMeta struct {
Operator string `json:"Operator"`
Reviewer *string `json:"Reviewer,omitempty"`
Duration float64 `json:"Duration"`
}
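// AviaNZLabel is a single species label attached to an AviaNZ segment.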
type AviaNZLabel struct {
Species string `json:"species"`
Certainty int `json:"certainty"`
Filter string `json:"filter"`
}
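// AviaNZSegment is the positional segment layout used by AviaNZ .data files:
// [startTime, endTime, freqLow, freqHigh, []AviaNZLabel].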
type AviaNZSegment [5]any
type predFileSpeciesKey struct {
File string
EbirdCode string
}
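// CallsFromPreds reads a prediction CSV (one row per clip, with 'file',
// 'start_time', 'end_time' and one column per eBird code), clusters positive
// detections into calls per file and species, and optionally writes AviaNZ
// .data files next to the WAVs in the CSV's directory.
//
// A minimal usage sketch (the path is illustrative):
//
//	out, err := CallsFromPreds(CallsFromPredsInput{
//		CSVPath:      "/recordings/20240101_kiwi_preds.csv",
//		WriteDotData: true,
//	})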
func CallsFromPreds(input CallsFromPredsInput) (CallsFromPredsOutput, error) {
var output CallsFromPredsOutput
filter := input.Filter
if filter == "" {
filter = ParseFilterFromFilename(input.CSVPath)
}
if filter == "" {
errMsg := "Filter must be specified via --filter flag or parsable from CSV filename"
output.Error = &errMsg
return output, fmt.Errorf("%s", errMsg)
}
output.Filter = filter
_, detections, clipDuration, err := readPredCSV(input.CSVPath)
if err != nil {
errMsg := err.Error()
output.Error = &errMsg
return output, err
}
output.ClipDuration = clipDuration
gapMultiplier := CLUSTER_GAP_MULTIPLIER
if input.GapMultiplier > 0 {
gapMultiplier = input.GapMultiplier
}
minDetections := MIN_DETECTIONS_PER_CLUSTER
if input.MinDetections >= 0 {
minDetections = input.MinDetections
}
gapThreshold := float64(gapMultiplier) * clipDuration
output.GapThreshold = gapThreshold
allCalls, speciesCount := clusterDetections(detections, clipDuration, gapThreshold, minDetections)
output.Calls = allCalls
output.TotalCalls = len(allCalls)
output.SpeciesCount = speciesCount
if input.WriteDotData {
dataFilesWritten, dataFilesSkipped, err := writeDotFiles(input.CSVPath, filter, allCalls, input.ProgressHandler)
if err != nil {
errMsg := fmt.Sprintf("Error writing .data files: %v", err)
output.Error = &errMsg
return output, fmt.Errorf("error writing .data files: %w", err)
}
output.DataFilesWritten = dataFilesWritten
output.DataFilesSkipped = dataFilesSkipped
}
return output, nil
}
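// readPredCSV opens the prediction CSV and returns the resolved columns, the
// positive-detection start times keyed by (file, species), and the clip
// duration inferred from the first data row.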
func readPredCSV(csvPath string) (predCSVColumns, map[predFileSpeciesKey][]float64, float64, error) {
file, err := os.Open(csvPath)
if err != nil {
return predCSVColumns{}, nil, 0, fmt.Errorf("failed to open CSV file: %w", err)
}
defer func() { _ = file.Close() }()
reader := csv.NewReader(file)
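// ReuseRecord lets Read reuse the returned slice between calls; records are
// consumed immediately and never retained.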
reader.ReuseRecord = true
header, err := reader.Read()
if err != nil {
return predCSVColumns{}, nil, 0, fmt.Errorf("failed to read CSV header: %w", err)
}
cols, err := findPredCSVColumns(header)
if err != nil {
return predCSVColumns{}, nil, 0, err
}
detections, clipDuration, err := readPredCSVRows(reader, cols)
if err != nil {
return predCSVColumns{}, nil, 0, err
}
return cols, detections, clipDuration, nil
}
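// predCSVColumns holds the header indexes resolved by findPredCSVColumns.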
type predCSVColumns struct {
fileIdx int
startTimeIdx int
endTimeIdx int
ebirdCodes []string
ebirdIdx []int
}
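// findPredCSVColumns locates the required columns and treats every remaining
// non-ignored header as an eBird species code column.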
func findPredCSVColumns(header []string) (predCSVColumns, error) {
cols := predCSVColumns{
fileIdx: -1,
startTimeIdx: -1,
endTimeIdx: -1,
}
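// Header columns that are not species codes and must not be clustered.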
ignoredColumns := map[string]bool{"NotKiwi": true, "0.0": true}
for i, col := range header {
switch col {
case "file":
cols.fileIdx = i
case "start_time":
cols.startTimeIdx = i
case "end_time":
cols.endTimeIdx = i
default:
if ignoredColumns[col] {
continue
}
cols.ebirdCodes = append(cols.ebirdCodes, col)
cols.ebirdIdx = append(cols.ebirdIdx, i)
}
}
if cols.fileIdx == -1 || cols.startTimeIdx == -1 || cols.endTimeIdx == -1 {
return cols, fmt.Errorf("CSV must have 'file', 'start_time', and 'end_time' columns")
}
if len(cols.ebirdCodes) == 0 {
return cols, fmt.Errorf("CSV must have at least one ebird code column")
}
return cols, nil
}
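// readPredCSVRows reads all data rows. The clip duration is taken from the
// first row (end_time - start_time) and assumed constant across the file;
// unparsable times read as zero.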
func readPredCSVRows(reader *csv.Reader, cols predCSVColumns) (map[predFileSpeciesKey][]float64, float64, error) {
detections := make(map[predFileSpeciesKey][]float64)
clipDuration := 0.0
record, err := reader.Read()
if err == io.EOF {
return detections, 0, nil
}
if err != nil {
return nil, 0, fmt.Errorf("failed to read first CSV row: %w", err)
}
startTime, _ := strconv.ParseFloat(record[cols.startTimeIdx], 64)
endTime, _ := strconv.ParseFloat(record[cols.endTimeIdx], 64)
clipDuration = endTime - startTime
addDetectionsFromRow(record, cols, startTime, detections)
for {
record, err := reader.Read()
if err == io.EOF {
break
}
if err != nil {
return nil, 0, fmt.Errorf("failed to read CSV row: %w", err)
}
startTime, _ = strconv.ParseFloat(record[cols.startTimeIdx], 64)
addDetectionsFromRow(record, cols, startTime, detections)
}
return detections, clipDuration, nil
}
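// addDetectionsFromRow records startTime under every species column whose
// cell is exactly "1".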
func addDetectionsFromRow(record []string, cols predCSVColumns, startTime float64, detections map[predFileSpeciesKey][]float64) {
fileName := record[cols.fileIdx]
for i, idx := range cols.ebirdIdx {
if record[idx] == "1" {
key := predFileSpeciesKey{File: fileName, EbirdCode: cols.ebirdCodes[i]}
detections[key] = append(detections[key], startTime)
}
}
}
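// clusterDetections turns per-(file, species) start times into ClusteredCalls.
// Clusters with minDetections segments or fewer are dropped, and the result is
// sorted by file, then start time.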
func clusterDetections(detections map[predFileSpeciesKey][]float64, clipDuration, gapThreshold float64, minDetections int) ([]ClusteredCall, map[string]int) {
var allCalls []ClusteredCall
speciesCount := make(map[string]int)
for key, startTimes := range detections {
sort.Float64s(startTimes)
clusters := clusterStartTimes(startTimes, gapThreshold)
for _, cluster := range clusters {
if len(cluster) <= minDetections {
continue
}
call := ClusteredCall{
File: key.File,
StartTime: cluster[0],
EndTime: cluster[len(cluster)-1] + clipDuration,
EbirdCode: key.EbirdCode,
Segments: len(cluster),
}
allCalls = append(allCalls, call)
speciesCount[key.EbirdCode]++
}
}
sort.Slice(allCalls, func(i, j int) bool {
if allCalls[i].File != allCalls[j].File {
return allCalls[i].File < allCalls[j].File
}
return allCalls[i].StartTime < allCalls[j].StartTime
})
return allCalls, speciesCount
}
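// DirCache caches a single directory listing so repeated case-insensitive
// file lookups avoid rescanning the directory.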
type DirCache struct {
dir string
wavMap map[string]string
dirMap map[string]string
}
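// NewDirCache lists dir once and indexes its files by lower-cased base name.
// On a read error it returns an empty cache rather than failing.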
func NewDirCache(dir string) *DirCache {
entries, err := os.ReadDir(dir)
if err != nil {
return &DirCache{dir: dir, wavMap: make(map[string]string), dirMap: make(map[string]string)}
}
wavMap := make(map[string]string, len(entries))
dirMap := make(map[string]string, len(entries))
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
ext := filepath.Ext(name)
base := strings.TrimSuffix(name, ext)
dirMap[strings.ToLower(base)] = name
if strings.EqualFold(ext, ".wav") {
wavMap[strings.ToLower(base)] = name
}
}
return &DirCache{dir: dir, wavMap: wavMap, dirMap: dirMap}
}
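// FindWAV returns the full path of a cached .wav file matching baseName
// case-insensitively, or "" if none exists.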
func (dc *DirCache) FindWAV(baseName string) string {
if name, ok := dc.wavMap[strings.ToLower(baseName)]; ok {
return filepath.Join(dc.dir, name)
}
return ""
}
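// FindFile is like FindWAV but matches any file extension.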
func (dc *DirCache) FindFile(baseName string) string {
if name, ok := dc.dirMap[strings.ToLower(baseName)]; ok {
return filepath.Join(dc.dir, name)
}
return ""
}
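// findWAVFile scans dir on every call for a WAV whose base name matches
// exactly; unlike DirCache it is uncached and the base-name match is
// case-sensitive.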
func findWAVFile(dir, baseName string) string {
entries, err := os.ReadDir(dir)
if err != nil {
return ""
}
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
ext := filepath.Ext(name)
nameNoExt := strings.TrimSuffix(name, ext)
if nameNoExt == baseName && strings.EqualFold(ext, ".wav") {
return filepath.Join(dir, name)
}
}
return ""
}
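// writeDotFiles groups calls by source file and writes one AviaNZ .data file
// per WAV, sequentially for small batches and via a worker pool otherwise.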
func writeDotFiles(csvPath, filter string, calls []ClusteredCall, progress ProgressHandler) (int, int, error) {
csvDir := filepath.Dir(csvPath)
callsByFile := make(map[string][]ClusteredCall)
for _, call := range calls {
filename := filepath.Base(call.File)
callsByFile[filename] = append(callsByFile[filename], call)
}
if progress != nil {
progress(0, len(callsByFile), "Processing WAV files")
}
if len(callsByFile) < 10 { // below this size, skip the worker pool
return writeDotFilesSequential(csvDir, filter, callsByFile, progress)
}
return writeDotFilesParallel(csvDir, filter, callsByFile, progress)
}
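// dotDataJob is one file's worth of work for a dotDataWorker.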
type dotDataJob struct {
filename string
fileCalls []ClusteredCall
}
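// dotDataResult reports whether a job's .data file was written.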
type dotDataResult struct {
filename string
written bool
err error
}
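// writeDotFilesSequential processes files one at a time, skipping files whose
// WAV is missing or unreadable and aborting on the first write error.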
func writeDotFilesSequential(csvDir, filter string, callsByFile map[string][]ClusteredCall, progress ProgressHandler) (int, int, error) {
dataFilesWritten := 0
dataFilesSkipped := 0
total := len(callsByFile)
processed := 0
for filename, fileCalls := range callsByFile {
baseName := strings.TrimSuffix(filename, filepath.Ext(filename))
wavPath := findWAVFile(csvDir, baseName)
if wavPath == "" {
dataFilesSkipped++
processed++
if progress != nil {
progress(processed, total, "")
}
continue
}
dataPath := wavPath + ".data"
sampleRate, duration, err := utils.ParseWAVHeaderMinimal(wavPath)
if err != nil {
dataFilesSkipped++
processed++
if progress != nil {
progress(processed, total, "")
}
continue
}
meta, segments := buildAviaNZMetaAndSegments(fileCalls, filter, duration, sampleRate)
if err := writeDotDataFileSafe(dataPath, segments, filter, meta); err != nil {
return dataFilesWritten, dataFilesSkipped, fmt.Errorf("failed to write %s: %w", dataPath, err)
}
dataFilesWritten++
processed++
if progress != nil {
progress(processed, total, "")
}
}
return dataFilesWritten, dataFilesSkipped, nil
}
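// writeDotFilesParallel fans jobs out to DOT_DATA_WORKERS goroutines. Unlike
// the sequential path it drains all results and reports only the first write
// error; failed writes are counted as skipped.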
func writeDotFilesParallel(csvDir, filter string, callsByFile map[string][]ClusteredCall, progress ProgressHandler) (int, int, error) {
total := len(callsByFile)
var processed atomic.Int32
jobs := make(chan dotDataJob, len(callsByFile))
results := make(chan dotDataResult, len(callsByFile))
var wg sync.WaitGroup
for range DOT_DATA_WORKERS {
wg.Add(1)
go dotDataWorker(csvDir, filter, jobs, results, &wg)
}
for filename, fileCalls := range callsByFile {
jobs <- dotDataJob{filename: filename, fileCalls: fileCalls}
}
close(jobs)
go func() {
wg.Wait()
close(results)
}()
dataFilesWritten := 0
dataFilesSkipped := 0
var firstErr error
for result := range results {
if result.err != nil && firstErr == nil {
firstErr = result.err
}
if result.written {
dataFilesWritten++
} else {
dataFilesSkipped++
}
if progress != nil {
current := int(processed.Add(1))
progress(current, total, "")
}
}
return dataFilesWritten, dataFilesSkipped, firstErr
}
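// dotDataWorker consumes jobs until the channel closes, emitting one result
// per job. Missing or unreadable WAVs are reported as skipped, not as errors.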
func dotDataWorker(csvDir, filter string, jobs <-chan dotDataJob, results chan<- dotDataResult, wg *sync.WaitGroup) {
defer wg.Done()
for job := range jobs {
baseName := strings.TrimSuffix(job.filename, filepath.Ext(job.filename))
wavPath := findWAVFile(csvDir, baseName)
if wavPath == "" {
results <- dotDataResult{filename: job.filename, written: false, err: nil}
continue
}
dataPath := wavPath + ".data"
sampleRate, duration, err := utils.ParseWAVHeaderMinimal(wavPath)
if err != nil {
results <- dotDataResult{filename: job.filename, written: false, err: nil}
continue
}
meta, segments := buildAviaNZMetaAndSegments(job.fileCalls, filter, duration, sampleRate)
if err := writeDotDataFileSafe(dataPath, segments, filter, meta); err != nil {
results <- dotDataResult{filename: job.filename, written: false, err: fmt.Errorf("failed to write %s: %w", dataPath, err)}
continue
}
results <- dotDataResult{filename: job.filename, written: true, err: nil}
}
}
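// buildAviaNZMetaAndSegments converts clustered calls into AviaNZ metadata and
// segments, labelling each segment with the filter and DEFAULT_CERTAINTY.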
func buildAviaNZMetaAndSegments(calls []ClusteredCall, filter string, duration float64, sampleRate int) (AviaNZMeta, []AviaNZSegment) {
reviewer := "None"
meta := AviaNZMeta{
Operator: "Auto",
Reviewer: &reviewer,
Duration: duration,
}
var segments []AviaNZSegment
for _, call := range calls {
labels := []AviaNZLabel{
{
Species: call.EbirdCode,
Certainty: DEFAULT_CERTAINTY,
Filter: filter,
},
}
segment := AviaNZSegment{
call.StartTime,
call.EndTime,
0, // freqLow: full-band segment
sampleRate, // freqHigh; convertAviaNZSegment reads this slot as FreqHigh
labels,
}
segments = append(segments, segment)
}
return meta, segments
}
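// writeAviaNZDataFile JSON-encodes the [meta, segment...] array to path.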
func writeAviaNZDataFile(path string, data []any) error {
file, err := os.Create(path)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
defer func() { _ = file.Close() }()
encoder := json.NewEncoder(file)
encoder.SetIndent("", "")
if err := encoder.Encode(data); err != nil {
return fmt.Errorf("failed to encode JSON: %w", err)
}
return nil
}
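// writeDotDataFileSafe creates path, or merges into an existing .data file.
// It refuses to clobber: an unparsable existing file, or one that already
// carries this filter's labels, is an error. Merged segments are re-sorted by
// start time.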
func writeDotDataFileSafe(path string, newSegments []AviaNZSegment, filter string, meta AviaNZMeta) error {
if _, err := os.Stat(path); err == nil {
existing, err := utils.ParseDataFile(path)
if err != nil {
return fmt.Errorf("cannot parse existing %s: %w (refusing to clobber)", path, err)
}
for _, seg := range existing.Segments {
if seg.HasFilterLabel(filter) {
return fmt.Errorf("%s already contains filter '%s' (refusing to clobber)", path, filter)
}
}
for _, newSeg := range newSegments {
seg := convertAviaNZSegment(newSeg, filter)
existing.Segments = append(existing.Segments, seg)
}
sort.Slice(existing.Segments, func(i, j int) bool {
return existing.Segments[i].StartTime < existing.Segments[j].StartTime
})
return existing.Write(path)
}
data := buildDataFileFromSegments(meta, newSegments)
return writeAviaNZDataFile(path, data)
}
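// convertAviaNZSegment maps an in-memory AviaNZSegment onto utils.Segment.
// The type assertions are safe only for segments built by
// buildAviaNZMetaAndSegments in this package.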
func convertAviaNZSegment(seg AviaNZSegment, filter string) *utils.Segment {
labels := seg[4].([]AviaNZLabel)
utilsLabels := make([]*utils.Label, len(labels))
for i, l := range labels {
utilsLabels[i] = &utils.Label{
Species: l.Species,
Certainty: l.Certainty,
Filter: filter,
}
}
var freqLow, freqHigh float64
switch v := seg[2].(type) {
case int:
freqLow = float64(v)
case float64:
freqLow = v
}
switch v := seg[3].(type) {
case int:
freqHigh = float64(v)
case float64:
freqHigh = v
}
return &utils.Segment{
StartTime: seg[0].(float64),
EndTime: seg[1].(float64),
FreqLow: freqLow,
FreqHigh: freqHigh,
Labels: utilsLabels,
}
}
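// buildDataFileFromSegments assembles the on-disk .data layout: metadata
// first, then one element per segment.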
func buildDataFileFromSegments(meta AviaNZMeta, segments []AviaNZSegment) []any {
result := make([]any, 0, 1+len(segments))
result = append(result, meta)
for _, seg := range segments {
result = append(result, seg)
}
return result
}
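// ParseFilterFromFilename extracts the filter from a CSV name with exactly
// three underscore-separated parts, returning the middle part; for example
// (illustrative) "20240101_kiwi_preds.csv" yields "kiwi". Any other shape
// yields "".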
func ParseFilterFromFilename(csvPath string) string {
filename := filepath.Base(csvPath)
name := strings.TrimSuffix(filename, ".csv")
parts := strings.Split(name, "_")
if len(parts) == 3 {
return parts[1]
}
return ""
}
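// clusterStartTimes splits ascending start times wherever consecutive values
// are more than gapThreshold apart. For example, with gapThreshold 10,
// [0 5 30] clusters into [[0 5], [30]]. Callers must sort the input first.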
func clusterStartTimes(startTimes []float64, gapThreshold float64) [][]float64 {
if len(startTimes) == 0 {
return nil
}
var clusters [][]float64
currentCluster := []float64{startTimes[0]}
for i := 1; i < len(startTimes); i++ {
gap := startTimes[i] - startTimes[i-1]
if gap <= gapThreshold {
currentCluster = append(currentCluster, startTimes[i])
} else {
clusters = append(clusters, currentCluster)
currentCluster = []float64{startTimes[i]}
}
}
clusters = append(clusters, currentCluster)
return clusters
}