package db
import (
"database/sql"
"fmt"
"strings"
"testing"
_ "github.com/duckdb/duckdb-go/v2"
)
// GetTableRowCount runs SELECT COUNT(*) against the given table and returns
// the resulting row count.
//
// The table name is spliced into the query text verbatim, so callers must
// pass a trusted identifier. On failure the count is 0 and the returned error
// wraps the underlying query/scan error.
func GetTableRowCount(db *sql.DB, table string) (int64, error) {
	query := fmt.Sprintf("SELECT COUNT(*) FROM %s", table)
	row := db.QueryRow(query)

	var n int64
	if scanErr := row.Scan(&n); scanErr != nil {
		return 0, fmt.Errorf("failed to count rows in %s: %w", table, scanErr)
	}
	return n, nil
}
// TestReadSchemaSQL checks that ReadSchemaSQL succeeds and that the returned
// schema text contains the DDL fragments the rest of the suite relies on.
func TestReadSchemaSQL(t *testing.T) {
	schemaSQL, err := ReadSchemaSQL()
	if err != nil {
		t.Fatalf("ReadSchemaSQL() error = %v", err)
	}

	// Each fragment is checked independently so one run reports every gap.
	required := []string{
		"CREATE TABLE dataset",
		"CREATE TYPE dataset_type",
		"CREATE INDEX",
	}
	for _, fragment := range required {
		if strings.Contains(schemaSQL, fragment) {
			continue
		}
		t.Error("schema missing " + fragment)
	}
}
// TestExtractDDLStatements runs ExtractDDLStatements over the schema returned
// by ReadSchemaSQL and checks minimum per-type statement counts as well as
// the presence of a set of expected table names.
func TestExtractDDLStatements(t *testing.T) {
	schemaSQL, err := ReadSchemaSQL()
	if err != nil {
		t.Fatalf("ReadSchemaSQL() error = %v", err)
	}

	parsed := ExtractDDLStatements(schemaSQL)
	if len(parsed) == 0 {
		t.Fatal("ExtractDDLStatements returned no statements")
	}

	// Tally statements by type and remember every non-empty name seen.
	perType := map[string]int{}
	seenNames := map[string]bool{}
	for _, s := range parsed {
		perType[s.Type]++
		if s.TableName != "" {
			seenNames[s.TableName] = true
		}

		// Log at most the first 50 bytes of each statement.
		preview := s.SQL
		if len(preview) > 50 {
			preview = preview[:50]
		}
		t.Logf("Statement type=%s table=%s sql=%s", s.Type, s.TableName, preview)
	}

	// Lower bounds on how many statements of each kind must be extracted.
	minimums := []struct {
		kind    string
		atLeast int
	}{
		{"CREATE_TYPE", 2},
		{"CREATE_TABLE", 10},
		{"CREATE_INDEX", 5},
	}
	for _, m := range minimums {
		if got := perType[m.kind]; got < m.atLeast {
			t.Errorf("expected at least %d %s statements, got %d", m.atLeast, m.kind, got)
		}
	}

	for _, name := range []string{"dataset", "location", "cluster", "file", "segment", "label"} {
		if seenNames[name] {
			continue
		}
		t.Errorf("missing table %s in extracted statements", name)
	}
}
// TestExtractDDLStatement_Types feeds individual DDL statements to
// parseDDLStatement and checks the Type and TableName it reports for each.
// For index statements the expected TableName is the index name.
func TestExtractDDLStatement_Types(t *testing.T) {
	type ddlCase struct {
		label    string
		input    string
		wantKind string
		wantName string
	}

	cases := []ddlCase{
		{
			label:    "CREATE TYPE",
			input:    "CREATE TYPE dataset_type AS ENUM ('structured', 'unstructured');",
			wantKind: "CREATE_TYPE",
			wantName: "",
		},
		{
			label:    "CREATE TABLE simple",
			input:    "CREATE TABLE dataset (id VARCHAR(12) PRIMARY KEY);",
			wantKind: "CREATE_TABLE",
			wantName: "dataset",
		},
		{
			label:    "CREATE TABLE with newlines",
			input:    "CREATE TABLE location\n(\n id VARCHAR(12) PRIMARY KEY\n);",
			wantKind: "CREATE_TABLE",
			wantName: "location",
		},
		{
			label:    "CREATE INDEX",
			input:    "CREATE INDEX idx_file_location ON file(location_id);",
			wantKind: "CREATE_INDEX",
			wantName: "idx_file_location",
		},
		{
			label:    "CREATE UNIQUE INDEX",
			input:    "CREATE UNIQUE INDEX idx_species_label ON species(label);",
			wantKind: "CREATE_INDEX",
			wantName: "idx_species_label",
		},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			got := parseDDLStatement(tc.input)
			if got.Type != tc.wantKind {
				t.Errorf("parseDDLStatement().Type = %v, want %v", got.Type, tc.wantKind)
			}
			if got.TableName != tc.wantName {
				t.Errorf("parseDDLStatement().TableName = %v, want %v", got.TableName, tc.wantName)
			}
		})
	}
}
// TestExtractTableName checks extractTableName against CREATE TABLE prefixes
// that differ in the whitespace between the table name and the opening
// parenthesis (space, newline, or none).
func TestExtractTableName(t *testing.T) {
	type nameCase struct {
		label  string
		input  string
		expect string
	}

	cases := []nameCase{
		{label: "simple table", input: "CREATE TABLE dataset (id VARCHAR(12) PRIMARY KEY", expect: "dataset"},
		{label: "table with space before paren", input: "CREATE TABLE location (id VARCHAR(12)", expect: "location"},
		{label: "table with newline", input: "CREATE TABLE cluster\n(\n id VARCHAR(12)", expect: "cluster"},
		{label: "table with no space", input: "CREATE TABLE file(id VARCHAR(21)", expect: "file"},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			if name := extractTableName(tc.input); name != tc.expect {
				t.Errorf("extractTableName() = %v, want %v", name, tc.expect)
			}
		})
	}
}
// TestExtractIndexName checks that extractIndexName returns the index name
// for plain and UNIQUE index statements, including one with a space before
// the column list.
func TestExtractIndexName(t *testing.T) {
	type indexCase struct {
		label  string
		input  string
		expect string
	}

	cases := []indexCase{
		{label: "CREATE INDEX", input: "CREATE INDEX idx_file_location ON file(location_id)", expect: "idx_file_location"},
		{label: "CREATE UNIQUE INDEX", input: "CREATE UNIQUE INDEX idx_species_label ON species(label)", expect: "idx_species_label"},
		{label: "index with spaces", input: "CREATE INDEX idx_test ON table_name (column)", expect: "idx_test"},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			if name := extractIndexName(tc.input); name != tc.expect {
				t.Errorf("extractIndexName() = %v, want %v", name, tc.expect)
			}
		})
	}
}
// TestExtractDDLStatements_SkipsComments passes a schema that interleaves
// "--" line comments with two DDL statements and expects exactly two
// extracted statements, none of which still contains a comment marker.
func TestExtractDDLStatements_SkipsComments(t *testing.T) {
	input := `-- This is a comment
CREATE TABLE test (id INT);
-- Another comment
CREATE INDEX idx_test ON test(id);
`

	parsed := ExtractDDLStatements(input)
	if got := len(parsed); got != 2 {
		t.Errorf("expected 2 statements, got %d", got)
	}

	for _, s := range parsed {
		if !strings.Contains(s.SQL, "--") {
			continue
		}
		t.Errorf("statement should not contain comments: %s", s.SQL)
	}
}
// TestGetFKOrder builds a three-level foreign-key chain (parent <- child <-
// grandchild) plus an unrelated table in an in-memory DuckDB database, then
// checks that GetFKOrder lists every table and places each referenced table
// ahead of the table that references it.
func TestGetFKOrder(t *testing.T) {
	db, err := sql.Open("duckdb", ":memory:")
	if err != nil {
		t.Fatalf("failed to open database: %v", err)
	}
	defer db.Close()

	schema := `
CREATE TABLE parent (id VARCHAR(12) PRIMARY KEY);
CREATE TABLE child (id VARCHAR(12) PRIMARY KEY, parent_id VARCHAR(12), FOREIGN KEY (parent_id) REFERENCES parent(id));
CREATE TABLE grandchild (id VARCHAR(12) PRIMARY KEY, child_id VARCHAR(12), FOREIGN KEY (child_id) REFERENCES child(id));
CREATE TABLE independent (id VARCHAR(12) PRIMARY KEY);
`
	if _, err = db.Exec(schema); err != nil {
		t.Fatalf("failed to create schema: %v", err)
	}

	order, err := GetFKOrder(db)
	if err != nil {
		t.Fatalf("GetFKOrder() error = %v", err)
	}

	// Map each table name to its position in the returned order.
	orderMap := make(map[string]int, len(order))
	for i, table := range order {
		orderMap[table] = i
	}

	// Presence must be verified before positions are compared: a table that is
	// absent from the map reads as position 0, which would let a missing
	// "parent" or "child" satisfy the ordering checks below by accident.
	missing := false
	for _, table := range []string{"parent", "child", "grandchild", "independent"} {
		if _, ok := orderMap[table]; !ok {
			t.Errorf("%s table missing from order", table)
			missing = true
		}
	}
	if missing {
		t.Fatalf("GetFKOrder() returned incomplete order: %v", order)
	}

	if orderMap["parent"] >= orderMap["child"] {
		t.Error("parent should come before child")
	}
	if orderMap["child"] >= orderMap["grandchild"] {
		t.Error("child should come before grandchild")
	}
}
// TestGetTableRowCount creates a one-column table in an in-memory DuckDB
// database, inserts three rows, and expects GetTableRowCount to report 3.
func TestGetTableRowCount(t *testing.T) {
	conn, err := sql.Open("duckdb", ":memory:")
	if err != nil {
		t.Fatalf("failed to open database: %v", err)
	}
	defer conn.Close()

	if _, err := conn.Exec("CREATE TABLE test (id INT)"); err != nil {
		t.Fatalf("failed to create table: %v", err)
	}
	if _, err := conn.Exec("INSERT INTO test VALUES (1), (2), (3)"); err != nil {
		t.Fatalf("failed to insert: %v", err)
	}

	got, err := GetTableRowCount(conn, "test")
	if err != nil {
		t.Fatalf("GetTableRowCount() error = %v", err)
	}
	if got != 3 {
		t.Errorf("GetTableRowCount() = %d, want 3", got)
	}
}