package tools
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"

	"skraak/db"
	"skraak/utils"
)
// ClusterInput is the request payload accepted by CreateOrUpdateCluster.
//
// Every field is a pointer so that "not supplied" (nil, and omitted from JSON
// output via omitempty) can be told apart from a zero value. A non-blank ID
// selects the update path; otherwise a new cluster is created, in which case
// DatasetID, LocationID, Name and SampleRate are required.
//
// On update, only the non-nil fields among Name, Path, SampleRate,
// Description and CyclicRecordingPatternID are written; DatasetID and
// LocationID are not used there. A blank CyclicRecordingPatternID on update
// clears the pattern. SampleRate is reported in Hz in success messages.
type ClusterInput struct {
	ID                       *string `json:"id,omitempty"`
	DatasetID                *string `json:"dataset_id,omitempty"`
	LocationID               *string `json:"location_id,omitempty"`
	Name                     *string `json:"name,omitempty"`
	SampleRate               *int    `json:"sample_rate,omitempty"`
	Path                     *string `json:"path,omitempty"`
	CyclicRecordingPatternID *string `json:"cyclic_recording_pattern_id,omitempty"`
	Description              *string `json:"description,omitempty"`
}
// ClusterOutput is the result returned by CreateOrUpdateCluster.
type ClusterOutput struct {
	// Cluster is the row read back from the database after a create or
	// update, or the pre-existing cluster when a create found an active
	// cluster with the same name in the same location.
	Cluster db.Cluster `json:"cluster"`
	// Message is a human-readable summary of what was done.
	Message string `json:"message"`
}
// CreateOrUpdateCluster creates a new cluster, or updates an existing one when
// input.ID holds a non-blank value.
//
// On create, dataset_id, location_id, name and sample_rate are required. If an
// active cluster with the same name already exists in the location, that
// cluster is returned instead of inserting a duplicate. On update, only the
// fields that are non-nil in input are modified.
func CreateOrUpdateCluster(
	ctx context.Context,
	input ClusterInput,
) (ClusterOutput, error) {
	isUpdate := input.ID != nil && strings.TrimSpace(*input.ID) != ""
	if !isUpdate {
		return createCluster(ctx, input)
	}
	return updateCluster(ctx, input)
}
// validateClusterFields runs the input checks shared by the create and update
// paths. Nil fields are skipped, and the first failing check's error is
// returned. The checks run in this order:
//   - length limits on name, description and path (utils.MaxNameLen,
//     utils.MaxDescriptionLen, utils.MaxPathLen);
//   - when a sample rate is supplied, utils.ValidatePositive followed by
//     utils.ValidateSampleRate.
func validateClusterFields(input ClusterInput) error {
	lengthChecks := []func() error{
		func() error { return utils.ValidateOptionalStringLength(input.Name, "name", utils.MaxNameLen) },
		func() error {
			return utils.ValidateOptionalStringLength(input.Description, "description", utils.MaxDescriptionLen)
		},
		func() error { return utils.ValidateOptionalStringLength(input.Path, "path", utils.MaxPathLen) },
	}
	for _, check := range lengthChecks {
		if err := check(); err != nil {
			return err
		}
	}

	if input.SampleRate == nil {
		return nil
	}
	rate := *input.SampleRate
	if err := utils.ValidatePositive(rate, "sample_rate"); err != nil {
		return err
	}
	if err := utils.ValidateSampleRate(rate); err != nil {
		return err
	}
	return nil
}
// createCluster inserts a new active cluster under an existing, active dataset
// and a location that belongs to that dataset.
//
// Required fields (dataset_id, location_id, name, sample_rate) and ID formats
// are checked before any database work. If an active cluster with the same name
// already exists in the location, that cluster is returned with an explanatory
// message and nothing is inserted. A whitespace-only cyclic_recording_pattern_id
// is treated as absent (no pattern), the same meaning the update path gives it.
//
// On error the returned ClusterOutput is the zero value.
func createCluster(ctx context.Context, input ClusterInput) (ClusterOutput, error) {
	// input is a copy, so resetting the pointer does not affect the caller.
	// Without this, a blank ID skips the existence check in
	// verifyClusterParentRefs but, if utils.ValidateOptionalShortID accepts
	// it, would still be inserted as a non-NULL string.
	if input.CyclicRecordingPatternID != nil && strings.TrimSpace(*input.CyclicRecordingPatternID) == "" {
		input.CyclicRecordingPatternID = nil
	}

	if err := validateCreateClusterFields(input); err != nil {
		return ClusterOutput{}, err
	}
	if err := validateCreateClusterIDs(input); err != nil {
		return ClusterOutput{}, err
	}
	if err := validateClusterFields(input); err != nil {
		return ClusterOutput{}, err
	}
	if err := utils.ValidateOptionalShortID(input.CyclicRecordingPatternID, "cyclic_recording_pattern_id"); err != nil {
		return ClusterOutput{}, err
	}

	var output ClusterOutput
	err := db.WithWriteTx(ctx, dbPath, "create_or_update_cluster", func(_ *sql.DB, tx *db.LoggedTx) error {
		datasetName, locationName, err := verifyClusterParentRefs(ctx, tx, input)
		if err != nil {
			return err
		}

		existing, err := findExistingClusterInLocation(ctx, tx, *input.LocationID, *input.Name)
		switch {
		case err == nil:
			output = ClusterOutput{
				Cluster: existing,
				Message: fmt.Sprintf("Cluster '%s' already exists in location '%s' (ID: %s) - returning existing cluster", existing.Name, locationName, existing.ID),
			}
			return nil
		case !errors.Is(err, sql.ErrNoRows):
			// Only "no matching row" makes it safe to insert. Any other
			// failure (e.g. reading back the matched row) must not fall
			// through and create a duplicate cluster.
			return fmt.Errorf("failed to check for existing cluster: %w", err)
		}

		created, err := insertNewCluster(ctx, tx, input, datasetName, locationName)
		if err != nil {
			return err
		}
		output = created
		return nil
	})
	if err != nil {
		return ClusterOutput{}, err
	}
	return output, nil
}
// verifyClusterParentRefs checks, through tx, the rows a new cluster will
// reference. The checks are:
//   - the dataset exists and is active (db.DatasetExistsAndActive);
//   - the location belongs to that dataset (db.LocationBelongsToDataset);
//   - when a non-blank cyclic_recording_pattern_id is given, that pattern
//     exists and is active.
//
// It returns the names reported by the two db helpers (dataset first, then
// location), which callers use in user-facing messages. On any failure both
// names are empty. DatasetID and LocationID must be non-nil.
func verifyClusterParentRefs(ctx context.Context, tx *db.LoggedTx, input ClusterInput) (string, string, error) {
	datasetID := *input.DatasetID

	datasetName, err := db.DatasetExistsAndActive(tx, datasetID)
	if err != nil {
		return "", "", err
	}
	locationName, err := db.LocationBelongsToDataset(tx, *input.LocationID, datasetID)
	if err != nil {
		return "", "", err
	}

	patternID := input.CyclicRecordingPatternID
	if patternID == nil || strings.TrimSpace(*patternID) == "" {
		return datasetName, locationName, nil
	}
	if err := verifyPatternExists(ctx, tx, *patternID); err != nil {
		return "", "", err
	}
	return datasetName, locationName, nil
}
// insertNewCluster inserts one active cluster row built from input under a
// freshly generated short ID, then reads it back through tx. datasetName and
// locationName are used only in the returned message.
//
// The caller must already have checked that DatasetID, LocationID, Name and
// SampleRate are non-nil. Path is written only when supplied, so any column
// default still applies when it is omitted. The optional pattern and
// description pointers are passed through as-is; a nil pointer is presumably
// stored as NULL (database/sql's default conversion; driver-dependent).
func insertNewCluster(ctx context.Context, tx *db.LoggedTx, input ClusterInput, datasetName, locationName string) (ClusterOutput, error) {
	id, err := utils.GenerateShortID()
	if err != nil {
		return ClusterOutput{}, fmt.Errorf("failed to generate ID: %w", err)
	}

	columns := []string{"id", "dataset_id", "location_id", "name", "sample_rate", "cyclic_recording_pattern_id", "description"}
	args := []any{id, *input.DatasetID, *input.LocationID, *input.Name, *input.SampleRate, input.CyclicRecordingPatternID, input.Description}
	// Path is validated on create (validateClusterFields) and written on
	// update (buildClusterUpdateQuery), so persist it here as well.
	if input.Path != nil {
		columns = append(columns, "path")
		args = append(args, *input.Path)
	}
	placeholders := strings.TrimSuffix(strings.Repeat("?, ", len(columns)), ", ")
	query := fmt.Sprintf(
		"INSERT INTO cluster (%s, created_at, last_modified, active) VALUES (%s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, TRUE)",
		strings.Join(columns, ", "), placeholders,
	)

	if _, err := tx.ExecContext(ctx, query, args...); err != nil {
		return ClusterOutput{}, fmt.Errorf("failed to create cluster: %w", err)
	}

	cluster, err := fetchClusterByID(ctx, tx, id)
	if err != nil {
		return ClusterOutput{}, fmt.Errorf("failed to fetch created cluster: %w", err)
	}
	return ClusterOutput{
		Cluster: cluster,
		Message: fmt.Sprintf("Successfully created cluster '%s' with ID %s in location '%s' at dataset '%s' (sample rate: %d Hz)",
			cluster.Name, cluster.ID, locationName, datasetName, cluster.SampleRate),
	}, nil
}
// validateCreateClusterFields enforces the fields a create request must carry.
// dataset_id, location_id and name must be non-nil and not blank after
// trimming whitespace, and sample_rate must be non-nil. The first missing field
// (in that order) is reported.
func validateCreateClusterFields(input ClusterInput) error {
	blank := func(s *string) bool {
		return s == nil || strings.TrimSpace(*s) == ""
	}

	switch {
	case blank(input.DatasetID):
		return fmt.Errorf("dataset_id is required when creating a cluster")
	case blank(input.LocationID):
		return fmt.Errorf("location_id is required when creating a cluster")
	case blank(input.Name):
		return fmt.Errorf("name is required when creating a cluster")
	case input.SampleRate == nil:
		return fmt.Errorf("sample_rate is required when creating a cluster")
	}
	return nil
}
// validateCreateClusterIDs checks that dataset_id and then location_id are
// well-formed short IDs (utils.ValidateShortID), returning the first failure.
// Both pointers must be non-nil; createCluster calls validateCreateClusterFields
// first to guarantee that.
func validateCreateClusterIDs(input ClusterInput) error {
	ids := []struct {
		value *string
		field string
	}{
		{input.DatasetID, "dataset_id"},
		{input.LocationID, "location_id"},
	}
	for _, id := range ids {
		// Dereference lazily so a failing dataset_id is reported before
		// location_id is ever touched.
		if err := utils.ValidateShortID(*id.value, id.field); err != nil {
			return err
		}
	}
	return nil
}
// verifyPatternExists checks, through tx, that patternID names an existing
// cyclic_recording_pattern row whose active flag is true. Existence and the
// active flag are fetched in one query; COALESCE makes a missing row read as
// inactive.
func verifyPatternExists(ctx context.Context, tx *db.LoggedTx, patternID string) error {
	const statusQuery = "SELECT EXISTS(SELECT 1 FROM cyclic_recording_pattern WHERE id = ?), COALESCE((SELECT active FROM cyclic_recording_pattern WHERE id = ?), false)"

	var found, isActive bool
	if err := tx.QueryRowContext(ctx, statusQuery, patternID, patternID).Scan(&found, &isActive); err != nil {
		return fmt.Errorf("failed to verify cyclic recording pattern: %w", err)
	}

	switch {
	case !found:
		return fmt.Errorf("cyclic recording pattern with ID '%s' does not exist", patternID)
	case !isActive:
		return fmt.Errorf("cyclic recording pattern with ID '%s' is not active", patternID)
	default:
		return nil
	}
}
// findExistingClusterInLocation looks up an active cluster with exactly the
// given name in locationID and, if found, loads the full row via
// fetchClusterByID.
//
// When nothing matches, the Scan error is returned unchanged. It is presumably
// sql.ErrNoRows, if LoggedTx.QueryRowContext yields a *sql.Row — confirm. The
// query has no ORDER BY, so if several rows match, whichever the database
// returns first is used.
func findExistingClusterInLocation(ctx context.Context, tx *db.LoggedTx, locationID, name string) (db.Cluster, error) {
	const lookup = "SELECT id FROM cluster WHERE location_id = ? AND name = ? AND active = true"

	var clusterID string
	if err := tx.QueryRowContext(ctx, lookup, locationID, name).Scan(&clusterID); err != nil {
		return db.Cluster{}, err
	}
	return fetchClusterByID(ctx, tx, clusterID)
}
// fetchClusterByID loads one cluster row by primary key through tx.
//
// The Scan error is returned unwrapped so callers can still inspect it, e.g.
// for sql.ErrNoRows (assuming LoggedTx.QueryRowContext yields a *sql.Row —
// confirm). On any error the returned db.Cluster is the zero value, never a
// partially scanned row.
//
// NOTE(review): the path column is not selected, although the update path
// writes it; confirm whether db.Cluster should carry it.
func fetchClusterByID(ctx context.Context, tx *db.LoggedTx, id string) (db.Cluster, error) {
	const query = "SELECT id, dataset_id, location_id, name, description, created_at, last_modified, active, cyclic_recording_pattern_id, sample_rate FROM cluster WHERE id = ?"

	var c db.Cluster
	if err := tx.QueryRowContext(ctx, query, id).Scan(
		&c.ID, &c.DatasetID, &c.LocationID, &c.Name, &c.Description,
		&c.CreatedAt, &c.LastModified, &c.Active, &c.CyclicRecordingPatternID, &c.SampleRate,
	); err != nil {
		return db.Cluster{}, err
	}
	return c, nil
}
// validateClusterActive returns nil only when clusterID names an existing
// cluster whose active flag is true. It queries database directly, fetching
// existence and the active flag in one round trip; COALESCE makes a missing
// row read as inactive.
func validateClusterActive(database *sql.DB, clusterID string) error {
	const statusQuery = "SELECT EXISTS(SELECT 1 FROM cluster WHERE id = ?), COALESCE((SELECT active FROM cluster WHERE id = ?), false)"

	var found, isActive bool
	if err := database.QueryRow(statusQuery, clusterID, clusterID).Scan(&found, &isActive); err != nil {
		return fmt.Errorf("failed to query cluster: %w", err)
	}

	switch {
	case !found:
		return fmt.Errorf("cluster not found: %s", clusterID)
	case !isActive:
		return fmt.Errorf("cluster '%s' is not active (cannot update inactive clusters)", clusterID)
	}
	return nil
}
// validateClusterCyclicPattern verifies the cyclic recording pattern referenced
// by input, if any. A nil or whitespace-only ID is accepted without touching
// the database. Otherwise the trimmed ID is checked with validateCyclicPattern.
func validateClusterCyclicPattern(database *sql.DB, input ClusterInput) error {
	raw := input.CyclicRecordingPatternID
	if raw == nil {
		return nil
	}
	if patternID := strings.TrimSpace(*raw); patternID != "" {
		return validateCyclicPattern(database, patternID)
	}
	return nil
}
// validateCyclicPattern returns nil only when patternID names an existing
// cyclic_recording_pattern row whose active flag is true. It queries database
// directly, fetching existence and the active flag in one round trip.
func validateCyclicPattern(database *sql.DB, patternID string) error {
	const statusQuery = "SELECT EXISTS(SELECT 1 FROM cyclic_recording_pattern WHERE id = ?), COALESCE((SELECT active FROM cyclic_recording_pattern WHERE id = ?), false)"

	var found, isActive bool
	if err := database.QueryRow(statusQuery, patternID, patternID).Scan(&found, &isActive); err != nil {
		return fmt.Errorf("failed to verify cyclic recording pattern: %w", err)
	}

	switch {
	case !found:
		return fmt.Errorf("cyclic recording pattern not found: %s", patternID)
	case !isActive:
		return fmt.Errorf("cyclic recording pattern '%s' is not active", patternID)
	default:
		return nil
	}
}
// buildClusterUpdateQuery assembles a parameterized UPDATE for the cluster
// identified by clusterID. It sets only the fields that are non-nil in input,
// in the fixed order name, path, sample_rate, description,
// cyclic_recording_pattern_id.
//
// A whitespace-only cyclic_recording_pattern_id clears that column to NULL;
// otherwise the trimmed ID is stored. last_modified is always bumped.
//
// It returns the SQL text and its positional arguments, with clusterID last
// for the WHERE clause. It returns an error if input sets no updatable field.
// Column names are fixed literals; every value goes through a placeholder.
func buildClusterUpdateQuery(input ClusterInput, clusterID string) (string, []any, error) {
	var (
		updates []string
		args    []any
	)
	set := func(column string, value any) {
		updates = append(updates, column+" = ?")
		args = append(args, value)
	}

	if input.Name != nil {
		set("name", *input.Name)
	}
	if input.Path != nil {
		set("path", *input.Path)
	}
	if input.SampleRate != nil {
		set("sample_rate", *input.SampleRate)
	}
	if input.Description != nil {
		set("description", *input.Description)
	}
	if input.CyclicRecordingPatternID != nil {
		if patternID := strings.TrimSpace(*input.CyclicRecordingPatternID); patternID == "" {
			updates = append(updates, "cyclic_recording_pattern_id = NULL")
		} else {
			set("cyclic_recording_pattern_id", patternID)
		}
	}

	if len(updates) == 0 {
		return "", nil, fmt.Errorf("no fields provided to update")
	}

	// CURRENT_TIMESTAMP is standard SQL and matches the INSERT in
	// insertNewCluster, unlike the engine-specific now().
	updates = append(updates, "last_modified = CURRENT_TIMESTAMP")
	args = append(args, clusterID)
	query := fmt.Sprintf("UPDATE cluster SET %s WHERE id = ?", strings.Join(updates, ", "))
	return query, args, nil
}
// validateClusterUpdateInput performs the database-free checks for an update
// and returns the cluster ID to update. input.ID must be non-nil. The checks
// are:
//   - input.ID must be a valid short ID;
//   - a supplied name must not be blank, since create requires a non-blank
//     name and an update must not be able to erase it;
//   - the shared limits in validateClusterFields must hold;
//   - a non-blank cyclic_recording_pattern_id must be a valid short ID. A
//     blank one is allowed and means "clear the pattern".
func validateClusterUpdateInput(input ClusterInput) (string, error) {
	clusterID := *input.ID
	if err := utils.ValidateShortID(clusterID, "cluster_id"); err != nil {
		return "", err
	}
	if input.Name != nil && strings.TrimSpace(*input.Name) == "" {
		return "", fmt.Errorf("name cannot be empty when updating a cluster")
	}
	if err := validateClusterFields(input); err != nil {
		return "", err
	}
	if input.CyclicRecordingPatternID != nil {
		if trimmed := strings.TrimSpace(*input.CyclicRecordingPatternID); trimmed != "" {
			if err := utils.ValidateShortID(trimmed, "cyclic_recording_pattern_id"); err != nil {
				return "", err
			}
		}
	}
	return clusterID, nil
}
// updateCluster applies a partial update to an existing, active cluster and
// returns the row as read back afterwards. Only fields that are non-nil in
// input are written (see buildClusterUpdateQuery); dataset_id and location_id
// are not used on this path.
//
// The active-cluster check and the check of any non-blank
// cyclic_recording_pattern_id run through the write transaction itself, not
// a separate connection, so they execute in the same transaction as the
// UPDATE and honor ctx. On error the returned ClusterOutput is the zero value.
func updateCluster(ctx context.Context, input ClusterInput) (ClusterOutput, error) {
	clusterID, err := validateClusterUpdateInput(input)
	if err != nil {
		return ClusterOutput{}, err
	}

	var output ClusterOutput
	err = db.WithWriteTx(ctx, dbPath, "create_or_update_cluster", func(_ *sql.DB, tx *db.LoggedTx) error {
		if err := checkClusterUpdateTargets(ctx, tx, clusterID, input); err != nil {
			return err
		}
		query, args, err := buildClusterUpdateQuery(input, clusterID)
		if err != nil {
			return err
		}
		if _, err := tx.ExecContext(ctx, query, args...); err != nil {
			return fmt.Errorf("failed to update cluster: %w", err)
		}
		cluster, err := fetchClusterByID(ctx, tx, clusterID)
		if err != nil {
			return fmt.Errorf("failed to fetch updated cluster: %w", err)
		}
		output = ClusterOutput{
			Cluster: cluster,
			Message: fmt.Sprintf("Successfully updated cluster '%s' (ID: %s)", cluster.Name, cluster.ID),
		}
		return nil
	})
	if err != nil {
		return ClusterOutput{}, err
	}
	return output, nil
}

// checkClusterUpdateTargets verifies, inside tx, the rows an update refers to.
// clusterID must name an existing active cluster, and a non-blank
// cyclic_recording_pattern_id in input must name an existing active pattern.
// Error messages match validateClusterActive and validateCyclicPattern.
func checkClusterUpdateTargets(ctx context.Context, tx *db.LoggedTx, clusterID string, input ClusterInput) error {
	exists, active, err := clusterUpdateRowStatus(ctx, tx, "cluster", clusterID)
	switch {
	case err != nil:
		return fmt.Errorf("failed to query cluster: %w", err)
	case !exists:
		return fmt.Errorf("cluster not found: %s", clusterID)
	case !active:
		return fmt.Errorf("cluster '%s' is not active (cannot update inactive clusters)", clusterID)
	}

	if input.CyclicRecordingPatternID == nil {
		return nil
	}
	patternID := strings.TrimSpace(*input.CyclicRecordingPatternID)
	if patternID == "" {
		// Blank means "clear the pattern"; there is nothing to verify.
		return nil
	}
	exists, active, err = clusterUpdateRowStatus(ctx, tx, "cyclic_recording_pattern", patternID)
	switch {
	case err != nil:
		return fmt.Errorf("failed to verify cyclic recording pattern: %w", err)
	case !exists:
		return fmt.Errorf("cyclic recording pattern not found: %s", patternID)
	case !active:
		return fmt.Errorf("cyclic recording pattern '%s' is not active", patternID)
	}
	return nil
}

// clusterUpdateRowStatus reports, through tx, whether a row with the given id
// exists in table and whether its active column is true (false when the row is
// missing). table is interpolated into the SQL, so it must be a trusted
// constant and never user input.
func clusterUpdateRowStatus(ctx context.Context, tx *db.LoggedTx, table, id string) (exists, active bool, err error) {
	query := fmt.Sprintf(
		"SELECT EXISTS(SELECT 1 FROM %[1]s WHERE id = ?), COALESCE((SELECT active FROM %[1]s WHERE id = ?), false)",
		table,
	)
	err = tx.QueryRowContext(ctx, query, id, id).Scan(&exists, &active)
	return exists, active, err
}