package db
import (
"database/sql"
"embed"
"fmt"
"slices"
"strings"
)
// schemaFS is the filesystem ReadSchemaSQL reads schema.sql from.
//
// NOTE(review): no go:embed directive is visible on this variable. Without
// one, embed.FS is empty and ReadSchemaSQL always returns an error. Confirm
// that a "go:embed schema.sql" directive line precedes this declaration in the
// real file.
var schemaFS embed.FS

// ReadSchemaSQL returns the contents of schema.sql from schemaFS as a string.
// On failure it returns "" and an error wrapping the underlying fs error.
func ReadSchemaSQL() (string, error) {
	raw, readErr := schemaFS.ReadFile("schema.sql")
	if readErr == nil {
		return string(raw), nil
	}
	return "", fmt.Errorf("failed to read schema.sql: %w", readErr)
}
// DDLStatement is a single statement extracted from a schema script, together
// with its classification.
type DDLStatement struct {
	// SQL is the statement text. ExtractDDLStatements trims surrounding
	// whitespace before building the value.
	SQL string
	// Type is one of "CREATE_TYPE", "CREATE_TABLE", "CREATE_INDEX" or
	// "UNKNOWN", as assigned by parseDDLStatement.
	Type string
	// TableName holds the table name for CREATE_TABLE statements and the
	// index name for CREATE_INDEX statements; it is empty for other types.
	TableName string
}
// ExtractDDLStatements splits schemaSQL into individual statements and
// classifies each one with parseDDLStatement.
//
// The input is processed line by line:
//   - Blank lines and lines holding only a "--" comment are dropped.
//   - A statement ends on the first line whose code part ends with ';'. The
//     code part is the line with any trailing "--" comment outside quotes
//     removed.
//   - Lines inside a statement are kept verbatim, including inline comments.
//     The terminating line keeps only its code part, so every emitted SQL
//     string ends with ';'.
//
// Limitations:
//   - Several statements on one line are not split.
//   - Block comments (/* */) and multi-line string literals are not
//     recognized.
//   - Trailing text after the last ';' is ignored.
func ExtractDDLStatements(schemaSQL string) []DDLStatement {
	var statements []DDLStatement
	var current strings.Builder
	for _, line := range strings.Split(schemaSQL, "\n") {
		codePart := stripSQLLineComment(line)
		code := strings.TrimSpace(codePart)
		if code == "" {
			continue
		}
		if !strings.HasSuffix(code, ";") {
			current.WriteString(line)
			current.WriteString("\n")
			continue
		}
		// Drop the trailing comment from the terminating line. Otherwise a
		// "); -- note" line would never close the statement, and the emitted
		// SQL would not end with ';'.
		current.WriteString(codePart)
		statements = append(statements, parseDDLStatement(strings.TrimSpace(current.String())))
		current.Reset()
	}
	return statements
}

// stripSQLLineComment returns line with any "--" comment removed. Dashes that
// appear inside single-quoted strings or double-quoted identifiers are kept.
// Quote state is tracked within this line only; doubled quotes ('it''s') are
// handled because each quote toggles the state.
func stripSQLLineComment(line string) string {
	var quote byte
	for i := 0; i < len(line); i++ {
		c := line[i]
		switch {
		case quote != 0:
			if c == quote {
				quote = 0
			}
		case c == '\'' || c == '"':
			quote = c
		case c == '-' && i+1 < len(line) && line[i+1] == '-':
			return line[:i]
		}
	}
	return line
}
// parseDDLStatement classifies sql by its leading keywords (case-insensitive).
//   - "CREATE TYPE" becomes CREATE_TYPE.
//   - "CREATE TABLE" becomes CREATE_TABLE, with the table name in TableName.
//   - "CREATE INDEX" or "CREATE UNIQUE INDEX" becomes CREATE_INDEX, with the
//     index name in TableName.
//   - Anything else becomes UNKNOWN with an empty TableName.
//
// sql is stored unchanged in the result's SQL field.
func parseDDLStatement(sql string) DDLStatement {
	stmt := DDLStatement{SQL: sql, Type: "UNKNOWN"}
	head := strings.ToUpper(sql)

	if strings.HasPrefix(head, "CREATE TYPE") {
		stmt.Type = "CREATE_TYPE"
		return stmt
	}
	if strings.HasPrefix(head, "CREATE TABLE") {
		stmt.Type = "CREATE_TABLE"
		stmt.TableName = extractTableName(sql)
		return stmt
	}
	isIndex := strings.HasPrefix(head, "CREATE INDEX") || strings.HasPrefix(head, "CREATE UNIQUE INDEX")
	if isIndex {
		stmt.Type = "CREATE_INDEX"
		stmt.TableName = extractIndexName(sql)
	}
	return stmt
}
// extractTableName returns the name declared after the first "CREATE TABLE"
// (matched case-insensitively) in sql, or "" if those keywords are absent.
//
// An optional IF NOT EXISTS clause is skipped. The name is the first token
// after the keywords. It ends at whitespace, '(' or ';' outside double quotes,
// so qualified names (main.users) and quoted identifiers ("my table") are
// returned verbatim, quotes included.
func extractTableName(sql string) string {
	const keyword = "CREATE TABLE"
	idx := strings.Index(strings.ToUpper(sql), keyword)
	if idx == -1 {
		return ""
	}
	rest := strings.TrimSpace(sql[idx+len(keyword):])

	// Without this, "CREATE TABLE IF NOT EXISTS t" would yield "IF NOT EXISTS t".
	const ifNotExists = "IF NOT EXISTS"
	if len(rest) >= len(ifNotExists) && strings.EqualFold(rest[:len(ifNotExists)], ifNotExists) {
		rest = strings.TrimSpace(rest[len(ifNotExists):])
	}

	inQuotes := false
	for i := 0; i < len(rest); i++ {
		switch c := rest[i]; {
		case c == '"':
			inQuotes = !inQuotes
		case inQuotes:
			// Inside a quoted identifier, delimiters are part of the name.
		case c == '(' || c == ';' || c == ' ' || c == '\t' || c == '\n' || c == '\r':
			return rest[:i]
		}
	}
	return rest
}
// extractIndexName returns the index name from a statement that starts with
// CREATE INDEX or CREATE UNIQUE INDEX (case-insensitive). It returns "" in
// three cases:
//   - sql has neither prefix;
//   - the index is unnamed ("CREATE INDEX ON t(...)");
//   - the name is not followed by an ON keyword.
//
// An optional IF NOT EXISTS clause is skipped. The name ends at whitespace,
// '(' or ';' outside double quotes, so a quoted identifier is returned with
// its quotes. ON may follow the name after any whitespace, including a line
// break.
func extractIndexName(sql string) string {
	// Compare the original text in place. Offsets taken from an upper-cased
	// copy can drift from sql's offsets for non-ASCII input.
	hasPrefixFold := func(s, prefix string) bool {
		return len(s) >= len(prefix) && strings.EqualFold(s[:len(prefix)], prefix)
	}
	// nameEnd returns the byte offset just past the leading identifier of s.
	nameEnd := func(s string) int {
		inQuotes := false
		for i := 0; i < len(s); i++ {
			switch c := s[i]; {
			case c == '"':
				inQuotes = !inQuotes
			case inQuotes:
				// Inside a quoted identifier, delimiters are part of the name.
			case c == '(' || c == ';' || c == ' ' || c == '\t' || c == '\n' || c == '\r':
				return i
			}
		}
		return len(s)
	}

	var rest string
	switch {
	case hasPrefixFold(sql, "CREATE UNIQUE INDEX"):
		rest = sql[len("CREATE UNIQUE INDEX"):]
	case hasPrefixFold(sql, "CREATE INDEX"):
		rest = sql[len("CREATE INDEX"):]
	default:
		return ""
	}
	rest = strings.TrimSpace(rest)
	if hasPrefixFold(rest, "IF NOT EXISTS") {
		rest = strings.TrimSpace(rest[len("IF NOT EXISTS"):])
	}

	end := nameEnd(rest)
	name := rest[:end]
	after := strings.Fields(rest[end:])
	if name == "" || strings.EqualFold(name, "ON") || len(after) == 0 || !strings.EqualFold(after[0], "ON") {
		return ""
	}
	return name
}
// FKRelation describes a single foreign-key reference. It is not used by the
// functions visible here.
type FKRelation struct {
	// Table is the table holding the foreign key (presumably the referencing
	// side; confirm against callers).
	Table string
	// Column is presumably the referencing column in Table. TODO confirm.
	Column string
	// ForeignTable is presumably the referenced table. TODO confirm.
	ForeignTable string
}
// GetFKOrder returns table names ordered so that every referenced table comes
// before the tables whose foreign keys point at it.
//
// The set includes every table named in a foreign-key constraint plus every
// base table in the "main" schema. Tables that cannot be ordered because of
// reference cycles are appended after the orderable ones.
func GetFKOrder(db *sql.DB) ([]string, error) {
	graph, known, err := buildFKDependencyGraph(db)
	if err != nil {
		return nil, err
	}
	// Tables with no foreign keys never appear in the graph; add them so they
	// are still ordered.
	if err = collectAllTables(db, known); err != nil {
		return nil, err
	}
	return topologicalSort(known, graph), nil
}
// buildFKDependencyGraph reads FOREIGN KEY constraints from DuckDB's
// duckdb_constraints() table function. It returns three values:
//   - dependsOnMe maps each referenced table to the tables that reference it.
//     There is one entry per constraint row, so a child can appear more than
//     once under the same parent.
//   - tables is the set of every table seen on either side of a constraint.
//   - The error is non-nil if the query, a row scan, or row iteration fails.
func buildFKDependencyGraph(db *sql.DB) (map[string][]string, map[string]bool, error) {
	const fkQuery = `
SELECT table_name, referenced_table
FROM duckdb_constraints()
WHERE constraint_type = 'FOREIGN KEY'
AND referenced_table IS NOT NULL
`
	rows, err := db.Query(fkQuery)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to query FK relationships: %w", err)
	}
	defer rows.Close()

	dependsOnMe := map[string][]string{}
	tables := map[string]bool{}
	for rows.Next() {
		var child, parent string
		if scanErr := rows.Scan(&child, &parent); scanErr != nil {
			return nil, nil, fmt.Errorf("failed to scan FK row: %w", scanErr)
		}
		tables[child], tables[parent] = true, true
		dependsOnMe[parent] = append(dependsOnMe[parent], child)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, nil, fmt.Errorf("error iterating FK rows: %w", iterErr)
	}
	return dependsOnMe, tables, nil
}
// collectAllTables adds every base table in the "main" schema (according to
// information_schema.tables) to tables. This way tables without foreign keys
// are still ordered. tables must be non-nil.
//
// Query, scan and iteration errors are wrapped with context, matching
// buildFKDependencyGraph.
func collectAllTables(db *sql.DB, tables map[string]bool) error {
	rows, err := db.Query(`
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'main'
AND table_type = 'BASE TABLE'
`)
	if err != nil {
		return fmt.Errorf("failed to query tables: %w", err)
	}
	defer rows.Close()
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return fmt.Errorf("failed to scan table name: %w", err)
		}
		tables[name] = true
	}
	if err := rows.Err(); err != nil {
		return fmt.Errorf("error iterating table rows: %w", err)
	}
	return nil
}
// topologicalSort orders tables so that each table comes after every table it
// references, using Kahn's algorithm.
//
// tables is the set of names to order (only the keys are used). dependsOnMe
// maps a referenced table to the tables referencing it. A child listed n times
// under one parent counts as n pending references.
//
// When several tables become ready together, they are released in
// alphabetical order, so the result does not depend on Go's randomized map
// iteration. Tables that never become ready (members of a reference cycle,
// self-references, or tables behind one) are appended at the end, also
// alphabetically, so no table is lost.
func topologicalSort(tables map[string]bool, dependsOnMe map[string][]string) []string {
	// fkCount[t] is the number of references from t to tables not yet emitted.
	// Missing keys read as zero.
	fkCount := make(map[string]int, len(tables))
	for _, dependents := range dependsOnMe {
		for _, dependent := range dependents {
			fkCount[dependent]++
		}
	}

	var queue []string
	for table := range tables {
		if fkCount[table] == 0 {
			queue = append(queue, table)
		}
	}
	slices.Sort(queue)

	var result []string
	placed := make(map[string]bool, len(tables))
	for len(queue) > 0 {
		current := queue[0]
		queue = queue[1:]
		result = append(result, current)
		placed[current] = true
		for _, dependent := range dependsOnMe[current] {
			fkCount[dependent]--
			if fkCount[dependent] == 0 {
				queue = append(queue, dependent)
			}
		}
	}

	// A set lookup replaces the previous slices.Contains scan, which made
	// this step quadratic.
	var remaining []string
	for table := range tables {
		if !placed[table] {
			remaining = append(remaining, table)
		}
	}
	slices.Sort(remaining)
	return append(result, remaining...)
}