package tools

import (
	"context"
	"database/sql"
	"encoding/csv"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"time"

	"github.com/modelcontextprotocol/go-sdk/mcp"

	"skraak_mcp/db"
	"skraak_mcp/utils"
)
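
// BulkFileImportInput defines the parameters for a bulk file import: the
// target dataset, a CSV describing one location per row, and a log file path
// that the tool streams progress to while it runs.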
type BulkFileImportInput struct {
DatasetID string `json:"dataset_id" jsonschema:"required,Dataset ID (12 characters)"`
CSVPath string `json:"csv_path" jsonschema:"required,Absolute path to CSV file. Format: location_name,location_id,directory_path,date_range,sample_rate,file_count (header required). date_range becomes cluster name. Example: 'Site A',abc123,/path/files,20240101-20240107,48000,100"`
LogFilePath string `json:"log_file_path" jsonschema:"required,Absolute path for progress log file (monitors import progress)"`
}
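
// BulkFileImportOutput summarizes a bulk import run: cluster and file counts,
// total processing time, and any errors encountered along the way.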
type BulkFileImportOutput struct {
	TotalLocations    int      `json:"total_locations"`
	ClustersCreated   int      `json:"clusters_created"`
	ClustersExisting  int      `json:"clusters_existing"`
	TotalFilesScanned int      `json:"total_files_scanned"`
	FilesImported     int      `json:"files_imported"`
	FilesDuplicate    int      `json:"files_duplicate"`
	FilesError        int      `json:"files_error"`
	ProcessingTime    string   `json:"processing_time"`
	Errors            []string `json:"errors,omitempty"`
}
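
// bulkLocationData holds one parsed data row from the input CSV. FileCount
// mirrors the CSV's file_count column; it is parsed but not otherwise used
// in this file.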
type bulkLocationData struct {
	LocationName  string
	LocationID    string
	DirectoryPath string
	DateRange     string
	SampleRate    int
	FileCount     int
}
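
// bulkImportStats accumulates per-cluster file counters.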
type bulkImportStats struct {
	TotalFiles     int
	ImportedFiles  int
	DuplicateFiles int
	ErrorFiles     int
}
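
// progressLogger writes timestamped lines to a log file and mirrors them
// into an in-memory buffer.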
type progressLogger struct {
	file   *os.File
	buffer *strings.Builder
}
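
// Log formats a message, prefixes it with a timestamp, and writes it to both
// the log file and the buffer. The file is synced after every write so that
// external monitors see progress immediately; write errors are ignored,
// since logging here is best-effort.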
func (l *progressLogger) Log(format string, args ...any) {
	timestamp := time.Now().Format("2006-01-02 15:04:05")
	message := fmt.Sprintf(format, args...)
	line := fmt.Sprintf("[%s] %s\n", timestamp, message)
	l.file.WriteString(line)
	l.file.Sync()
	l.buffer.WriteString(line)
}
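
// BulkFileImport imports files for multiple locations in two phases:
// Phase 1 creates (or reuses) one cluster per CSV row, keyed by location ID
// and date range; Phase 2 recursively imports the files found under each
// row's directory into its cluster. Progress is streamed to the log file
// named in the input so long-running imports can be monitored externally.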
func BulkFileImport(
ctx context.Context,
req *mcp.CallToolRequest,
input BulkFileImportInput,
) (*mcp.CallToolResult, BulkFileImportOutput, error) {
startTime := time.Now()
var output BulkFileImportOutput
logFile, err := os.OpenFile(input.LogFilePath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
if err != nil {
return nil, output, fmt.Errorf("failed to open log file: %w", err)
}
defer logFile.Close()
logger := &progressLogger{
file: logFile,
buffer: &strings.Builder{},
}
logger.Log("Starting bulk file import for dataset %s", input.DatasetID)
logger.Log("Validating input parameters...")
if err := bulkValidateInput(input); err != nil {
logger.Log("ERROR: Validation failed: %v", err)
output.Errors = []string{fmt.Sprintf("validation failed: %v", err)}
output.ProcessingTime = time.Since(startTime).String()
return nil, output, fmt.Errorf("validation failed: %w", err)
}
logger.Log("Validation complete")
logger.Log("Reading CSV file: %s", input.CSVPath)
locations, err := bulkReadCSV(input.CSVPath)
if err != nil {
logger.Log("ERROR: Failed to read CSV: %v", err)
output.Errors = []string{fmt.Sprintf("failed to read CSV: %v", err)}
output.ProcessingTime = time.Since(startTime).String()
return nil, output, fmt.Errorf("failed to read CSV: %w", err)
}
logger.Log("Loaded %d locations from CSV", len(locations))
output.TotalLocations = len(locations)
logger.Log("=== Phase 1: Creating/Validating Clusters ===")
clusterIDMap := make(map[string]string)
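	// NOTE: dbPath is assumed to be a package-level variable holding the
	// database path, defined elsewhere in the tools package; it is not
	// declared in this file.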
database, err := db.OpenWriteableDB(dbPath)
if err != nil {
logger.Log("ERROR: Failed to open database: %v", err)
output.Errors = []string{fmt.Sprintf("failed to open database: %v", err)}
output.ProcessingTime = time.Since(startTime).String()
return nil, output, fmt.Errorf("failed to open database: %w", err)
}
defer database.Close()
for i, loc := range locations {
logger.Log("[%d/%d] Processing location: %s", i+1, len(locations), loc.LocationName)
var existingClusterID string
err := database.QueryRow(`
SELECT id FROM cluster
WHERE location_id = ? AND name = ? AND active = true
`, loc.LocationID, loc.DateRange).Scan(&existingClusterID)
var clusterID string
		if errors.Is(err, sql.ErrNoRows) {
clusterID, err = bulkCreateCluster(database, input.DatasetID, loc.LocationID, loc.DateRange, loc.SampleRate)
if err != nil {
errMsg := fmt.Sprintf("Failed to create cluster for location %s: %v", loc.LocationName, err)
logger.Log("ERROR: %s", errMsg)
output.Errors = append(output.Errors, errMsg)
output.ProcessingTime = time.Since(startTime).String()
return nil, output, fmt.Errorf("failed to create cluster: %w", err)
}
logger.Log(" Created cluster: %s", clusterID)
output.ClustersCreated++
} else if err != nil {
errMsg := fmt.Sprintf("Failed to check cluster for location %s: %v", loc.LocationName, err)
logger.Log("ERROR: %s", errMsg)
output.Errors = append(output.Errors, errMsg)
output.ProcessingTime = time.Since(startTime).String()
return nil, output, fmt.Errorf("failed to check cluster: %w", err)
} else {
clusterID = existingClusterID
logger.Log(" Using existing cluster: %s", clusterID)
output.ClustersExisting++
}
compositeKey := loc.LocationID + "|" + loc.DateRange
clusterIDMap[compositeKey] = clusterID
}
logger.Log("=== Phase 2: Importing Files ===")
totalImported := 0
totalDuplicates := 0
totalErrors := 0
totalScanned := 0
for i, loc := range locations {
compositeKey := loc.LocationID + "|" + loc.DateRange
clusterID, ok := clusterIDMap[compositeKey]
		if !ok {
			continue
		}
logger.Log("[%d/%d] Importing files for: %s", i+1, len(locations), loc.LocationName)
logger.Log(" Directory: %s", loc.DirectoryPath)
if _, err := os.Stat(loc.DirectoryPath); os.IsNotExist(err) {
logger.Log(" WARNING: Directory not found, skipping")
continue
}
stats, err := bulkImportFilesForCluster(database, logger, loc.DirectoryPath, input.DatasetID, loc.LocationID, clusterID)
if err != nil {
errMsg := fmt.Sprintf("Failed to import files for location %s: %v", loc.LocationName, err)
logger.Log("ERROR: %s", errMsg)
output.Errors = append(output.Errors, errMsg)
output.TotalFilesScanned = totalScanned
output.FilesImported = totalImported
output.FilesDuplicate = totalDuplicates
output.FilesError = totalErrors
output.ProcessingTime = time.Since(startTime).String()
return nil, output, fmt.Errorf("failed to import files: %w", err)
}
logger.Log(" Scanned: %d files", stats.TotalFiles)
logger.Log(" Imported: %d, Duplicates: %d", stats.ImportedFiles, stats.DuplicateFiles)
if stats.ErrorFiles > 0 {
logger.Log(" Errors: %d files", stats.ErrorFiles)
}
totalScanned += stats.TotalFiles
totalImported += stats.ImportedFiles
totalDuplicates += stats.DuplicateFiles
totalErrors += stats.ErrorFiles
}
logger.Log("=== Import Complete ===")
logger.Log("Total files scanned: %d", totalScanned)
logger.Log("Files imported: %d", totalImported)
logger.Log("Duplicates skipped: %d", totalDuplicates)
logger.Log("Errors: %d", totalErrors)
logger.Log("Processing time: %s", time.Since(startTime).Round(time.Second))
output.TotalFilesScanned = totalScanned
output.FilesImported = totalImported
output.FilesDuplicate = totalDuplicates
output.FilesError = totalErrors
output.ProcessingTime = time.Since(startTime).String()
return &mcp.CallToolResult{}, output, nil
}
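
// bulkValidateInput verifies that the CSV file and the log file's directory
// are accessible, and that the target dataset exists and is active.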
func bulkValidateInput(input BulkFileImportInput) error {
if _, err := os.Stat(input.CSVPath); err != nil {
return fmt.Errorf("CSV file not accessible: %w", err)
}
logDir := filepath.Dir(input.LogFilePath)
if _, err := os.Stat(logDir); err != nil {
return fmt.Errorf("log file directory not accessible: %w", err)
}
database, err := db.OpenReadOnlyDB(dbPath)
if err != nil {
return fmt.Errorf("failed to open database: %w", err)
}
defer database.Close()
var datasetExists bool
err = database.QueryRow("SELECT EXISTS(SELECT 1 FROM dataset WHERE id = ? AND active = true)", input.DatasetID).Scan(&datasetExists)
if err != nil {
return fmt.Errorf("failed to query dataset: %w", err)
}
if !datasetExists {
return fmt.Errorf("dataset not found or inactive: %s", input.DatasetID)
}
return nil
}
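
// bulkReadCSV parses the input CSV into location records, skipping the
// header row and validating the numeric columns.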
func bulkReadCSV(path string) ([]bulkLocationData, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
}
defer file.Close()
reader := csv.NewReader(file)
records, err := reader.ReadAll()
if err != nil {
return nil, err
}
	if len(records) == 0 {
		return nil, errors.New("CSV file is empty")
	}
var locations []bulkLocationData
for i, record := range records {
		if i == 0 {
			// Skip the header row.
			continue
		}
if len(record) < 6 {
return nil, fmt.Errorf("CSV row %d has insufficient columns (expected 6, got %d)", i+1, len(record))
}
		sampleRate, err := strconv.Atoi(record[4])
		if err != nil {
			return nil, fmt.Errorf("invalid sample_rate in row %d: %w", i+1, err)
		}
		fileCount, err := strconv.Atoi(record[5])
		if err != nil {
			return nil, fmt.Errorf("invalid file_count in row %d: %w", i+1, err)
		}
locations = append(locations, bulkLocationData{
LocationName: record[0],
LocationID: record[1],
DirectoryPath: record[2],
DateRange: record[3],
SampleRate: sampleRate,
FileCount: fileCount,
})
}
return locations, nil
}
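
// bulkCreateCluster inserts a new active cluster row, deriving its path from
// the location name with spaces and slashes replaced by underscores.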
func bulkCreateCluster(database *sql.DB, datasetID, locationID, name string, sampleRate int) (string, error) {
clusterID, err := utils.GenerateShortID()
	if err != nil {
		return "", fmt.Errorf("failed to generate cluster ID: %w", err)
	}
now := time.Now().UTC()
var locationName string
err = database.QueryRow("SELECT name FROM location WHERE id = ?", locationID).Scan(&locationName)
	if err != nil {
		return "", fmt.Errorf("failed to get location name: %w", err)
	}
path := strings.ReplaceAll(locationName, " ", "_")
path = strings.ReplaceAll(path, "/", "_")
_, err = database.Exec(`
INSERT INTO cluster (id, dataset_id, location_id, name, path, sample_rate, active, created_at, last_modified)
VALUES (?, ?, ?, ?, ?, ?, true, ?, ?)
`, clusterID, datasetID, locationID, name, path, sampleRate, now, now)
	if err != nil {
		return "", fmt.Errorf("failed to insert cluster: %w", err)
	}
return clusterID, nil
}
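
// bulkImportFilesForCluster recursively imports all files under folderPath
// into the given cluster via utils.ImportCluster and maps the results into
// bulkImportStats.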
func bulkImportFilesForCluster(database *sql.DB, logger *progressLogger, folderPath, datasetID, locationID, clusterID string) (*bulkImportStats, error) {
stats := &bulkImportStats{}
if _, err := os.Stat(folderPath); os.IsNotExist(err) {
logger.Log(" WARNING: Directory not found, skipping")
return stats, nil
}
logger.Log(" Importing cluster %s", clusterID)
clusterOutput, err := utils.ImportCluster(database, utils.ClusterImportInput{
FolderPath: folderPath,
DatasetID: datasetID,
LocationID: locationID,
ClusterID: clusterID,
Recursive: true,
})
if err != nil {
return nil, err
}
stats.TotalFiles = clusterOutput.TotalFiles
stats.ImportedFiles = clusterOutput.ImportedFiles
stats.DuplicateFiles = clusterOutput.SkippedFiles
stats.ErrorFiles = clusterOutput.FailedFiles
	// Log only the first few per-file errors to keep the progress log readable.
	for i, fileErr := range clusterOutput.Errors {
		if i >= 5 {
			logger.Log("    ... and %d more file errors", len(clusterOutput.Errors)-5)
			break
		}
		logger.Log("    ERROR: %s: %s", fileErr.FileName, fileErr.Error)
	}
logger.Log(" Complete: %d imported, %d duplicates, %d errors", stats.ImportedFiles, stats.DuplicateFiles, stats.ErrorFiles)
return stats, nil
}