package tools
import (
"context"
"database/sql"
"encoding/csv"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"time"
gonanoid "github.com/matoous/go-nanoid/v2"
"github.com/modelcontextprotocol/go-sdk/mcp"
"skraak_mcp/db"
"skraak_mcp/utils"
)
// BulkFileImportInput defines the input parameters for the bulk_file_import tool.
type BulkFileImportInput struct {
// DatasetID is the target dataset; it must reference an existing active dataset row.
DatasetID string `json:"dataset_id" jsonschema:"required,Dataset ID (12 characters)"`
// CSVPath is the absolute path to the locations CSV (see bulkReadCSV for the column layout).
CSVPath string `json:"csv_path" jsonschema:"required,Absolute path to CSV file"`
// LogFilePath is the absolute path of the progress log; the file is created
// if missing, appended to otherwise, and synced per line so it can be tailed.
LogFilePath string `json:"log_file_path" jsonschema:"required,Absolute path for progress log"`
}
// BulkFileImportOutput defines the output structure for the bulk_file_import tool.
type BulkFileImportOutput struct {
// TotalLocations is the number of data rows read from the CSV.
TotalLocations int `json:"total_locations"`
// ClustersCreated counts clusters newly inserted during this run.
ClustersCreated int `json:"clusters_created"`
// ClustersExisting counts locations whose active cluster already existed.
ClustersExisting int `json:"clusters_existing"`
// TotalFilesScanned is the number of WAV files discovered on disk.
TotalFilesScanned int `json:"total_files_scanned"`
// FilesImported counts files successfully inserted into the database.
FilesImported int `json:"files_imported"`
// FilesDuplicate counts files skipped because their content hash was already present.
FilesDuplicate int `json:"files_duplicate"`
// FilesError counts files that failed to import for other reasons.
FilesError int `json:"files_error"`
// ProcessingTime is the total wall-clock duration, formatted via time.Duration.String.
ProcessingTime string `json:"processing_time"`
// Errors holds human-readable descriptions of fatal and per-location failures.
Errors []string `json:"errors,omitempty"`
}
// bulkLocationData holds CSV row data for a location.
// Fields map 1:1 onto CSV columns 0-5 in declaration order
// (see bulkReadCSV).
type bulkLocationData struct {
// LocationName is the human-readable name (column 0); used only for logging.
LocationName string
// LocationID is the database ID of the location row (column 1).
LocationID string
// DirectoryPath is the directory scanned recursively for WAV files (column 2).
DirectoryPath string
// DateRange is used as the cluster name (column 3).
DateRange string
// SampleRate is the recording sample rate stored on the cluster (column 4).
SampleRate int
// FileCount is the expected file count from the CSV (column 5).
// NOTE(review): currently parsed but never compared against the actual scan — confirm if intended.
FileCount int
}
// bulkImportStats tracks import statistics for a single cluster.
type bulkImportStats struct {
// TotalFiles is the number of WAV files found under the cluster's directory.
TotalFiles int
// ImportedFiles counts successful database inserts.
ImportedFiles int
// DuplicateFiles counts files skipped because their hash already existed.
DuplicateFiles int
// ErrorFiles counts files that failed for any other reason.
ErrorFiles int
}
// progressLogger handles writing to both log file and internal buffer
type progressLogger struct {
file *os.File
buffer *strings.Builder
}
// Log writes a formatted message with timestamp to both log file and buffer
func (l *progressLogger) Log(format string, args ...interface{}) {
timestamp := time.Now().Format("2006-01-02 15:04:05")
message := fmt.Sprintf(format, args...)
line := fmt.Sprintf("[%s] %s\n", timestamp, message)
// Write to file
l.file.WriteString(line)
l.file.Sync() // Ensure immediate write for tail monitoring
// Also keep in memory for potential error reporting
l.buffer.WriteString(line)
}
// BulkFileImport implements the bulk_file_import MCP tool.
//
// It reads a CSV describing recording locations, ensures an active cluster
// exists for each one (creating clusters as needed), then recursively imports
// the WAV files found under each location's directory. Progress is appended
// to input.LogFilePath (synced per line so it can be tailed) and counts are
// returned in BulkFileImportOutput.
//
// On a fatal error the partially-filled output (counters, Errors,
// ProcessingTime) is returned together with a non-nil error. Per-file
// failures during import are tallied in the stats but do not abort the run.
//
// NOTE(review): ctx and req are currently unused — a long-running import
// cannot be cancelled via context; confirm whether that is intentional.
func BulkFileImport(
ctx context.Context,
req *mcp.CallToolRequest,
input BulkFileImportInput,
) (*mcp.CallToolResult, BulkFileImportOutput, error) {
startTime := time.Now()
var output BulkFileImportOutput
// Open the progress log in append mode (created if missing) so repeated
// runs accumulate into the same file.
logFile, err := os.OpenFile(input.LogFilePath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
if err != nil {
return nil, output, fmt.Errorf("failed to open log file: %w", err)
}
defer logFile.Close()
logger := &progressLogger{
file: logFile,
buffer: &strings.Builder{},
}
logger.Log("Starting bulk file import for dataset %s", input.DatasetID)
// Phase 0: validate CSV path, log directory, and dataset existence.
logger.Log("Validating input parameters...")
if err := bulkValidateInput(input); err != nil {
logger.Log("ERROR: Validation failed: %v", err)
output.Errors = []string{fmt.Sprintf("validation failed: %v", err)}
output.ProcessingTime = time.Since(startTime).String()
return nil, output, fmt.Errorf("validation failed: %w", err)
}
logger.Log("Validation complete")
// Read and parse the locations CSV (one data row per location).
logger.Log("Reading CSV file: %s", input.CSVPath)
locations, err := bulkReadCSV(input.CSVPath)
if err != nil {
logger.Log("ERROR: Failed to read CSV: %v", err)
output.Errors = []string{fmt.Sprintf("failed to read CSV: %v", err)}
output.ProcessingTime = time.Since(startTime).String()
return nil, output, fmt.Errorf("failed to read CSV: %w", err)
}
logger.Log("Loaded %d locations from CSV", len(locations))
output.TotalLocations = len(locations)
// Phase 1 (as logged): create or look up one cluster per location.
logger.Log("=== Phase 1: Creating/Validating Clusters ===")
clusterIDMap := make(map[string]string) // locationID -> clusterID
// dbPath is a package-level variable defined elsewhere in this package.
database, err := db.OpenWriteableDB(dbPath)
if err != nil {
logger.Log("ERROR: Failed to open database: %v", err)
output.Errors = []string{fmt.Sprintf("failed to open database: %v", err)}
output.ProcessingTime = time.Since(startTime).String()
return nil, output, fmt.Errorf("failed to open database: %w", err)
}
defer database.Close()
for i, loc := range locations {
logger.Log("[%d/%d] Processing location: %s", i+1, len(locations), loc.LocationName)
// An active cluster is keyed by (location_id, name); the CSV DateRange
// is used as the cluster name.
var existingClusterID string
err := database.QueryRow(`
SELECT id FROM cluster
WHERE location_id = ? AND name = ? AND active = true
`, loc.LocationID, loc.DateRange).Scan(&existingClusterID)
var clusterID string
if err == sql.ErrNoRows {
// No active cluster yet for this location/date range: create one.
clusterID, err = bulkCreateCluster(database, input.DatasetID, loc.LocationID, loc.DateRange, loc.SampleRate)
if err != nil {
errMsg := fmt.Sprintf("Failed to create cluster for location %s: %v", loc.LocationName, err)
logger.Log("ERROR: %s", errMsg)
output.Errors = append(output.Errors, errMsg)
output.ProcessingTime = time.Since(startTime).String()
return nil, output, fmt.Errorf("failed to create cluster: %w", err)
}
logger.Log(" Created cluster: %s", clusterID)
output.ClustersCreated++
} else if err != nil {
// Unexpected query failure (not "no rows"): abort the whole import.
errMsg := fmt.Sprintf("Failed to check cluster for location %s: %v", loc.LocationName, err)
logger.Log("ERROR: %s", errMsg)
output.Errors = append(output.Errors, errMsg)
output.ProcessingTime = time.Since(startTime).String()
return nil, output, fmt.Errorf("failed to check cluster: %w", err)
} else {
clusterID = existingClusterID
logger.Log(" Using existing cluster: %s", clusterID)
output.ClustersExisting++
}
clusterIDMap[loc.LocationID] = clusterID
}
// Phase 2 (as logged): import WAV files for each location's cluster.
logger.Log("=== Phase 2: Importing Files ===")
totalImported := 0
totalDuplicates := 0
totalErrors := 0
totalScanned := 0
for i, loc := range locations {
clusterID, ok := clusterIDMap[loc.LocationID]
if !ok {
continue // Should not happen, but safety check
}
logger.Log("[%d/%d] Importing files for: %s", i+1, len(locations), loc.LocationName)
logger.Log(" Directory: %s", loc.DirectoryPath)
// A missing directory is a soft failure: warn and move on.
if _, err := os.Stat(loc.DirectoryPath); os.IsNotExist(err) {
logger.Log(" WARNING: Directory not found, skipping")
continue
}
// Per-file failures are tallied inside stats; a non-nil err here is a
// cluster-level failure (bad location row or directory scan error).
stats, err := bulkImportFilesForCluster(database, logger, loc.DirectoryPath, input.DatasetID, loc.LocationID, clusterID)
if err != nil {
errMsg := fmt.Sprintf("Failed to import files for location %s: %v", loc.LocationName, err)
logger.Log("ERROR: %s", errMsg)
output.Errors = append(output.Errors, errMsg)
// Preserve counts accumulated so far in the failure output.
output.TotalFilesScanned = totalScanned
output.FilesImported = totalImported
output.FilesDuplicate = totalDuplicates
output.FilesError = totalErrors
output.ProcessingTime = time.Since(startTime).String()
return nil, output, fmt.Errorf("failed to import files: %w", err)
}
logger.Log(" Scanned: %d files", stats.TotalFiles)
logger.Log(" Imported: %d, Duplicates: %d", stats.ImportedFiles, stats.DuplicateFiles)
if stats.ErrorFiles > 0 {
logger.Log(" Errors: %d files", stats.ErrorFiles)
}
totalScanned += stats.TotalFiles
totalImported += stats.ImportedFiles
totalDuplicates += stats.DuplicateFiles
totalErrors += stats.ErrorFiles
}
// Final summary: log totals and populate the output counters.
logger.Log("=== Import Complete ===")
logger.Log("Total files scanned: %d", totalScanned)
logger.Log("Files imported: %d", totalImported)
logger.Log("Duplicates skipped: %d", totalDuplicates)
logger.Log("Errors: %d", totalErrors)
logger.Log("Processing time: %s", time.Since(startTime).Round(time.Second))
output.TotalFilesScanned = totalScanned
output.FilesImported = totalImported
output.FilesDuplicate = totalDuplicates
output.FilesError = totalErrors
output.ProcessingTime = time.Since(startTime).String()
return &mcp.CallToolResult{}, output, nil
}
// bulkValidateInput checks that the CSV file is readable, the log file's
// parent directory exists, and the target dataset is present and active.
// It returns nil when all preconditions hold.
func bulkValidateInput(input BulkFileImportInput) error {
	// The CSV must exist before any database work is attempted.
	if _, err := os.Stat(input.CSVPath); err != nil {
		return fmt.Errorf("CSV file not accessible: %w", err)
	}
	// The log file itself may not exist yet, but its directory must.
	if _, err := os.Stat(filepath.Dir(input.LogFilePath)); err != nil {
		return fmt.Errorf("log file directory not accessible: %w", err)
	}
	// A read-only connection is sufficient for validation queries.
	conn, err := db.OpenReadOnlyDB(dbPath)
	if err != nil {
		return fmt.Errorf("failed to open database: %w", err)
	}
	defer conn.Close()
	// The dataset must exist and be active.
	var exists bool
	query := "SELECT EXISTS(SELECT 1 FROM dataset WHERE id = ? AND active = true)"
	if err := conn.QueryRow(query, input.DatasetID).Scan(&exists); err != nil {
		return fmt.Errorf("failed to query dataset: %w", err)
	}
	if !exists {
		return fmt.Errorf("dataset not found or inactive: %s", input.DatasetID)
	}
	return nil
}
// bulkReadCSV reads and parses the bulk-import CSV file.
//
// Expected column layout (the header row is skipped):
//
//	0: location name, 1: location ID, 2: directory path,
//	3: date range, 4: sample rate, 5: file count
//
// It returns one bulkLocationData per data row, or an error describing the
// first malformed row. Row numbers in error messages are 1-based and count
// the header as row 1.
func bulkReadCSV(path string) ([]bulkLocationData, error) {
	file, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer file.Close()
	records, err := csv.NewReader(file).ReadAll()
	if err != nil {
		return nil, err
	}
	if len(records) == 0 {
		return nil, fmt.Errorf("CSV file is empty")
	}
	// Pre-size for every data row (everything after the header).
	locations := make([]bulkLocationData, 0, len(records)-1)
	for i, record := range records[1:] {
		row := i + 2 // 1-based CSV row number, counting the header as row 1
		if len(record) < 6 {
			return nil, fmt.Errorf("CSV row %d has insufficient columns (expected 6, got %d)", row, len(record))
		}
		// Wrap strconv errors with %w so callers can inspect them.
		sampleRate, err := strconv.Atoi(record[4])
		if err != nil {
			return nil, fmt.Errorf("invalid sample_rate in row %d: %w", row, err)
		}
		fileCount, err := strconv.Atoi(record[5])
		if err != nil {
			return nil, fmt.Errorf("invalid file_count in row %d: %w", row, err)
		}
		locations = append(locations, bulkLocationData{
			LocationName:  record[0],
			LocationID:    record[1],
			DirectoryPath: record[2],
			DateRange:     record[3],
			SampleRate:    sampleRate,
			FileCount:     fileCount,
		})
	}
	return locations, nil
}
// bulkCreateCluster inserts a new active cluster row for the given dataset
// and location and returns its generated 12-character ID.
//
// name is the cluster's display name (the CSV date range). The cluster's
// path column is derived from the location's human-readable name with
// spaces and "/" replaced by underscores so it is safe as a path component.
func bulkCreateCluster(database *sql.DB, datasetID, locationID, name string, sampleRate int) (string, error) {
	// Generate a 12-character nanoid, matching the ID format used elsewhere.
	clusterID, err := gonanoid.New(12)
	if err != nil {
		return "", fmt.Errorf("failed to generate cluster ID: %w", err)
	}
	now := time.Now().UTC()
	// Look up the location name; the cluster path is derived from it.
	var locationName string
	err = database.QueryRow("SELECT name FROM location WHERE id = ?", locationID).Scan(&locationName)
	if err != nil {
		return "", fmt.Errorf("failed to get location name: %w", err)
	}
	// Normalize path: spaces and path separators become underscores.
	path := strings.ReplaceAll(locationName, " ", "_")
	path = strings.ReplaceAll(path, "/", "_")
	_, err = database.Exec(`
INSERT INTO cluster (id, dataset_id, location_id, name, path, sample_rate, active, created_at, last_modified)
VALUES (?, ?, ?, ?, ?, ?, true, ?, ?)
`, clusterID, datasetID, locationID, name, path, sampleRate, now, now)
	if err != nil {
		return "", fmt.Errorf("failed to insert cluster: %w", err)
	}
	return clusterID, nil
}
// bulkImportFilesForCluster imports all WAV files found recursively under
// folderPath into the given cluster and returns per-cluster statistics.
//
// Per-file failures never abort the scan: duplicates (content hash already
// active in the database) increment DuplicateFiles; other failures increment
// ErrorFiles (only the first five are logged to keep the log readable).
// A non-nil error is returned only for cluster-level problems such as a
// missing location row or an unreadable directory tree.
func bulkImportFilesForCluster(database *sql.DB, logger *progressLogger, folderPath, datasetID, locationID, clusterID string) (*bulkImportStats, error) {
	stats := &bulkImportStats{}
	// Coordinates and timezone drive per-file timestamp localization and
	// astronomical calculations in bulkImportSingleFile.
	var latitude, longitude float64
	var timezoneID string
	err := database.QueryRow(`
SELECT latitude, longitude, timezone_id
FROM location
WHERE id = ?
`, locationID).Scan(&latitude, &longitude, &timezoneID)
	if err != nil {
		return nil, fmt.Errorf("failed to get location data: %w", err)
	}
	// Collect every file with a case-insensitive ".wav" extension.
	var wavFiles []string
	err = filepath.Walk(folderPath, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".wav") {
			wavFiles = append(wavFiles, path)
		}
		return nil
	})
	if err != nil {
		return nil, fmt.Errorf("failed to scan directory: %w", err)
	}
	stats.TotalFiles = len(wavFiles)
	if stats.TotalFiles == 0 {
		return stats, nil
	}
	for i, filePath := range wavFiles {
		// Periodic progress so long imports stay visible in the log.
		if (i+1)%100 == 0 {
			logger.Log(" Processing file %d/%d...", i+1, len(wavFiles))
		}
		err := bulkImportSingleFile(database, filePath, datasetID, locationID, clusterID, latitude, longitude, timezoneID)
		switch {
		case err == nil:
			stats.ImportedFiles++
		case err.Error() == "duplicate":
			// Exact match on the sentinel string returned by
			// bulkImportSingleFile. The previous substring check
			// (strings.Contains) also matched unrelated failures whose
			// message merely contained "duplicate", misclassifying real
			// errors as duplicates.
			stats.DuplicateFiles++
		default:
			stats.ErrorFiles++
			// Log only the first few errors to avoid flooding the log.
			if stats.ErrorFiles <= 5 {
				logger.Log(" ERROR: %s: %v", filepath.Base(filePath), err)
			}
		}
	}
	return stats, nil
}
// bulkImportSingleFile hashes, deduplicates, and inserts a single WAV file
// into the file table.
//
// It returns the bare error string "duplicate" when a file with the same
// xxh64 hash is already active in the database; callers rely on that exact
// string to count duplicates separately from real failures.
//
// NOTE(review): datasetID is accepted but never written — dataset membership
// appears to be implied via the cluster row; confirm this is intentional.
func bulkImportSingleFile(database *sql.DB, filePath, datasetID, locationID, clusterID string, latitude, longitude float64, timezoneID string) error {
	// The content hash is the dedup key across the whole database.
	hash, err := utils.ComputeXXH64(filePath)
	if err != nil {
		return fmt.Errorf("hash calculation failed: %w", err)
	}
	// Check for an existing active file with the same content hash.
	var existingID string
	err = database.QueryRow("SELECT id FROM file WHERE xxh64_hash = ? AND active = true", hash).Scan(&existingID)
	if err == nil {
		// A row was found: this exact file content is already imported.
		return fmt.Errorf("duplicate")
	} else if err != sql.ErrNoRows {
		// The message must NOT contain the word "duplicate": callers
		// classify duplicates by matching on that string, and the previous
		// wording ("duplicate check failed") made genuine query failures
		// count as duplicates.
		return fmt.Errorf("hash lookup failed: %w", err)
	}
	// Extract duration, sample rate, and (optional) AudioMoth metadata.
	wavMeta, err := utils.ParseWAVHeader(filePath)
	if err != nil {
		return fmt.Errorf("WAV metadata extraction failed: %w", err)
	}
	// Prefer the AudioMoth comment timestamp; fall back to parsing the
	// filename when the comment is absent or unparseable.
	var timestamp time.Time
	if utils.IsAudioMoth(wavMeta.Comment, wavMeta.Artist) {
		// Parse errors are deliberately ignored: the filename fallback
		// below handles them.
		if mothData, err := utils.ParseAudioMothComment(wavMeta.Comment); err == nil && mothData != nil {
			timestamp = mothData.Timestamp
		}
	}
	if timestamp.IsZero() {
		results, err := utils.ParseFilenameTimestamps([]string{filepath.Base(filePath)})
		// %v (not %w) on purpose: err may be nil when results is empty.
		if err != nil || len(results) == 0 {
			return fmt.Errorf("timestamp parsing failed: %v", err)
		}
		localTimes, err := utils.ApplyTimezoneOffset(results, timezoneID)
		if err != nil {
			return fmt.Errorf("timezone application failed: %w", err)
		}
		timestamp = localTimes[0]
	}
	// Day/night flags and moon phase derive from the recording's start
	// time, duration, and coordinates.
	astro := utils.CalculateAstronomicalData(timestamp.UTC(), wavMeta.Duration, latitude, longitude)
	// 21-character alphanumeric nanoid for the file primary key.
	fileID, err := gonanoid.Generate("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", 21)
	if err != nil {
		return fmt.Errorf("ID generation failed: %w", err)
	}
	now := time.Now().UTC()
	_, err = database.Exec(`
INSERT INTO file (
id, location_id, cluster_id,
file_name, xxh64_hash, duration, sample_rate,
timestamp_local, maybe_solar_night, maybe_civil_night,
moon_phase,
active, created_at, last_modified
) VALUES (
?, ?, ?,
?, ?, ?, ?,
?, ?, ?,
?,
true, ?, ?
)
`, fileID, locationID, clusterID,
		filepath.Base(filePath), hash, wavMeta.Duration, wavMeta.SampleRate,
		timestamp, astro.SolarNight, astro.CivilNight,
		astro.MoonPhase,
		now, now)
	if err != nil {
		return fmt.Errorf("database insert failed: %w", err)
	}
	return nil
}