package tools
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"

	"skraak/db"
	"skraak/utils"
)
// DatasetInput holds the request parameters for CreateOrUpdateDataset.
// Every field is a pointer so that an omitted field (nil) can be told apart
// from an explicitly supplied value.
type DatasetInput struct {
// ID selects the dataset to update. When nil or blank (whitespace only),
// a new dataset is created instead.
ID *string `json:"id,omitempty"`
// Name is required (non-blank) when creating; when updating, a non-nil
// value replaces the stored name.
Name *string `json:"name,omitempty"`
// Description is optional; when updating, a non-nil value replaces the
// stored description.
Description *string `json:"description,omitempty"`
// Type is one of "structured", "unstructured", "test" or "train",
// matched case-insensitively. On create it defaults to "structured" when nil.
Type *string `json:"type,omitempty"`
}
// DatasetOutput is the result returned by CreateOrUpdateDataset.
type DatasetOutput struct {
// Dataset is the dataset row as read back from the database after the
// create, the update, or the lookup of an existing same-named dataset.
Dataset db.Dataset `json:"dataset"`
// Message is a human-readable summary of the action that was taken.
Message string `json:"message"`
}
// CreateOrUpdateDataset is the entry point of the dataset tool.
//
// A request that carries a non-blank ID is treated as an update of that
// dataset; any other request (ID nil, empty, or whitespace only) creates a
// new dataset, or returns an existing active dataset with the same name.
func CreateOrUpdateDataset(
	ctx context.Context,
	input DatasetInput,
) (DatasetOutput, error) {
	targetsExisting := input.ID != nil && strings.TrimSpace(*input.ID) != ""
	if !targetsExisting {
		return createDataset(ctx, input)
	}
	return updateDataset(ctx, input)
}
// createDataset creates a new active dataset, or returns the existing active
// dataset with the same name instead of inserting a duplicate.
//
// input.Name is required and must not be blank; input.Description is
// optional; input.Type is parsed by parseDatasetType. The name lookup and the
// insert run inside one write transaction. On any error the returned
// DatasetOutput is the zero value.
func createDataset(ctx context.Context, input DatasetInput) (DatasetOutput, error) {
	if input.Name == nil || strings.TrimSpace(*input.Name) == "" {
		return DatasetOutput{}, fmt.Errorf("name is required when creating a dataset")
	}
	if err := utils.ValidateStringLength(*input.Name, "name", utils.MaxDatasetNameLen); err != nil {
		return DatasetOutput{}, err
	}
	if err := utils.ValidateOptionalStringLength(input.Description, "description", utils.MaxDescriptionLen); err != nil {
		return DatasetOutput{}, err
	}
	datasetType, err := parseDatasetType(input.Type)
	if err != nil {
		return DatasetOutput{}, err
	}

	var output DatasetOutput
	err = db.WithWriteTx(ctx, dbPath, "create_or_update_dataset", func(_ *sql.DB, tx *db.LoggedTx) error {
		// The name is matched exactly as supplied, and only against active datasets.
		var existingID string
		lookupErr := tx.QueryRowContext(ctx,
			"SELECT id FROM dataset WHERE name = ? AND active = true",
			*input.Name,
		).Scan(&existingID)

		var (
			result DatasetOutput
			opErr  error
		)
		switch {
		case lookupErr == nil:
			result, opErr = handleExistingDataset(ctx, tx, existingID)
		case errors.Is(lookupErr, sql.ErrNoRows):
			result, opErr = insertNewDataset(ctx, tx, *input.Name, input.Description, datasetType)
		default:
			// A genuine query failure must not be mistaken for "no such
			// dataset"; falling through to the insert could create a duplicate.
			return fmt.Errorf("failed to look up dataset by name: %w", lookupErr)
		}
		if opErr != nil {
			return opErr
		}
		output = result
		return nil
	})
	if err != nil {
		// Don't hand back a half-populated result (e.g. when commit fails).
		return DatasetOutput{}, err
	}
	return output, nil
}
func parseDatasetType(t *string) (db.DatasetType, error) {
datasetType := db.DatasetTypeStructured if t != nil {
typeStr := strings.ToLower(strings.TrimSpace(*t))
switch typeStr {
case "structured":
datasetType = db.DatasetTypeStructured
case "unstructured":
datasetType = db.DatasetTypeUnstructured
case "test":
datasetType = db.DatasetTypeTest
case "train":
datasetType = db.DatasetTypeTrain
default:
return "", fmt.Errorf("invalid type '%s': must be 'structured', 'unstructured', 'test', or 'train'", *t)
}
}
return datasetType, nil
}
// handleExistingDataset reloads the dataset with the given ID inside tx and
// wraps it in a DatasetOutput whose message says the existing dataset is
// being returned rather than a new one created.
func handleExistingDataset(ctx context.Context, tx *db.LoggedTx, existingID string) (DatasetOutput, error) {
	const selectByID = "SELECT id, name, description, created_at, last_modified, active, type FROM dataset WHERE id = ?"

	var found db.Dataset
	row := tx.QueryRowContext(ctx, selectByID, existingID)
	if scanErr := row.Scan(
		&found.ID,
		&found.Name,
		&found.Description,
		&found.CreatedAt,
		&found.LastModified,
		&found.Active,
		&found.Type,
	); scanErr != nil {
		return DatasetOutput{}, fmt.Errorf("failed to fetch existing dataset: %w", scanErr)
	}

	msg := fmt.Sprintf("Dataset with name '%s' already exists (ID: %s) - returning existing dataset", found.Name, found.ID)
	return DatasetOutput{Dataset: found, Message: msg}, nil
}
// insertNewDataset inserts a new active dataset row with a freshly generated
// short ID and both timestamps set to CURRENT_TIMESTAMP. It then reads the
// row back inside the same transaction so the caller gets the stored values.
func insertNewDataset(ctx context.Context, tx *db.LoggedTx, name string, description *string, datasetType db.DatasetType) (DatasetOutput, error) {
	const (
		insertStmt = "INSERT INTO dataset (id, name, description, type, created_at, last_modified, active) VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, TRUE)"
		selectStmt = "SELECT id, name, description, created_at, last_modified, active, type FROM dataset WHERE id = ?"
	)

	newID, err := utils.GenerateShortID()
	if err != nil {
		return DatasetOutput{}, fmt.Errorf("failed to generate ID: %w", err)
	}

	if _, err = tx.ExecContext(ctx, insertStmt, newID, name, description, string(datasetType)); err != nil {
		return DatasetOutput{}, fmt.Errorf("failed to create dataset: %w", err)
	}

	var created db.Dataset
	err = tx.QueryRowContext(ctx, selectStmt, newID).Scan(
		&created.ID,
		&created.Name,
		&created.Description,
		&created.CreatedAt,
		&created.LastModified,
		&created.Active,
		&created.Type,
	)
	if err != nil {
		return DatasetOutput{}, fmt.Errorf("failed to fetch created dataset: %w", err)
	}

	summary := fmt.Sprintf("Successfully created dataset '%s' with ID %s (type: %s)",
		created.Name, created.ID, created.Type)
	return DatasetOutput{Dataset: created, Message: summary}, nil
}
// validateUpdateInput checks an update request before any database work is
// done. It covers:
//   - the dataset ID (present and a valid short ID),
//   - the optional name (not blank, within the length limit),
//   - the optional description (within the length limit),
//   - the optional type (one of the allowed values).
//
// It returns the first problem found, or nil.
func validateUpdateInput(input DatasetInput) error {
	// Guard the dereference below; a nil ID would otherwise panic.
	if input.ID == nil {
		return fmt.Errorf("dataset_id is required when updating a dataset")
	}
	if err := utils.ValidateShortID(*input.ID, "dataset_id"); err != nil {
		return err
	}
	// createDataset refuses blank names, so an update must not be able to
	// blank one out either.
	if input.Name != nil && strings.TrimSpace(*input.Name) == "" {
		return fmt.Errorf("name cannot be empty when updating a dataset")
	}
	if err := utils.ValidateOptionalStringLength(input.Name, "name", utils.MaxDatasetNameLen); err != nil {
		return err
	}
	if err := utils.ValidateOptionalStringLength(input.Description, "description", utils.MaxDescriptionLen); err != nil {
		return err
	}
	return validateDatasetType(input.Type)
}
// validateDatasetType reports whether an optional type string names a known
// dataset type. A nil pointer is accepted (type left unchanged). Matching is
// case-insensitive but does not trim whitespace.
func validateDatasetType(t *string) error {
	if t == nil {
		return nil
	}
	normalized := strings.ToLower(*t)
	if normalized == "structured" || normalized == "unstructured" ||
		normalized == "test" || normalized == "train" {
		return nil
	}
	return fmt.Errorf("invalid dataset type: %s (must be 'structured', 'unstructured', 'test', or 'train')", *t)
}
// buildUpdateQuery assembles a parameterized UPDATE statement for the
// dataset identified by datasetID.
//
// Only non-nil fields of input are written; the type is lower-cased.
// last_modified is always bumped. The returned args are in placeholder
// order, with datasetID last for the WHERE clause. An error is returned
// when input contains no updatable field.
func buildUpdateQuery(input DatasetInput, datasetID string) (string, []any, error) {
	var (
		setClauses []string
		params     []any
	)
	set := func(column string, value any) {
		setClauses = append(setClauses, column+" = ?")
		params = append(params, value)
	}

	if input.Name != nil {
		set("name", *input.Name)
	}
	if input.Description != nil {
		set("description", *input.Description)
	}
	if input.Type != nil {
		set("type", strings.ToLower(*input.Type))
	}
	if len(setClauses) == 0 {
		return "", nil, fmt.Errorf("no fields provided to update")
	}

	setClauses = append(setClauses, "last_modified = now()")
	params = append(params, datasetID)
	return "UPDATE dataset SET " + strings.Join(setClauses, ", ") + " WHERE id = ?", params, nil
}
// updateDataset applies a partial update to an existing, active dataset and
// returns the row as re-read after the update.
//
// input.ID selects the dataset. Only the non-nil fields among Name,
// Description and Type are written (see buildUpdateQuery). The update and the
// read-back run inside one write transaction and honor ctx. On any error the
// returned DatasetOutput is the zero value.
func updateDataset(ctx context.Context, input DatasetInput) (DatasetOutput, error) {
	// Validate first so *input.ID is only dereferenced after it is checked.
	if err := validateUpdateInput(input); err != nil {
		return DatasetOutput{}, err
	}
	datasetID := *input.ID

	var output DatasetOutput
	err := db.WithWriteTx(ctx, dbPath, "create_or_update_dataset", func(database *sql.DB, tx *db.LoggedTx) error {
		// NOTE(review): this check queries through database, not tx, and
		// verifyDatasetActive takes no ctx; moving it into the transaction
		// would require changing that helper's signature.
		if verifyErr := verifyDatasetActive(database, datasetID); verifyErr != nil {
			return verifyErr
		}
		query, args, buildErr := buildUpdateQuery(input, datasetID)
		if buildErr != nil {
			return buildErr
		}
		// Use the ctx-aware variants so cancellation/deadlines are respected,
		// matching the create path.
		if _, execErr := tx.ExecContext(ctx, query, args...); execErr != nil {
			return fmt.Errorf("failed to update dataset: %w", execErr)
		}
		var dataset db.Dataset
		scanErr := tx.QueryRowContext(ctx,
			"SELECT id, name, description, created_at, last_modified, active, type FROM dataset WHERE id = ?",
			datasetID,
		).Scan(&dataset.ID, &dataset.Name, &dataset.Description, &dataset.CreatedAt, &dataset.LastModified, &dataset.Active, &dataset.Type)
		if scanErr != nil {
			return fmt.Errorf("failed to fetch updated dataset: %w", scanErr)
		}
		output = DatasetOutput{
			Dataset: dataset,
			Message: fmt.Sprintf("Successfully updated dataset '%s' (ID: %s)", dataset.Name, dataset.ID),
		}
		return nil
	})
	if err != nil {
		// Don't hand back a half-populated result (e.g. when commit fails).
		return DatasetOutput{}, err
	}
	return output, nil
}
// verifyDatasetActive returns nil only when a dataset with datasetID exists
// and its active flag is true. A missing row and an inactive row are reported
// as distinct errors. A NULL active value counts as inactive, via COALESCE.
func verifyDatasetActive(database *sql.DB, datasetID string) error {
	const probe = "SELECT EXISTS(SELECT 1 FROM dataset WHERE id = ?), COALESCE((SELECT active FROM dataset WHERE id = ?), false)"

	var found, isActive bool
	if err := database.QueryRow(probe, datasetID, datasetID).Scan(&found, &isActive); err != nil {
		return fmt.Errorf("failed to query dataset: %w", err)
	}

	switch {
	case !found:
		return fmt.Errorf("dataset not found: %s", datasetID)
	case !isActive:
		return fmt.Errorf("dataset '%s' is not active (cannot update inactive datasets)", datasetID)
	default:
		return nil
	}
}