package tools
import (
"context"
"fmt"
"skraak_mcp/db"
"skraak_mcp/utils"
"strings"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// ClusterInput is the argument payload for CreateOrUpdateCluster.
//
// Every field is a pointer so that an omitted JSON field (nil) can be told
// apart from an explicit zero value: when updating, only non-nil fields are
// written. A non-blank ID selects update mode; otherwise a new cluster is
// created and dataset_id, location_id, name and sample_rate become required.
type ClusterInput struct {
ID *string `json:"id,omitempty" jsonschema:"Cluster ID (12 characters). Omit to create a new cluster, provide to update an existing one."`
DatasetID *string `json:"dataset_id,omitempty" jsonschema:"ID of the parent dataset (12-character nanoid). Required for create."`
LocationID *string `json:"location_id,omitempty" jsonschema:"ID of the parent location (12-character nanoid). Required for create."`
Name *string `json:"name,omitempty" jsonschema:"Cluster name (max 140 characters). Required for create."`
SampleRate *int `json:"sample_rate,omitempty" jsonschema:"Sample rate in Hz (must be positive). Required for create."`
Path *string `json:"path,omitempty" jsonschema:"Normalized folder path (max 255 characters)"`
// On update, a blank (whitespace-only) value clears the pattern reference.
CyclicRecordingPatternID *string `json:"cyclic_recording_pattern_id,omitempty" jsonschema:"Optional ID of cyclic recording pattern (12-character nanoid). Set to empty string to clear."`
Description *string `json:"description,omitempty" jsonschema:"Optional cluster description (max 255 characters)"`
}
// ClusterOutput is the structured result of CreateOrUpdateCluster: the
// cluster row as read back from the database after the write, plus a
// human-readable summary message.
type ClusterOutput struct {
Cluster db.Cluster `json:"cluster" jsonschema:"The created or updated cluster"`
Message string `json:"message" jsonschema:"Success message"`
}
// CreateOrUpdateCluster is the MCP tool handler for cluster upserts.
//
// When input.ID is present and not blank the call is routed to an update of
// that cluster; in every other case a new cluster is created. The request
// object is accepted to satisfy the handler signature but is not inspected.
func CreateOrUpdateCluster(
	ctx context.Context,
	req *mcp.CallToolRequest,
	input ClusterInput,
) (*mcp.CallToolResult, ClusterOutput, error) {
	hasID := input.ID != nil && strings.TrimSpace(*input.ID) != ""
	if !hasID {
		return createCluster(ctx, input)
	}
	return updateCluster(ctx, input)
}
// validateClusterFields checks the length and range limits of whichever
// fields are set in input. It is shared by the create and update paths, so it
// only inspects non-nil fields; "required on create" checks live in
// createCluster.
//
// Lengths are counted in characters (Unicode code points), matching the
// "max N characters" wording of the tool schema, rather than in UTF-8 bytes,
// so names containing non-ASCII letters are not rejected prematurely.
//
// It returns the first violation found, or nil when everything is in range.
func validateClusterFields(input ClusterInput) error {
	if input.Name != nil {
		// A provided name must be meaningful on update too; create already
		// rejects blank names before reaching this point.
		if strings.TrimSpace(*input.Name) == "" {
			return fmt.Errorf("name must not be empty")
		}
		if n := len([]rune(*input.Name)); n > 140 {
			return fmt.Errorf("name must be 140 characters or less (got %d)", n)
		}
	}
	if input.Path != nil {
		if n := len([]rune(*input.Path)); n > 255 {
			return fmt.Errorf("path must be 255 characters or less (got %d)", n)
		}
	}
	if input.Description != nil {
		if n := len([]rune(*input.Description)); n > 255 {
			return fmt.Errorf("description must be 255 characters or less (got %d)", n)
		}
	}
	if input.SampleRate != nil && *input.SampleRate <= 0 {
		return fmt.Errorf("sample_rate must be positive (got %d)", *input.SampleRate)
	}
	return nil
}
// createCluster inserts a new cluster inside a single transaction.
//
// dataset_id, location_id, name and sample_rate are required. The dataset
// and location must exist and be active, and the location must belong to the
// given dataset. When cyclic_recording_pattern_id is non-blank it must refer
// to an existing, active pattern; a blank value is stored as NULL. path is
// written only when supplied. On success the freshly inserted row is read
// back and returned.
func createCluster(ctx context.Context, input ClusterInput) (*mcp.CallToolResult, ClusterOutput, error) {
	var output ClusterOutput

	if input.DatasetID == nil || strings.TrimSpace(*input.DatasetID) == "" {
		return nil, output, fmt.Errorf("dataset_id is required when creating a cluster")
	}
	if input.LocationID == nil || strings.TrimSpace(*input.LocationID) == "" {
		return nil, output, fmt.Errorf("location_id is required when creating a cluster")
	}
	if input.Name == nil || strings.TrimSpace(*input.Name) == "" {
		return nil, output, fmt.Errorf("name is required when creating a cluster")
	}
	if input.SampleRate == nil {
		return nil, output, fmt.Errorf("sample_rate is required when creating a cluster")
	}
	if err := validateClusterFields(input); err != nil {
		return nil, output, err
	}

	database, err := db.OpenWriteableDB(dbPath)
	if err != nil {
		return nil, output, fmt.Errorf("database connection failed: %w", err)
	}
	defer database.Close()

	tx, err := database.BeginTx(ctx, nil)
	if err != nil {
		return nil, output, fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Roll back on every early return, including validation failures where
	// err is nil. After a successful Commit this is a no-op returning
	// sql.ErrTxDone, which is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	// rowExists reports whether table has a row with the given id. A separate
	// EXISTS query is needed because selecting columns with a WHERE that
	// matches nothing yields no row at all (Scan fails) rather than false.
	// table is always a hard-coded literal below, never caller input.
	rowExists := func(table, id string) (bool, error) {
		var found bool
		err := tx.QueryRowContext(ctx, "SELECT EXISTS(SELECT 1 FROM "+table+" WHERE id = ?)", id).Scan(&found)
		return found, err
	}

	// Parent dataset: must exist and be active.
	datasetID := *input.DatasetID
	found, err := rowExists("dataset", datasetID)
	if err != nil {
		return nil, output, fmt.Errorf("failed to verify dataset: %w", err)
	}
	if !found {
		return nil, output, fmt.Errorf("dataset with ID '%s' does not exist", datasetID)
	}
	var (
		datasetActive bool
		datasetName   string
	)
	if err := tx.QueryRowContext(ctx,
		"SELECT active, name FROM dataset WHERE id = ?", datasetID,
	).Scan(&datasetActive, &datasetName); err != nil {
		return nil, output, fmt.Errorf("failed to verify dataset: %w", err)
	}
	if !datasetActive {
		return nil, output, fmt.Errorf("dataset '%s' (ID: %s) is not active", datasetName, datasetID)
	}

	// Parent location: must exist, be active, and belong to the dataset.
	locationID := *input.LocationID
	found, err = rowExists("location", locationID)
	if err != nil {
		return nil, output, fmt.Errorf("failed to verify location: %w", err)
	}
	if !found {
		return nil, output, fmt.Errorf("location with ID '%s' does not exist", locationID)
	}
	var (
		locationActive    bool
		locationName      string
		locationDatasetID string
	)
	if err := tx.QueryRowContext(ctx,
		"SELECT active, name, dataset_id FROM location WHERE id = ?", locationID,
	).Scan(&locationActive, &locationName, &locationDatasetID); err != nil {
		return nil, output, fmt.Errorf("failed to verify location: %w", err)
	}
	if !locationActive {
		return nil, output, fmt.Errorf("location '%s' (ID: %s) is not active", locationName, locationID)
	}
	if locationDatasetID != datasetID {
		return nil, output, fmt.Errorf("location '%s' (ID: %s) does not belong to dataset '%s' (ID: %s) - it belongs to dataset ID '%s'",
			locationName, locationID, datasetName, datasetID, locationDatasetID)
	}

	// Optional recording pattern. A blank ID means "no pattern" and is stored
	// as NULL; a non-blank ID is validated and stored in its trimmed form.
	var patternID any
	if input.CyclicRecordingPatternID != nil {
		if trimmed := strings.TrimSpace(*input.CyclicRecordingPatternID); trimmed != "" {
			found, err = rowExists("cyclic_recording_pattern", trimmed)
			if err != nil {
				return nil, output, fmt.Errorf("failed to verify cyclic recording pattern: %w", err)
			}
			if !found {
				return nil, output, fmt.Errorf("cyclic recording pattern with ID '%s' does not exist", trimmed)
			}
			var patternActive bool
			if err := tx.QueryRowContext(ctx,
				"SELECT active FROM cyclic_recording_pattern WHERE id = ?", trimmed,
			).Scan(&patternActive); err != nil {
				return nil, output, fmt.Errorf("failed to verify cyclic recording pattern: %w", err)
			}
			if !patternActive {
				return nil, output, fmt.Errorf("cyclic recording pattern with ID '%s' is not active", trimmed)
			}
			patternID = trimmed
		}
	}

	// Dereference explicitly so a nil description becomes SQL NULL without
	// relying on driver-specific pointer handling.
	var description any
	if input.Description != nil {
		description = *input.Description
	}

	id, err := utils.GenerateShortID()
	if err != nil {
		return nil, output, fmt.Errorf("failed to generate ID: %w", err)
	}

	columns := []string{"id", "dataset_id", "location_id", "name", "sample_rate", "cyclic_recording_pattern_id", "description"}
	args := []any{id, datasetID, locationID, *input.Name, *input.SampleRate, patternID, description}
	// path is written only when supplied, so the column's own default (if
	// any) still applies when it is omitted.
	if input.Path != nil {
		columns = append(columns, "path")
		args = append(args, *input.Path)
	}
	placeholders := strings.TrimSuffix(strings.Repeat("?, ", len(columns)), ", ")
	insert := fmt.Sprintf(
		"INSERT INTO cluster (%s, created_at, last_modified, active) VALUES (%s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, TRUE)",
		strings.Join(columns, ", "), placeholders,
	)
	if _, err := tx.ExecContext(ctx, insert, args...); err != nil {
		return nil, output, fmt.Errorf("failed to create cluster: %w", err)
	}

	// Read the row back inside the transaction so the returned values
	// (timestamps, defaults) are exactly what was stored.
	var cluster db.Cluster
	if err := tx.QueryRowContext(ctx,
		"SELECT id, dataset_id, location_id, name, description, created_at, last_modified, active, cyclic_recording_pattern_id, sample_rate FROM cluster WHERE id = ?",
		id,
	).Scan(&cluster.ID, &cluster.DatasetID, &cluster.LocationID, &cluster.Name, &cluster.Description,
		&cluster.CreatedAt, &cluster.LastModified, &cluster.Active, &cluster.CyclicRecordingPatternID, &cluster.SampleRate); err != nil {
		return nil, output, fmt.Errorf("failed to fetch created cluster: %w", err)
	}

	if err := tx.Commit(); err != nil {
		return nil, output, fmt.Errorf("failed to commit transaction: %w", err)
	}

	output.Cluster = cluster
	output.Message = fmt.Sprintf("Successfully created cluster '%s' with ID %s in location '%s' at dataset '%s' (sample rate: %d Hz)",
		cluster.Name, cluster.ID, locationName, datasetName, cluster.SampleRate)
	return &mcp.CallToolResult{}, output, nil
}
// updateCluster applies a partial update to an existing cluster.
//
// Only fields that are non-nil in input are written, and last_modified is
// always bumped. A blank cyclic_recording_pattern_id clears the pattern
// (sets it to NULL); a non-blank one must reference an existing, active
// pattern. The existence checks, the UPDATE and the read-back all run in one
// transaction and honour ctx.
func updateCluster(ctx context.Context, input ClusterInput) (*mcp.CallToolResult, ClusterOutput, error) {
	var output ClusterOutput
	// The dispatcher selects update mode on the trimmed ID, so look up the
	// same trimmed value.
	clusterID := strings.TrimSpace(*input.ID)

	if err := validateClusterFields(input); err != nil {
		return nil, output, err
	}

	// nil means "leave unchanged"; a pointer to "" means "clear".
	var patternID *string
	if input.CyclicRecordingPatternID != nil {
		trimmed := strings.TrimSpace(*input.CyclicRecordingPatternID)
		patternID = &trimmed
	}

	// Build the SET clause before touching the database so an empty update
	// is rejected without opening a connection.
	var updates []string
	var args []any
	if input.Name != nil {
		updates = append(updates, "name = ?")
		args = append(args, *input.Name)
	}
	if input.Path != nil {
		updates = append(updates, "path = ?")
		args = append(args, *input.Path)
	}
	if input.SampleRate != nil {
		updates = append(updates, "sample_rate = ?")
		args = append(args, *input.SampleRate)
	}
	if input.Description != nil {
		updates = append(updates, "description = ?")
		args = append(args, *input.Description)
	}
	if patternID != nil {
		if *patternID == "" {
			updates = append(updates, "cyclic_recording_pattern_id = NULL")
		} else {
			updates = append(updates, "cyclic_recording_pattern_id = ?")
			args = append(args, *patternID)
		}
	}
	if len(updates) == 0 {
		return nil, output, fmt.Errorf("no fields provided to update")
	}
	// Same timestamp expression as the create path, for consistency.
	updates = append(updates, "last_modified = CURRENT_TIMESTAMP")
	args = append(args, clusterID)
	// Only fixed column fragments are joined here; all values are bound.
	query := fmt.Sprintf("UPDATE cluster SET %s WHERE id = ?", strings.Join(updates, ", "))

	database, err := db.OpenWriteableDB(dbPath)
	if err != nil {
		return nil, output, fmt.Errorf("failed to open database: %w", err)
	}
	defer database.Close()

	tx, err := database.BeginTx(ctx, nil)
	if err != nil {
		return nil, output, fmt.Errorf("failed to begin transaction: %w", err)
	}
	// No-op after a successful Commit (returns sql.ErrTxDone, ignored).
	defer func() { _ = tx.Rollback() }()

	var exists bool
	if err := tx.QueryRowContext(ctx,
		"SELECT EXISTS(SELECT 1 FROM cluster WHERE id = ?)", clusterID,
	).Scan(&exists); err != nil {
		return nil, output, fmt.Errorf("failed to query cluster: %w", err)
	}
	if !exists {
		return nil, output, fmt.Errorf("cluster not found: %s", clusterID)
	}

	if patternID != nil && *patternID != "" {
		var patternExists bool
		if err := tx.QueryRowContext(ctx,
			"SELECT EXISTS(SELECT 1 FROM cyclic_recording_pattern WHERE id = ? AND active = true)",
			*patternID,
		).Scan(&patternExists); err != nil {
			return nil, output, fmt.Errorf("failed to verify cyclic recording pattern: %w", err)
		}
		if !patternExists {
			return nil, output, fmt.Errorf("cyclic recording pattern not found or not active: %s", *patternID)
		}
	}

	if _, err := tx.ExecContext(ctx, query, args...); err != nil {
		return nil, output, fmt.Errorf("failed to update cluster: %w", err)
	}

	var cluster db.Cluster
	if err := tx.QueryRowContext(ctx,
		"SELECT id, dataset_id, location_id, name, description, created_at, last_modified, active, cyclic_recording_pattern_id, sample_rate FROM cluster WHERE id = ?",
		clusterID,
	).Scan(&cluster.ID, &cluster.DatasetID, &cluster.LocationID, &cluster.Name, &cluster.Description,
		&cluster.CreatedAt, &cluster.LastModified, &cluster.Active, &cluster.CyclicRecordingPatternID, &cluster.SampleRate); err != nil {
		return nil, output, fmt.Errorf("failed to fetch updated cluster: %w", err)
	}

	if err := tx.Commit(); err != nil {
		return nil, output, fmt.Errorf("failed to commit transaction: %w", err)
	}

	output.Cluster = cluster
	output.Message = fmt.Sprintf("Successfully updated cluster '%s' (ID: %s)", cluster.Name, cluster.ID)
	return &mcp.CallToolResult{}, output, nil
}