package tools
import (
"context"
"database/sql"
"fmt"
"os"
"time"
"github.com/modelcontextprotocol/go-sdk/mcp"
"skraak_mcp/db"
"skraak_mcp/utils"
)
// ImportAudioFilesInput holds the arguments accepted by ImportAudioFiles.
// The json tags define the wire names; the jsonschema tags supply the
// descriptions and "required" markers published for the tool.
type ImportAudioFilesInput struct {
	// FolderPath is the directory to import from; validateImportInput
	// rejects it unless os.Stat succeeds and it is a directory.
	FolderPath string `json:"folder_path" jsonschema:"required,Absolute path to folder containing WAV files"`
	// DatasetID must name an active dataset.
	DatasetID string `json:"dataset_id" jsonschema:"required,Dataset ID (12 characters)"`
	// LocationID must name an active location that belongs to DatasetID.
	LocationID string `json:"location_id" jsonschema:"required,Location ID (12 characters)"`
	// ClusterID must name an active cluster that belongs to LocationID.
	ClusterID string `json:"cluster_id" jsonschema:"required,Cluster ID (12 characters)"`
	// Recursive is optional; nil means ImportAudioFiles treats it as true.
	Recursive *bool `json:"recursive,omitempty" jsonschema:"Scan subfolders recursively (default: true)"`
}
// ImportAudioFilesOutput is the structured result returned by ImportAudioFiles.
type ImportAudioFilesOutput struct {
	// Summary carries the aggregate counts and timing for the run.
	Summary ImportSummary `json:"summary" jsonschema:"Import summary with counts and statistics"`
	// FileIDs is always serialized (no omitempty).
	// NOTE(review): ImportAudioFiles currently sets this to an empty slice
	// rather than the imported IDs — confirm whether it should be populated.
	FileIDs []string `json:"file_ids" jsonschema:"List of successfully imported file IDs"`
	// Errors holds the errors collected during the import; because of
	// omitempty it is left out of the JSON when the slice is empty.
	Errors []utils.FileImportError `json:"errors,omitempty" jsonschema:"Errors encountered during import (if any)"`
}
// ImportSummary reports the aggregate outcome of one ImportAudioFiles call.
// The counters and TotalDuration are copied from the result of
// utils.ImportCluster. ProcessingTime is the handler's wall-clock time,
// rendered with time.Duration.String (for example "1.5s").
type ImportSummary struct {
	TotalFiles     int `json:"total_files"`
	ImportedFiles  int `json:"imported_files"`
	SkippedFiles   int `json:"skipped_files"`
	FailedFiles    int `json:"failed_files"`
	AudioMothFiles int `json:"audiomoth_files"`
	// TotalDuration is the summed audio duration; the JSON key indicates
	// seconds.
	TotalDuration  float64 `json:"total_duration_seconds"`
	ProcessingTime string  `json:"processing_time"`
}
// ImportAudioFiles is the MCP tool handler that imports the WAV files found
// under input.FolderPath into an existing cluster.
//
// The steps are:
//  1. Validate the request with validateImportInput. This checks that the
//     folder exists and is a directory, and that the dataset, location and
//     cluster are active and correctly nested. It opens the database
//     read-only.
//  2. Return early if ctx has already been cancelled, before any write.
//  3. Open the package-level dbPath database for writing.
//  4. Record input.FolderPath as the cluster's path (utils.EnsureClusterPath).
//  5. Run utils.ImportCluster and copy its counts into the summary.
//
// input.Recursive defaults to true when omitted. A non-nil error means the
// run as a whole failed (validation, database, cluster path or cluster
// import) or was cancelled before writing began. The errors collected by
// utils.ImportCluster are returned in output.Errors instead. req is
// currently unused.
func ImportAudioFiles(
	ctx context.Context,
	req *mcp.CallToolRequest,
	input ImportAudioFilesInput,
) (*mcp.CallToolResult, ImportAudioFilesOutput, error) {
	startTime := time.Now()
	var output ImportAudioFilesOutput

	recursive := true
	if input.Recursive != nil {
		recursive = *input.Recursive
	}

	if err := validateImportInput(input, dbPath); err != nil {
		return nil, output, fmt.Errorf("validation failed: %w", err)
	}

	// Honour cancellation before anything is written. utils.ImportCluster
	// takes no context, so this is the last point where a cancelled request
	// can be abandoned cleanly. %w keeps context.Canceled/DeadlineExceeded
	// matchable with errors.Is.
	if err := ctx.Err(); err != nil {
		return nil, output, fmt.Errorf("import cancelled: %w", err)
	}

	database, err := db.OpenWriteableDB(dbPath)
	if err != nil {
		return nil, output, fmt.Errorf("failed to open database: %w", err)
	}
	defer database.Close()

	if err := utils.EnsureClusterPath(database, input.ClusterID, input.FolderPath); err != nil {
		return nil, output, fmt.Errorf("failed to set cluster path: %w", err)
	}

	clusterOutput, err := utils.ImportCluster(database, utils.ClusterImportInput{
		FolderPath: input.FolderPath,
		DatasetID:  input.DatasetID,
		LocationID: input.LocationID,
		ClusterID:  input.ClusterID,
		Recursive:  recursive,
	})
	if err != nil {
		return nil, output, fmt.Errorf("cluster import failed: %w", err)
	}

	output = ImportAudioFilesOutput{
		Summary: ImportSummary{
			TotalFiles:     clusterOutput.TotalFiles,
			ImportedFiles:  clusterOutput.ImportedFiles,
			SkippedFiles:   clusterOutput.SkippedFiles,
			FailedFiles:    clusterOutput.FailedFiles,
			AudioMothFiles: clusterOutput.AudioMothFiles,
			TotalDuration:  clusterOutput.TotalDuration,
			ProcessingTime: time.Since(startTime).String(),
		},
		// A non-nil empty slice makes this encode as [] rather than null.
		// NOTE(review): no IDs from clusterOutput are mapped here, so this is
		// always empty despite the schema promising the imported file IDs —
		// confirm whether utils.ImportCluster exposes them and wire them in.
		FileIDs: []string{},
		Errors:  clusterOutput.Errors,
	}
	return &mcp.CallToolResult{}, output, nil
}
// validateImportInput checks an import request before ImportAudioFiles writes
// anything.
//
// The checks run in order and stop at the first failure:
//   - input.FolderPath can be stat'ed and is a directory;
//   - input.DatasetID names a dataset row with active = true;
//   - input.LocationID names an active location whose dataset_id is
//     input.DatasetID;
//   - input.ClusterID names an active cluster whose location_id is
//     input.LocationID.
//
// The database at dbPath is opened with db.OpenReadOnlyDB and closed on
// return. It returns nil only when every check passes.
func validateImportInput(input ImportAudioFilesInput, dbPath string) error {
	fi, statErr := os.Stat(input.FolderPath)
	if statErr != nil {
		return fmt.Errorf("folder not accessible: %w", statErr)
	}
	if !fi.IsDir() {
		return fmt.Errorf("path is not a directory: %s", input.FolderPath)
	}

	conn, openErr := db.OpenReadOnlyDB(dbPath)
	if openErr != nil {
		return fmt.Errorf("failed to open database: %w", openErr)
	}
	defer conn.Close()

	// The dataset is the top of the hierarchy, so an EXISTS probe suffices.
	const datasetQuery = "SELECT EXISTS(SELECT 1 FROM dataset WHERE id = ? AND active = true)"
	var found bool
	if err := conn.QueryRow(datasetQuery, input.DatasetID).Scan(&found); err != nil {
		return fmt.Errorf("failed to query dataset: %w", err)
	}
	if !found {
		return fmt.Errorf("dataset not found or inactive: %s", input.DatasetID)
	}

	// The location must exist and point back at the requested dataset.
	const locationQuery = "SELECT dataset_id FROM location WHERE id = ? AND active = true"
	var parentDataset string
	switch err := conn.QueryRow(locationQuery, input.LocationID).Scan(&parentDataset); {
	case err == sql.ErrNoRows:
		return fmt.Errorf("location not found or inactive: %s", input.LocationID)
	case err != nil:
		return fmt.Errorf("failed to query location: %w", err)
	case parentDataset != input.DatasetID:
		return fmt.Errorf("location %s does not belong to dataset %s", input.LocationID, input.DatasetID)
	}

	// The cluster must exist and point back at the requested location.
	const clusterQuery = "SELECT location_id FROM cluster WHERE id = ? AND active = true"
	var parentLocation string
	switch err := conn.QueryRow(clusterQuery, input.ClusterID).Scan(&parentLocation); {
	case err == sql.ErrNoRows:
		return fmt.Errorf("cluster not found or inactive: %s", input.ClusterID)
	case err != nil:
		return fmt.Errorf("failed to query cluster: %w", err)
	case parentLocation != input.LocationID:
		return fmt.Errorf("cluster %s does not belong to location %s", input.ClusterID, input.LocationID)
	}

	return nil
}