package tools

import (
	"context"
	"database/sql"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"time"

	"skraak/db"
	"skraak/utils"
)

// ImportSegmentsInput defines the input parameters for the import_segments tool
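// A hypothetical tool-call payload, shown with placeholder paths and IDs, might look like:
//
//	{
//	  "folder": "/recordings/2024-01",
//	  "mapping": "/recordings/mapping.json",
//	  "dataset_id": "ds0001",
//	  "location_id": "lc0001",
//	  "cluster_id": "cl0001"
//	}
//
// ProgressHandler is set programmatically by the caller and is not part of the JSON payload.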
type ImportSegmentsInput struct {
	Folder          string `json:"folder"`
	Mapping         string `json:"mapping"`
	DatasetID       string `json:"dataset_id"`
	LocationID      string `json:"location_id"`
	ClusterID       string `json:"cluster_id"`
	ProgressHandler func(processed, total int, message string) `json:"-"`
}

// ImportSegmentsOutput defines the output structure for the import_segments tool
type ImportSegmentsOutput struct {
	Summary  ImportSegmentsSummary `json:"summary"`
	Segments []SegmentImport       `json:"segments"`
	Errors   []ImportSegmentError  `json:"errors,omitempty"`
}

// ImportSegmentsSummary provides summary statistics for the import operation
type ImportSegmentsSummary struct {
	DataFilesFound     int   `json:"data_files_found"`
	DataFilesProcessed int   `json:"data_files_processed"`
	TotalSegments      int   `json:"total_segments"`
	ImportedSegments   int   `json:"imported_segments"`
	ImportedLabels     int   `json:"imported_labels"`
	ImportedSubtypes   int   `json:"imported_subtypes"`
	ProcessingTimeMs   int64 `json:"processing_time_ms"`
}

// SegmentImport represents an imported segment in the output
type SegmentImport struct {
	SegmentID string        `json:"segment_id"`
	FileName  string        `json:"file_name"`
	StartTime float64       `json:"start_time"`
	EndTime   float64       `json:"end_time"`
	FreqLow   float64       `json:"freq_low"`
	FreqHigh  float64       `json:"freq_high"`
	Labels    []LabelImport `json:"labels"`
}

// LabelImport represents an imported label in the output
type LabelImport struct {
	LabelID   string `json:"label_id"`
	Species   string `json:"species"`
	CallType  string `json:"calltype,omitempty"`
	Filter    string `json:"filter"`
	Certainty int    `json:"certainty"`
	Comment   string `json:"comment,omitempty"`
}

// ImportSegmentError records a single error encountered during segment import
type ImportSegmentError struct {
	File    string            `json:"file,omitempty"`
	Stage   utils.ImportStage `json:"stage"`
	Message string            `json:"message"`
}

// scannedDataFile holds parsed data for a .data file
type scannedDataFile struct {
	DataPath string
	WavPath  string
	WavHash  string
	FileID   string
	Duration float64
	Segments []*utils.Segment
}

// segmentValidation holds the results of pre-import validation (phases B+C).
type segmentValidation struct {
	scannedFiles  []scannedDataFile
	filterIDMap   map[string]string
	speciesIDMap  map[string]string
	calltypeIDMap map[string]map[string]string
	fileIDMap     map[string]scannedDataFile
}

// validateAndPrepareSegments performs phases B+C: parse data files, validate DB state, and prepare ID maps.
func validateAndPrepareSegments(
	database *sql.DB,
	input ImportSegmentsInput,
	mapping utils.MappingFile,
	dataFiles []string,
) (*segmentValidation, []ImportSegmentError, error) {
	// Phase B: Parse all .data files and collect unique values
	scannedFiles, parseErrors, uniqueFilters, uniqueSpecies, uniqueCalltypes := scanAllDataFiles(dataFiles, input.Folder)
	if len(scannedFiles) == 0 {
		return nil, parseErrors, nil
	}

	// Validate dataset/location/cluster hierarchy
	if err := validateSegmentHierarchy(database, input.DatasetID, input.LocationID, input.ClusterID); err != nil {
		return nil, parseErrors, err
	}

	// Validate all filters exist
	filterIDMap, err := validateFiltersExist(database, uniqueFilters)
	if err != nil {
		return nil, parseErrors, fmt.Errorf("filter validation failed: %w", err)
	}

	// Validate mapping covers all species/calltypes and they exist in DB
	validationResult, err := utils.ValidateMappingAgainstDB(database, mapping, uniqueSpecies, uniqueCalltypes)
	if err != nil {
		return nil, parseErrors, fmt.Errorf("mapping validation failed: %w", err)
	}
	if validationResult.HasErrors() {
		return nil, parseErrors, fmt.Errorf("mapping validation failed: %s", validationResult.Error())
	}

	// Load species and calltype ID maps
	speciesIDMap, calltypeIDMap, err := loadSpeciesCalltypeIDs(database, mapping, uniqueSpecies, uniqueCalltypes)
	if err != nil {
		return nil, parseErrors, fmt.Errorf("failed to load species/calltype IDs: %w", err)
	}

	// Validate files: hash exists, linked to dataset, no existing labels
	fileIDMap, hashErrors := validateAndMapFiles(database, scannedFiles, input.ClusterID, input.DatasetID)
	allErrors := append(parseErrors, hashErrors...)

	return &segmentValidation{
		scannedFiles:  scannedFiles,
		filterIDMap:   filterIDMap,
		speciesIDMap:  speciesIDMap,
		calltypeIDMap: calltypeIDMap,
		fileIDMap:     fileIDMap,
	}, allErrors, nil
}

// ImportSegments imports segments from AviaNZ .data files into the database
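// in five phases: (A) input validation, (B) parsing of the .data files, (C) validation
// against the database, (D) a transactional insert of segments, labels, and subtypes,
// and (E) writing the generated IDs back into the .data files.
//
// A minimal usage sketch; the handler, paths, and ID values below are illustrative only:
//
//	out, err := ImportSegments(ctx, ImportSegmentsInput{
//		Folder:     "/recordings/2024-01",
//		Mapping:    "/recordings/mapping.json",
//		DatasetID:  "ds0001",
//		LocationID: "lc0001",
//		ClusterID:  "cl0001",
//		ProgressHandler: func(processed, total int, message string) {
//			fmt.Printf("[%d/%d] %s\n", processed, total, message)
//		},
//	})
//	if err != nil {
//		return err
//	}
//	fmt.Printf("imported %d segments, %d labels\n",
//		out.Summary.ImportedSegments, out.Summary.ImportedLabels)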
func ImportSegments(ctx context.Context, input ImportSegmentsInput) (ImportSegmentsOutput, error) {
	startTime := time.Now()
	var output ImportSegmentsOutput
	output.Segments = make([]SegmentImport, 0)
	output.Errors = make([]ImportSegmentError, 0)

	// Phase A: Input Validation
	if err := validateSegmentImportInput(input); err != nil {
		return output, err
	}

	// Load mapping file
	mapping, err := utils.LoadMappingFile(input.Mapping)
	if err != nil {
		return output, fmt.Errorf("failed to load mapping file: %w", err)
	}

	// Find .data files
	dataFiles, err := utils.FindDataFiles(input.Folder)
	if err != nil {
		return output, fmt.Errorf("failed to find .data files: %w", err)
	}
	output.Summary.DataFilesFound = len(dataFiles)

	if len(dataFiles) == 0 {
		return output, fmt.Errorf("no .data files found in folder: %s", input.Folder)
	}

	// Phase B+C: Parse data files and validate against DB
	database, err := db.OpenWriteableDB(dbPath)
	if err != nil {
		return output, fmt.Errorf("failed to open database: %w", err)
	}
	defer database.Close()

	val, valErrors, err := validateAndPrepareSegments(database, input, mapping, dataFiles)
	output.Errors = append(output.Errors, valErrors...)
	if err != nil {
		return output, err
	}
	if val == nil || len(val.fileIDMap) == 0 {
		output.Summary.ProcessingTimeMs = time.Since(startTime).Milliseconds()
		return output, nil
	}

	// Phase D: Transactional Import
	importedSegments, importedLabels, importedSubtypes, fileUpdates, importErrors := importSegmentsIntoDB(
		ctx, database, val.fileIDMap, val.scannedFiles, mapping, val.filterIDMap, val.speciesIDMap, val.calltypeIDMap, input.DatasetID, input.ProgressHandler,
	)
	output.Errors = append(output.Errors, importErrors...)
	output.Segments = append(output.Segments, importedSegments...)

	// Phase E: Write IDs back to .data files
	if len(fileUpdates) > 0 {
		writeErrors := writeIDsToDataFiles(fileUpdates)
		output.Errors = append(output.Errors, writeErrors...)
	}

	output.Summary.DataFilesProcessed = len(val.fileIDMap)
	output.Summary.TotalSegments = countTotalSegments(val.fileIDMap)
	output.Summary.ImportedSegments = len(importedSegments)
	output.Summary.ImportedLabels = importedLabels
	output.Summary.ImportedSubtypes = importedSubtypes
	output.Summary.ProcessingTimeMs = time.Since(startTime).Milliseconds()

	return output, nil
}

// validateSegmentImportInput validates input parameters
func validateSegmentImportInput(input ImportSegmentsInput) error {
	// Validate folder exists
	if info, err := os.Stat(input.Folder); err != nil {
		return fmt.Errorf("folder does not exist: %s", input.Folder)
	} else if !info.IsDir() {
		return fmt.Errorf("path is not a folder: %s", input.Folder)
	}

	// Validate mapping file exists
	if _, err := os.Stat(input.Mapping); err != nil {
		return fmt.Errorf("mapping file does not exist: %s", input.Mapping)
	}

	// Validate IDs
	if err := utils.ValidateShortID(input.DatasetID, "dataset_id"); err != nil {
		return err
	}
	if err := utils.ValidateShortID(input.LocationID, "location_id"); err != nil {
		return err
	}
	if err := utils.ValidateShortID(input.ClusterID, "cluster_id"); err != nil {
		return err
	}

	return nil
}

// validateSegmentHierarchy validates dataset/location/cluster relationships
func validateSegmentHierarchy(dbConn *sql.DB, datasetID, locationID, clusterID string) error {
	// Validate dataset exists and is structured
	if err := db.ValidateDatasetTypeForImport(dbConn, datasetID); err != nil {
		return err
	}

	// Validate location belongs to dataset
	if err := db.ValidateLocationBelongsToDataset(dbConn, locationID, datasetID); err != nil {
		return err
	}

	// Validate cluster belongs to location
	if err := db.ClusterBelongsToLocation(dbConn, clusterID, locationID); err != nil {
		return err
	}

	return nil
}

// scanAllDataFiles parses all .data files and collects unique values
func scanAllDataFiles(dataFiles []string, folder string) (
	[]scannedDataFile,
	[]ImportSegmentError,
	map[string]bool,
	map[string]bool,
	map[string]map[string]bool,
) {
	var scanned []scannedDataFile
	var errors []ImportSegmentError
	uniqueFilters := make(map[string]bool)
	uniqueSpecies := make(map[string]bool)
	uniqueCalltypes := make(map[string]map[string]bool) // species -> calltype -> true

	for _, dataPath := range dataFiles {
		// Find corresponding WAV file
		wavPath := strings.TrimSuffix(dataPath, ".data")
		if _, err := os.Stat(wavPath); err != nil {
			errors = append(errors, ImportSegmentError{
				File:    filepath.Base(dataPath),
				Stage:   utils.StageValidation,
				Message: fmt.Sprintf("corresponding WAV file not found: %s", filepath.Base(wavPath)),
			})
			continue
		}

		// Parse .data file
		df, err := utils.ParseDataFile(dataPath)
		if err != nil {
			errors = append(errors, ImportSegmentError{
				File:    filepath.Base(dataPath),
				Stage:   utils.StageValidation,
				Message: fmt.Sprintf("failed to parse .data file: %v", err),
			})
			continue
		}

		// Collect unique filters, species, calltypes
		for _, seg := range df.Segments {
			for _, label := range seg.Labels {
				uniqueFilters[label.Filter] = true
				uniqueSpecies[label.Species] = true
				if label.CallType != "" {
					if uniqueCalltypes[label.Species] == nil {
						uniqueCalltypes[label.Species] = make(map[string]bool)
					}
					uniqueCalltypes[label.Species][label.CallType] = true
				}
			}
		}

		scanned = append(scanned, scannedDataFile{
			DataPath: dataPath,
			WavPath:  wavPath,
			Duration: df.Meta.Duration,
			Segments: df.Segments,
		})
	}

	return scanned, errors, uniqueFilters, uniqueSpecies, uniqueCalltypes
}

// validateFiltersExist checks all filters exist in DB and returns ID map
func validateFiltersExist(dbConn *sql.DB, filterNames map[string]bool) (map[string]string, error) {
	filterIDMap := make(map[string]string)

	if len(filterNames) == 0 {
		return filterIDMap, nil
	}

	names := make([]string, 0, len(filterNames))
	for name := range filterNames {
		names = append(names, name)
	}

	query := `SELECT id, name FROM filter WHERE name IN (` + db.Placeholders(len(names)) + `) AND active = true`
	args := make([]any, len(names))
	for i, name := range names {
		args[i] = name
	}

	rows, err := dbConn.Query(query, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to query filters: %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		var id, name string
		if err := rows.Scan(&id, &name); err != nil {
			return nil, fmt.Errorf("failed to scan filter row: %w", err)
		}
		filterIDMap[name] = id
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate filter rows: %w", err)
	}

	// Check for missing filters
	var missing []string
	for name := range filterNames {
		if _, exists := filterIDMap[name]; !exists {
			missing = append(missing, name)
		}
	}

	if len(missing) > 0 {
		return nil, fmt.Errorf("filters not found in database: [%s]", strings.Join(missing, ", "))
	}

	return filterIDMap, nil
}

// loadSpeciesCalltypeIDs loads species and calltype ID maps
func loadSpeciesCalltypeIDs(
	dbConn *sql.DB,
	mapping utils.MappingFile,
	uniqueSpecies map[string]bool,
	uniqueCalltypes map[string]map[string]bool,
) (map[string]string, map[string]map[string]string, error) {
	speciesIDMap := make(map[string]string)
	calltypeIDMap := make(map[string]map[string]string) // dbSpecies -> dbCalltype -> calltype_id

	// Collect all DB species labels from mapping
	dbSpeciesSet := make(map[string]bool)
	for dataSpecies := range uniqueSpecies {
		if dbSpecies, ok := mapping.GetDBSpecies(dataSpecies); ok {
			dbSpeciesSet[dbSpecies] = true
		}
	}

	// Load species IDs
	if len(dbSpeciesSet) > 0 {
		dbSpeciesList := make([]string, 0, len(dbSpeciesSet))
		for s := range dbSpeciesSet {
			dbSpeciesList = append(dbSpeciesList, s)
		}

		query := `SELECT id, label FROM species WHERE label IN (` + db.Placeholders(len(dbSpeciesList)) + `) AND active = true`
		args := make([]any, len(dbSpeciesList))
		for i, s := range dbSpeciesList {
			args[i] = s
		}

		rows, err := dbConn.Query(query, args...)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to query species: %w", err)
		}
		defer rows.Close()

		for rows.Next() {
			var id, label string
			if err := rows.Scan(&id, &label); err != nil {
				return nil, nil, fmt.Errorf("failed to scan species row: %w", err)
			}
			speciesIDMap[label] = id
		}
		if err := rows.Err(); err != nil {
			return nil, nil, fmt.Errorf("failed to iterate species rows: %w", err)
		}
	}

	// Load calltype IDs
	for dataSpecies, ctSet := range uniqueCalltypes {
		dbSpecies, ok := mapping.GetDBSpecies(dataSpecies)
		if !ok {
			continue
		}

		if calltypeIDMap[dbSpecies] == nil {
			calltypeIDMap[dbSpecies] = make(map[string]string)
		}

		for dataCalltype := range ctSet {
			dbCalltype := mapping.GetDBCalltype(dataSpecies, dataCalltype)

			// Query calltype ID
			var calltypeID string
			err := dbConn.QueryRow(`
				SELECT ct.id
				FROM call_type ct
				JOIN species s ON ct.species_id = s.id
				WHERE s.label = ? AND ct.label = ? AND ct.active = true
			`, dbSpecies, dbCalltype).Scan(&calltypeID)

			if err == nil {
				calltypeIDMap[dbSpecies][dbCalltype] = calltypeID
			}
		}
	}

	return speciesIDMap, calltypeIDMap, nil
}

// validateAndMapFiles validates files exist by hash, are linked to dataset, and have no existing labels
func validateAndMapFiles(
	dbConn *sql.DB,
	scannedFiles []scannedDataFile,
	clusterID string,
	datasetID string,
) (map[string]scannedDataFile, []ImportSegmentError) {
	fileIDMap := make(map[string]scannedDataFile)
	var errors []ImportSegmentError

	for _, sf := range scannedFiles {
		// Compute hash
		hash, err := utils.ComputeXXH64(sf.WavPath)
		if err != nil {
			errors = append(errors, ImportSegmentError{
				File:    filepath.Base(sf.WavPath),
				Stage:   utils.StageHash,
				Message: fmt.Sprintf("failed to compute hash: %v", err),
			})
			continue
		}
		sf.WavHash = hash

		// Find file by hash in cluster
		var fileID string
		var duration float64
		err = dbConn.QueryRow(`
			SELECT id, duration FROM file WHERE xxh64_hash = ? AND cluster_id = ? AND active = true
		`, hash, clusterID).Scan(&fileID, &duration)

		if err == sql.ErrNoRows {
			errors = append(errors, ImportSegmentError{
				File:    filepath.Base(sf.WavPath),
				Stage:   utils.StageValidation,
				Message: fmt.Sprintf("file hash not found in database for cluster (hash: %s)", hash),
			})
			continue
		}
		if err != nil {
			errors = append(errors, ImportSegmentError{
				File:    filepath.Base(sf.WavPath),
				Stage:   utils.StageValidation,
				Message: fmt.Sprintf("failed to query file: %v", err),
			})
			continue
		}

		sf.FileID = fileID
		sf.Duration = duration

		// Verify file is linked to dataset via file_dataset junction table (composite FK)
		var fileLinkedToDataset bool
		err = dbConn.QueryRow(`
			SELECT EXISTS(SELECT 1 FROM file_dataset WHERE file_id = ? AND dataset_id = ?)
		`, fileID, datasetID).Scan(&fileLinkedToDataset)
		if err != nil {
			errors = append(errors, ImportSegmentError{
				File:    filepath.Base(sf.WavPath),
				Stage:   utils.StageValidation,
				Message: fmt.Sprintf("failed to verify file-dataset link: %v", err),
			})
			continue
		}
		if !fileLinkedToDataset {
			errors = append(errors, ImportSegmentError{
				File:    filepath.Base(sf.WavPath),
				Stage:   utils.StageValidation,
				Message: fmt.Sprintf("file exists in cluster but is not linked to dataset %s", datasetID),
			})
			continue
		}

		// Check no existing labels for this file
		var labelCount int
		err = dbConn.QueryRow(`
			SELECT COUNT(*) FROM label l
			JOIN segment s ON l.segment_id = s.id
			WHERE s.file_id = ? AND l.active = true
		`, fileID).Scan(&labelCount)

		if err != nil {
			errors = append(errors, ImportSegmentError{
				File:    filepath.Base(sf.WavPath),
				Stage:   utils.StageValidation,
				Message: fmt.Sprintf("failed to check existing labels: %v", err),
			})
			continue
		}

		if labelCount > 0 {
			errors = append(errors, ImportSegmentError{
				File:    filepath.Base(sf.WavPath),
				Stage:   utils.StageValidation,
				Message: fmt.Sprintf("file already has %d label(s) - fresh imports only", labelCount),
			})
			continue
		}

		fileIDMap[fileID] = sf
	}

	return fileIDMap, errors
}

// dataFileUpdate holds data to write back to .data file after import
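// For example, LabelIDs[2][0] = "aBcDeF..." (a hypothetical long ID) records that the
// first label of the third segment in the .data file was imported under that ID, so
// writeIDsToDataFiles can attach it as skraak_label_id.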
type dataFileUpdate struct {
	DataPath string
	WavHash  string
	LabelIDs map[int]map[int]string // segmentIndex -> labelIndex -> labelID
}

// importLabelResult holds the result of importing a single label.
type importLabelResult struct {
	labelImport      LabelImport
	labelID          string
	subtypesImported int
	err              ImportSegmentError
	hasError         bool
}

// importSingleLabel inserts a single label and its metadata/subtype into the DB.
func importSingleLabel(
	ctx context.Context,
	tx *db.LoggedTx,
	label *utils.Label,
	segmentID string,
	segIdx, labelIdx int,
	sf scannedDataFile,
	mapping utils.MappingFile,
	filterIDMap map[string]string,
	speciesIDMap map[string]string,
	calltypeIDMap map[string]map[string]string,
) importLabelResult {
	dbSpecies, ok := mapping.GetDBSpecies(label.Species)
	if !ok {
		return importLabelResult{err: ImportSegmentError{
			File: filepath.Base(sf.DataPath), Stage: utils.StageImport,
			Message: fmt.Sprintf("species not found in mapping: %s", label.Species),
		}, hasError: true}
	}

	speciesID, ok := speciesIDMap[dbSpecies]
	if !ok {
		return importLabelResult{err: ImportSegmentError{
			File: filepath.Base(sf.DataPath), Stage: utils.StageImport,
			Message: fmt.Sprintf("species ID not found: %s", dbSpecies),
		}, hasError: true}
	}

	filterID, ok := filterIDMap[label.Filter]
	if !ok {
		return importLabelResult{err: ImportSegmentError{
			File: filepath.Base(sf.DataPath), Stage: utils.StageImport,
			Message: fmt.Sprintf("filter ID not found: %s", label.Filter),
		}, hasError: true}
	}

	labelID, err := utils.GenerateLongID()
	if err != nil {
		return importLabelResult{err: ImportSegmentError{
			File: filepath.Base(sf.DataPath), Stage: utils.StageImport,
			Message: fmt.Sprintf("failed to generate label ID: %v", err),
		}, hasError: true}
	}

	_, err = tx.ExecContext(ctx, `
		INSERT INTO label (id, segment_id, species_id, filter_id, certainty, created_at, last_modified, active)
		VALUES (?, ?, ?, ?, ?, now(), now(), true)
	`, labelID, segmentID, speciesID, filterID, label.Certainty)
	if err != nil {
		return importLabelResult{err: ImportSegmentError{
			File: filepath.Base(sf.DataPath), Stage: utils.StageImport,
			Message: fmt.Sprintf("failed to insert label: %v", err),
		}, hasError: true}
	}

	// Insert label_metadata if a comment exists; marshal with encoding/json so that
	// quotes, backslashes, and control characters in the comment are escaped correctly
	if label.Comment != "" {
		metadataJSON, err := json.Marshal(map[string]string{"comment": label.Comment})
		if err != nil {
			return importLabelResult{err: ImportSegmentError{
				File: filepath.Base(sf.DataPath), Stage: utils.StageImport,
				Message: fmt.Sprintf("failed to encode label_metadata: %v", err),
			}, hasError: true}
		}
		if _, err := tx.ExecContext(ctx, `
			INSERT INTO label_metadata (label_id, json, created_at, last_modified, active)
			VALUES (?, ?, now(), now(), true)
		`, labelID, string(metadataJSON)); err != nil {
			return importLabelResult{err: ImportSegmentError{
				File: filepath.Base(sf.DataPath), Stage: utils.StageImport,
				Message: fmt.Sprintf("failed to insert label_metadata: %v", err),
			}, hasError: true}
		}
	}

	labelImport := LabelImport{
		LabelID:   labelID,
		Species:   dbSpecies,
		Filter:    label.Filter,
		Certainty: label.Certainty,
	}
	if label.Comment != "" {
		labelImport.Comment = label.Comment
	}

	// Insert label_subtype if calltype exists
	if label.CallType != "" {
		if err := importCalltype(ctx, tx, labelID, label, dbSpecies, filterID, mapping, calltypeIDMap, sf); err != nil {
			return importLabelResult{err: *err, hasError: true}
		}
		labelImport.CallType = mapping.GetDBCalltype(label.Species, label.CallType)
		return importLabelResult{labelImport: labelImport, labelID: labelID, subtypesImported: 1}
	}

	return importLabelResult{labelImport: labelImport, labelID: labelID}
}

// importCalltype inserts a label_subtype row for a calltype label.
func importCalltype(
	ctx context.Context,
	tx *db.LoggedTx,
	labelID string,
	label *utils.Label,
	dbSpecies string,
	filterID string,
	mapping utils.MappingFile,
	calltypeIDMap map[string]map[string]string,
	sf scannedDataFile,
) *ImportSegmentError {
	dbCalltype := mapping.GetDBCalltype(label.Species, label.CallType)

	calltypeID := ""
	if calltypeIDMap[dbSpecies] != nil {
		calltypeID = calltypeIDMap[dbSpecies][dbCalltype]
	}
	if calltypeID == "" {
		return &ImportSegmentError{
			File: filepath.Base(sf.DataPath), Stage: utils.StageImport,
			Message: fmt.Sprintf("calltype ID not found: %s/%s", dbSpecies, dbCalltype),
		}
	}

	subtypeID, err := utils.GenerateLongID()
	if err != nil {
		return &ImportSegmentError{
			File: filepath.Base(sf.DataPath), Stage: utils.StageImport,
			Message: fmt.Sprintf("failed to generate label_subtype ID: %v", err),
		}
	}

	_, err = tx.ExecContext(ctx, `
		INSERT INTO label_subtype (id, label_id, calltype_id, filter_id, certainty, created_at, last_modified, active)
		VALUES (?, ?, ?, ?, ?, now(), now(), true)
	`, subtypeID, labelID, calltypeID, filterID, label.Certainty)
	if err != nil {
		return &ImportSegmentError{
			File: filepath.Base(sf.DataPath), Stage: utils.StageImport,
			Message: fmt.Sprintf("failed to insert label_subtype: %v", err),
		}
	}
	return nil
}

// importSegmentsIntoDB performs the transactional import
func importSegmentsIntoDB(
	ctx context.Context,
	database *sql.DB,
	fileIDMap map[string]scannedDataFile,
	scannedFiles []scannedDataFile,
	mapping utils.MappingFile,
	filterIDMap map[string]string,
	speciesIDMap map[string]string,
	calltypeIDMap map[string]map[string]string,
	datasetID string,
	progressHandler func(processed, total int, message string),
) ([]SegmentImport, int, int, []dataFileUpdate, []ImportSegmentError) {
	var importedSegments []SegmentImport
	var errors []ImportSegmentError
	importedLabels := 0
	importedSubtypes := 0
	var fileUpdates []dataFileUpdate

	tx, err := db.BeginLoggedTx(ctx, database, "import_segments")
	if err != nil {
		errors = append(errors, ImportSegmentError{
			Stage:   utils.StageImport,
			Message: fmt.Sprintf("failed to begin transaction: %v", err),
		})
		return nil, 0, 0, nil, errors
	}
	defer tx.Rollback()

	totalFiles := len(fileIDMap)
	processedFiles := 0

	for _, sf := range fileIDMap {
		if sf.FileID == "" {
			continue
		}

		processedFiles++
		if progressHandler != nil {
			progressHandler(processedFiles, totalFiles, filepath.Base(sf.DataPath))
		}

		fileUpdate := dataFileUpdate{
			DataPath: sf.DataPath,
			WavHash:  sf.WavHash,
			LabelIDs: make(map[int]map[int]string),
		}

		for segIdx, seg := range sf.Segments {
			segImp, labelIDs, subtypes, segErrs := importSegment(ctx, tx, seg, segIdx, sf, datasetID, mapping, filterIDMap, speciesIDMap, calltypeIDMap)
			errors = append(errors, segErrs...)
			importedSubtypes += subtypes

			if len(segImp.Labels) == 0 {
				// No labels succeeded; delete the orphaned segment if it was actually inserted
				if segImp.SegmentID != "" {
					if _, err := tx.ExecContext(ctx, `DELETE FROM segment WHERE id = ?`, segImp.SegmentID); err != nil {
						errors = append(errors, ImportSegmentError{
							File: filepath.Base(sf.DataPath), Stage: utils.StageImport,
							Message: fmt.Sprintf("failed to delete orphaned segment: %v", err),
						})
					}
				}
			} else {
				importedSegments = append(importedSegments, segImp)
				importedLabels += len(labelIDs)
				fileUpdate.LabelIDs[segIdx] = labelIDs
			}
		}

		fileUpdates = append(fileUpdates, fileUpdate)
	}

	if err := tx.Commit(); err != nil {
		errors = append(errors, ImportSegmentError{
			Stage:   utils.StageImport,
			Message: fmt.Sprintf("failed to commit transaction: %v", err),
		})
		return nil, 0, 0, nil, errors
	}

	return importedSegments, importedLabels, importedSubtypes, fileUpdates, errors
}

// importSegment inserts a single segment and its labels into the DB.
func importSegment(
	ctx context.Context,
	tx *db.LoggedTx,
	seg *utils.Segment,
	segIdx int,
	sf scannedDataFile,
	datasetID string,
	mapping utils.MappingFile,
	filterIDMap map[string]string,
	speciesIDMap map[string]string,
	calltypeIDMap map[string]map[string]string,
) (SegmentImport, map[int]string, int, []ImportSegmentError) {
	var errors []ImportSegmentError

	if seg.StartTime >= seg.EndTime {
		errors = append(errors, ImportSegmentError{
			File: filepath.Base(sf.DataPath), Stage: utils.StageImport,
			Message: fmt.Sprintf("invalid segment bounds: start=%.2f >= end=%.2f", seg.StartTime, seg.EndTime),
		})
		return SegmentImport{}, nil, 0, errors
	}

	if seg.EndTime > sf.Duration {
		errors = append(errors, ImportSegmentError{
			File: filepath.Base(sf.DataPath), Stage: utils.StageImport,
			Message: fmt.Sprintf("segment end time (%.2f) exceeds file duration (%.2f)", seg.EndTime, sf.Duration),
		})
		return SegmentImport{}, nil, 0, errors
	}

	segmentID, err := utils.GenerateLongID()
	if err != nil {
		errors = append(errors, ImportSegmentError{
			File: filepath.Base(sf.DataPath), Stage: utils.StageImport,
			Message: fmt.Sprintf("failed to generate segment ID: %v", err),
		})
		return SegmentImport{}, nil, 0, errors
	}

	_, err = tx.ExecContext(ctx, `
		INSERT INTO segment (id, file_id, dataset_id, start_time, end_time, freq_low, freq_high, created_at, last_modified, active)
		VALUES (?, ?, ?, ?, ?, ?, ?, now(), now(), true)
	`, segmentID, sf.FileID, datasetID, seg.StartTime, seg.EndTime, seg.FreqLow, seg.FreqHigh)
	if err != nil {
		errors = append(errors, ImportSegmentError{
			File: filepath.Base(sf.DataPath), Stage: utils.StageImport,
			Message: fmt.Sprintf("failed to insert segment: %v", err),
		})
		return SegmentImport{}, nil, 0, errors
	}

	segImport := SegmentImport{
		SegmentID: segmentID,
		FileName:  filepath.Base(sf.WavPath),
		StartTime: seg.StartTime,
		EndTime:   seg.EndTime,
		FreqLow:   seg.FreqLow,
		FreqHigh:  seg.FreqHigh,
		Labels:    make([]LabelImport, 0),
	}
	labelIDs := make(map[int]string)
	var subtypesImported int

	for labelIdx, label := range seg.Labels {
		result := importSingleLabel(ctx, tx, label, segmentID, segIdx, labelIdx, sf, mapping, filterIDMap, speciesIDMap, calltypeIDMap)
		if result.hasError {
			errors = append(errors, result.err)
			continue
		}
		labelIDs[labelIdx] = result.labelID
		segImport.Labels = append(segImport.Labels, result.labelImport)
		subtypesImported += result.subtypesImported
	}

	return segImport, labelIDs, subtypesImported, errors
}

// countTotalSegments counts total segments from validated files
func countTotalSegments(fileIDMap map[string]scannedDataFile) int {
	count := 0
	for _, sf := range fileIDMap {
		count += len(sf.Segments)
	}
	return count
}

// writeIDsToDataFiles writes skraak_hash and per-label skraak_label_id values back to the .data files
func writeIDsToDataFiles(fileUpdates []dataFileUpdate) []ImportSegmentError {
	var errors []ImportSegmentError

	for _, fu := range fileUpdates {
		// Parse the .data file
		df, err := utils.ParseDataFile(fu.DataPath)
		if err != nil {
			errors = append(errors, ImportSegmentError{
				File:    filepath.Base(fu.DataPath),
				Stage:   utils.StageImport,
				Message: fmt.Sprintf("failed to re-parse .data file for writing: %v", err),
			})
			continue
		}

		// Write skraak_hash to metadata
		if df.Meta.Extra == nil {
			df.Meta.Extra = make(map[string]any)
		}
		df.Meta.Extra["skraak_hash"] = fu.WavHash

		// Write skraak_label_id to each label
		for segIdx, labelIDs := range fu.LabelIDs {
			if segIdx >= len(df.Segments) {
				continue
			}
			seg := df.Segments[segIdx]
			for labelIdx, labelID := range labelIDs {
				if labelIdx >= len(seg.Labels) {
					continue
				}
				label := seg.Labels[labelIdx]
				if label.Extra == nil {
					label.Extra = make(map[string]any)
				}
				label.Extra["skraak_label_id"] = labelID
			}
		}

		// Write the updated .data file
		if err := df.Write(fu.DataPath); err != nil {
			errors = append(errors, ImportSegmentError{
				File:    filepath.Base(fu.DataPath),
				Stage:   utils.StageImport,
				Message: fmt.Sprintf("failed to write updated .data file: %v", err),
			})
			continue
		}
	}

	return errors
}