Half-completed crypto experiments in Haskell.
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PackageImports #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- __FIXME__: See "Sponge" for the justification for these.
{-# OPTIONS_GHC -Wno-missing-safe-haskell-mode -Wno-unsafe #-}
-- Because of big `Fin`s.
{-# OPTIONS_GHC -freduction-depth=0 #-}

-- | This is an implementation of [Keccak](https://keccak.team/keccak.html) with
--   a lot of static guarantees.
--
--   There are some deviations from the spec:
-- * `keccak_f`, rather than taking the value of @b@ takes the value of @l@, so
--   where the spec says "Keccak-/f/[1600]", we would write
--   @`keccak_f` (Proxy \@(`Nat.FromGHC` 6))@.
-- * `keccak_rc` and `keccak_c` each take an extra `Proxy` (@l@ and @r@,
--   respectively), but these are uniquely determined by the other parameters,
--   so it's a minor annoyance.
module Keccak
  ( keccak_rc,
    keccak_c,
    keccak,
  )
where

import Control.Applicative (Applicative (..))
import Control.Category (Category (..))
import Control.Lens ((&), (.~))
import Data.Bit (Bit (..), Vector)
import Data.Bits (Bits (..))
import Data.Bool (Bool (..), bool)
import Data.Eq (Eq (..))
import Data.Fin (Fin)
import Data.Foldable (Foldable (..))
import Data.Function (flip)
import Data.Maybe (Maybe)
import Data.Ord (Ord (..))
import Data.Proxy (Proxy (..))
import Data.Semigroup (Semigroup (..))
import Data.Type.Nat (Nat (..))
import Data.Type.Nat qualified as Nat
import Data.Type.Nat.LE qualified as Nat
import Data.Vec.Lazy (Vec (..), (!))
import Data.Vec.Lazy qualified as Vec
import Data.Vec.Lazy.Lens qualified as Vec
import Data.Vector.Generic qualified as V
import Data.Word (Word64)
import Numeric (Floating (..))
import "this" Sponge (B, Block, State, W, WBits, sponge)
import Prelude
  ( Double,
    Integral (..),
    Num (..),
    RealFrac (truncate),
    fromIntegral,
    ($),
  )

-- | Index-aware map over a vector of vectors: the supplied function receives
--   the outer index, the inner index and the element stored at that position.
imap2 :: (Fin m -> Fin n -> a -> b) -> Vec m (Vec n a) -> Vec m (Vec n b)
imap2 f = Vec.imap (\outer -> Vec.imap (f outer))

-- = Keccak-specific definitions

-- | The θ step: every lane is combined (via `xor`) with a value derived from
--   the parities of two neighbouring columns.
--
--   The state is indexed as @a ! x ! y ! z@.
theta :: Nat.SNatI (W l) => Proxy l -> State l -> State l
theta Proxy a = imap2 (\x _ lane -> (effect ! x) `xor` lane) a
  where
    -- @parity ! x ! z@ is the `xor` of the five bits @a ! x ! y ! z@.
    parity =
      Vec.map
        ( \sheet ->
            Vec.tabulate
              (\z -> foldr1 xor (Vec.tabulate (\y -> (sheet ! y) ! z)))
        )
        a
    -- @effect ! x ! z@ joins the parity of column @x - 1@ at position @z@ with
    -- the parity of column @x + 1@ at position @z - 1@. The index arithmetic
    -- is whatever `Fin`'s `Num` instance provides (presumably modular, so that
    -- it wraps around -- confirm against the @fin@ package).
    effect =
      Vec.tabulate
        ( \x ->
            Vec.tabulate
              ( \z ->
                  ((parity ! (x - 1)) ! z)
                    `xor` ((parity ! (x + 1)) ! (z - 1))
              )
        )

-- | The combined ρ and π steps: the lane found at @a ! x ! y@ is rotated by
--   @`rot_tbl` x y@ positions and stored at @result ! y ! (2 * x + 3 * y)@.
--
--   __FIXME__: I’m pretty sure this isn’t doing the right thing.
rho_pi :: Nat.SNatI (W l) => Proxy l -> State l -> State l
rho_pi Proxy a = Vec.tabulate sheetFor
  where
    -- Starting point of every output sheet; each of its lanes is overwritten
    -- below provided @2 * x + 3 * y@ visits every index as @x@ varies.
    blank = pure (pure False)
    sheetFor y =
      Vec.ifoldr
        ( \x ax acc ->
            acc & Vec.ix (2 * x + 3 * y)
              .~ Vec.tabulate (\i -> (ax ! y) ! (i - rot_tbl x y))
        )
        blank
        a

-- | The χ step: the element at index @x@ is `xor`ed with the conjunction of
--   the complemented element at @x + 1@ and the element at @x + 2@.
chi :: Bits a => Vec Block a -> Vec Block a
chi a = Vec.imap mix a
  where
    mix x ax = (complement (a ! (x + 1)) .&. (a ! (x + 2))) `xor` ax

-- | The ι step: `xor`s the bits of the round constant @rc@ into the lane at
--   @x == 0@, @y == 0@; every other lane is returned unchanged.
iota :: Nat.SNatI (W l) => Proxy l -> Word64 -> State l -> State l
iota Proxy rc = Vec.imap onSheet
  where
    -- Only the sheet at index 0 is touched.
    onSheet x ax = if x == 0 then Vec.imap onLane ax else ax
    -- Within that sheet, only the lane at index 0 is touched.
    onLane y axy =
      if y == 0
        then Vec.imap (\i b -> get_round_bit rc i `xor` b) axy
        else axy

-- | A single round of the permutation: `theta`, then `rho_pi`, then `chi`,
--   and finally `iota` with the given round constant.
round ::
  (Nat.SNatI (W l), Bits (WBits l)) => Proxy l -> Word64 -> State l -> State l
round l rc a = iota l rc $ chi $ rho_pi l $ theta l a

-- | The first @n_r@ of the 24 round constants, one per round.
--
--   The constants are stored at their full 64-bit width; `get_round_bit` only
--   ever looks at the bit positions that exist in the lane it is applied to,
--   so narrower lanes see the correspondingly truncated constant.
--
--   Should calculate this rather than building an explicit `Vec`.
round_consts :: Nat.LE n_r (Nat.FromGHC 24) => Vec n_r Word64
round_consts =
  Vec.take
    ( 0x0000000000000001
        ::: 0x0000000000008082
        ::: 0x800000000000808a
        ::: 0x8000000080008000
        ::: 0x000000000000808b
        ::: 0x0000000080000001
        ::: 0x8000000080008081
        ::: 0x8000000000008009
        ::: 0x000000000000008a
        ::: 0x0000000000000088
        ::: 0x0000000080008009
        ::: 0x000000008000000a
        ::: 0x000000008000808b
        ::: 0x800000000000008b
        ::: 0x8000000000008089
        ::: 0x8000000000008003
        ::: 0x8000000000008002
        ::: 0x8000000000000080
        ::: 0x000000000000800a
        ::: 0x800000008000000a
        ::: 0x8000000080008081
        ::: 0x8000000000008080
        ::: 0x0000000080000001
        ::: 0x8000000080008008
        ::: VNil
    )

-- where
--   rc t = (x ** t `mod` x ** 8 + x ** 6 + x ** 5 + x ** 4 + 1) `mod` x

-- | Looks up the rotation offset for the lane at column @x@, row @y@.
--
--   The table holds the offsets as plain numeric literals; they become
--   @`Fin` w@ values through `Fin`'s `Num` instance (presumably reducing them
--   modulo @w@ -- confirm against the @fin@ package).
rot_tbl :: forall w. Nat.SNatI w => Fin Block -> Fin Block -> Fin w
rot_tbl x y = (offsets ! x) ! y
  where
    -- Outer index is @x@, inner index is @y@.
    offsets :: Vec Block (Vec Block (Fin w))
    offsets =
      (0 ::: 36 ::: 3 ::: 105 ::: 210 ::: VNil)
        ::: (1 ::: 300 ::: 10 ::: 45 ::: 66 ::: VNil)
        ::: (190 ::: 6 ::: 171 ::: 15 ::: 253 ::: VNil)
        ::: (28 ::: 55 ::: 153 ::: 21 ::: 120 ::: VNil)
        ::: (91 ::: 276 ::: 231 ::: 136 ::: 78 ::: VNil)
        ::: VNil

-- | Reports whether bit number @bit_i@ (counting from the least significant
--   bit) of the round constant @round_c@ is set.
--
--   The bit is tested directly with `testBit` instead of building the mask
--   @2 ** bit_i@ as a `Double` and truncating it back to a `Word64`, which
--   keeps the computation entirely in integer arithmetic.
get_round_bit :: Nat.SNatI w => Word64 -> Fin w -> Bool
get_round_bit round_c bit_i = testBit round_c (fromIntegral bit_i)

-- | The Keccak-/f/ permutation: applies `round` once for each of the @n_r@
--   round constants, in order, starting from the given state.
--
--   Ideally, this function would accept @b@ rather than @l@ as a type
--   parameter, but it's hard to define the type that way, so this accepts @l@,
--   which is required to be in the range [0..6], producing the following
--   values:
--
--   +---+-----+----+------+
--   | l | n_r | w  |  b   |
--   +===+=====+====+======+
--   | 0 |  12 |  1 |   25 |
--   +---+-----+----+------+
--   | 1 |  14 |  2 |   50 |
--   +---+-----+----+------+
--   | 2 |  16 |  4 |  100 |
--   +---+-----+----+------+
--   | 3 |  18 |  8 |  200 |
--   +---+-----+----+------+
--   | 4 |  20 | 16 |  400 |
--   +---+-----+----+------+
--   | 5 |  22 | 32 |  800 |
--   +---+-----+----+------+
--   | 6 |  24 | 64 | 1600 |
--   +---+-----+----+------+
keccak_f ::
  forall l n_r.
  ( Nat.SNatI (W l),
    Bits (WBits l),
    n_r ~ Nat.Plus (Nat.FromGHC 12) (Nat.Mult2 l),
    -- This implies @l <= 6@
    Nat.LE n_r (Nat.FromGHC 24)
  ) =>
  Proxy l ->
  State l ->
  State l
-- A strict left fold, so no chain of suspended rounds is built up while
-- walking the round constants.
keccak_f l a = foldl' (flip (round l)) a (round_consts @n_r)

-- | The padding appended to a message of @l@ bits for the block size given by
--   the proxy: a set bit, then the smallest number of cleared bits, then a
--   final set bit, such that message plus padding fills a whole number of
--   blocks.
--
--   The number of cleared bits is computed with `Integer` arithmetic. This
--   avoids both a refutable pattern match on the reflected block size and
--   subtraction on unary `Nat` values; the block size is at least 1 by
--   construction (the proxy is an @'S@), so the `mod` below never divides by
--   zero.
pad10x1 :: Nat.SNatI x => Proxy ('S x) -> Nat -> Vector Bit
pad10x1 x l =
  let rate = toInteger (Nat.reflect x)
      -- Bits of the last, partially filled block that the message occupies.
      used = toInteger l `mod` rate
      -- Two bits are taken by the delimiting ones; when fewer than two bits
      -- are free this wraps around into an additional block.
      zeros = (rate - 2 - used) `mod` rate
   in V.cons (Bit True) $
        V.replicate (fromInteger zeros) (Bit False)
          <> V.singleton (Bit True)

-- | Runs `sponge` with `keccak_f` as the permutation and `pad10x1` as the
--   padding rule.
--
--   Ideally this wouldn't require @l@, but until `keccak_f` accepts @b@
--   instead of @l@, this is necessary.
keccak_rc ::
  forall l' l n_r r c n.
  ( Nat.SNatI r,
    Nat.SNatI c,
    Nat.SNatI n,
    Nat.SNatI (W l),
    Bits (WBits l),
    Nat.SNatI (Nat.Mult Block (W l)),
    B l ~ Nat.Plus ('S r) ('S c),
    Nat.LE ('S r) (B l),
    Nat.LE ('S l') (Nat.Mult ('S n) ('S r)),
    n_r ~ Nat.Plus (Nat.FromGHC 12) (Nat.Mult2 l),
    -- This implies @l <= 6@
    Nat.LE n_r (Nat.FromGHC 24)
  ) =>
  Proxy l ->
  Proxy ('S r) ->
  Proxy ('S c) ->
  -- | This must be at least ⌈l' / r⌉
  Proxy ('S n) ->
  Vector Bit ->
  Maybe (Vec ('S l') Bool)
keccak_rc l = sponge l permutation pad10x1
  where
    permutation = keccak_f l

-- | `keccak_rc` with @l@ fixed to @`Nat.FromGHC` 6@.
--
--   This taking @r@ seems a bit silly, but since @b@ is fixed, it means that
--   @r@ is uniquely determined by @c@ (and vice versa).
keccak_c ::
  forall l' l r c n.
  ( l ~ Nat.FromGHC 6,
    Nat.SNatI r,
    Nat.SNatI c,
    Nat.SNatI n,
    B l ~ Nat.Plus ('S r) ('S c),
    Nat.LE ('S r) (B l),
    Nat.LE ('S l') (Nat.Mult ('S n) ('S r))
  ) =>
  Proxy ('S r) ->
  Proxy ('S c) ->
  -- | This must be at least ⌈l' / r⌉
  Proxy ('S n) ->
  Vector Bit ->
  Maybe (Vec ('S l') Bool)
keccak_c = keccak_rc (Proxy :: Proxy l)

-- | `keccak_c` with the two size proxies fixed: the first is
--   @`Nat.FromGHC` 1024@ and the second is @`Nat.FromGHC` 576@.
keccak ::
  forall l' l r c n.
  ( l ~ Nat.FromGHC 6,
    r ~ Nat.FromGHC 1024,
    c ~ Nat.FromGHC 576,
    Nat.SNatI n,
    Nat.LE r (B l),
    Nat.LE ('S l') (Nat.Mult ('S n) r)
  ) =>
  -- | This must be at least ⌈l' / 1024⌉
  Proxy ('S n) ->
  Vector Bit ->
  Maybe (Vec ('S l') Bool)
keccak = keccak_c (Proxy :: Proxy r) (Proxy :: Proxy c)