{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PackageImports #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -Wno-missing-safe-haskell-mode -Wno-unsafe #-}
{-# OPTIONS_GHC -freduction-depth=0 #-}
module Keccak
( keccak_rc,
keccak_c,
keccak,
)
where
import Control.Applicative (Applicative (..))
import Control.Category (Category (..))
import Control.Lens ((&), (.~))
import Data.Bit (Bit (..), Vector)
import Data.Bits (Bits (..))
import Data.Bool (Bool (..), bool)
import Data.Eq (Eq (..))
import Data.Fin (Fin)
import Data.Foldable (Foldable (..))
import Data.Function (flip)
import Data.Maybe (Maybe)
import Data.Ord (Ord (..))
import Data.Proxy (Proxy (..))
import Data.Semigroup (Semigroup (..))
import Data.Type.Nat (Nat (..))
import Data.Type.Nat qualified as Nat
import Data.Type.Nat.LE qualified as Nat
import Data.Vec.Lazy (Vec (..), (!))
import Data.Vec.Lazy qualified as Vec
import Data.Vec.Lazy.Lens qualified as Vec
import Data.Vector.Generic qualified as V
import Data.Word (Word64)
import Numeric (Floating (..))
import "this" Sponge (B, Block, State, W, WBits, sponge)
import Prelude
( Double,
Integral (..),
Num (..),
RealFrac (truncate),
fromIntegral,
($),
)
-- | Indexed map over a two-level 'Vec'.
--
-- The mapping function receives the outer index, the inner index and the
-- element found at that position.
imap2 :: (Fin m -> Fin n -> a -> b) -> Vec m (Vec n a) -> Vec m (Vec n b)
imap2 f rows = Vec.imap (\i row -> Vec.imap (f i) row) rows
-- | The theta step.
--
-- Reading the state as @a ! x ! y ! z@, this first computes, for every @x@
-- and @z@, the XOR of bit @z@ over all @y@.  Every lane at outer index @x@ is
-- then XORed with the parity found at @x - 1@ combined with the parity found
-- at @x + 1@ taken one bit position lower (@z - 1@).
--
-- NOTE(review): the index arithmetic relies on 'Fin'\'s 'Num' instance
-- wrapping around at the ends — confirm against the @fin@ version in use.
theta :: Nat.SNatI (W l) => Proxy l -> State l -> State l
theta Proxy state = imap2 (\x _ lane -> (effect ! x) `xor` lane) state
  where
    -- parity ! x ! z is the XOR of bit z across every lane stored under x.
    parity =
      Vec.map
        ( \sheet ->
            Vec.tabulate
              (\z -> foldr1 xor (Vec.tabulate (\y -> (sheet ! y) ! z)))
        )
        state
    -- effect ! x is the lane-shaped mask XORed into every lane stored under x.
    effect =
      Vec.tabulate
        ( \x ->
            Vec.tabulate
              ( \z ->
                  ((parity ! (x - 1)) ! z) `xor` ((parity ! (x + 1)) ! (z - 1))
              )
        )
-- | The rho and pi steps, fused into a single pass.
--
-- The lane read from @a ! x ! y@ is rotated by @rot_tbl x y@ positions (bit
-- @i@ of the rotated lane is bit @i - rot_tbl x y@ of the source lane) and is
-- written to @result ! y ! (2 * x + 3 * y)@.
--
-- Each output sheet starts out filled with 'False' and has its slots
-- overwritten one source sheet at a time.
rho_pi :: Nat.SNatI (W l) => Proxy l -> State l -> State l
rho_pi Proxy state = Vec.tabulate newSheet
  where
    -- Build the output sheet stored under outer index y.
    newSheet y =
      Vec.ifoldr
        ( \x sheet acc ->
            let slot = 2 * x + 3 * y
                rotated =
                  Vec.tabulate (\i -> (sheet ! y) ! (i - rot_tbl x y))
             in acc & Vec.ix slot .~ rotated
        )
        (pure (pure False))
        state
-- | The chi step.
--
-- The element at index @x@ is XORed with the complement of the element at
-- @x + 1@ ANDed with the element at @x + 2@.  All reads come from the input
-- vector, never from already-updated elements.
chi :: Bits a => Vec Block a -> Vec Block a
chi planes = Vec.imap mix planes
  where
    mix x plane =
      (complement (planes ! (x + 1)) .&. (planes ! (x + 2))) `xor` plane
-- | The iota step: mix a round constant into a single lane.
--
-- Only the lane at outer index @0@ and inner index @0@ is changed: bit @z@ of
-- that lane is XORed with @get_round_bit rc z@.  Every other lane is returned
-- as it was.
iota :: Nat.SNatI (W l) => Proxy l -> Word64 -> State l -> State l
iota Proxy rc = Vec.imap onSheet
  where
    -- Leave every sheet alone except the one stored under index 0.
    onSheet x sheet = bool sheet (Vec.imap onLane sheet) (x == 0)
    -- Within that sheet, only the lane stored under index 0 is modified.
    onLane y lane =
      bool lane (Vec.imap (xor . get_round_bit rc) lane) (y == 0)
-- | One round of the permutation.
--
-- Applies 'theta', then 'rho_pi', then 'chi', and finally 'iota' with the
-- given round constant.
round ::
  (Nat.SNatI (W l), Bits (WBits l)) => Proxy l -> Word64 -> State l -> State l
round l rc state = iota l rc (chi (rho_pi l (theta l state)))
-- | The first @n_r@ of the 24 Keccak round constants, in round order.
--
-- Every constant is written as a full 64-bit value.  'get_round_bit' only
-- ever inspects the bit positions that exist in a lane, so narrower lane
-- widths simply ignore the upper bits.
round_consts :: Nat.LE n_r (Nat.FromGHC 24) => Vec n_r Word64
round_consts =
  Vec.take
    ( 0x0000000000000001
        ::: 0x0000000000008082
        ::: 0x800000000000808a
        ::: 0x8000000080008000
        ::: 0x000000000000808b
        ::: 0x0000000080000001
        ::: 0x8000000080008081
        ::: 0x8000000000008009
        ::: 0x000000000000008a
        ::: 0x0000000000000088
        ::: 0x0000000080008009
        ::: 0x000000008000000a
        ::: 0x000000008000808b
        ::: 0x800000000000008b
        ::: 0x8000000000008089
        ::: 0x8000000000008003
        ::: 0x8000000000008002
        ::: 0x8000000000000080
        ::: 0x000000000000800a
        ::: 0x800000008000000a
        ::: 0x8000000080008081
        ::: 0x8000000000008080
        ::: 0x0000000080000001
        ::: 0x8000000080008008
        ::: VNil
    )
-- | Rotation amount for the lane at position @(x, y)@, expressed as an index
-- into a lane of width @w@.
--
-- The table is indexed by @x@ first and @y@ second.  Its entries are numeric
-- literals converted through 'Fin'\'s 'Num' instance, so several of them are
-- larger than small lane widths.
--
-- NOTE(review): this relies on that instance reducing literals modulo @w@ —
-- confirm against the @fin@ version in use.
rot_tbl :: forall w. Nat.SNatI w => Fin Block -> Fin Block -> Fin w
rot_tbl x y = (offsets ! x) ! y
  where
    offsets :: Vec Block (Vec Block (Fin w))
    offsets =
      (0 ::: 36 ::: 3 ::: 105 ::: 210 ::: VNil)
        ::: (1 ::: 300 ::: 10 ::: 45 ::: 66 ::: VNil)
        ::: (190 ::: 6 ::: 171 ::: 15 ::: 253 ::: VNil)
        ::: (28 ::: 55 ::: 153 ::: 21 ::: 120 ::: VNil)
        ::: (91 ::: 276 ::: 231 ::: 136 ::: 78 ::: VNil)
        ::: VNil
-- | Whether bit number @bit_i@ (counting from the least significant bit) of
-- the round constant @round_c@ is set.
--
-- Positions at or beyond 64 lie outside the 'Word64' and yield 'False'.
get_round_bit :: Nat.SNatI w => Word64 -> Fin w -> Bool
get_round_bit round_c bit_i = testBit round_c (fromIntegral bit_i)
-- | Run the state through every round of the permutation.
--
-- The number of rounds @n_r@ is tied to @l@ by the equality constraint below.
-- The state is threaded through 'round' once for each element of
-- @round_consts \@n_r@, in order.
keccak_f ::
  forall l n_r.
  ( Nat.SNatI (W l),
    Bits (WBits l),
    n_r ~ Nat.Plus (Nat.FromGHC 12) (Nat.Mult2 l),
    -- This implies @l <= 6@
    Nat.LE n_r (Nat.FromGHC 24)
  ) =>
  Proxy l ->
  State l ->
  State l
-- A strict left fold: each intermediate state is forced to weak head normal
-- form before the next round is applied, instead of building a thunk chain.
keccak_f l a = foldl' (flip $ round l) a $ round_consts @n_r
-- | Padding bits to append to a message of length @l@ for a block size of
-- @'S x@: a 'True' bit, a run of 'False' bits, and a closing 'True' bit.
--
-- Writing the block size as @lastIndex + 1@, the number of 'False' bits is
-- @lastIndex - (l `mod` block size) - 1@, unless that difference is already
-- zero, in which case @lastIndex@ 'False' bits are used.  Either way the
-- message length plus the padding length is a multiple of the block size, and
-- the padding is at least two bits long.
pad10x1 :: Nat.SNatI x => Proxy ('S x) -> Nat -> Vector Bit
pad10x1 x l =
  V.cons (Bit True) $
    V.replicate (fromIntegral zeroCount) (Bit False)
      <> V.singleton (Bit True)
  where
    -- The block size is reflected from a successor type, so this lazy
    -- pattern always matches.
    S lastIndex = Nat.reflect x
    zeroCount =
      case lastIndex - (l `mod` S lastIndex) of
        Z -> lastIndex
        S fewer -> fewer
-- | The most general entry point: hands 'keccak_f' and 'pad10x1' to 'sponge'
-- for a caller-chosen @l@.
--
-- The constraints require @'S r + 'S c@ to equal @B l@ and the output length
-- @'S l'@ to be at most @'S n * 'S r@.  The three trailing proxies and the
-- input bits are passed to 'sponge' unchanged.
--
-- NOTE(review): @r@, @c@ and @n@ look like rate, capacity and a block count;
-- the meaning of a 'Nothing' result is decided inside 'sponge' — confirm
-- against the @Sponge@ module.
keccak_rc ::
  forall l' l n_r r c n.
  ( Nat.SNatI r,
    Nat.SNatI c,
    Nat.SNatI n,
    Nat.SNatI (W l),
    Bits (WBits l),
    Nat.SNatI (Nat.Mult Block (W l)),
    B l ~ Nat.Plus ('S r) ('S c),
    Nat.LE ('S r) (B l),
    Nat.LE ('S l') (Nat.Mult ('S n) ('S r)),
    n_r ~ Nat.Plus (Nat.FromGHC 12) (Nat.Mult2 l),
    -- This implies @l <= 6@
    Nat.LE n_r (Nat.FromGHC 24)
  ) =>
  Proxy l ->
  Proxy ('S r) ->
  Proxy ('S c) ->
  Proxy ('S n) ->
  Vector Bit ->
  Maybe (Vec ('S l') Bool)
keccak_rc l rate capacity blocks message =
  sponge l (keccak_f l) pad10x1 rate capacity blocks message
-- | 'keccak_rc' with @l@ fixed to 6; the remaining type-level parameters are
-- still chosen by the caller through the proxies.
keccak_c ::
  forall l' l r c n.
  ( l ~ Nat.FromGHC 6,
    Nat.SNatI r,
    Nat.SNatI c,
    Nat.SNatI n,
    B l ~ Nat.Plus ('S r) ('S c),
    Nat.LE ('S r) (B l),
    Nat.LE ('S l') (Nat.Mult ('S n) ('S r))
  ) =>
  Proxy ('S r) ->
  Proxy ('S c) ->
  Proxy ('S n) ->
  Vector Bit ->
  Maybe (Vec ('S l') Bool)
keccak_c rate capacity blocks message =
  keccak_rc (Proxy @l) rate capacity blocks message
-- | 'keccak_c' with @r@ fixed to 1024 and @c@ fixed to 576 (and therefore
-- @l@ fixed to 6).
--
-- The caller supplies only the @'S n@ proxy and the input bits; the output
-- length @'S l'@ must be at most @'S n * r@.
keccak ::
  forall l' l r c n.
  ( l ~ Nat.FromGHC 6,
    r ~ Nat.FromGHC 1024,
    c ~ Nat.FromGHC 576,
    Nat.SNatI n,
    Nat.LE r (B l),
    Nat.LE ('S l') (Nat.Mult ('S n) r)
  ) =>
  Proxy ('S n) ->
  Vector Bit ->
  Maybe (Vec ('S l') Bool)
keccak blocks message = keccak_c (Proxy @r) (Proxy @c) blocks message