-- PLFA agda exercises
------------------------------------------------------------------------
-- The Agda standard library
--
-- A solver that uses reflection to automatically obtain and solve
-- equations over rings.
------------------------------------------------------------------------

{-# OPTIONS --cubical-compatible --safe #-}

module Tactic.RingSolver where

open import Algebra
open import Data.Fin.Base   as Fin     using (Fin)
open import Data.Vec.Base   as Vec     using (Vec; _∷_; [])
open import Data.List.Base  as List    using (List; _∷_; [])
open import Data.Maybe.Base as Maybe   using (Maybe; just; nothing; fromMaybe)
open import Data.Nat.Base              using (ℕ; suc; zero; _<ᵇ_)
open import Data.Bool.Base             using (Bool; if_then_else_; true; false)
open import Data.Unit.Base             using (⊤)
open import Data.String.Base as String using (String; _++_; parens)
open import Data.Product.Base          using (_,_; proj₁)
open import Function.Base
open import Relation.Nullary.Decidable

open import Reflection
open import Reflection.AST.Argument
open import Reflection.AST.Term as Term
open import Reflection.AST.AlphaEquality
open import Reflection.AST.Name as Name
open import Reflection.TCM.Syntax
open import Data.Nat.Reflection
open import Data.List.Reflection
import Data.Vec.Reflection as Vec

open import Tactic.RingSolver.NonReflective renaming (solve to solver)
open import Tactic.RingSolver.Core.AlmostCommutativeRing
open import Tactic.RingSolver.Core.NatSet as NatSet

open AlmostCommutativeRing

------------------------------------------------------------------------
-- Utilities

private
  -- A partial assignment from variable indices to the terms that
  -- should replace them; indices without an entry map to `nothing`.
  VarMap : Set
  VarMap = (index : ℕ) → Maybe Term

  -- Keeps the payload of an argument only when it is a visible
  -- (explicit) one; hidden and instance arguments give `nothing`.
  getVisible : Arg Term → Maybe Term
  getVisible (arg (arg-info visibility _) payload) with visibility
  ... | visible = just payload
  ... | _       = nothing

  -- Collects the last n visible arguments of an applied definition, in
  -- their original order. The result is `nothing` when the term is not
  -- a `def`, or when it carries fewer than n visible arguments.
  getVisibleArgs : ∀ n → Term → Maybe (Vec Term n)
  getVisibleArgs n (def _ args) =
    Maybe.map Vec.reverse (List.foldl snoc start visibleArgs n)
    where
    visibleArgs : List Term
    visibleArgs = List.mapMaybe getVisible args

    -- Before any argument has been seen only the empty vector exists.
    start : ∀ k → Maybe (Vec Term k)
    start zero    = just []
    start (suc _) = nothing

    -- Puts the newest argument in front of the (k - 1) most recent
    -- ones; the outer `Vec.reverse` restores the original order.
    snoc : (∀ k → Maybe (Vec Term k)) → Term → ∀ k → Maybe (Vec Term k)
    snoc recent t zero    = just []
    snoc recent t (suc k) = Maybe.map (t ∷_) (recent k)
  getVisibleArgs _ _ = nothing

  -- Reflects the members of the set, taken in `NatSet.toList` order,
  -- as a vector term whose elements are the variables with those
  -- indices.
  curriedTerm : NatSet → Term
  curriedTerm indices =
    List.foldr (λ i rest → var i [] Vec.`∷ rest) Vec.`[] (NatSet.toList indices)

------------------------------------------------------------------------
-- Reflection utilities for rings

-- The reflected type `AlmostCommutativeRing _ _`, with both of its
-- visible arguments left as unknowns.
`AlmostCommutativeRing : Term
`AlmostCommutativeRing = def (quote AlmostCommutativeRing) unknownArgs
  where
  unknownArgs : Args Term
  unknownArgs = 2 ⋯⟨∷⟩ []

-- One term per ring operator; `getRingOperatorTerms` fills these with
-- the normalised operators of a given ring.
record RingOperatorTerms : Set where
  constructor add⇒_mul⇒_pow⇒_neg⇒_sub⇒_
  field
    add : Term
    mul : Term
    pow : Term
    neg : Term
    sub : Term

-- Type-checks the candidate term against `AlmostCommutativeRing and
-- returns the term produced by `checkType`.
checkIsRing : Term → TC Term
checkIsRing candidate = checkType candidate `AlmostCommutativeRing

module RingReflection (`ring : Term) where

  -- Takes the name of a function that takes the ring as its first
  -- explicit argument, together with the terms of its remaining
  -- arguments, and inserts the required ring arguments
  --   e.g. "_+_" $ʳ xs = "_+_ {_} {_} ring xs"
  infixr 6 _$ʳ_
  _$ʳ_ : Name → Args Term → Term
  f $ʳ xs = def f (2 ⋯⟅∷⟆ `ring ⟨∷⟩ xs)

  -- The carrier of the ring.
  `Carrier : Term
  `Carrier = quote Carrier $ʳ []

  -- Reflexivity of the ring's equality, with one hidden argument left
  -- unspecified.
  `refl : Term
  `refl = quote refl $ʳ (1 ⋯⟅∷⟆ [])

  -- Symmetry applied to the given proof term.
  `sym : Term → Term
  `sym proof = quote sym $ʳ (2 ⋯⟅∷⟆ proof ⟨∷⟩ [])

  -- Transitivity applied to the two given proof terms.
  `trans : Term → Term → Term
  `trans first second = quote trans $ʳ (3 ⋯⟅∷⟆ first ⟨∷⟩ second ⟨∷⟩ [])

  -- Normalises each of the fields of the ring operator so we can
  -- compare the result against the normalised definitions we come
  -- across when converting the term passed to the macro.
  getRingOperatorTerms : TC RingOperatorTerms
  getRingOperatorTerms = do
    plus   ← normalise (quote _+_  $ʳ [])
    times  ← normalise (quote _*_  $ʳ [])
    power  ← normalise (quote _^_  $ʳ [])
    negate ← normalise (quote (-_) $ʳ [])
    minus  ← normalise (quote _-_  $ʳ [])
    pure (add⇒ plus mul⇒ times pow⇒ power neg⇒ negate sub⇒ minus)

------------------------------------------------------------------------
-- Reflection utilities for ring solver

-- Builders for the reflected solver expressions of a fixed ring and a
-- fixed number of variables, plus the conversion from goal terms to
-- those expressions.
module RingSolverReflection (ring : Term) (numberOfVariables : ℕ) where
  open RingReflection ring

  -- The number of variables, reflected as a term.
  `numberOfVariables : Term
  `numberOfVariables = toTerm numberOfVariables

  -- This function applies the hidden arguments that the constructors
  -- of Expr need. The first is the universe level, the second is the
  -- type it contains, and the third is the number of variables it's
  -- indexed by. All three of these could likely be inferred; the first
  -- is left unspecified, while the second and third are supplied
  -- explicitly because we know them.
  infix -1 _$ᵉ_
  _$ᵉ_ : Name → List (Arg Term) → Term
  e $ᵉ xs = con e (1 ⋯⟅∷⟆ `Carrier ⟅∷⟆ `numberOfVariables ⟅∷⟆ xs)

  -- A constant expression.
  `Κ : Term → Term
  `Κ x = quote Κ $ᵉ (x ⟨∷⟩ [])

  -- An expression built with the Ι constructor. constructSolution uses
  -- it for variables, passing a term produced by toFinTerm.
  `I : Term → Term
  `I x = quote Ι $ᵉ (x ⟨∷⟩ [])

  infixl 6 _`⊜_

  -- A call to _⊜_ for this ring, with the number of variables passed
  -- as a hidden argument and the two sides as visible ones.
  _`⊜_ : Term → Term → Term
  x `⊜ y = quote _⊜_  $ʳ (`numberOfVariables ⟅∷⟆ x ⟨∷⟩ y ⟨∷⟩ [])

  -- A call to Ops.correct for this ring on the expression x and the
  -- term ρ, with one hidden argument left unspecified.
  `correct : Term → Term → Term
  `correct x ρ = quote Ops.correct $ʳ (1 ⋯⟅∷⟆ x ⟨∷⟩ ρ ⟨∷⟩ [])

  -- A call to the non-reflective solver (imported as `solver`) for
  -- this ring, passing the number of variables, `f and `eq as visible
  -- arguments.
  `solver : Term → Term → Term
  `solver `f `eq = quote solver $ʳ (`numberOfVariables ⟨∷⟩ `f ⟨∷⟩ `eq ⟨∷⟩ [])

  -- Converts the raw terms provided by the macro into the `Expr`s
  -- used internally by the solver.
  --
  -- When trying to figure out the shape of an expression, one of
  -- the difficult tasks is recognizing where constants in the
  -- underlying ring are used. If we were only dealing with ℕ, we
  -- might look for its constructors: however, we want to deal with
  -- arbitrary types which implement AlmostCommutativeRing. If the
  -- Term type contained type information we might be able to
  -- recognize it there, but it doesn't.
  --
  -- We're in luck, though, because all other cases in the following
  -- function *are* recognizable. As a result, the "catch-all" case
  -- will just assume that it has a constant expression.
  convertTerm : RingOperatorTerms → VarMap → Term → TC Term
  convertTerm operatorTerms varMap = convert
    where
    open RingOperatorTerms operatorTerms

    mutual
      convert : Term → TC Term
      -- First try and match directly against the fields
      convert (def (quote _+_) xs) = convertOp₂ (quote _⊕_) xs
      convert (def (quote _*_) xs) = convertOp₂ (quote _⊗_) xs
      convert (def (quote  -_) xs) = convertOp₁ (quote  ⊝_) xs
      convert (def (quote _^_) xs) = convertExp xs
      convert (def (quote _-_) xs) = convertSub xs
      -- Other definitions: these may be the underlying implementation
      -- of one of the ring's fields, so compare them after normalising
      convert (def nm          xs) = convertUnknownName nm xs
      -- Variables: use the entry in the variable map if there is one,
      -- and otherwise treat the variable as a constant
      convert v@(var x _)          = pure $ fromMaybe (`Κ v) (varMap x)
      -- Special case to recognise "suc" for naturals
      convert (`suc x)             = convertSuc x
      -- Otherwise we're forced to treat it as a constant
      convert t                    = pure $ `Κ t

      -- Application of a ring operator often doesn't have a type as
      -- simple as "Carrier → Carrier → Carrier": there may be hidden
      -- arguments, etc. Here, we do our best to handle those cases,
      -- by just taking the last two explicit arguments. Leading
      -- arguments are dropped one at a time until exactly two visible
      -- ones remain; if that never happens the result is `unknown`.
      convertOp₂ : Name → Args Term → TC Term
      convertOp₂ nm (x ⟨∷⟩ y ⟨∷⟩ []) = do
        x' ← convert x
        y' ← convert y
        pure (nm $ᵉ (x' ⟨∷⟩ y' ⟨∷⟩ []))
      convertOp₂ nm (x ∷ xs)         = convertOp₂ nm xs
      convertOp₂ _  _                = pure unknown

      -- As convertOp₂, but for a unary operator: only the final
      -- visible argument is converted.
      convertOp₁ : Name → Args Term → TC Term
      convertOp₁ nm (x ⟨∷⟩ []) = do
        x' ← convert x
        pure (nm $ᵉ (x' ⟨∷⟩ []))
      convertOp₁ nm (x ∷ xs)   = convertOp₁ nm xs
      convertOp₁ _  _          = pure unknown

      -- Exponentiation: only the base is converted to an expression;
      -- the exponent term is passed to _⊛_ unchanged.
      convertExp : Args Term → TC Term
      convertExp (x ⟨∷⟩ y ⟨∷⟩ []) = do
        x' ← convert x
        pure (quote _⊛_ $ᵉ (x' ⟨∷⟩ y ⟨∷⟩ []))
      convertExp (x ∷ xs)         = convertExp xs
      convertExp _                = pure unknown

      -- Subtraction: "x - y" is represented as "x ⊕ (⊝ y)".
      convertSub : Args Term → TC Term
      convertSub (x ⟨∷⟩ y ⟨∷⟩ []) = do
        x'  ← convert x
        -y' ← convertOp₁ (quote (⊝_)) (y ⟨∷⟩ [])
        pure (quote _⊕_ $ᵉ x' ⟨∷⟩ -y' ⟨∷⟩ [])
      convertSub (x ∷ xs)         = convertSub xs
      convertSub _                = pure unknown

      -- Normalises the bare name and compares it, up to α-equality,
      -- with each of the normalised ring operators in turn. A name
      -- that matches none of them is kept, with its arguments, as a
      -- constant.
      convertUnknownName : Name → Args Term → TC Term
      convertUnknownName nm xs = do
        nameTerm ← normalise (def nm [])
        if (nameTerm =α= add) then convertOp₂ (quote _⊕_) xs else
          if (nameTerm =α= mul) then convertOp₂ (quote _⊗_) xs else
            if (nameTerm =α= neg) then convertOp₁ (quote ⊝_)  xs else
              if (nameTerm =α= pow) then convertExp             xs else
                if (nameTerm =α= sub) then convertSub            xs else
                  pure (`Κ (def nm xs))

      -- "suc x" is represented as "Κ 1 ⊕ x".
      convertSuc : Term → TC Term
      convertSuc x = do x' ← convert x; pure (quote _⊕_ $ᵉ (`Κ (toTerm 1) ⟨∷⟩ x' ⟨∷⟩ []))

------------------------------------------------------------------------
-- Macros
------------------------------------------------------------------------
-- Quantified macro

open RingReflection
open RingSolverReflection

-- Aborts with a type error reporting that the goal of solve-∀ does not
-- have the expected quantified shape; the offending term is included.
malformedForallTypeError : ∀ {a} {A : Set a} → Term → TC A
malformedForallTypeError found =
  typeError $
    strErr "Malformed call to solve." ∷
    strErr "Expected target type to be like: ∀ x y → x + y ≈ y + x." ∷
    strErr "Instead: " ∷
    termErr found ∷
    []

-- Variable map for the quantified macro: every index strictly below
-- numVars is mapped to the variable with that same index, and all
-- other indices have no entry.
quantifiedVarMap : ℕ → VarMap
quantifiedVarMap numVars i with i <ᵇ numVars
... | true  = just (var i [])
... | false = nothing

-- Builds the call to the non-reflective solver for a quantified goal.
-- Both sides are converted to solver expressions and joined with _`⊜_;
-- that equation is wrapped with `prependVLams variables`, and the
-- reflexivity proof with `prependHLams variables`.
constructCallToSolver : Term → RingOperatorTerms → List String → Term → Term → TC Term
constructCallToSolver `ring opNames variables `lhs `rhs = do
  `lhsExpr ← toExpr `lhs
  `rhsExpr ← toExpr `rhs
  let equation = _`⊜_ `ring arity `lhsExpr `rhsExpr
      problem  = prependVLams variables equation
      evidence = prependHLams variables (`refl `ring)
  pure (`solver `ring arity problem evidence)
  where
  -- One solver variable per quantified variable.
  arity : ℕ
  arity = List.length variables

  toExpr : Term → TC Term
  toExpr = convertTerm `ring arity opNames (quantifiedVarMap arity)

-- This is the main macro which solves for equations in which the
-- variables are universally quantified over:
--
--   lemma : ∀ x y → x + y ≈ y + x
--   lemma = solve-∀ ring
--
-- where ring is your implementation of AlmostCommutativeRing.
-- (Find some example implementations in
-- Polynomial.Solver.Ring.AlmostCommutativeRing.Instances).
-- NOTE(review): the module path above does not match this file's
-- Tactic.RingSolver.* imports and may be out of date; confirm.
solve-∀-macro : Name → Term → TC ⊤
solve-∀-macro ring hole = do
  -- Check that the named definition is an AlmostCommutativeRing before
  -- normalising its operators.
  `ring ← checkIsRing (def ring [])
  commitTC
  operatorTerms ← getRingOperatorTerms `ring

  -- Obtain and sanitise the goal type
  `hole ← inferType hole >>= reduce
  -- Split the goal into its leading binders and the remaining equation.
  let variablesAndTypes , equation = stripPis `hole

  let variables = List.map proj₁ variablesAndTypes
  -- The two sides are the last two visible arguments of the equation;
  -- anything else is reported as a malformed goal.
  just (lhs ∷ rhs ∷ []) ← pure (getVisibleArgs 2 equation)
    where nothing → malformedForallTypeError `hole

  -- Build the call to the solver and fill the hole with it.
  solverCall ← constructCallToSolver `ring operatorTerms variables lhs rhs
  unify hole solverCall

macro
  -- Macro entry point; all the work is done by solve-∀-macro.
  solve-∀ : Name → Term → TC ⊤
  solve-∀ ring hole = solve-∀-macro ring hole

------------------------------------------------------------------------
-- Unquantified macro

-- Aborts with a type error reporting that the first argument of solve
-- is not a list of free variables; the offending term is included.
malformedArgumentListError : ∀ {a} {A : Set a} → Term → TC A
malformedArgumentListError found =
  typeError $
    strErr "Malformed call to solve." ∷
    strErr "First argument should be a list of free variables." ∷
    strErr "Instead: " ∷
    termErr found ∷
    []

-- Aborts with a type error reporting that the goal of solve is not an
-- equation between two sides; the offending term is included.
malformedGoalError : ∀ {a} {A : Set a} → Term → TC A
malformedGoalError found =
  typeError $
    strErr "Malformed call to solve." ∷
    strErr "Goal type should be of the form: LHS ≈ RHS" ∷
    strErr "Instead: " ∷
    termErr found ∷
    []

-- Type-checks `xs as a list over the ring's carrier and returns the
-- checked term in normal form.
checkIsListOfVariables : Term → Term → TC Term
checkIsListOfVariables `ring `xs = do
  checked ← checkType `xs (`List (`Carrier `ring))
  normalise checked

-- Extracts the de Bruijn indices from a reflected list of variables.
-- The result is `nothing` as soon as an element is anything other than
-- a variable applied to no arguments.
getVariableIndices : Term → Maybe NatSet
getVariableIndices = collect []
  where
  collect : NatSet → Term → Maybe NatSet
  collect seen `[]`              = just seen
  collect seen (var i [] `∷` xs) = collect (insert i seen) xs
  collect _    _                 = nothing

-- Builds the proof term for an unquantified goal lhs ≈ rhs: each side
-- is converted to a solver expression and passed to `correct together
-- with the reflected variables, and the two results are chained as
-- `trans (`sym lhsProof) rhsProof.
constructSolution : Term → RingOperatorTerms → NatSet → Term → Term → TC Term
constructSolution `ring opTerms variables `lhs `rhs = do
  `lhsProof ← prove `lhs
  `rhsProof ← prove `rhs
  pure (`trans `ring (`sym `ring `lhsProof) `rhsProof)
  where
  arity : ℕ
  arity = List.length variables

  -- The variables themselves, reflected as a vector term.
  env : Term
  env = curriedTerm variables

  -- A variable found in the set becomes an Ι expression carrying the
  -- result of `lookup`, reflected with toFinTerm.
  varMap : VarMap
  varMap i = Maybe.map (λ position → `I `ring arity (toFinTerm position)) (lookup variables i)

  prove : Term → TC Term
  prove t = do
    t' ← convertTerm `ring arity opTerms varMap t
    pure (`correct `ring arity t' env)

-- Use this macro when you want to solve something *under* a lambda.
-- For example: say you have a long proof, and you just want the solver
-- to deal with an intermediate step. Call it like so:
--
--   lemma₃ : ∀ x y → x + y * 1 + 3 ≈ 2 + 1 + y + x
--   lemma₃ x y = begin
--     x + y * 1 + 3 ≈⟨ +-comm x (y * 1) ⟨ +-cong ⟩ refl ⟩
--     y * 1 + x + 3 ≈⟨ solve (x ∷ y ∷ []) Int.ring ⟩
--     3 + y + x     ≡⟨ refl ⟩
--     2 + 1 + y + x ∎
--
-- The first argument is the free variables, and the second is the
-- ring implementation (as before).
solve-macro : Term → Name → Term → TC ⊤
solve-macro variables ring hole = do
  -- Check that the named definition is an AlmostCommutativeRing before
  -- normalising its operators.
  `ring ← checkIsRing (def ring [])
  commitTC
  operatorTerms ← getRingOperatorTerms `ring

  -- Obtain and sanitise the list of variables
  listOfVariables′ ← checkIsListOfVariables `ring variables
  commitTC
  -- Every element must be a bare variable; otherwise report the
  -- normalised list as malformed.
  just variableIndices ← pure (getVariableIndices listOfVariables′)
    where nothing → malformedArgumentListError listOfVariables′

  -- Obtain and sanitise the goal type
  hole′ ← inferType hole >>= reduce
  -- The two sides are the last two visible arguments of the goal type.
  just (lhs ∷ rhs ∷ []) ← pure (getVisibleArgs 2 hole′)
    where nothing → malformedGoalError hole′

  -- Build the proof term and fill the hole with it.
  solution ← constructSolution `ring operatorTerms variableIndices lhs rhs
  unify hole solution

macro
  -- Macro entry point; all the work is done by solve-macro.
  solve : Term → Name → Term → TC ⊤
  solve variables ring hole = solve-macro variables ring hole