package db
import (
"context"
"database/sql"
"fmt"
)
// Querier is the minimal single-row query surface used by the lookup and
// validation helpers in this package. Both *sql.DB and *sql.Tx satisfy it, so
// callers may pass either a database handle or an open transaction.
type Querier interface {
	// QueryRowContext runs query with args under ctx and returns a single row.
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
	// QueryRow runs query with args and returns a single row.
	QueryRow(query string, args ...any) *sql.Row
}
// GetDatasetType reports the value of the type column for the dataset whose
// id is datasetID.
//
// The boolean result says whether a matching row was found. A missing dataset
// is not an error and yields ("", false, nil). Any other query or scan
// failure is returned as-is in ("", false, err).
func GetDatasetType(q Querier, datasetID string) (string, bool, error) {
	var kind string
	row := q.QueryRow("SELECT type FROM dataset WHERE id = ?", datasetID)
	switch err := row.Scan(&kind); {
	case err == sql.ErrNoRows:
		return "", false, nil
	case err != nil:
		return "", false, err
	}
	return kind, true, nil
}
// ValidateDatasetTypeForImport returns nil only when datasetID names an
// existing dataset whose type is "structured", because file imports accept
// only that kind. Otherwise it returns an error for one of three cases: the
// type lookup failed (wrapped with %w), the dataset does not exist, or it has
// a different type.
func ValidateDatasetTypeForImport(q Querier, datasetID string) error {
	kind, found, lookupErr := GetDatasetType(q, datasetID)
	switch {
	case lookupErr != nil:
		return fmt.Errorf("failed to query dataset type: %w", lookupErr)
	case !found:
		return fmt.Errorf("dataset not found: %s", datasetID)
	case kind == "structured":
		return nil
	default:
		return fmt.Errorf("dataset '%s' is type '%s' - file imports only support 'structured' datasets", datasetID, kind)
	}
}
// ValidateDatasetTypeUnstructured returns nil only when datasetID names an
// existing dataset whose type is "unstructured". Otherwise it returns an error
// for one of three cases: the type lookup failed (wrapped with %w), the
// dataset does not exist, or it has a different type.
func ValidateDatasetTypeUnstructured(q Querier, datasetID string) error {
	kind, found, lookupErr := GetDatasetType(q, datasetID)
	if lookupErr != nil {
		return fmt.Errorf("failed to query dataset type: %w", lookupErr)
	}
	if !found {
		return fmt.Errorf("dataset not found: %s", datasetID)
	}
	if kind == "unstructured" {
		return nil
	}
	return fmt.Errorf("dataset '%s' is type '%s' - this command only supports 'unstructured' datasets", datasetID, kind)
}
// ValidateLocationBelongsToDataset checks that locationID refers to a location
// row with active = true whose dataset_id equals datasetID.
//
// It returns nil on success. Otherwise it returns an error in one of three
// cases: no active row matches, any other query/scan failure occurs (wrapped
// with %w), or the location is attached to a different dataset.
func ValidateLocationBelongsToDataset(q Querier, locationID, datasetID string) error {
	var owner string
	row := q.QueryRow("SELECT dataset_id FROM location WHERE id = ? AND active = true", locationID)
	if err := row.Scan(&owner); err != nil {
		// The WHERE clause filters on active, so "no rows" covers both a
		// missing location and an inactive one.
		if err == sql.ErrNoRows {
			return fmt.Errorf("location not found or inactive: %s", locationID)
		}
		return fmt.Errorf("failed to query location: %w", err)
	}
	if owner == datasetID {
		return nil
	}
	return fmt.Errorf("location %s does not belong to dataset %s", locationID, datasetID)
}
// DatasetExistsAndActive verifies that the dataset with ID datasetID exists and
// is active, and returns its name on success.
//
// A single query fetches three values:
//   - whether the row exists;
//   - the active flag, coalesced to false when absent;
//   - the name, coalesced to '' when absent.
//
// On any failure the returned name is "" and the error says whether the query
// failed (wrapped with %w), the dataset is missing, or it is inactive.
func DatasetExistsAndActive(q Querier, datasetID string) (name string, err error) {
	const query = "SELECT EXISTS(SELECT 1 FROM dataset WHERE id = ?), COALESCE((SELECT active FROM dataset WHERE id = ?), false), COALESCE((SELECT name FROM dataset WHERE id = ?), '')"
	var (
		found    bool
		isActive bool
		dsName   string
	)
	row := q.QueryRow(query, datasetID, datasetID, datasetID)
	if scanErr := row.Scan(&found, &isActive, &dsName); scanErr != nil {
		return "", fmt.Errorf("failed to verify dataset: %w", scanErr)
	}
	switch {
	case !found:
		return "", fmt.Errorf("dataset with ID '%s' does not exist", datasetID)
	case !isActive:
		return "", fmt.Errorf("dataset '%s' (ID: %s) is not active", dsName, datasetID)
	}
	return dsName, nil
}
// LocationBelongsToDataset verifies that the location with ID locationID
// exists, is active, and is attached to the dataset with ID datasetID. On
// success it returns the location's name.
//
// One query with no FROM clause fetches four values: existence, the active
// flag, the name, and the owning dataset_id. Each scalar subquery is wrapped
// in COALESCE, so a missing row (or NULL column) becomes false or '' instead
// of NULL. On failure the returned name is "" and the error says which check
// failed: the query itself (wrapped with %w), existence, active status, or
// dataset ownership.
func LocationBelongsToDataset(q Querier, locationID, datasetID string) (name string, err error) {
	var exists, active bool
	var locDatasetID string
	err = q.QueryRow(
		"SELECT EXISTS(SELECT 1 FROM location WHERE id = ?), COALESCE((SELECT active FROM location WHERE id = ?), false), COALESCE((SELECT name FROM location WHERE id = ?), ''), COALESCE((SELECT dataset_id FROM location WHERE id = ?), '')",
		locationID, locationID, locationID, locationID,
	).Scan(&exists, &active, &name, &locDatasetID)
	if err != nil {
		return "", fmt.Errorf("failed to verify location: %w", err)
	}
	if !exists {
		return "", fmt.Errorf("location with ID '%s' does not exist", locationID)
	}
	if !active {
		return "", fmt.Errorf("location '%s' (ID: %s) is not active", name, locationID)
	}
	if locDatasetID != datasetID {
		// Name the dataset the caller asked about under "dataset ID", and show
		// the actual owner separately so the mismatch is unambiguous.
		return "", fmt.Errorf("location '%s' (ID: %s) does not belong to dataset ID '%s' (belongs to '%s')",
			name, locationID, datasetID, locDatasetID)
	}
	return name, nil
}
// ClusterBelongsToLocation returns nil when the cluster with ID clusterID
// exists, is active, and its location_id equals locationID.
//
// Existence, the active flag (coalesced to false) and location_id (coalesced
// to '') are read in one query. Otherwise it returns an error naming the
// failed check: the query itself (wrapped with %w), existence, active status,
// or location ownership.
func ClusterBelongsToLocation(q Querier, clusterID, locationID string) error {
	const query = "SELECT EXISTS(SELECT 1 FROM cluster WHERE id = ?), COALESCE((SELECT active FROM cluster WHERE id = ?), false), COALESCE((SELECT location_id FROM cluster WHERE id = ?), '')"
	var (
		found    bool
		isActive bool
		ownerID  string
	)
	row := q.QueryRow(query, clusterID, clusterID, clusterID)
	if err := row.Scan(&found, &isActive, &ownerID); err != nil {
		return fmt.Errorf("failed to verify cluster: %w", err)
	}
	switch {
	case !found:
		return fmt.Errorf("cluster with ID '%s' does not exist", clusterID)
	case !isActive:
		return fmt.Errorf("cluster '%s' is not active", clusterID)
	case ownerID != locationID:
		return fmt.Errorf("cluster '%s' does not belong to location '%s'", clusterID, locationID)
	default:
		return nil
	}
}