package tools
import (
"context"
"database/sql"
"fmt"
"skraak_mcp/db"
"skraak_mcp/utils"
"strings"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// PatternInput is the argument set accepted by the cyclic recording pattern
// tool. Every field is a pointer so that "not supplied" can be distinguished
// from a zero value. A missing or blank ID means "create a new pattern"; a
// non-blank ID means "update the pattern with that ID".
type PatternInput struct {
	ID            *string `json:"id,omitempty" jsonschema:"Pattern ID (12 characters). Omit to create a new pattern, provide to update an existing one."`
	RecordSeconds *int    `json:"record_seconds,omitempty" jsonschema:"Number of seconds to record (must be positive). Required for create."`
	SleepSeconds  *int    `json:"sleep_seconds,omitempty" jsonschema:"Number of seconds to sleep between recordings (must be positive for create, >= 0 for update)."`
}
// PatternOutput is the structured result of the cyclic recording pattern
// tool. It carries the pattern row as read back from the database together
// with a human-readable message describing what happened.
type PatternOutput struct {
	Pattern db.CyclicRecordingPattern `json:"pattern" jsonschema:"The created or updated recording pattern"`
	Message string                    `json:"message" jsonschema:"Success message"`
}
// CreateOrUpdatePattern is the MCP tool handler for cyclic recording patterns.
//
// When input.ID holds a non-blank value (after trimming whitespace), the
// handler updates the pattern with that ID. Otherwise it creates a new
// pattern; if an identical active pattern already exists, that pattern is
// returned instead. The request argument is not used.
func CreateOrUpdatePattern(
	ctx context.Context,
	req *mcp.CallToolRequest,
	input PatternInput,
) (*mcp.CallToolResult, PatternOutput, error) {
	hasID := input.ID != nil && strings.TrimSpace(*input.ID) != ""
	if !hasID {
		return createPattern(ctx, input)
	}
	return updatePattern(ctx, input)
}
// createPattern inserts a new active cyclic recording pattern. If an active
// pattern with the same record/sleep durations is already stored, it returns
// that pattern instead, so repeated identical requests do not create
// duplicates.
//
// Both input.RecordSeconds and input.SleepSeconds are required and must be
// strictly positive. The duplicate check and the insert run in a single
// transaction on a writeable connection opened from the package-level dbPath.
// On error, the returned PatternOutput is the zero value.
func createPattern(ctx context.Context, input PatternInput) (*mcp.CallToolResult, PatternOutput, error) {
	var output PatternOutput

	if input.RecordSeconds == nil {
		return nil, output, fmt.Errorf("record_seconds is required when creating a pattern")
	}
	if input.SleepSeconds == nil {
		return nil, output, fmt.Errorf("sleep_seconds is required when creating a pattern")
	}
	recordS, sleepS := *input.RecordSeconds, *input.SleepSeconds
	if recordS <= 0 {
		return nil, output, fmt.Errorf("record_seconds must be positive (got %d)", recordS)
	}
	if sleepS <= 0 {
		return nil, output, fmt.Errorf("sleep_seconds must be positive (got %d)", sleepS)
	}

	database, err := db.OpenWriteableDB(dbPath)
	if err != nil {
		return nil, output, fmt.Errorf("database connection failed: %w", err)
	}
	defer database.Close()

	tx, err := database.BeginTx(ctx, nil)
	if err != nil {
		return nil, output, fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Always roll back on exit. After a successful Commit this is a no-op
	// (it returns sql.ErrTxDone, deliberately ignored). It guarantees the
	// transaction is released on every early return or panic, without
	// depending on the value of a captured err variable.
	defer func() { _ = tx.Rollback() }()

	// Reuse an active pattern with identical durations. Read the full row in
	// one query instead of looking up the ID and then re-selecting it.
	var existing db.CyclicRecordingPattern
	err = tx.QueryRowContext(ctx,
		"SELECT id, record_s, sleep_s, created_at, last_modified, active FROM cyclic_recording_pattern WHERE record_s = ? AND sleep_s = ? AND active = true",
		recordS, sleepS,
	).Scan(&existing.ID, &existing.RecordS, &existing.SleepS, &existing.CreatedAt, &existing.LastModified, &existing.Active)
	switch {
	case err == nil:
		if err := tx.Commit(); err != nil {
			return nil, output, fmt.Errorf("failed to commit transaction: %w", err)
		}
		output.Pattern = existing
		output.Message = fmt.Sprintf("Pattern already exists with ID %s (record %ds, sleep %ds) - returning existing pattern",
			existing.ID, existing.RecordS, existing.SleepS)
		return &mcp.CallToolResult{}, output, nil
	case err != sql.ErrNoRows:
		// sql.ErrNoRows means "no duplicate" and falls through to the insert.
		// Any other error is a real failure.
		return nil, output, fmt.Errorf("failed to check for existing pattern: %w", err)
	}

	id, err := utils.GenerateShortID()
	if err != nil {
		return nil, output, fmt.Errorf("failed to generate ID: %w", err)
	}

	_, err = tx.ExecContext(ctx,
		"INSERT INTO cyclic_recording_pattern (id, record_s, sleep_s, created_at, last_modified, active) VALUES (?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, TRUE)",
		id, recordS, sleepS,
	)
	if err != nil {
		return nil, output, fmt.Errorf("failed to create pattern: %w", err)
	}

	// Read the row back so the response carries the timestamps the database
	// actually stored.
	var created db.CyclicRecordingPattern
	err = tx.QueryRowContext(ctx,
		"SELECT id, record_s, sleep_s, created_at, last_modified, active FROM cyclic_recording_pattern WHERE id = ?",
		id,
	).Scan(&created.ID, &created.RecordS, &created.SleepS, &created.CreatedAt, &created.LastModified, &created.Active)
	if err != nil {
		return nil, output, fmt.Errorf("failed to fetch created pattern: %w", err)
	}

	if err = tx.Commit(); err != nil {
		return nil, output, fmt.Errorf("failed to commit transaction: %w", err)
	}

	output.Pattern = created
	output.Message = fmt.Sprintf("Successfully created cyclic recording pattern with ID %s (record %ds, sleep %ds)",
		created.ID, created.RecordS, created.SleepS)
	return &mcp.CallToolResult{}, output, nil
}
// updatePattern changes the record and/or sleep duration of the existing
// pattern identified by input.ID and returns the row as stored afterwards.
//
// Only the fields present in input are written; last_modified is always
// refreshed. When supplied, record_seconds must be > 0 and sleep_seconds must
// be >= 0. input.ID must be non-nil; the dispatcher calls this only when it is
// set and non-blank. Surrounding whitespace is trimmed, matching the
// dispatcher's blank check. The existence check, update and read-back run in
// one context-aware transaction. On error, the returned PatternOutput is the
// zero value.
func updatePattern(ctx context.Context, input PatternInput) (*mcp.CallToolResult, PatternOutput, error) {
	var output PatternOutput
	patternID := strings.TrimSpace(*input.ID)

	if input.RecordSeconds != nil && *input.RecordSeconds <= 0 {
		return nil, output, fmt.Errorf("record_seconds must be greater than 0: %d", *input.RecordSeconds)
	}
	if input.SleepSeconds != nil && *input.SleepSeconds < 0 {
		return nil, output, fmt.Errorf("sleep_seconds must be greater than or equal to 0: %d", *input.SleepSeconds)
	}

	// Build the SET clause from the supplied fields before touching the
	// database, so an empty update is rejected without opening a connection.
	var (
		updates []string
		args    []any
	)
	if input.RecordSeconds != nil {
		updates = append(updates, "record_s = ?")
		args = append(args, *input.RecordSeconds)
	}
	if input.SleepSeconds != nil {
		updates = append(updates, "sleep_s = ?")
		args = append(args, *input.SleepSeconds)
	}
	if len(updates) == 0 {
		return nil, output, fmt.Errorf("no fields provided to update")
	}
	// CURRENT_TIMESTAMP is the same timestamp expression createPattern uses.
	updates = append(updates, "last_modified = CURRENT_TIMESTAMP")
	args = append(args, patternID)
	// Only fixed column fragments are interpolated here; every value is a
	// bound parameter.
	query := fmt.Sprintf("UPDATE cyclic_recording_pattern SET %s WHERE id = ?", strings.Join(updates, ", "))

	database, err := db.OpenWriteableDB(dbPath)
	if err != nil {
		return nil, output, fmt.Errorf("failed to open database: %w", err)
	}
	defer database.Close()

	tx, err := database.BeginTx(ctx, nil)
	if err != nil {
		return nil, output, fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Always roll back on exit. After a successful Commit this is a no-op
	// whose error (sql.ErrTxDone) is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	var exists bool
	err = tx.QueryRowContext(ctx,
		"SELECT EXISTS(SELECT 1 FROM cyclic_recording_pattern WHERE id = ?)",
		patternID,
	).Scan(&exists)
	if err != nil {
		return nil, output, fmt.Errorf("failed to query pattern: %w", err)
	}
	if !exists {
		return nil, output, fmt.Errorf("pattern not found: %s", patternID)
	}

	if _, err = tx.ExecContext(ctx, query, args...); err != nil {
		return nil, output, fmt.Errorf("failed to update pattern: %w", err)
	}

	var pattern db.CyclicRecordingPattern
	err = tx.QueryRowContext(ctx,
		"SELECT id, record_s, sleep_s, created_at, last_modified, active FROM cyclic_recording_pattern WHERE id = ?",
		patternID,
	).Scan(&pattern.ID, &pattern.RecordS, &pattern.SleepS, &pattern.CreatedAt, &pattern.LastModified, &pattern.Active)
	if err != nil {
		return nil, output, fmt.Errorf("failed to fetch updated pattern: %w", err)
	}

	if err = tx.Commit(); err != nil {
		return nil, output, fmt.Errorf("failed to commit transaction: %w", err)
	}

	output.Pattern = pattern
	output.Message = fmt.Sprintf("Successfully updated pattern (ID: %s, record %ds, sleep %ds)",
		pattern.ID, pattern.RecordS, pattern.SleepS)
	return &mcp.CallToolResult{}, output, nil
}