Half-completed crypto experiments in Haskell.
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE Safe #-}
{-# LANGUAGE StandaloneDeriving #-}

module PLONKish
  ( OrderedPair,
    orderedPair,
    MultivarPoly (..),
    MultivariatePolynomial (..),
    evaluateMultivariatePolynomial,
    ColumnSpecification (..),
    Configuration (..),
    Circuit (configuration, values),
    circuit,
    EqualityConstraintFailure (..),
    addEqualityConstraint,
    -- generateProvingKey,
    -- generateVerificationKey,
  )
where

import Control.Applicative (Applicative (..))
import Control.Category (Category (..))
import Control.Exception (Exception (..))
import Data.Data (Data)
import Data.Eq (Eq (..))
import Data.Fin (Fin)
import Data.Foldable (Foldable (..))
import Data.Functor (Functor (..))
import Data.List.NonEmpty (NonEmpty, nonEmpty)
import Data.Map (Map)
import Data.Map qualified as Map
import Data.Maybe (Maybe (..), catMaybes, maybe)
import Data.Ord (Ord (..))
import Data.Proxy (Proxy (..))
import Data.Semigroup (Semigroup (..), Sum (..))
import Data.Set (Set)
import Data.Set qualified as Set
import Data.Traversable (Traversable)
import Data.Tuple qualified as Tuple
import Data.Type.Nat (Nat (..))
import Data.Type.Nat qualified as Nat
import Data.Typeable (Typeable)
import Data.Vec.Lazy (Vec)
import Data.Vec.Lazy qualified as Vec
import GHC.Generics (Generic, Generic1)
import Prelude (Bool, Either (..), Integer, Num (..), Show (..), flip, ($))

-- | An unordered pair stored in canonical form: the smaller element always
--   comes first, so pairs built from the same two elements in either order
--   compare equal.
--
--   __NB__: The derived `Functor` and `Traversable` instances can map the
--           elements to values that are no longer in increasing order; only
--           `orderedPair` establishes the ordering.
data OrderedPair a
  = OrderedPair a a
  deriving stock (Data, Eq, Ord, Generic, Generic1, Functor, Foldable, Traversable)

-- | Builds an `OrderedPair`. The arguments are swapped unless the first is
--   strictly smaller than the second.
orderedPair :: Ord a => a -> a -> OrderedPair a
orderedPair lo hi
  | lo < hi = OrderedPair lo hi
orderedPair lo hi = OrderedPair hi lo

-- | A sequence holding /at most/ @n@ elements: each `BCons` raises the bound
--   by one, while `BNil` fits any bound.
data BoundedList n a where
  -- | The empty sequence, valid at every bound.
  BNil :: BoundedList m e
  -- | Prepends an element; the tail must fit within a bound one smaller.
  BCons :: e -> BoundedList m e -> BoundedList ('S m) e

deriving stock instance (Eq e) => Eq (BoundedList m e)

deriving stock instance Foldable (BoundedList m)

-- | Relaxes the bound of a `BoundedList` by @m@, so that a list fitting a
--   smaller bound can be used where a larger one is required. The elements
--   themselves are left untouched.
weaken ::
  forall m n a. Proxy m -> BoundedList n a -> BoundedList (Nat.Plus n m) a
-- The complete signature above permits polymorphic recursion on @n@, so no
-- separate worker is needed.
weaken Proxy BNil = BNil
weaken extra (BCons x rest) = BCons x (weaken extra rest)

-- | A `BoundedList` whose elements are kept in ascending order.
--
--   The constructor is not exported, so the ordering is only as reliable as the
--   functions in this module that build values of this type.
newtype OrderedBoundedList n a = OrderedBoundedList
  { unorderedBoundedList :: BoundedList n a
  }
  deriving stock (Foldable)

-- | Prepends an element without checking the ordering invariant.
--
--   __Precondition__: @a@ must be no greater than every element already in
--   the list, otherwise the result is no longer ordered. `insertOrdered`
--   places the element correctly instead.
bcons :: a -> OrderedBoundedList n a -> OrderedBoundedList ('S n) a
bcons a = OrderedBoundedList . BCons a . unorderedBoundedList

-- | Inserts an element at its sorted position, preserving the ordering
--   invariant that `bcons` leaves to its caller.
insertOrdered :: Ord a => a -> OrderedBoundedList n a -> OrderedBoundedList ('S n) a
insertOrdered x = OrderedBoundedList . insertBounded x . unorderedBoundedList

-- | Inserts into a `BoundedList` that is already in ascending order, placing
--   the new element before the first element it does not exceed.
insertBounded :: Ord a => a -> BoundedList n a -> BoundedList ('S n) a
insertBounded x BNil = BCons x BNil
insertBounded x list@(BCons y rest) =
  if x <= y
    then BCons x list
    else BCons y (insertBounded x rest)

-- | Placeholder constraint for polynomial coefficients.
--
--   NOTE(review): this is currently just `Num`, which supplies no division or
--   multiplicative inverses, so it does not actually require a field. It is
--   presumably meant to become a finite-field class (see the note on
--   `Configuration` that @f@ should be a `FiniteField`) — confirm.
type Field = Num

-- | A multivariate polynomial in which no term exceeds a maximum degree.
--
--   Each key spells out one monomial by repeating a variable once per power.
--   For example, with the variables mapped as x=0, y=1, z=2, the term
--   @7 * x^2 * y * z^3@ is stored as the entry @([0, 0, 1, 2, 2, 2], 7)@.
--   The type system guarantees a key never holds more than @maximumDegree@
--   variables. Keeping each key in ascending order is a convention upheld by
--   the functions that build `OrderedBoundedList`s; the type does not
--   enforce it.
type MVP vars maximumDegree coefficient =
  Map (OrderedBoundedList maximumDegree (Fin vars)) coefficient

-- | Evaluates a multivariate polynomial at a point, given one value per
--   variable.
--
--   Each term contributes its coefficient multiplied by the value of every
--   variable in its key, once per occurrence, so repeated variables act as
--   powers. The results of all terms are summed; an empty polynomial
--   evaluates to 0.
evaluateMVP ::
  Field field => MVP vars maximumDegree field -> Vec vars field -> field
evaluateMVP mvp values = getSum (Map.foldMapWithKey evalTerm mvp)
  where
    evalTerm monomial coefficient = Sum (foldr multiplyBy coefficient monomial)
    -- Multiplies from the left, so the coefficient ends up innermost.
    multiplyBy variable acc = (values Vec.! variable) * acc

-- | A polynomial whose @vars@ variables are each bound to a cell of a
--   circuit, identified by a column and a row offset.
data MultivariatePolynomial coeff maximumDegree c vars = MultivariatePolynomial
  { -- | For each variable, the column it reads and its row offset.
    --
    --   The row here is `Integer`, because in the `Configuration` we don't know
    --   how many rows we have, and so the actual row is calculated later,
    --   relative to the current row, modulo the number of rows.
    variables :: Vec vars (Fin c, Integer),
    -- | The polynomial over the variables bound above.
    polynomial :: MVP vars maximumDegree coeff
  }

-- | Evaluates a polynomial at a given row of a circuit: each variable takes
--   the value of its cell, with the cell's row offset applied relative to
--   @currentRow@.
evaluateMultivariatePolynomial ::
  (Field f, Nat.SNatI r) =>
  MultivariatePolynomial f maximumDegree c vars ->
  Circuit f c r maximumDegree ->
  Fin r ->
  f
evaluateMultivariatePolynomial mvp circuit currentRow =
  evaluateMVP (polynomial mvp) (fmap readCell (variables mvp))
  where
    -- The offset is added in `Fin r`. This presumably wraps modulo @r@ via
    -- `Fin`'s `Num` instance, matching the note on `variables` — confirm.
    readCell (column, rowOffset) =
      (values circuit Vec.! column) Vec.! (currentRow + fromInteger rowOffset)

-- | The kind of each column in a `Configuration`.
--
--   NOTE(review): these presumably mirror Halo2's fixed, advice (witness) and
--   instance (public input) columns — confirm.
data ColumnSpecification = Fixed | Advice | Instance

-- | A `MultivariatePolynomial` whose variable count is hidden. This lets
--   polynomials over different numbers of variables share a single container
--   type, as in `Configuration`.
data MultivarPoly f maximumConstraintDegree c where
  MultivarPoly ::
    MultivariatePolynomial f maximumConstraintDegree c v ->
    MultivarPoly f maximumConstraintDegree c

-- | The layout and constraints of a circuit, independent of its row count.
--
--   @f@ should be a `FiniteField`, and @c@ is the number of columns.
--
--   NOTE(review): no `Ord` instance for `MultivarPoly` is visible in this
--   module, so the `Set`/`Map` fields keyed on it cannot yet be populated with
--   `Set.insert`/`Map.insert`.
data Configuration f c mcd = Configuration
  { -- | The kind of each column.
    columns :: Vec c ColumnSpecification,
    -- | The only columns whose cells may take part in equality constraints.
    equalityConstraintColumns :: Set (Fin c),
    -- | Polynomial constraints over the circuit's cells.
    polynomialConstraints :: Set (MultivarPoly f mcd c),
    -- |
    --
    --  __TODO__: Determine if a `Map` is the right structure here. I.e., do we
    --            ever need to look up by column and do we need to support
    --            duplicate polynomials? (If the latter, we need to decide
    --            between a `MultiMap` and a @`Map` (`MultivariatePolynomial` f)
    --            [`Fin` c]@).
    lookupArguments :: Map (MultivarPoly f mcd c) (Fin c)
  }

-- | Whether cells in the given column may take part in equality constraints,
--   according to the configuration's @equalityConstraintColumns@.
checkEqualityColumn :: Configuration f c mcd -> Fin c -> Bool
checkEqualityColumn config column =
  column `Set.member` equalityConstraintColumns config

-- | A `Configuration` together with concrete cell values and the equality
--   constraints placed on those cells.
--
--   @r@ is the number of rows in the `Circuit`.
data Circuit f c r mcd = Circuit
  { -- | The layout and constraints this circuit was built from.
    configuration :: Configuration f c mcd,
    -- | Pairs of cells, each given as a (column, row) pair, declared equal.
    --
    --  __NB__: `Set` maintains its elements in ascending order, so we do have
    --          determinism for generating keys.
    equalityConstraints' :: Set (OrderedPair (Fin c, Fin r)),
    -- | Cell values, indexed first by column and then by row.
    values :: Vec c (Vec r f)
  }

-- | Builds a `Circuit` from a configuration and its cell values. The result
--   starts with no equality constraints; `addEqualityConstraint` adds them.
circuit ::
  Configuration f c maximumConstraintDegree ->
  Vec c (Vec r f) ->
  Circuit f c r maximumConstraintDegree
circuit conf cells =
  Circuit
    { configuration = conf,
      equalityConstraints' = Set.empty,
      values = cells
    }

-- | Reasons a proposed equality constraint between two cells can be
--   rejected. Each cell is given as a (column, row) pair.
data EqualityConstraintFailure c r
  = -- | The column is not one that may take part in equality constraints;
    --   the second field holds the set of columns that are allowed.
    IllegalEqualityColumn (Fin c) (Set (Fin c))
  | -- | Both sides of the constraint are the same cell.
    --
    --  __TODO__: This should actually be a warning, not an error.
    TautologicalEqualityConstraint (Fin c, Fin r)
  deriving stock (Ord, Eq)

-- | Human-readable descriptions of each failure.
--
--  __TODO__: This should not be defined here, and the output should be
--            rephrased to make sense in the context in which it's forced.
--            (`Show` output is conventionally Haskell-source-like debugging
--            text rather than prose.)
instance Show (EqualityConstraintFailure c r) where
  show = \case
    IllegalEqualityColumn column allowedColumns ->
      "Column "
        <> show column
        <> " is not in the set of columns that can participate in equality constraints (`Halo2.PLONKish.Configuration.equalityConstraintColumns`): "
        <> show allowedColumns
        <> "."
    -- Documenting this as if it were a warning, even though it currently causes
    -- a failure.
    TautologicalEqualityConstraint cell ->
      "The equality constraint "
        <> show cell
        <> " == "
        <> show cell
        -- The leading space stops the sentence from running into the cell,
        -- e.g. "(0,1) == (0,1) is ..." rather than "(0,1) == (0,1)is ...".
        <> " is tautological, because both elements are the same. Refusing to add it."

-- | Lets failures be thrown and caught as exceptions. The displayed text is
--   exactly the `Show` output, which is also the class default.
instance (Typeable c, Typeable r) => Exception (EqualityConstraintFailure c r) where
  displayException = show

-- | Records that cells @a@ and @b@ (each a (column, row) pair) must be equal.
--
--   Returns every validation failure if the constraint is rejected. Otherwise
--   returns the circuit with the constraint added; the pair is stored in
--   canonical order, so @a@/@b@ and @b@/@a@ give the same constraint.
addEqualityConstraint ::
  (Fin c, Fin r) ->
  (Fin c, Fin r) ->
  Circuit f c r mcd ->
  Either (NonEmpty (EqualityConstraintFailure c r)) (Circuit f c r mcd)
addEqualityConstraint a b circuit =
  case nonEmpty (validateEqualityConstraint (configuration circuit) a b) of
    Just failures -> Left failures
    Nothing -> Right withConstraint
  where
    withConstraint =
      circuit
        { equalityConstraints' =
            Set.insert (orderedPair a b) (equalityConstraints' circuit)
        }

-- | Collects every reason the constraint @a == b@ would be rejected. An empty
--   list means the constraint is acceptable.
--
--   A tautology (both sides are the same cell) is reported first. It is
--   followed by the illegal-column failures in ascending order with
--   duplicates removed, so a column shared by both cells is reported once.
validateEqualityConstraint ::
  forall f c mcd r.
  Configuration f c mcd ->
  (Fin c, Fin r) ->
  (Fin c, Fin r) ->
  [EqualityConstraintFailure c r]
validateEqualityConstraint config a b =
  [TautologicalEqualityConstraint a | a == b] <> illegalColumns
  where
    -- Passing through a `Set` sorts the failures and drops the duplicate that
    -- arises when both cells lie in the same illegal column.
    illegalColumns =
      Set.toList (Set.fromList (catMaybes (fmap columnFailure [a, b])))
    columnFailure cell =
      let column = Tuple.fst cell
       in if checkEqualityColumn config column
            then Nothing
            else
              Just
                (IllegalEqualityColumn column (equalityConstraintColumns config))

-- -- | This builds an elliptic curve from the two coefficients, returning a
-- --   function that, given the x component of a point will return the positive y
-- --   component. The negative component can be found via point inversion.
-- ellipticCurve :: Field f => f -> f -> f -> f
-- ellipticCurve a b x = sqrt $ x ** 3 + a * x + b

-- type F_101 = ()

-- meh :: F_101 -> F_101
-- meh = ellipticCurve 0 3

-- type ProvingKey = ()

-- type VerificationKey = ()

-- generateProvingKey :: Circuit f c r mcd -> ProvingKey
-- generateProvingKey = _

-- generateVerificationKey :: Circuit f c r mcd -> VerificationKey
-- generateVerificationKey = _