package db

import (
	"database/sql"
	"embed"
	"fmt"
	"slices"
	"strings"
)

// schemaFS holds schema.sql, embedded into the binary at build time so the
// schema is available without the source tree being present.
//
//go:embed schema.sql
var schemaFS embed.FS

// ReadSchemaSQL returns the full text of schema.sql.
//
// The file is read from the embedded filesystem rather than from disk, so a
// distributed binary does not need schema.sql shipped alongside it.
func ReadSchemaSQL() (string, error) {
	raw, readErr := schemaFS.ReadFile("schema.sql")
	if readErr == nil {
		return string(raw), nil
	}
	return "", fmt.Errorf("failed to read schema.sql: %w", readErr)
}

// DDLStatement is a single statement split out of schema SQL by
// ExtractDDLStatements, together with the classification assigned to it by
// parseDDLStatement.
type DDLStatement struct {
	SQL       string // statement text with surrounding whitespace trimmed
	Type      string // "CREATE_TYPE", "CREATE_TABLE", "CREATE_INDEX" or "UNKNOWN", as set by parseDDLStatement
	TableName string // table name for CREATE_TABLE; the index name (not its table) for CREATE_INDEX; "" otherwise
}

// ExtractDDLStatements splits schema SQL into individual DDL statements and
// classifies each one with parseDDLStatement.
//
// Statements are returned in the order they appear in schemaSQL; nothing is
// reordered by statement type, so the schema itself must declare types before
// the tables that use them and tables before their indexes.
//
// Splitting is line-based:
//   - blank lines and lines whose first non-blank characters are "--" are
//     dropped;
//   - a statement ends on a line whose trimmed text ends with ";".
//
// Consequently a ";" followed by a trailing "--" comment on the same line does
// not end a statement, several statements on one line are not separated, and
// /* block comments */ are not recognized. Any text after the last
// terminating ";" is returned as a final statement rather than being dropped.
func ExtractDDLStatements(schemaSQL string) []DDLStatement {
	var statements []DDLStatement
	var current strings.Builder

	// flush classifies the buffered statement (unless it is blank) and clears
	// the buffer for the next one.
	flush := func() {
		if text := strings.TrimSpace(current.String()); text != "" {
			statements = append(statements, parseDDLStatement(text))
		}
		current.Reset()
	}

	for _, line := range strings.Split(schemaSQL, "\n") {
		trimmed := strings.TrimSpace(line)

		// Blank lines and whole-line "--" comments never reach the output.
		if trimmed == "" || strings.HasPrefix(trimmed, "--") {
			continue
		}

		current.WriteString(line)
		current.WriteString("\n")

		// A statement ends on a line whose last non-blank character is ';'.
		if strings.HasSuffix(trimmed, ";") {
			flush()
		}
	}

	// Keep a final statement that lacks a terminating semicolon instead of
	// silently discarding it.
	flush()

	return statements
}

// parseDDLStatement classifies one DDL statement by its leading keywords,
// compared case-insensitively.
//
// CREATE TYPE statements get Type "CREATE_TYPE". CREATE TABLE statements get
// "CREATE_TABLE" plus the table name. CREATE [UNIQUE] INDEX statements get
// "CREATE_INDEX" with the index name stored in TableName. Anything else is
// "UNKNOWN".
func parseDDLStatement(sql string) DDLStatement {
	stmt := DDLStatement{SQL: sql, Type: "UNKNOWN"}
	head := strings.ToUpper(sql)

	if strings.HasPrefix(head, "CREATE TYPE") {
		stmt.Type = "CREATE_TYPE"
		return stmt
	}
	if strings.HasPrefix(head, "CREATE TABLE") {
		stmt.Type = "CREATE_TABLE"
		stmt.TableName = extractTableName(sql)
		return stmt
	}
	if strings.HasPrefix(head, "CREATE INDEX") || strings.HasPrefix(head, "CREATE UNIQUE INDEX") {
		stmt.Type = "CREATE_INDEX"
		stmt.TableName = extractIndexName(sql)
		return stmt
	}
	return stmt
}

// sqlWhitespace lists the ASCII whitespace bytes that may separate SQL tokens.
const sqlWhitespace = " \t\n\r\f\v"

// extractTableName returns the table name from a CREATE TABLE statement. It
// returns "" if sql does not contain "CREATE TABLE", matched case-insensitively
// with a single space between the two keywords.
//
// An optional IF NOT EXISTS clause is skipped. The name ends at the first
// whitespace, '(' or ';' outside double quotes, so each of these yields "t":
//
//	CREATE TABLE t (id INTEGER);
//	CREATE TABLE t(id INTEGER);
//	CREATE TABLE IF NOT EXISTS t (id INTEGER);
//	CREATE TABLE t AS SELECT * FROM s;
//
// The name is returned as written, including any schema qualifier ("main.t")
// or surrounding double quotes.
func extractTableName(sql string) string {
	const keyword = "CREATE TABLE"

	// Upper-case ASCII letters only. Unlike strings.ToUpper this never changes
	// the byte length, so an offset found in upper is also valid in sql.
	upper := []byte(sql)
	for i, c := range upper {
		if 'a' <= c && c <= 'z' {
			upper[i] = c - 'a' + 'A'
		}
	}
	idx := strings.Index(string(upper), keyword)
	if idx == -1 {
		return ""
	}

	rest := sql[idx+len(keyword):]
	if after, ok := skipKeywords(rest, "IF", "NOT", "EXISTS"); ok {
		rest = after
	}
	name, _ := leadingIdent(rest)
	return name
}

// extractIndexName returns the index name from a CREATE [UNIQUE] INDEX
// statement. It returns "" if sql is not such a statement or if the name is
// not followed by ON, which includes unnamed indexes such as
// "CREATE INDEX ON t(c)".
//
// Keywords are matched case-insensitively and may be separated by any
// whitespace, including newlines. An optional IF NOT EXISTS clause is skipped.
// The name is returned as written, including any double quotes.
func extractIndexName(sql string) string {
	rest, ok := skipKeywords(sql, "CREATE", "UNIQUE", "INDEX")
	if !ok {
		rest, ok = skipKeywords(sql, "CREATE", "INDEX")
	}
	if !ok {
		return ""
	}
	if after, found := skipKeywords(rest, "IF", "NOT", "EXISTS"); found {
		rest = after
	}

	name, after := leadingIdent(rest)
	// An unnamed index puts ON directly after INDEX; there is no name to return.
	if name == "" || strings.EqualFold(name, "ON") {
		return ""
	}
	if _, found := skipKeywords(after, "ON"); !found {
		return ""
	}
	return name
}

// skipKeywords removes keywords from the start of s, in order. Each keyword is
// matched case-insensitively after optional whitespace. It must not run on into
// an identifier character, so "IF" does not match the start of "ifx". If any
// keyword fails to match, s is returned unchanged together with false.
func skipKeywords(s string, keywords ...string) (string, bool) {
	rest := s
	for _, kw := range keywords {
		rest = strings.TrimLeft(rest, sqlWhitespace)
		if len(rest) <= len(kw) || !strings.EqualFold(rest[:len(kw)], kw) {
			return s, false
		}
		next := rest[len(kw)]
		if next == '_' || next >= 0x80 ||
			('0' <= next && next <= '9') ||
			('a' <= next && next <= 'z') ||
			('A' <= next && next <= 'Z') {
			return s, false
		}
		rest = rest[len(kw):]
	}
	return rest, true
}

// leadingIdent skips leading whitespace in s, then splits it into the
// identifier at its start and the remainder. The identifier ends at the first
// whitespace, '(' or ';' outside double quotes, so a quoted name such as
// "my table" is kept whole, quotes included.
func leadingIdent(s string) (ident, remainder string) {
	s = strings.TrimLeft(s, sqlWhitespace)
	quoted := false
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case c == '"':
			quoted = !quoted
		case quoted:
			// Any byte inside double quotes belongs to the identifier.
		case c == '(' || c == ';' || strings.IndexByte(sqlWhitespace, c) >= 0:
			return s[:i], s[i:]
		}
	}
	return s, ""
}

// FKRelation describes one foreign key relationship between two tables.
// NOTE(review): nothing in this file constructs or reads FKRelation; confirm it
// is used elsewhere in the package before relying on it.
type FKRelation struct {
	Table        string // table that declares the FK
	Column       string // FK column in Table
	ForeignTable string // table referenced by the FK
}

// GetFKOrder returns the database's table names ordered for copying by
// foreign key dependencies. Tables that are referenced come before the tables
// that reference them.
//
// The table set is the union of every table taking part in an FK and every
// base table in the 'main' schema. The ordering itself, including what happens
// to tables in FK cycles, is delegated to topologicalSort.
func GetFKOrder(db *sql.DB) ([]string, error) {
	reverseDeps, tableSet, err := buildFKDependencyGraph(db)
	if err == nil {
		// Tables without any FK still have to appear in the ordering.
		err = collectAllTables(db, tableSet)
	}
	if err != nil {
		return nil, err
	}
	return topologicalSort(tableSet, reverseDeps), nil
}

// buildFKDependencyGraph reads FOREIGN KEY constraints from
// duckdb_constraints() and builds the reverse dependency graph.
//
// The first result maps each referenced table to the tables holding an FK to
// it, with one entry per constraint row, so duplicates are possible. The
// second result is the set of every table named on either side of an FK. On
// error both maps are nil.
func buildFKDependencyGraph(db *sql.DB) (map[string][]string, map[string]bool, error) {
	const fkQuery = `
		SELECT table_name, referenced_table
		FROM duckdb_constraints()
		WHERE constraint_type = 'FOREIGN KEY'
		AND referenced_table IS NOT NULL
	`
	rows, err := db.Query(fkQuery)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to query FK relationships: %w", err)
	}
	defer rows.Close()

	reverseDeps := map[string][]string{}
	seen := map[string]bool{}
	for rows.Next() {
		var child, parent string
		if scanErr := rows.Scan(&child, &parent); scanErr != nil {
			return nil, nil, fmt.Errorf("failed to scan FK row: %w", scanErr)
		}
		seen[child], seen[parent] = true, true
		reverseDeps[parent] = append(reverseDeps[parent], child)
	}

	if iterErr := rows.Err(); iterErr != nil {
		return nil, nil, fmt.Errorf("error iterating FK rows: %w", iterErr)
	}
	return reverseDeps, seen, nil
}

// collectAllTables adds every base table in the 'main' schema, as listed by
// information_schema.tables, to the tables set. tables must be non-nil.
//
// Errors from the query, from scanning a row, or from row iteration are all
// wrapped with context. This matches buildFKDependencyGraph, which wraps its
// iteration error the same way.
func collectAllTables(db *sql.DB, tables map[string]bool) error {
	rows, err := db.Query(`
		SELECT table_name 
		FROM information_schema.tables 
		WHERE table_schema = 'main' 
		AND table_type = 'BASE TABLE'
	`)
	if err != nil {
		return fmt.Errorf("failed to query tables: %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return fmt.Errorf("failed to scan table name: %w", err)
		}
		tables[name] = true
	}

	// rows.Next returning false may mean an error, not just the end of rows.
	if err := rows.Err(); err != nil {
		return fmt.Errorf("error iterating table rows: %w", err)
	}
	return nil
}

// topologicalSort orders tables so that every table comes after the tables it
// references (Kahn's algorithm).
//
// tables is the set of names to order (its keys; values are ignored).
// dependsOnMe maps a referenced table to the tables holding an FK to it. It has
// one entry per FK, so duplicates are allowed and handled.
//
// Self-references (a table with an FK to itself) are ignored, because they
// impose no ordering between tables. Edges naming a table outside tables are
// also ignored, so the result only contains names from tables.
//
// Ties are broken alphabetically so the output is deterministic. Tables that
// cannot be ordered because they sit in, or depend on, an FK cycle are
// appended at the end in name order.
func topologicalSort(tables map[string]bool, dependsOnMe map[string][]string) []string {
	// constrains reports whether referenced -> dependent affects the order.
	constrains := func(referenced, dependent string) bool {
		if referenced == dependent {
			return false
		}
		_, okRef := tables[referenced]
		_, okDep := tables[dependent]
		return okRef && okDep
	}

	// pending[t] counts the FK edges from not-yet-placed tables into t.
	pending := make(map[string]int, len(tables))
	for referenced, dependents := range dependsOnMe {
		for _, dependent := range dependents {
			if constrains(referenced, dependent) {
				pending[dependent]++
			}
		}
	}

	var queue []string
	for table := range tables {
		if pending[table] == 0 {
			queue = append(queue, table)
		}
	}
	slices.Sort(queue)

	result := make([]string, 0, len(tables))
	placed := make(map[string]bool, len(tables))
	for len(queue) > 0 {
		current := queue[0]
		queue = queue[1:]
		result = append(result, current)
		placed[current] = true

		var ready []string
		for _, dependent := range dependsOnMe[current] {
			if !constrains(current, dependent) {
				continue
			}
			pending[dependent]--
			if pending[dependent] == 0 {
				ready = append(ready, dependent)
			}
		}
		slices.Sort(ready)
		queue = append(queue, ready...)
	}

	// Anything not placed is in, or depends on, an FK cycle.
	if len(result) < len(tables) {
		var leftover []string
		for table := range tables {
			if !placed[table] {
				leftover = append(leftover, table)
			}
		}
		slices.Sort(leftover)
		result = append(result, leftover...)
	}
	return result
}