package tools
import (
	"context"
	"database/sql"
	"encoding/base64"
	"fmt"
	"math"
	"regexp"
	"strings"
	"time"

	"skraak/db"
)
// dbPath is the database location that ExecuteSQL hands to db.WithReadDB.
// It is package-level mutable state and is not guarded by any lock.
var dbPath string

// SetDBPath sets the database path used by subsequent ExecuteSQL calls.
// NOTE(review): the assignment is unsynchronized; presumably this is called
// once at startup before any concurrent ExecuteSQL call — confirm.
func SetDBPath(path string) {
	dbPath = path
}
// ExecuteSQLInput is the request accepted by ExecuteSQL.
type ExecuteSQLInput struct {
	// Query is the SQL text; validation requires it to start with SELECT or WITH.
	Query string `json:"query"`
	// Parameters are bind arguments passed to the driver along with Query.
	Parameters []any `json:"parameters,omitempty"`
	// Limit optionally caps the number of returned rows and must be in
	// 1..maxLimit. It only takes effect when Query has no LIMIT clause of its
	// own; nil selects defaultLimit.
	Limit *int `json:"limit,omitempty"`
}
// ColumnInfo describes one column of a query result.
type ColumnInfo struct {
	// Name is the column name reported by the driver.
	Name string `json:"name"`
	// DatabaseType is the driver's type name for the column (from
	// sql.ColumnType.DatabaseTypeName); it is empty when the driver does not
	// report one.
	DatabaseType string `json:"database_type"`
}
// ExecuteSQLOutput is the result returned by ExecuteSQL.
type ExecuteSQLOutput struct {
	// Rows holds one map per result row, keyed by column name.
	Rows []map[string]any `json:"rows"`
	// RowCount is len(Rows).
	RowCount int `json:"row_count"`
	// Columns lists the result columns in the order the query produced them.
	Columns []ColumnInfo `json:"columns"`
	// Limited is true when an automatically appended LIMIT cut the result
	// short, i.e. more rows were available than were returned.
	Limited bool `json:"limited"`
	// Query is the query text reported back to the caller (serialised as
	// "query_executed"); when a LIMIT was appended it shows the requested
	// limit rather than the probe value actually sent.
	Query string `json:"query_executed"`
}
// Query screening patterns, compiled once at package initialisation.
var (
	// selectPattern requires the query to open with SELECT or WITH
	// (case-insensitive, leading whitespace allowed).
	selectPattern = regexp.MustCompile(`(?i)^\s*(SELECT|WITH)\s+`)
	// forbiddenPattern flags data- or schema-modifying keywords anywhere in
	// the text, matched as whole words (so e.g. "created_at" is not flagged).
	forbiddenPattern = regexp.MustCompile(`(?i)\b(INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|TRUNCATE|GRANT|REVOKE)\b`)
	// limitPattern detects a LIMIT clause the caller already wrote, whether
	// the row count is a literal (LIMIT 10) or a bind placeholder
	// (LIMIT ?, LIMIT ?1, LIMIT $1). Recognising placeholders prevents a
	// second LIMIT from being appended, which would produce invalid SQL.
	limitPattern = regexp.MustCompile(`(?i)\bLIMIT\s+(?:\d+|\?\d*|\$\d+)`)
)
// Row caps: defaultLimit applies when the caller supplies no limit, and
// maxLimit is the largest explicit limit that validation accepts.
const defaultLimit, maxLimit = 1000, 10000
// ExecuteSQL validates a SELECT/WITH query and runs it through
// db.WithReadDB against the path set by SetDBPath. It returns every result
// row as a map keyed by column name.
//
// When the query carries no LIMIT of its own, one is appended that asks for
// limit+1 rows. The extra row only signals that more data exists: it is
// dropped and Limited is set. On success Rows is never nil, so it encodes as
// [] rather than null.
func ExecuteSQL(ctx context.Context, input ExecuteSQLInput) (ExecuteSQLOutput, error) {
	if err := validateSQLQuery(input.Query, input.Limit); err != nil {
		return ExecuteSQLOutput{}, err
	}

	limit := resolveLimit(input.Limit)
	execQuery, limitAdded := applyLimit(input.Query, limit)
	reported := buildQueryReported(input.Query, limitAdded, limit)

	var out ExecuteSQLOutput
	err := db.WithReadDB(dbPath, func(conn *sql.DB) error {
		rows, err := executeSQLQuery(ctx, conn, execQuery, input.Parameters)
		if err != nil {
			return err
		}
		defer rows.Close()

		cols, names, err := buildColumnInfo(rows)
		if err != nil {
			return err
		}
		records, err := scanResultRows(rows, names)
		if err != nil {
			return err
		}
		if records == nil {
			records = []map[string]any{}
		}

		// Only an auto-appended LIMIT over-fetches by one; a user-written
		// LIMIT is honoured as-is.
		truncated := limitAdded && len(records) > limit
		if truncated {
			records = records[:limit]
		}

		out = ExecuteSQLOutput{
			Rows:     records,
			RowCount: len(records),
			Columns:  cols,
			Limited:  truncated,
			Query:    reported,
		}
		return nil
	})
	return out, err
}
// validateSQLQuery rejects anything other than a plain read query, and any
// explicit limit outside 1..maxLimit. A nil limit is always accepted.
//
// This is a keyword screen, not a SQL parser:
//   - the query must start with SELECT or WITH;
//   - it must not contain any keyword matched by forbiddenPattern anywhere
//     in its text. This includes string literals and comments, so some
//     harmless queries are rejected.
func validateSQLQuery(query string, limit *int) error {
	if strings.TrimSpace(query) == "" {
		return fmt.Errorf("query cannot be empty")
	}
	if !selectPattern.MatchString(query) {
		return fmt.Errorf("only SELECT and WITH queries are allowed")
	}
	if forbiddenPattern.MatchString(query) {
		// Keep this keyword list in sync with forbiddenPattern so users see
		// every keyword that can cause the rejection.
		return fmt.Errorf("query contains forbidden keywords (INSERT/UPDATE/DELETE/DROP/CREATE/ALTER/TRUNCATE/GRANT/REVOKE)")
	}
	if limit != nil && (*limit < 1 || *limit > maxLimit) {
		return fmt.Errorf("limit must be between 1 and %d", maxLimit)
	}
	return nil
}
// resolveLimit returns the caller-requested row cap, or defaultLimit when
// none was supplied.
func resolveLimit(limit *int) int {
	if limit == nil {
		return defaultLimit
	}
	return *limit
}
// applyLimit returns the query to execute and whether a LIMIT was appended.
//
// If limitPattern finds a LIMIT clause, the query is returned unchanged. A
// LIMIT anywhere in the text counts, including one inside a subquery or CTE,
// in which case the outer query stays uncapped.
//
// Otherwise "LIMIT limit+1" is appended; the extra row lets the caller
// detect that results were cut off. Before appending:
//   - trailing semicolons are removed, so the clause does not land after the
//     statement terminator (which would be a syntax error);
//   - the clause is placed on its own line, so a trailing "--" comment
//     cannot swallow it.
func applyLimit(query string, limit int) (string, bool) {
	if limitPattern.MatchString(query) {
		return query, false
	}
	base := strings.TrimRight(strings.TrimSpace(query), "; \t\r\n")
	return fmt.Sprintf("%s\nLIMIT %d", base, limit+1), true
}
// executeSQLQuery runs query on database with params bound as query
// arguments; a nil or empty params slice binds nothing. The caller must
// close the returned rows.
func executeSQLQuery(ctx context.Context, database *sql.DB, query string, params []any) (*sql.Rows, error) {
	return database.QueryContext(ctx, query, params...)
}
// buildColumnInfo reads the result-set metadata from rows. It returns:
//   - one name/database-type descriptor per column, for the response;
//   - the bare column names in result order, for use when scanning.
func buildColumnInfo(rows *sql.Rows) ([]ColumnInfo, []string, error) {
	names, err := rows.Columns()
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get columns: %w", err)
	}
	types, err := rows.ColumnTypes()
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get column types: %w", err)
	}

	info := make([]ColumnInfo, 0, len(names))
	for i := range names {
		info = append(info, ColumnInfo{
			Name:         names[i],
			DatabaseType: types[i].DatabaseTypeName(),
		})
	}
	return info, names, nil
}
// scanResultRows drains rows into one map per row, keyed by column name, and
// passes each raw value through convertValue before storing it. columns must
// be in result-set order (as returned by rows.Columns). If two columns share
// a name, the later one overwrites the earlier in the map. The returned
// slice is nil when there are no rows.
func scanResultRows(rows *sql.Rows, columns []string) ([]map[string]any, error) {
	var out []map[string]any
	width := len(columns)
	for rows.Next() {
		raw := make([]any, width)
		dest := make([]any, width)
		for i := range raw {
			dest[i] = &raw[i]
		}
		if err := rows.Scan(dest...); err != nil {
			return nil, fmt.Errorf("row scan failed: %w", err)
		}

		record := make(map[string]any, width)
		for i, name := range columns {
			record[name] = convertValue(raw[i])
		}
		out = append(out, record)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("row iteration failed: %w", err)
	}
	return out, nil
}
// buildQueryReported returns the query text shown to the caller as
// "query_executed".
//
// If no LIMIT was auto-appended, the original query is returned verbatim.
// Otherwise, surrounding whitespace and trailing semicolons are stripped and
// " LIMIT <limit>" is appended. This shows the user-facing limit rather than
// the limit+1 probe actually sent. Stripping the semicolons keeps the text a
// valid statement instead of "...; LIMIT n".
func buildQueryReported(originalQuery string, autoAddedLimit bool, limit int) string {
	if !autoAddedLimit {
		return originalQuery
	}
	base := strings.TrimRight(strings.TrimSpace(originalQuery), "; \t\r\n")
	return fmt.Sprintf("%s LIMIT %d", base, limit)
}
// convertValue turns a raw scanned column value into a form suitable for the
// JSON response:
//   - nil (SQL NULL) stays nil;
//   - time.Time becomes an RFC 3339 string;
//   - []byte becomes standard base64;
//   - strings, booleans and every Go integer type are returned unchanged, so
//     they encode as JSON strings, booleans and numbers. Previously only
//     int64 passed through and narrower integers were stringified;
//   - finite float32/float64 values are returned unchanged. NaN and ±Inf
//     become the strings "NaN", "+Inf" and "-Inf", because encoding/json
//     refuses to encode them and would fail the whole response;
//   - anything else is rendered with %v.
func convertValue(val any) any {
	switch v := val.(type) {
	case nil:
		return nil
	case time.Time:
		return v.Format(time.RFC3339)
	case []byte:
		return base64.StdEncoding.EncodeToString(v)
	case string, bool,
		int, int8, int16, int32, int64,
		uint, uint8, uint16, uint32, uint64:
		return v
	case float64:
		if math.IsNaN(v) || math.IsInf(v, 0) {
			return fmt.Sprintf("%v", v)
		}
		return v
	case float32:
		if f := float64(v); math.IsNaN(f) || math.IsInf(f, 0) {
			return fmt.Sprintf("%v", v)
		}
		return v
	default:
		return fmt.Sprintf("%v", v)
	}
}