package tools

import (
	"context"
	"database/sql"
	"encoding/csv"
	"fmt"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"time"

	"skraak/db"
	"skraak/utils"
)
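
// BulkFileImportInput holds the parameters for a bulk audio-file import: the
// target dataset, the CSV describing the locations to import, and the path of
// a log file that receives progress output.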
type BulkFileImportInput struct {
DatasetID string `json:"dataset_id"`
CSVPath string `json:"csv_path"`
LogFilePath string `json:"log_file_path"`
}
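
// BulkFileImportOutput summarizes a bulk import run: cluster and file counts,
// total processing time, and any errors encountered.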
type BulkFileImportOutput struct {
TotalLocations int `json:"total_locations"`
ClustersCreated int `json:"clusters_created"`
ClustersExisting int `json:"clusters_existing"`
TotalFilesScanned int `json:"total_files_scanned"`
FilesImported int `json:"files_imported"`
FilesDuplicate int `json:"files_duplicate"`
FilesError int `json:"files_error"`
ProcessingTime string `json:"processing_time"`
Errors []string `json:"errors,omitempty"`
}
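
// bulkLocationData is one parsed row of the import CSV (see bulkReadCSV for
// the expected column order).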
type bulkLocationData struct {
LocationName string
LocationID string
DirectoryPath string
DateRange string
SampleRate int
FileCount int
}
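
// bulkImportStats accumulates per-file counters across an import run.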
type bulkImportStats struct {
TotalFiles int
ImportedFiles int
DuplicateFiles int
ErrorFiles int
}
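
// progressLogger writes timestamped progress lines to the log file, syncing
// after each write so progress is visible while the import runs, and keeps an
// in-memory copy of everything logged in buffer.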
type progressLogger struct {
file *os.File
buffer *strings.Builder
}
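
// Log formats a message, prefixes it with a timestamp, and appends it to both
// the log file and the in-memory buffer. Write failures are reported on
// stderr rather than aborting the import.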
func (l *progressLogger) Log(format string, args ...any) {
timestamp := time.Now().Format("2006-01-02 15:04:05")
message := fmt.Sprintf(format, args...)
line := fmt.Sprintf("[%s] %s\n", timestamp, message)
if _, err := l.file.WriteString(line); err != nil {
fmt.Fprintf(os.Stderr, "Warning: log write failed: %v\n", err)
}
if err := l.file.Sync(); err != nil {
fmt.Fprintf(os.Stderr, "Warning: log sync failed: %v\n", err)
}
l.buffer.WriteString(line)
}
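
// failOutput records the given errors and the elapsed time on the output
// before an early return.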
func (o *BulkFileImportOutput) failOutput(errs []string, startTime time.Time) {
o.Errors = errs
o.ProcessingTime = time.Since(startTime).String()
}
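
// BulkFileImport imports audio files for every location listed in the input
// CSV. After validating the input it runs in two phases: phase 1 creates (or
// reuses) one cluster per location/date-range pair, and phase 2 scans each
// location's directory and imports its files into the matching cluster.
// Progress is appended to the log file at input.LogFilePath.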
func BulkFileImport(
ctx context.Context,
input BulkFileImportInput,
) (BulkFileImportOutput, error) {
startTime := time.Now()
var output BulkFileImportOutput
logFile, err := os.OpenFile(input.LogFilePath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
if err != nil {
return output, fmt.Errorf("failed to open log file: %w", err)
}
defer func() { _ = logFile.Close() }()
logger := &progressLogger{
file: logFile,
buffer: &strings.Builder{},
}
logger.Log("Starting bulk file import for dataset %s", input.DatasetID)
logger.Log("Validating input parameters...")
if err := bulkValidateInput(input); err != nil {
logger.Log("ERROR: Validation failed: %v", err)
output.failOutput([]string{fmt.Sprintf("validation failed: %v", err)}, startTime)
return output, fmt.Errorf("validation failed: %w", err)
}
logger.Log("Validation complete")
logger.Log("Reading CSV file: %s", input.CSVPath)
locations, err := bulkReadCSV(input.CSVPath)
if err != nil {
logger.Log("ERROR: Failed to read CSV: %v", err)
output.failOutput([]string{fmt.Sprintf("failed to read CSV: %v", err)}, startTime)
return output, fmt.Errorf("failed to read CSV: %w", err)
}
logger.Log("Loaded %d locations from CSV", len(locations))
output.TotalLocations = len(locations)
logger.Log("Validating location_ids belong to dataset...")
if err := bulkValidateLocations(logger, locations, input.DatasetID); err != nil {
output.failOutput([]string{err.Error()}, startTime)
return output, err
}
logger.Log("Location validation complete")
logger.Log("=== Phase 1: Creating/Validating Clusters ===")
database, err := db.OpenWriteableDB(dbPath)
if err != nil {
logger.Log("ERROR: Failed to open database: %v", err)
output.failOutput([]string{fmt.Sprintf("failed to open database: %v", err)}, startTime)
return output, fmt.Errorf("failed to open database: %w", err)
}
defer database.Close()
clusterIDMap, created, existing, err := bulkCreateClusters(ctx, database, logger, locations, input.DatasetID)
if err != nil {
		output.failOutput([]string{err.Error()}, startTime)
return output, err
}
output.ClustersCreated = created
output.ClustersExisting = existing
logger.Log("=== Phase 2: Importing Files ===")
	fileStats, errs := bulkImportAllFiles(ctx, database, logger, locations, clusterIDMap, input.DatasetID)
output.TotalFilesScanned = fileStats.TotalFiles
output.FilesImported = fileStats.ImportedFiles
output.FilesDuplicate = fileStats.DuplicateFiles
output.FilesError = fileStats.ErrorFiles
output.Errors = append(output.Errors, errs...)
if len(errs) > 0 {
output.ProcessingTime = time.Since(startTime).String()
return output, fmt.Errorf("failed to import files: %s", errs[0])
}
logger.Log("=== Import Complete ===")
logger.Log("Total files scanned: %d", fileStats.TotalFiles)
logger.Log("Files imported: %d", fileStats.ImportedFiles)
logger.Log("Duplicates skipped: %d", fileStats.DuplicateFiles)
logger.Log("Errors: %d", fileStats.ErrorFiles)
logger.Log("Processing time: %s", time.Since(startTime).Round(time.Second))
output.ProcessingTime = time.Since(startTime).String()
return output, nil
}
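
// bulkValidateInput checks the dataset ID, the existence of the CSV file and
// the log file's directory, and that the dataset type allows imports.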
func bulkValidateInput(input BulkFileImportInput) error {
if err := utils.ValidateShortID(input.DatasetID, "dataset_id"); err != nil {
return err
}
if _, err := os.Stat(input.CSVPath); err != nil {
return fmt.Errorf("CSV file not accessible: %w", err)
}
logDir := filepath.Dir(input.LogFilePath)
if _, err := os.Stat(logDir); err != nil {
return fmt.Errorf("log file directory not accessible: %w", err)
}
database, err := db.OpenReadOnlyDB(dbPath)
if err != nil {
return fmt.Errorf("failed to open database: %w", err)
}
defer database.Close()
if err := db.ValidateDatasetTypeForImport(database, input.DatasetID); err != nil {
return err
}
return nil
}
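
// bulkValidateLocationsBelongToDataset checks each unique location ID in the
// CSV against the dataset and returns one error message per mismatch.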
func bulkValidateLocationsBelongToDataset(dbConn *sql.DB, locations []bulkLocationData, datasetID string) []string {
var errors []string
uniqueLocations := make(map[string]bool)
for _, loc := range locations {
uniqueLocations[loc.LocationID] = true
}
for locationID := range uniqueLocations {
if err := db.ValidateLocationBelongsToDataset(dbConn, locationID, datasetID); err != nil {
errors = append(errors, err.Error())
}
}
return errors
}
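
// bulkValidateLocations opens a read-only connection and verifies that every
// location referenced by the CSV belongs to the target dataset, logging each
// failure before returning a summary error.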
func bulkValidateLocations(logger *progressLogger, locations []bulkLocationData, datasetID string) error {
	readDB, err := db.OpenReadOnlyDB(dbPath)
	if err != nil {
		logger.Log("ERROR: Failed to open database: %v", err)
		return fmt.Errorf("failed to open database: %w", err)
	}
	defer func() { _ = readDB.Close() }()
	locationErrors := bulkValidateLocationsBelongToDataset(readDB, locations, datasetID)
if len(locationErrors) > 0 {
for _, locErr := range locationErrors {
logger.Log("ERROR: %s", locErr)
}
return fmt.Errorf("location validation failed: %d location(s) do not belong to dataset %s", len(locationErrors), datasetID)
}
return nil
}
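
// bulkCreateClusters ensures one active cluster exists per location/date-range
// pair. It returns a map from "locationID|dateRange" composite keys to cluster
// IDs, plus counts of clusters created and reused.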
func bulkCreateClusters(ctx context.Context, database *sql.DB, logger *progressLogger, locations []bulkLocationData, datasetID string) (map[string]string, int, int, error) {
clusterIDMap := make(map[string]string)
created := 0
existing := 0
for i, loc := range locations {
logger.Log("[%d/%d] Processing location: %s", i+1, len(locations), loc.LocationName)
		// Reuse an existing active cluster for this location/date-range pair if present.
		var existingClusterID string
		err := database.QueryRowContext(ctx, `
			SELECT id FROM cluster
			WHERE location_id = ? AND name = ? AND active = true
		`, loc.LocationID, loc.DateRange).Scan(&existingClusterID)
var clusterID string
if err == sql.ErrNoRows {
clusterID, err = bulkCreateCluster(ctx, database, datasetID, loc.LocationID, loc.DateRange, loc.SampleRate)
if err != nil {
logger.Log("ERROR: Failed to create cluster for location %s: %v", loc.LocationName, err)
return nil, 0, 0, fmt.Errorf("failed to create cluster: %w", err)
}
logger.Log(" Created cluster: %s", clusterID)
created++
} else if err != nil {
logger.Log("ERROR: Failed to check cluster for location %s: %v", loc.LocationName, err)
return nil, 0, 0, fmt.Errorf("failed to check cluster: %w", err)
} else {
clusterID = existingClusterID
logger.Log(" Using existing cluster: %s", clusterID)
existing++
}
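		// Key clusters by "locationID|dateRange" so the import phase can find them.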
compositeKey := loc.LocationID + "|" + loc.DateRange
clusterIDMap[compositeKey] = clusterID
}
return clusterIDMap, created, existing, nil
}
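
// bulkImportAllFiles runs phase 2: for each location it looks up the cluster
// created in phase 1, skips locations whose directory is missing, and imports
// the directory's files, accumulating counters across all locations.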
func bulkImportAllFiles(ctx context.Context, database *sql.DB, logger *progressLogger, locations []bulkLocationData, clusterIDMap map[string]string, datasetID string) (bulkImportStats, []string) {
var total bulkImportStats
var errs []string
for i, loc := range locations {
compositeKey := loc.LocationID + "|" + loc.DateRange
clusterID, ok := clusterIDMap[compositeKey]
if !ok {
continue
}
logger.Log("[%d/%d] Importing files for: %s", i+1, len(locations), loc.LocationName)
logger.Log(" Directory: %s", loc.DirectoryPath)
if _, err := os.Stat(loc.DirectoryPath); os.IsNotExist(err) {
logger.Log(" WARNING: Directory not found, skipping")
continue
}
		stats, err := bulkImportFilesForCluster(ctx, database, logger, loc.DirectoryPath, datasetID, loc.LocationID, clusterID)
if err != nil {
errMsg := fmt.Sprintf("Failed to import files for location %s: %v", loc.LocationName, err)
logger.Log("ERROR: %s", errMsg)
return total, []string{errMsg}
}
logger.Log(" Scanned: %d files", stats.TotalFiles)
logger.Log(" Imported: %d, Duplicates: %d", stats.ImportedFiles, stats.DuplicateFiles)
if stats.ErrorFiles > 0 {
logger.Log(" Errors: %d files", stats.ErrorFiles)
}
total.TotalFiles += stats.TotalFiles
total.ImportedFiles += stats.ImportedFiles
total.DuplicateFiles += stats.DuplicateFiles
total.ErrorFiles += stats.ErrorFiles
}
return total, errs
}
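
// bulkReadCSV parses the import CSV. The first row is treated as a header;
// each subsequent row must have at least six columns, read in order as
// location_name, location_id, directory_path, date_range, sample_rate,
// and file_count.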
func bulkReadCSV(path string) ([]bulkLocationData, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
}
defer func() { _ = file.Close() }()
reader := csv.NewReader(file)
records, err := reader.ReadAll()
if err != nil {
return nil, err
}
if len(records) == 0 {
return nil, fmt.Errorf("CSV file is empty")
}
var locations []bulkLocationData
	for i, record := range records {
		if i == 0 {
			continue // skip the header row
		}
if len(record) < 6 {
return nil, fmt.Errorf("CSV row %d has insufficient columns (expected 6, got %d)", i+1, len(record))
}
locationName := strings.TrimSpace(record[0])
if locationName == "" {
return nil, fmt.Errorf("empty location_name in row %d", i+1)
}
directoryPath := strings.TrimSpace(record[2])
if directoryPath == "" {
return nil, fmt.Errorf("empty directory_path in row %d", i+1)
}
dateRange := strings.TrimSpace(record[3])
if dateRange == "" {
return nil, fmt.Errorf("empty date_range in row %d", i+1)
}
		locationID := strings.TrimSpace(record[1])
if err := utils.ValidateShortID(locationID, "location_id"); err != nil {
return nil, fmt.Errorf("invalid location_id in row %d: %v", i+1, err)
}
		sampleRate, err := strconv.Atoi(strings.TrimSpace(record[4]))
if err != nil {
return nil, fmt.Errorf("invalid sample_rate in row %d: %v", i+1, err)
}
if err := utils.ValidateSampleRate(sampleRate); err != nil {
return nil, fmt.Errorf("invalid sample_rate in row %d: %v", i+1, err)
}
		fileCount, err := strconv.Atoi(strings.TrimSpace(record[5]))
if err != nil {
return nil, fmt.Errorf("invalid file_count in row %d: %v", i+1, err)
}
locations = append(locations, bulkLocationData{
LocationName: locationName,
LocationID: locationID,
DirectoryPath: directoryPath,
DateRange: dateRange,
SampleRate: sampleRate,
FileCount: fileCount,
})
}
return locations, nil
}
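
// bulkCreateCluster inserts a new active cluster named after the date range,
// with a path derived from the location's name, inside its own logged
// transaction, and returns the generated cluster ID.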
func bulkCreateCluster(ctx context.Context, database *sql.DB, datasetID, locationID, name string, sampleRate int) (string, error) {
clusterID, err := utils.GenerateShortID()
if err != nil {
return "", fmt.Errorf("failed to generate cluster ID: %v", err)
}
now := time.Now().UTC()
var locationName string
	err = database.QueryRowContext(ctx, "SELECT name FROM location WHERE id = ?", locationID).Scan(&locationName)
	if err != nil {
		return "", fmt.Errorf("failed to get location name: %w", err)
}
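	// Build the cluster path from the location name, replacing spaces and slashes.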
path := strings.ReplaceAll(locationName, " ", "_")
path = strings.ReplaceAll(path, "/", "_")
tx, err := db.BeginLoggedTx(ctx, database, "bulk_file_import")
if err != nil {
return "", fmt.Errorf("failed to begin transaction: %w", err)
}
	defer func() { _ = tx.Rollback() }() // no-op once the transaction is committed
_, err = tx.ExecContext(ctx, `
INSERT INTO cluster (id, dataset_id, location_id, name, path, sample_rate, active, created_at, last_modified)
VALUES (?, ?, ?, ?, ?, ?, true, ?, ?)
`, clusterID, datasetID, locationID, name, path, sampleRate, now, now)
if err != nil {
return "", fmt.Errorf("failed to insert cluster: %w", err)
}
if err = tx.Commit(); err != nil {
return "", fmt.Errorf("failed to commit cluster creation: %w", err)
}
return clusterID, nil
}
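
// bulkImportFilesForCluster imports every file under folderPath into the given
// cluster within a single logged transaction, then copies the import counts
// into a bulkImportStats. At most the first five per-file errors are logged.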
func bulkImportFilesForCluster(ctx context.Context, database *sql.DB, logger *progressLogger, folderPath, datasetID, locationID, clusterID string) (*bulkImportStats, error) {
stats := &bulkImportStats{}
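	// Defensive re-check; the caller also skips locations whose directory is missing.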
if _, err := os.Stat(folderPath); os.IsNotExist(err) {
logger.Log(" WARNING: Directory not found, skipping")
return stats, nil
}
logger.Log(" Importing cluster %s", clusterID)
tx, err := db.BeginLoggedTx(ctx, database, "import_audio_files")
if err != nil {
return nil, fmt.Errorf("failed to begin transaction: %w", err)
}
clusterOutput, err := utils.ImportCluster(database, tx.UnderlyingTx(), utils.ClusterImportInput{
FolderPath: folderPath,
DatasetID: datasetID,
LocationID: locationID,
ClusterID: clusterID,
Recursive: true,
})
if err != nil {
		_ = tx.Rollback()
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("transaction commit failed: %w", err)
}
stats.TotalFiles = clusterOutput.TotalFiles
stats.ImportedFiles = clusterOutput.ImportedFiles
stats.DuplicateFiles = clusterOutput.SkippedFiles
stats.ErrorFiles = clusterOutput.FailedFiles
	// Log at most the first five per-file errors to keep the log readable.
	for i, fileErr := range clusterOutput.Errors {
		if i >= 5 {
			break
		}
		logger.Log(" ERROR: %s: %s", fileErr.FileName, fileErr.Error)
	}
logger.Log(" Complete: %d imported, %d duplicates, %d errors", stats.ImportedFiles, stats.DuplicateFiles, stats.ErrorFiles)
return stats, nil
}