{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module Rule (
    -- * Types
    Rule (..),

    -- * Classes
    IntoRule (..),

    -- * Constructors
    defaultRule,

    -- * Conversion
    fromTuple,
    fromRule,
    toSNats,
) where

import Data.Proxy

import Unsafe.Coerce (unsafeCoerce)

import GHC.Exts (withDict)
import GHC.TypeLits

-- | A distribution used to divide a value across three categories.
--
-- The three shares are tracked at the type level as a promoted triple of
-- naturals and stored at the term level as their singletons ('SNat'), so
-- the constructor itself carries no 'KnownNat' constraints.
data Rule (t :: (Nat, Nat, Nat)) where
    MkRule :: forall n w s. SNat n -> SNat w -> SNat s -> Rule '(n, w, s)

-- | Types that can be converted into a 'Rule'.
--
-- The functional dependency states that the source type alone determines
-- the type-level index of the resulting rule.
class IntoRule source t | source -> t where
    -- | Convert a value into its corresponding 'Rule'.
    toRule :: source -> Rule t

-- | A 'Rule' converts into itself unchanged.
instance IntoRule (Rule t) t where
    toRule rule = rule

-- | Reflect a type-level triple of naturals to the term level.
--
-- The argument only fixes the type-level triple; its value is never
-- inspected, so any proxy-like type constructor can be used.
fromTuple ::
    forall a b c proxy.
    (KnownNat a, KnownNat b, KnownNat c) =>
    proxy '(a, b, c)
    -> (Integer, Integer, Integer)
fromTuple _ = (first, second, third)
  where
    first = natVal (Proxy :: Proxy a)
    second = natVal (Proxy :: Proxy b)
    third = natVal (Proxy :: Proxy c)

-- | Convert a 'Rule' to the term level.
--
-- Each stored singleton is turned into its 'Integer' value, in the same
-- order as the components of the rule's type-level triple.
fromRule ::
    Rule '(a, b, c)
    -> (Integer, Integer, Integer)
fromRule rule = case rule of
    MkRule n w s -> (fromSNat n, fromSNat w, fromSNat s)

-- | The default rule, with shares of 50, 30 and 20.
defaultRule :: Rule '(50, 30, 20)
defaultRule = MkRule first second third
  where
    first = SNat :: SNat 50
    second = SNat :: SNat 30
    third = SNat :: SNat 20

-- | Build a 'Rule' purely from 'KnownNat' evidence.
--
-- Each field is the singleton for the corresponding component of the
-- result's type-level triple; the signature fixes which one is which.
fromKnownNats :: forall n w s. (KnownNat n, KnownNat w, KnownNat s) => Rule '(n, w, s)
fromKnownNats = MkRule SNat SNat SNat

-- | Assemble a 'Rule' directly from three singleton naturals.
--
-- The singletons are stored as given. 'MkRule' already takes 'SNat'
-- fields, so there is no need to convert each singleton into a 'KnownNat'
-- dictionary only to rebuild the same singleton from it.
fromSTuple :: forall n w s. SNat n -> SNat w -> SNat s -> Rule '(n, w, s)
fromSTuple = MkRule

-- | Extract the three singleton naturals stored in a 'Rule', in order.
toSNats :: forall n w s. Rule '(n, w, s) -> (SNat n, SNat w, SNat s)
toSNats rule = case rule of
    MkRule first second third -> (first, second, third)

-- | Forget the contents of a 'Rule', keeping only its type-level index.
--
-- The argument is never evaluated; the proxy's type comes from the
-- signature.
toProxy :: forall n w s. Rule '(n, w, s) -> Proxy '(n, w, s)
toProxy = const Proxy