package tools
import (
"context"
"fmt"
"skraak_mcp/db"
"skraak_mcp/utils"
"strings"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// DatasetInput is the tool input accepted by CreateOrUpdateDataset.
//
// Every field is an optional pointer: nil means the caller did not supply it.
// The update path relies on this to change only the fields that are present.
type DatasetInput struct {
// ID selects the operation. Nil or blank creates a new dataset; any other value updates the dataset with that ID.
ID *string `json:"id,omitempty" jsonschema:"Dataset ID (12 characters). Omit to create a new dataset, provide to update an existing one."`
// Name is the dataset name. It must be non-blank on create; on update nil leaves the stored name unchanged.
Name *string `json:"name,omitempty" jsonschema:"Dataset name (max 255 characters). Required for create."`
// Description is free text. On create nil is passed to the INSERT as-is; on update nil leaves the stored value unchanged.
Description *string `json:"description,omitempty" jsonschema:"Optional dataset description (max 255 characters)"`
// Type is one of "organise", "test" or "train", matched case-insensitively. Nil defaults to organise on create.
Type *string `json:"type,omitempty" jsonschema:"Dataset type: 'organise'/'test'/'train' (defaults to 'organise' on create)"`
}
// DatasetOutput is the structured result returned by CreateOrUpdateDataset.
type DatasetOutput struct {
// Dataset is the row as read back from the database after the insert or update.
Dataset db.Dataset `json:"dataset" jsonschema:"The created or updated dataset"`
// Message is a human-readable summary naming the dataset and its ID.
Message string `json:"message" jsonschema:"Success message"`
}
// CreateOrUpdateDataset dispatches a dataset tool call to the create or the
// update path.
//
// The call is treated as an update when input.ID is present and contains
// something other than whitespace; otherwise a new dataset is created. The
// req parameter is not used here. The returned values are whatever the chosen
// path produces.
func CreateOrUpdateDataset(
	ctx context.Context,
	req *mcp.CallToolRequest,
	input DatasetInput,
) (*mcp.CallToolResult, DatasetOutput, error) {
	hasID := input.ID != nil && strings.TrimSpace(*input.ID) != ""
	if !hasID {
		return createDataset(ctx, input)
	}
	return updateDataset(ctx, input)
}
func createDataset(ctx context.Context, input DatasetInput) (*mcp.CallToolResult, DatasetOutput, error) {
var output DatasetOutput
if input.Name == nil || strings.TrimSpace(*input.Name) == "" {
return nil, output, fmt.Errorf("name is required when creating a dataset")
}
if len(*input.Name) > 255 {
return nil, output, fmt.Errorf("name must be 255 characters or less (got %d)", len(*input.Name))
}
if input.Description != nil && len(*input.Description) > 255 {
return nil, output, fmt.Errorf("description must be 255 characters or less (got %d)", len(*input.Description))
}
datasetType := db.DatasetTypeOrganise if input.Type != nil {
typeStr := strings.ToLower(strings.TrimSpace(*input.Type))
switch typeStr {
case "organise":
datasetType = db.DatasetTypeOrganise
case "test":
datasetType = db.DatasetTypeTest
case "train":
datasetType = db.DatasetTypeTrain
default:
return nil, output, fmt.Errorf("invalid type '%s': must be 'organise', 'test', or 'train'", *input.Type)
}
}
database, err := db.OpenWriteableDB(dbPath)
if err != nil {
return nil, output, fmt.Errorf("database connection failed: %w", err)
}
defer database.Close()
tx, err := database.BeginTx(ctx, nil)
if err != nil {
return nil, output, fmt.Errorf("failed to begin transaction: %w", err)
}
defer func() {
if err != nil {
tx.Rollback()
}
}()
id, err := utils.GenerateShortID()
if err != nil {
return nil, output, fmt.Errorf("failed to generate ID: %w", err)
}
_, err = tx.ExecContext(ctx,
"INSERT INTO dataset (id, name, description, type, created_at, last_modified, active) VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, TRUE)",
id, *input.Name, input.Description, string(datasetType),
)
if err != nil {
return nil, output, fmt.Errorf("failed to create dataset: %w", err)
}
var dataset db.Dataset
err = tx.QueryRowContext(ctx,
"SELECT id, name, description, created_at, last_modified, active, type FROM dataset WHERE id = ?",
id,
).Scan(&dataset.ID, &dataset.Name, &dataset.Description, &dataset.CreatedAt, &dataset.LastModified, &dataset.Active, &dataset.Type)
if err != nil {
return nil, output, fmt.Errorf("failed to fetch created dataset: %w", err)
}
if err = tx.Commit(); err != nil {
return nil, output, fmt.Errorf("failed to commit transaction: %w", err)
}
output.Dataset = dataset
output.Message = fmt.Sprintf("Successfully created dataset '%s' with ID %s (type: %s)",
dataset.Name, dataset.ID, dataset.Type)
return &mcp.CallToolResult{}, output, nil
}
// updateDataset applies the non-nil fields of input to the dataset identified
// by input.ID and returns the row as read back from the database.
//
// All validation happens before the database is opened:
//   - ID must be present and non-blank; surrounding whitespace is trimmed so
//     the lookup matches what the create/update dispatch accepted.
//   - Name, when supplied, must be non-blank and at most 255 characters.
//   - Description, when supplied, must be at most 255 characters.
//   - Type, when supplied, must be 'organise', 'test' or 'train'; it is
//     matched case-insensitively after trimming and stored in lower case.
//   - At least one of Name, Description or Type must be supplied.
//
// The existence check, the UPDATE and the read-back run in one transaction and
// honour ctx, so the returned row is the one this call wrote.
//
// On success it returns an empty CallToolResult, the populated DatasetOutput
// and a nil error. On failure the result is nil, the output is the zero value
// and the error describes what went wrong (including "dataset not found").
func updateDataset(ctx context.Context, input DatasetInput) (*mcp.CallToolResult, DatasetOutput, error) {
	var output DatasetOutput

	var datasetID string
	if input.ID != nil {
		datasetID = strings.TrimSpace(*input.ID)
	}
	if datasetID == "" {
		return nil, output, fmt.Errorf("id is required when updating a dataset")
	}

	// Collect the SET clauses and their arguments while validating, so each
	// field is inspected exactly once.
	var (
		updates []string
		args    []any
	)

	if input.Name != nil {
		if strings.TrimSpace(*input.Name) == "" {
			return nil, output, fmt.Errorf("name cannot be empty")
		}
		// Limits are expressed in characters, so count runes rather than bytes.
		if n := len([]rune(*input.Name)); n > 255 {
			return nil, output, fmt.Errorf("name must be 255 characters or less (got %d)", n)
		}
		updates = append(updates, "name = ?")
		args = append(args, *input.Name)
	}

	if input.Description != nil {
		if n := len([]rune(*input.Description)); n > 255 {
			return nil, output, fmt.Errorf("description must be 255 characters or less (got %d)", n)
		}
		updates = append(updates, "description = ?")
		args = append(args, *input.Description)
	}

	if input.Type != nil {
		typeValue := strings.ToLower(strings.TrimSpace(*input.Type))
		switch typeValue {
		case "organise", "test", "train":
		default:
			return nil, output, fmt.Errorf("invalid dataset type: %s (must be 'organise', 'test', or 'train')", *input.Type)
		}
		updates = append(updates, "type = ?")
		args = append(args, typeValue)
	}

	if len(updates) == 0 {
		return nil, output, fmt.Errorf("no fields provided to update")
	}
	updates = append(updates, "last_modified = now()")
	args = append(args, datasetID)

	database, err := db.OpenWriteableDB(dbPath)
	if err != nil {
		return nil, output, fmt.Errorf("failed to open database: %w", err)
	}
	defer database.Close()

	tx, err := database.BeginTx(ctx, nil)
	if err != nil {
		return nil, output, fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Roll back on every exit path. After a successful Commit this is a no-op
	// that only reports the transaction as already finished, so the error is
	// deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	var exists bool
	err = tx.QueryRowContext(ctx, "SELECT EXISTS(SELECT 1 FROM dataset WHERE id = ?)", datasetID).Scan(&exists)
	if err != nil {
		return nil, output, fmt.Errorf("failed to query dataset: %w", err)
	}
	if !exists {
		return nil, output, fmt.Errorf("dataset not found: %s", datasetID)
	}

	// Only fixed column fragments are interpolated into the statement; every
	// user-supplied value is bound through a placeholder.
	query := fmt.Sprintf("UPDATE dataset SET %s WHERE id = ?", strings.Join(updates, ", "))
	if _, err := tx.ExecContext(ctx, query, args...); err != nil {
		return nil, output, fmt.Errorf("failed to update dataset: %w", err)
	}

	var dataset db.Dataset
	err = tx.QueryRowContext(ctx,
		"SELECT id, name, description, created_at, last_modified, active, type FROM dataset WHERE id = ?",
		datasetID,
	).Scan(&dataset.ID, &dataset.Name, &dataset.Description, &dataset.CreatedAt, &dataset.LastModified, &dataset.Active, &dataset.Type)
	if err != nil {
		return nil, output, fmt.Errorf("failed to fetch updated dataset: %w", err)
	}

	if err := tx.Commit(); err != nil {
		return nil, output, fmt.Errorf("failed to commit transaction: %w", err)
	}

	output.Dataset = dataset
	output.Message = fmt.Sprintf("Successfully updated dataset '%s' (ID: %s)", dataset.Name, dataset.ID)
	return &mcp.CallToolResult{}, output, nil
}