{-# LANGUAGE DerivingVia, GeneralizedNewtypeDeriving, StandaloneDeriving #-}
module Math.ParetoFront (
  Comparison(..),
  Debatable(..),
  Front,
  Strata,
  singleton,
  stratum,
  getStrata,
  getFront,
  quota,
  nestedFold
 ) where

import Data.Foldable
import Data.List (partition, transpose)
import Data.Ord (Down(..))
import Data.Semigroup(Arg(..), Max(..), Min(..))

-- | The verdict of comparing two items by one or more, possibly
-- contradicting, criteria.
--
-- Constructors are declared from worst to best for the first item, so the
-- derived 'Ord' ranks 'Dominated' lowest and 'Dominates' highest.
data Comparison
  = Dominated
    -- ^ The second item is preferred by every criterion that expressed a
    -- preference.
  | WeakTie
    -- ^ No criterion prefers either item.
  | StrongTie
    -- ^ Each item is preferred by at least one criterion.
  | Dominates
    -- ^ The first item is preferred by every criterion that expressed a
    -- preference.
  deriving (Ord, Eq, Show, Read)

-- | Merges the verdicts of separate criteria. 'WeakTie' (no preference) is
-- neutral, two matching 'Dominates' or 'Dominated' verdicts stand, and any
-- other combination is a 'StrongTie'.
instance Semigroup Comparison where
  x <> y
    | x == WeakTie = y
    | y == WeakTie = x
    | x == y, x /= StrongTie = x
    | otherwise = StrongTie

-- | No criteria means no preference, so 'WeakTie' is the identity.
instance Monoid Comparison where
  mempty = WeakTie

-- | Types whose values can be weighed against each other by one or more,
-- possibly contradicting, criteria.
class Debatable a where
  -- | Weigh the first value against the second.
  weigh
    :: a          -- ^ first item; 'Dominates' means this one is preferred
    -> a          -- ^ second item; 'Dominated' means this one is preferred
    -> Comparison

-- | 'Min' prefers the smaller of two values; equal values tie weakly.
instance Ord a => Debatable (Min a) where
  weigh (Min x) (Min y) = verdict (compare x y)
    where
      -- The verdict is read from the first value's point of view.
      verdict LT = Dominates
      verdict EQ = WeakTie
      verdict GT = Dominated

-- Plain numeric types prefer smaller values, exactly as if wrapped in 'Min'.
-- NOTE(review): GHC's 'compare' on Double/Float appears to answer GT whenever
-- either argument is NaN, which would make a NaN 'Dominated' from both
-- sides -- confirm, and consider filtering NaNs before weighing.
deriving via (Min Int) instance Debatable Int
deriving via (Min Integer) instance Debatable Integer
deriving via (Min Float) instance Debatable Float
deriving via (Min Double) instance Debatable Double

-- 'Max' prefers larger values by borrowing the instance of Min (Down a).
deriving via (Min (Down a)) instance Ord a => Debatable (Max a)

-- | 'Down' reverses the preference of the wrapped type; ties are unchanged.
instance Debatable a => Debatable (Down a) where
  weigh (Down x) (Down y) = reverseVerdict (weigh x y)
    where
      reverseVerdict Dominates = Dominated
      reverseVerdict Dominated = Dominates
      reverseVerdict tie = tie

-- | An 'Arg' is weighed by its first field only; the second field is a
-- payload that never influences the verdict.
instance Debatable a => Debatable (Arg a b) where
  weigh l r = case (l, r) of
    (Arg x _, Arg y _) -> weigh x y

-- | A pair counts each component as a separate criterion and merges the
-- verdicts with the 'Comparison' semigroup. The pairs themselves are taken
-- apart lazily.
instance (Debatable a, Debatable b) => Debatable (a,b) where
  weigh x y = weigh (fst x) (fst y) <> weigh (snd x) (snd y)

-- | Triples: one criterion per component, merged left to right.
instance (Debatable a, Debatable b, Debatable c) => Debatable (a,b,c) where
  weigh x y = weigh x1 y1 <> weigh x2 y2 <> weigh x3 y3
    where
      -- where-bound pattern bindings are lazy, like irrefutable patterns.
      (x1, x2, x3) = x
      (y1, y2, y3) = y

-- | Quadruples: one criterion per component, merged left to right.
instance (Debatable a, Debatable b, Debatable c, Debatable d) =>
  Debatable (a,b,c,d) where
  weigh x y = weigh x1 y1 <> weigh x2 y2 <> weigh x3 y3 <> weigh x4 y4
    where
      (x1, x2, x3, x4) = x
      (y1, y2, y3, y4) = y

-- | A collection of mutually non-dominating items: no item in it is
-- preferred over another by all criteria.
newtype Front a = Front [a]
  deriving (Foldable, Show)

-- | A sequence of 'Front's, best first: every item of a later 'Front' is
-- dominated by (worse by all criteria than) some item of the 'Front' just
-- before it.
newtype Strata a = Strata [Front a]
  deriving (Show)

-- | Folds over every item of every 'Front', one 'Front' after another.
instance Foldable Strata where
  foldMap f = foldMap (foldMap f) . getStrata

-- | A 'Front' holding exactly one item.
singleton :: a -> Front a
singleton = Front . pure

-- | 'Strata' made of a single one-item 'Front'.
stratum :: a -> Strata a
stratum = Strata . pure . singleton

-- | Merge two fronts, sorting every item into survivors and beaten items.
--
-- Returns @(survivors, beatenA, beatenB)@. @survivors@ holds the items of
-- either input that no item of the /other/ input dominates, with the first
-- input's survivors first. @beatenA@ and @beatenB@ hold the remaining items
-- of the first and second input respectively. Items are only weighed
-- against the other front, never against members of their own front.
fuse :: Debatable a => Front a -> Front a -> (Front a, Front a, Front a)
-- With nothing on the left, every right-hand item survives. This case must
-- be handled separately: the general equation reads per-column verdicts via
-- 'transpose', and transpose [] == [] would silently drop every item of b.
fuse (Front []) b = (b, Front [], Front [])
fuse (Front a) (Front b) = let
  -- Row i holds weigh (a !! i) against each item of b, so every pair is
  -- weighed exactly once; transposing recovers the per-b-item columns.
  m = map (flip map b . weigh) a
  m' = transpose m
  -- An item of a survives when no item of b dominates it ...
  s = map (notElem Dominated) m
  -- ... and an item of b survives when no item of a dominates it.
  s' = map (notElem Dominates) m'
  (f1, t1) = partition snd $ zip a s
  (f2, t2) = partition snd $ zip b s'
  in (Front $ map fst (f1 ++ f2), Front $ map fst t1, Front $ map fst t2)

-- | Combining two 'Front's keeps every item except those dominated by
-- (worse by all criteria than) some item of the other 'Front'.
instance Debatable a => Semigroup (Front a) where
  x <> y = survivors
    where
      (survivors, _, _) = fuse x y

-- | The empty 'Front'.
instance Debatable a => Monoid (Front a) where
  mempty = Front []

-- | Merging two 'Strata' is the two-element case of 'mconcat'.
instance Debatable a => Semigroup (Strata a) where
  x <> y = mconcat [x, y]

-- | 'mconcat' merges any number of 'Strata' level by level. The fronts that
-- share a level are fused pairwise, left to right. Items beaten during a
-- fuse are demoted to the next level (created if missing), where they
-- compete again.
instance Debatable a => Monoid (Strata a) where
  mempty = Strata []
  mconcat = Strata . merge . transpose . map getStrata
    where
      -- Each element of the outer list holds the fronts competing for one
      -- level, best level first.
      merge [] = []
      merge (level : deeper) = case level of
        [] -> merge deeper
        [settled] -> settled : merge deeper
        (x : y : rest) ->
          let (kept, beatenX, beatenY) = fuse x y
          in merge ((kept : rest) : demote beatenX (demote beatenY deeper))
      -- Add a front to the next level down; empty fronts are discarded.
      demote (Front []) lower = lower
      demote front [] = [[front]]
      demote front (next : lower) = (front : next) : lower


-- | The items of a 'Front' as a plain list.
getFront :: Front a -> [a]
getFront front = case front of
  Front items -> items

-- | The 'Front's of a 'Strata', best first.
getStrata :: Strata a -> [Front a]
getStrata strata = case strata of
  Strata fronts -> fronts

-- | Keep the leading 'Front's needed to account for the first @n@ items and
-- drop the rest. A 'Front' is kept whenever fewer than @n@ items come
-- before it, so the result may hold more than @n@ items; @n <= 0@ keeps
-- nothing.
quota :: Int -> Strata a -> Strata a
quota limit (Strata fronts) = Strata (takeFronts limit fronts)
  where
    takeFronts remaining (front : rest)
      | remaining > 0 = front : takeFronts (remaining - length front) rest
    takeFronts _ _ = []

-- | Summarise each 'Front' separately by 'foldMap'-ing the first function
-- over its items, convert each summary with the second function, and
-- combine the converted summaries in order.
nestedFold :: (Monoid m, Monoid n) => (a -> m) -> (m -> n) -> Strata a -> n
nestedFold perItem perFront = foldMap summarise . getStrata
  where
    summarise front = perFront (foldMap perItem front)