package tools
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"

	"skraak/db"
)
// ExportDatasetInput holds the parameters for a single ExportDataset call.
type ExportDatasetInput struct {
	// DatasetID identifies the dataset row to export.
	DatasetID string `json:"dataset_id"`
	// Output is the path of the DuckDB file to create.
	Output string `json:"output"`
	// DryRun, when true, computes row counts but writes nothing to disk.
	DryRun bool `json:"dry_run"`
	// Force allows overwriting an existing output file.
	Force bool `json:"force"`
}
// ExportDatasetOutput is the result of an ExportDataset call.
type ExportDatasetOutput struct {
	// DatasetID echoes the requested dataset id.
	DatasetID string `json:"dataset_id"`
	// DatasetName is the dataset's name as read from the source database.
	DatasetName string `json:"dataset_name"`
	// OutputPath echoes the requested output path.
	OutputPath string `json:"output_path"`
	// RowCounts maps table name to the number of rows exported (or that
	// would be exported, on a dry run); tables with zero rows are omitted.
	RowCounts map[string]int64 `json:"row_counts"`
	// FileSizeMB is the size of the written output file; unset on dry runs
	// or when the file cannot be stat'ed.
	FileSizeMB float64 `json:"file_size_mb,omitempty"`
	// DryRun echoes whether this was a dry run.
	DryRun bool `json:"dry_run"`
	// Message is a human-readable summary of the outcome.
	Message string `json:"message"`
}
// TableRelationship describes how a table relates to a dataset and
// therefore how its rows are selected during export.
//
// Note: the original declaration placed all four fields on one line with
// no separators, which is not valid Go; they are split onto their own
// lines here.
type TableRelationship struct {
	// Table is the table name, identical in source and output databases.
	Table string
	// Relation is one of "owned", "owned-via", or "copy".
	Relation string
	// FilterCol is the column used to filter rows for "owned" and
	// "owned-via" relations.
	FilterCol string
	// ViaTable names the intermediate table an "owned-via" relation is
	// resolved through.
	ViaTable string
}
// datasetTables is the manifest of tables included in a dataset export.
//
//   - "owned":     rows are selected where FilterCol equals the dataset id
//     (the dataset table itself matches on its own id column).
//   - "owned-via": rows are selected by following FilterCol through
//     ViaTable back to the dataset.
//   - "copy":      reference tables copied in full, unfiltered.
//
// Listing order here is NOT the insert order: orderByFKDependency re-sorts
// the manifest against the live schema's foreign-key ordering (note e.g.
// label_metadata is listed before label).
var datasetTables = []TableRelationship{
	{Table: "dataset", Relation: "owned", FilterCol: "id"},
	{Table: "location", Relation: "owned", FilterCol: "dataset_id"},
	{Table: "cluster", Relation: "owned", FilterCol: "dataset_id"},
	{Table: "segment", Relation: "owned", FilterCol: "dataset_id"},
	{Table: "file_dataset", Relation: "owned", FilterCol: "dataset_id"},
	{Table: "file", Relation: "owned-via", FilterCol: "cluster_id", ViaTable: "cluster"},
	{Table: "moth_metadata", Relation: "owned-via", FilterCol: "file_id", ViaTable: "file"},
	{Table: "file_metadata", Relation: "owned-via", FilterCol: "file_id", ViaTable: "file"},
	{Table: "label_metadata", Relation: "owned-via", FilterCol: "label_id", ViaTable: "label"},
	{Table: "label", Relation: "owned-via", FilterCol: "segment_id", ViaTable: "segment"},
	{Table: "label_subtype", Relation: "owned-via", FilterCol: "label_id", ViaTable: "label"},
	{Table: "ebird_taxonomy", Relation: "copy"},
	{Table: "species", Relation: "copy"},
	{Table: "call_type", Relation: "copy"},
	{Table: "cyclic_recording_pattern", Relation: "copy"},
	{Table: "filter", Relation: "copy"},
}
// ExportDataset exports a single structured dataset from the main database
// into a fresh DuckDB file at input.Output, together with an (empty)
// JSONL event-log sibling file. On a dry run it only reports the row
// counts that would be exported.
//
// Returns a populated ExportDatasetOutput; on error the output carries
// whatever fields were filled in before the failure.
func ExportDataset(
	ctx context.Context,
	input ExportDatasetInput,
) (ExportDatasetOutput, error) {
	var output ExportDatasetOutput
	output.DatasetID = input.DatasetID
	output.OutputPath = input.Output
	output.DryRun = input.DryRun
	output.RowCounts = make(map[string]int64)

	// Phase 1: read-only inspection of the source database. The helper
	// closes its handle before returning, because copyDataToOutput later
	// re-ATTACHes the same file from the output connection.
	orderedTables, err := prepareExport(ctx, input, &output)
	if err != nil {
		return output, err
	}

	if input.DryRun {
		output.Message = fmt.Sprintf("Would export dataset '%s' (%s)", output.DatasetName, input.DatasetID)
		return output, nil
	}

	// Phase 2: create the output database and copy the data across.
	if err := createOutputDir(input.Output); err != nil {
		return output, err
	}
	outputDB, err := createOutputDatabase(input.Output)
	if err != nil {
		return output, fmt.Errorf("failed to create output database: %w", err)
	}
	if err := copyDataToOutput(ctx, outputDB, orderedTables, input.DatasetID); err != nil {
		outputDB.Close() // previously leaked on this path
		return output, err
	}
	if _, err := outputDB.ExecContext(ctx, "DETACH source"); err != nil {
		outputDB.Close() // previously leaked on this path
		return output, fmt.Errorf("failed to detach source database: %w", err)
	}
	// Close before stat so the file size reflects the flushed database.
	outputDB.Close()

	if info, err := os.Stat(input.Output); err == nil {
		output.FileSizeMB = float64(info.Size()) / 1024 / 1024
	}

	// Create an empty companion event log next to the export.
	eventLogPath := input.Output + ".events.jsonl"
	eventFile, err := os.Create(eventLogPath)
	if err != nil {
		return output, fmt.Errorf("failed to create event log file: %w", err)
	}
	if err := eventFile.Close(); err != nil {
		return output, fmt.Errorf("failed to close event log file: %w", err)
	}

	output.Message = fmt.Sprintf("Successfully exported dataset '%s' (%s) to %s",
		output.DatasetName, input.DatasetID, input.Output)
	return output, nil
}

// prepareExport opens the source database read-only, validates the
// dataset and output path, resolves the FK-ordered table manifest, and
// fills out.DatasetName and out.RowCounts. The source handle is always
// closed before returning.
func prepareExport(ctx context.Context, input ExportDatasetInput, out *ExportDatasetOutput) ([]TableRelationship, error) {
	sourceDB, err := db.OpenReadOnlyDB(dbPath)
	if err != nil {
		return nil, fmt.Errorf("failed to open source database: %w", err)
	}
	defer sourceDB.Close()

	datasetName, err := verifyExportDataset(ctx, sourceDB, input)
	if err != nil {
		return nil, err
	}
	out.DatasetName = datasetName

	if err := checkOutputFile(input); err != nil {
		return nil, err
	}
	orderedTables, err := getOrderedTableManifest(sourceDB)
	if err != nil {
		return nil, err
	}
	if err := countAllTableRows(ctx, sourceDB, orderedTables, input.DatasetID, out); err != nil {
		return nil, err
	}
	return orderedTables, nil
}
// checkOutputFile rejects a non-forced export whose output file already
// exists. Dry runs write nothing, so they skip the check entirely.
func checkOutputFile(input ExportDatasetInput) error {
	if input.DryRun || input.Force {
		return nil
	}
	// os.Stat succeeding means the file is already there.
	if _, err := os.Stat(input.Output); err == nil {
		return fmt.Errorf("output file exists: %s (use --force to overwrite)", input.Output)
	}
	return nil
}
// verifyExportDataset confirms that the dataset exists, is active, and has
// type "structured" (the only exportable kind), returning its name.
func verifyExportDataset(ctx context.Context, sourceDB *sql.DB, input ExportDatasetInput) (string, error) {
	var datasetName, datasetType string
	err := sourceDB.QueryRowContext(ctx,
		"SELECT name, type FROM dataset WHERE id = ? AND active = true",
		input.DatasetID,
	).Scan(&datasetName, &datasetType)
	if errors.Is(err, sql.ErrNoRows) {
		return "", fmt.Errorf("dataset not found: %s", input.DatasetID)
	}
	if err != nil {
		// Previously every query failure (including connection errors)
		// was reported as "not found"; preserve the underlying error.
		return "", fmt.Errorf("failed to look up dataset %s: %w", input.DatasetID, err)
	}
	if datasetType != "structured" {
		return "", fmt.Errorf("cannot export dataset of type '%s': only structured datasets are supported", datasetType)
	}
	return datasetName, nil
}
// getOrderedTableManifest returns datasetTables re-sorted into the
// foreign-key dependency order computed from the live schema, so parents
// are inserted before children.
func getOrderedTableManifest(sourceDB *sql.DB) ([]TableRelationship, error) {
	fkOrder, err := db.GetFKOrder(sourceDB)
	if err != nil {
		return nil, fmt.Errorf("failed to compute table order: %w", err)
	}
	ordered := orderByFKDependency(datasetTables, fkOrder)
	return ordered, nil
}
// countAllTableRows records in output.RowCounts how many rows each table
// contributes to the export; tables with zero rows are omitted from the map.
func countAllTableRows(ctx context.Context, sourceDB *sql.DB, tables []TableRelationship, datasetID string, output *ExportDatasetOutput) error {
	for _, tr := range tables {
		n, err := countTableRows(ctx, sourceDB, tr, datasetID)
		if err != nil {
			return fmt.Errorf("failed to count rows in %s: %w", tr.Table, err)
		}
		if n == 0 {
			continue
		}
		output.RowCounts[tr.Table] = n
	}
	return nil
}
// createOutputDir ensures the parent directory of outputPath exists,
// creating it (and any ancestors) if necessary. A bare filename (parent
// ".") requires no work.
func createOutputDir(outputPath string) error {
	dir := filepath.Dir(outputPath)
	if dir == "" || dir == "." {
		return nil
	}
	if err := os.MkdirAll(dir, 0755); err != nil {
		return fmt.Errorf("failed to create output directory: %w", err)
	}
	return nil
}
// copyDataToOutput attaches the main database as "source" on the output
// connection and copies every table in the (FK-ordered) manifest: "copy"
// tables wholesale, everything else filtered to the dataset. The caller is
// responsible for the final DETACH.
func copyDataToOutput(ctx context.Context, outputDB *sql.DB, tables []TableRelationship, datasetID string) error {
	// ATTACH takes the file path as a SQL string literal, so it cannot be
	// bound as a placeholder; escape single quotes to keep the literal
	// well-formed for paths containing '.
	escapedPath := strings.ReplaceAll(dbPath, "'", "''")
	if _, err := outputDB.ExecContext(ctx, fmt.Sprintf("ATTACH '%s' AS source", escapedPath)); err != nil {
		return fmt.Errorf("failed to attach source database: %w", err)
	}
	for _, tr := range tables {
		var err error
		if tr.Relation == "copy" {
			err = copyTableAsIs(ctx, outputDB, tr.Table)
		} else {
			err = copyTableData(ctx, outputDB, tr, datasetID)
		}
		if err != nil {
			return fmt.Errorf("failed to copy %s: %w", tr.Table, err)
		}
	}
	return nil
}
// createOutputDatabase creates a fresh DuckDB database file at outputPath
// and applies the project schema DDL to it. The returned handle must be
// closed by the caller.
func createOutputDatabase(outputPath string) (*sql.DB, error) {
	// Start from a clean slate. Previously any Remove error was silently
	// ignored, so e.g. a permission failure would leave a stale database
	// that the tolerant DDL loop below quietly reused.
	if err := os.Remove(outputPath); err != nil && !os.IsNotExist(err) {
		return nil, fmt.Errorf("failed to remove existing output file: %w", err)
	}
	connStr := outputPath + "?access_mode=read_write"
	database, err := sql.Open("duckdb", connStr)
	if err != nil {
		return nil, fmt.Errorf("failed to create output database: %w", err)
	}
	schemaSQL, err := db.ReadSchemaSQL()
	if err != nil {
		database.Close()
		return nil, fmt.Errorf("failed to read schema: %w", err)
	}
	for _, stmt := range db.ExtractDDLStatements(schemaSQL) {
		// NOTE(review): CREATE_TABLE_AS statements are skipped here —
		// presumably those tables are populated by the copy phase; confirm.
		if stmt.Type == "CREATE_TABLE_AS" {
			continue
		}
		if _, err := database.Exec(stmt.SQL); err != nil {
			// The schema may redundantly create objects; tolerate that
			// but fail on any other DDL error.
			if !strings.Contains(err.Error(), "already exists") {
				database.Close()
				return nil, fmt.Errorf("failed to execute DDL for %s: %w", stmt.TableName, err)
			}
		}
	}
	return database, nil
}
// copyTableAsIs copies every row of table from the attached "source"
// database into the output database, unfiltered.
func copyTableAsIs(ctx context.Context, outputDB *sql.DB, table string) error {
	stmt := fmt.Sprintf("INSERT INTO %s SELECT * FROM source.%s", table, table)
	if _, err := outputDB.ExecContext(ctx, stmt); err != nil {
		return err
	}
	return nil
}
// copyTableData copies the dataset-filtered rows of tr.Table from the
// attached "source" database into the output. The generated statement
// carries a single '?' placeholder bound to datasetID.
func copyTableData(ctx context.Context, outputDB *sql.DB, tr TableRelationship, datasetID string) error {
	var query string
	switch tr.Relation {
	case "owned":
		// The dataset table matches on its own id; every other owned
		// table carries a dataset_id foreign key.
		filterCol := "dataset_id"
		if tr.Table == "dataset" {
			filterCol = "id"
		}
		query = fmt.Sprintf("INSERT INTO %s SELECT * FROM source.%s WHERE %s = ?", tr.Table, tr.Table, filterCol)
	case "owned-via":
		query = buildOwnedViaQuery(tr, datasetID)
	default:
		return fmt.Errorf("unknown relation type: %s", tr.Relation)
	}
	_, err := outputDB.ExecContext(ctx, query, datasetID)
	return err
}
// buildOwnedViaQuery returns an INSERT ... SELECT statement that copies
// the rows of tr.Table reachable from a dataset through tr.ViaTable. Each
// statement contains exactly one '?' placeholder for the dataset id, which
// the caller binds via ExecContext.
//
// The second parameter was never used (the id is bound by the caller);
// it is blanked here but kept so existing call sites compile unchanged.
func buildOwnedViaQuery(tr TableRelationship, _ string) string {
	switch tr.ViaTable {
	case "cluster":
		return fmt.Sprintf(`INSERT INTO %s SELECT * FROM source.%s
WHERE %s IN (SELECT id FROM source.cluster WHERE dataset_id = ?)`,
			tr.Table, tr.Table, tr.FilterCol)
	case "file":
		// Two hops: table -> file -> cluster -> dataset.
		return fmt.Sprintf(`INSERT INTO %s SELECT * FROM source.%s
WHERE %s IN (SELECT id FROM source.file WHERE cluster_id IN
(SELECT id FROM source.cluster WHERE dataset_id = ?))`,
			tr.Table, tr.Table, tr.FilterCol)
	case "segment":
		return fmt.Sprintf(`INSERT INTO %s SELECT * FROM source.%s
WHERE %s IN (SELECT id FROM source.segment WHERE dataset_id = ?)`,
			tr.Table, tr.Table, tr.FilterCol)
	case "label":
		// Two hops: table -> label -> segment -> dataset.
		return fmt.Sprintf(`INSERT INTO %s SELECT * FROM source.%s
WHERE %s IN (SELECT id FROM source.label WHERE segment_id IN
(SELECT id FROM source.segment WHERE dataset_id = ?))`,
			tr.Table, tr.Table, tr.FilterCol)
	default:
		// Generic single-hop fallback for any other via-table.
		return fmt.Sprintf(`INSERT INTO %s SELECT * FROM source.%s WHERE %s IN
(SELECT id FROM source.%s WHERE dataset_id = ?)`,
			tr.Table, tr.Table, tr.FilterCol, tr.ViaTable)
	}
}
// countTableRows returns the number of rows of tr.Table that belong to the
// given dataset, according to the table's relation kind. Unknown relations
// count as zero (copyTableData rejects them later during the actual copy).
//
// The handle parameter is named conn rather than db so it no longer
// shadows the imported db package.
func countTableRows(ctx context.Context, conn *sql.DB, tr TableRelationship, datasetID string) (int64, error) {
	var query string
	switch tr.Relation {
	case "copy":
		// Reference tables are copied wholesale, so count everything.
		query = "SELECT COUNT(*) FROM " + tr.Table
	case "owned":
		if tr.Table == "dataset" {
			query = "SELECT COUNT(*) FROM " + tr.Table + " WHERE id = ?"
		} else {
			query = "SELECT COUNT(*) FROM " + tr.Table + " WHERE dataset_id = ?"
		}
	case "owned-via":
		query = buildCountOwnedViaQuery(tr)
	default:
		return 0, nil
	}
	// NOTE(review): datasetID is passed even for the placeholder-free
	// "copy" query; the driver apparently tolerates the extra arg — confirm.
	var count int64
	err := conn.QueryRowContext(ctx, query, datasetID).Scan(&count)
	return count, err
}
// buildCountOwnedViaQuery returns a COUNT(*) query for an "owned-via"
// table, mirroring the SELECT shape used by buildOwnedViaQuery. Each query
// holds exactly one '?' placeholder for the dataset id.
func buildCountOwnedViaQuery(tr TableRelationship) string {
	var tmpl string
	switch tr.ViaTable {
	case "cluster":
		tmpl = `SELECT COUNT(*) FROM %s WHERE %s IN
(SELECT id FROM cluster WHERE dataset_id = ?)`
	case "file":
		// Two hops: table -> file -> cluster -> dataset.
		tmpl = `SELECT COUNT(*) FROM %s WHERE %s IN
(SELECT id FROM file WHERE cluster_id IN
(SELECT id FROM cluster WHERE dataset_id = ?))`
	case "segment":
		tmpl = `SELECT COUNT(*) FROM %s WHERE %s IN
(SELECT id FROM segment WHERE dataset_id = ?)`
	case "label":
		// Two hops: table -> label -> segment -> dataset.
		tmpl = `SELECT COUNT(*) FROM %s WHERE %s IN
(SELECT id FROM label WHERE segment_id IN
(SELECT id FROM segment WHERE dataset_id = ?))`
	default:
		// Generic single-hop fallback for any other via-table.
		return fmt.Sprintf(`SELECT COUNT(*) FROM %s WHERE %s IN
(SELECT id FROM %s WHERE dataset_id = ?)`, tr.Table, tr.FilterCol, tr.ViaTable)
	}
	return fmt.Sprintf(tmpl, tr.Table, tr.FilterCol)
}
// orderByFKDependency returns a copy of tables sorted by their position in
// fkOrder, so foreign-key parents are inserted before children. Tables
// absent from fkOrder map to position 0 and sort first.
//
// Uses sort.SliceStable (the original used unstable sort.Slice) so that
// tables sharing a position — in particular all absent ones at 0 — keep
// their manifest order and the result is deterministic.
func orderByFKDependency(tables []TableRelationship, fkOrder []string) []TableRelationship {
	orderMap := make(map[string]int, len(fkOrder))
	for i, table := range fkOrder {
		orderMap[table] = i
	}
	sorted := make([]TableRelationship, len(tables))
	copy(sorted, tables)
	sort.SliceStable(sorted, func(i, j int) bool {
		return orderMap[sorted[i].Table] < orderMap[sorted[j].Table]
	})
	return sorted
}