package tools
import (
"context"
"fmt"
"skraak_mcp/db"
"skraak_mcp/utils"
"strings"
"time"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// LocationInput is the argument of the create-or-update location tool.
//
// Every field is a pointer so an omitted field (nil) can be distinguished from
// a zero value. When ID is nil or blank a new location is created and
// dataset_id, name, latitude, longitude and timezone_id must all be supplied;
// when ID is set the existing location is updated and only the non-nil fields
// are written.
type LocationInput struct {
// ID selects update mode when non-blank.
ID *string `json:"id,omitempty" jsonschema:"Location ID (12 characters). Omit to create a new location, provide to update an existing one."`
// DatasetID is the owning dataset; on update it moves the location.
DatasetID *string `json:"dataset_id,omitempty" jsonschema:"ID of the parent dataset (12-character nanoid). Required for create."`
Name *string `json:"name,omitempty" jsonschema:"Location name (max 140 characters). Required for create."`
Latitude *float64 `json:"latitude,omitempty" jsonschema:"Latitude in decimal degrees (-90 to 90). Required for create."`
Longitude *float64 `json:"longitude,omitempty" jsonschema:"Longitude in decimal degrees (-180 to 180). Required for create."`
// TimezoneID is validated by loading it with time.LoadLocation.
TimezoneID *string `json:"timezone_id,omitempty" jsonschema:"IANA timezone ID (e.g. 'Pacific/Auckland'). Required for create."`
Description *string `json:"description,omitempty" jsonschema:"Optional location description (max 255 characters)"`
}
// LocationOutput is the structured result of the create-or-update location
// tool: the location row as read back from the database after the write, plus
// a human-readable summary of what was done.
type LocationOutput struct {
Location db.Location `json:"location" jsonschema:"The created or updated location"`
Message string `json:"message" jsonschema:"Success message"`
}
// CreateOrUpdateLocation is the MCP handler for the location upsert tool.
//
// A non-blank input.ID routes the call to an update of that location; a nil or
// whitespace-only ID routes it to creation of a new location. The request
// argument is not consulted.
func CreateOrUpdateLocation(
	ctx context.Context,
	req *mcp.CallToolRequest,
	input LocationInput,
) (*mcp.CallToolResult, LocationOutput, error) {
	hasID := input.ID != nil && strings.TrimSpace(*input.ID) != ""
	if !hasID {
		return createLocation(ctx, input)
	}
	return updateLocation(ctx, input)
}
// validateLocationFields checks the range and format of every location field
// present in input. Nil (omitted) fields are skipped, so the same checks serve
// create (after its required-field checks) and partial update.
//
// Length limits are counted in Unicode code points, matching the "characters"
// wording of the limits, so text with non-ASCII letters (e.g. macrons) is not
// penalised for its multi-byte UTF-8 encoding. A supplied name must not be
// blank. The same rule is enforced on create, so an update cannot erase it.
// timezone_id must be loadable by time.LoadLocation and must not be blank or
// "Local".
//
// It returns the first violation found, or nil.
func validateLocationFields(input LocationInput) error {
	if input.Name != nil {
		if strings.TrimSpace(*input.Name) == "" {
			return fmt.Errorf("name must not be empty")
		}
		// len([]rune(s)) counts code points; the compiler does not allocate here.
		if n := len([]rune(*input.Name)); n > 140 {
			return fmt.Errorf("name must be 140 characters or less (got %d)", n)
		}
	}
	if input.Description != nil {
		if n := len([]rune(*input.Description)); n > 255 {
			return fmt.Errorf("description must be 255 characters or less (got %d)", n)
		}
	}
	// The negated inclusive range also rejects NaN, for which every ordered
	// comparison is false and the plain "< min || > max" test would pass.
	if input.Latitude != nil && !(*input.Latitude >= -90 && *input.Latitude <= 90) {
		return fmt.Errorf("latitude must be between -90 and 90 (got %f)", *input.Latitude)
	}
	if input.Longitude != nil && !(*input.Longitude >= -180 && *input.Longitude <= 180) {
		return fmt.Errorf("longitude must be between -180 and 180 (got %f)", *input.Longitude)
	}
	if input.TimezoneID != nil {
		tz := *input.TimezoneID
		if n := len([]rune(tz)); n > 40 {
			return fmt.Errorf("timezone_id must be 40 characters or less (got %d)", n)
		}
		// time.LoadLocation returns UTC for "" and the host zone for "Local"
		// without error, but neither is an IANA zone name worth storing.
		if strings.TrimSpace(tz) == "" || tz == "Local" {
			return fmt.Errorf("invalid timezone_id '%s': must be an IANA timezone name", tz)
		}
		if _, err := time.LoadLocation(tz); err != nil {
			return fmt.Errorf("invalid timezone_id '%s': %w", tz, err)
		}
	}
	return nil
}
// createLocation inserts a new, active location under an existing, active
// dataset and returns the row as read back from the database.
//
// dataset_id, name, latitude, longitude and timezone_id are required;
// description is optional. The dataset lookup, the INSERT and the read-back
// all run in one transaction bound to ctx, so they either commit together or
// not at all.
//
// Errors are returned for missing/invalid fields, a missing or inactive
// dataset, and any database failure.
func createLocation(ctx context.Context, input LocationInput) (*mcp.CallToolResult, LocationOutput, error) {
	var output LocationOutput

	if input.DatasetID == nil || strings.TrimSpace(*input.DatasetID) == "" {
		return nil, output, fmt.Errorf("dataset_id is required when creating a location")
	}
	if input.Name == nil || strings.TrimSpace(*input.Name) == "" {
		return nil, output, fmt.Errorf("name is required when creating a location")
	}
	if input.Latitude == nil {
		return nil, output, fmt.Errorf("latitude is required when creating a location")
	}
	if input.Longitude == nil {
		return nil, output, fmt.Errorf("longitude is required when creating a location")
	}
	if input.TimezoneID == nil || strings.TrimSpace(*input.TimezoneID) == "" {
		return nil, output, fmt.Errorf("timezone_id is required when creating a location")
	}
	if err := validateLocationFields(input); err != nil {
		return nil, output, err
	}

	database, err := db.OpenWriteableDB(dbPath)
	if err != nil {
		return nil, output, fmt.Errorf("database connection failed: %w", err)
	}
	defer database.Close()

	tx, err := database.BeginTx(ctx, nil)
	if err != nil {
		return nil, output, fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Roll back unconditionally. Some early returns below, such as a missing or
	// inactive dataset, happen with err == nil, so a rollback keyed on err would
	// leave the transaction open. After a successful Commit, Rollback is a
	// harmless no-op.
	defer func() { _ = tx.Rollback() }()

	// Check existence separately. A single "SELECT ... FROM dataset WHERE id = ?"
	// yields no row at all for an unknown ID, so Scan would fail with a driver
	// error instead of reaching the "does not exist" message.
	var datasetExists bool
	err = tx.QueryRowContext(ctx,
		"SELECT EXISTS(SELECT 1 FROM dataset WHERE id = ?)",
		*input.DatasetID,
	).Scan(&datasetExists)
	if err != nil {
		return nil, output, fmt.Errorf("failed to verify dataset: %w", err)
	}
	if !datasetExists {
		return nil, output, fmt.Errorf("dataset with ID '%s' does not exist", *input.DatasetID)
	}

	var datasetActive bool
	var datasetName string
	err = tx.QueryRowContext(ctx,
		"SELECT active, name FROM dataset WHERE id = ?",
		*input.DatasetID,
	).Scan(&datasetActive, &datasetName)
	if err != nil {
		return nil, output, fmt.Errorf("failed to verify dataset: %w", err)
	}
	if !datasetActive {
		return nil, output, fmt.Errorf("dataset '%s' (ID: %s) is not active", datasetName, *input.DatasetID)
	}

	id, err := utils.GenerateShortID()
	if err != nil {
		return nil, output, fmt.Errorf("failed to generate ID: %w", err)
	}

	_, err = tx.ExecContext(ctx,
		"INSERT INTO location (id, dataset_id, name, latitude, longitude, timezone_id, description, created_at, last_modified, active) VALUES (?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, TRUE)",
		id, *input.DatasetID, *input.Name, *input.Latitude, *input.Longitude, *input.TimezoneID, input.Description,
	)
	if err != nil {
		return nil, output, fmt.Errorf("failed to create location: %w", err)
	}

	var location db.Location
	err = tx.QueryRowContext(ctx,
		"SELECT id, dataset_id, name, latitude, longitude, description, created_at, last_modified, active, timezone_id FROM location WHERE id = ?",
		id,
	).Scan(&location.ID, &location.DatasetID, &location.Name, &location.Latitude, &location.Longitude,
		&location.Description, &location.CreatedAt, &location.LastModified, &location.Active, &location.TimezoneID)
	if err != nil {
		return nil, output, fmt.Errorf("failed to fetch created location: %w", err)
	}

	if err = tx.Commit(); err != nil {
		return nil, output, fmt.Errorf("failed to commit transaction: %w", err)
	}

	output.Location = location
	output.Message = fmt.Sprintf("Successfully created location '%s' with ID %s in dataset '%s' (%.6f, %.6f, %s)",
		location.Name, location.ID, datasetName, location.Latitude, location.Longitude, location.TimezoneID)
	return &mcp.CallToolResult{}, output, nil
}
// updateLocation applies a partial update to the location identified by
// *input.ID and returns the row as read back after the write.
//
// Only non-nil fields of input are written, and last_modified is always
// refreshed. If dataset_id is supplied, that dataset must exist. The existence
// checks, the UPDATE and the read-back run in one transaction bound to ctx.
// The returned row is therefore the one this call wrote, and cancelling ctx
// aborts the database work.
//
// Errors are returned for invalid fields, an empty update, an unknown location
// or dataset, and any database failure.
func updateLocation(ctx context.Context, input LocationInput) (*mcp.CallToolResult, LocationOutput, error) {
	var output LocationOutput
	locationID := *input.ID

	if err := validateLocationFields(input); err != nil {
		return nil, output, err
	}

	// Build the SET clause before opening the database, so a request with
	// nothing to change fails fast. Column names are fixed literals and every
	// value is bound as a query parameter.
	updates := make([]string, 0, 7)
	args := make([]any, 0, 7)
	if input.DatasetID != nil {
		updates = append(updates, "dataset_id = ?")
		args = append(args, *input.DatasetID)
	}
	if input.Name != nil {
		updates = append(updates, "name = ?")
		args = append(args, *input.Name)
	}
	if input.Latitude != nil {
		updates = append(updates, "latitude = ?")
		args = append(args, *input.Latitude)
	}
	if input.Longitude != nil {
		updates = append(updates, "longitude = ?")
		args = append(args, *input.Longitude)
	}
	if input.Description != nil {
		updates = append(updates, "description = ?")
		args = append(args, *input.Description)
	}
	if input.TimezoneID != nil {
		updates = append(updates, "timezone_id = ?")
		args = append(args, *input.TimezoneID)
	}
	if len(updates) == 0 {
		return nil, output, fmt.Errorf("no fields provided to update")
	}
	updates = append(updates, "last_modified = now()")
	args = append(args, locationID) // binds the trailing "WHERE id = ?"
	query := fmt.Sprintf("UPDATE location SET %s WHERE id = ?", strings.Join(updates, ", "))

	database, err := db.OpenWriteableDB(dbPath)
	if err != nil {
		return nil, output, fmt.Errorf("failed to open database: %w", err)
	}
	defer database.Close()

	tx, err := database.BeginTx(ctx, nil)
	if err != nil {
		return nil, output, fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Unconditional rollback releases the transaction on every early return.
	// After a successful Commit it is a no-op.
	defer func() { _ = tx.Rollback() }()

	var exists bool
	err = tx.QueryRowContext(ctx, "SELECT EXISTS(SELECT 1 FROM location WHERE id = ?)", locationID).Scan(&exists)
	if err != nil {
		return nil, output, fmt.Errorf("failed to query location: %w", err)
	}
	if !exists {
		return nil, output, fmt.Errorf("location not found: %s", locationID)
	}

	if input.DatasetID != nil {
		var datasetExists bool
		err = tx.QueryRowContext(ctx, "SELECT EXISTS(SELECT 1 FROM dataset WHERE id = ?)", *input.DatasetID).Scan(&datasetExists)
		if err != nil {
			return nil, output, fmt.Errorf("failed to query dataset: %w", err)
		}
		if !datasetExists {
			return nil, output, fmt.Errorf("dataset not found: %s", *input.DatasetID)
		}
	}

	if _, err = tx.ExecContext(ctx, query, args...); err != nil {
		return nil, output, fmt.Errorf("failed to update location: %w", err)
	}

	var location db.Location
	err = tx.QueryRowContext(ctx,
		"SELECT id, dataset_id, name, latitude, longitude, description, created_at, last_modified, active, timezone_id FROM location WHERE id = ?",
		locationID,
	).Scan(&location.ID, &location.DatasetID, &location.Name, &location.Latitude, &location.Longitude,
		&location.Description, &location.CreatedAt, &location.LastModified, &location.Active, &location.TimezoneID)
	if err != nil {
		return nil, output, fmt.Errorf("failed to fetch updated location: %w", err)
	}

	if err = tx.Commit(); err != nil {
		return nil, output, fmt.Errorf("failed to commit transaction: %w", err)
	}

	output.Location = location
	output.Message = fmt.Sprintf("Successfully updated location '%s' (ID: %s)", location.Name, location.ID)
	return &mcp.CallToolResult{}, output, nil
}