{-# LANGUAGE OverloadedStrings #-}

module Tproxy where

import Ip
import System.Process (callProcess, readProcess)
import Data.List (isPrefixOf)
import Control.Exception (try, SomeException)
import System.IO (hPutStrLn, stderr)
import Control.Monad (forM_, void, when)

import System.Directory (findExecutable)


-- | Run a system command with arguments, reporting the outcome on
--   stdout (@[OK] ...@) or stderr (@[ERROR] ...@) instead of propagating
--   a command failure.
--
--   The first argument is an action that resolves the executable path
--   (e.g. 'iptablesBin'). It runs outside the error handler, so a failure
--   to resolve the binary still propagates to the caller.
--
--   Only 'IOError's are caught. 'callProcess' uses these for a non-zero
--   exit status and for a failure to spawn the process. Asynchronous
--   exceptions, such as the 'UserInterrupt' that 'callProcess' raises when
--   the child is killed by Ctrl-C, are deliberately not caught, so
--   interrupting a long batch of rules actually stops it.
runCmd :: IO FilePath -> [String] -> IO ()
runCmd cmd args = do
  command <- cmd
  let shown = command ++ " " ++ unwords args
  res <- try (callProcess command args) :: IO (Either IOError ())
  case res of
    Left e  -> hPutStrLn stderr $ "[ERROR] " ++ shown ++ ": " ++ show e
    Right _ -> putStrLn $ "[OK] " ++ shown

-- | Path of the @iptables@ executable, looked up on the search path every
--   time this action runs. Fails in 'IO' with "iptables not found" when it
--   cannot be located.
iptablesBin :: IO FilePath
iptablesBin = do
  located <- findExecutable "iptables"
  case located of
    Just path -> pure path
    Nothing   -> fail "iptables not found"

-- | Path of the @ip@ (iproute2) executable, looked up on every run.
--   Fails in 'IO' with "ip not found" when it cannot be located.
ipBin :: IO FilePath
ipBin = do
  located <- findExecutable "ip"
  case located of
    Just path -> pure path
    Nothing   -> fail "ip not found"

-- | Run @iptables@ with the given arguments via 'runCmd'.
--
--   The binary is whatever @iptables@ resolves to on the search path
--   ('iptablesBin'). The arguments are passed through unchanged; nothing is
--   prefixed or rewritten (e.g. for @xtables-nft-multi@). A failing command
--   is reported by 'runCmd' rather than thrown. A missing binary still
--   throws.
iptables :: [String] -> IO ()
iptables = runCmd iptablesBin
-- | Run @ip@ (iproute2) with the given arguments and return its standard
--   output. Errors are not caught here: 'readProcess' throws if the command
--   exits non-zero, and 'ipBin' throws if the binary cannot be found.
ip :: [String] -> IO String
ip args = ipBin >>= \path -> readProcess path args ""


-- | Detect the first global-scope IPv4 address of this host, as printed by
--   @ip -o -4 addr show scope global@ (e.g. @192.168.1.5/24@).
--
--   This mirrors @awk '{print $4}'@. In one-line (@-o@) output each address
--   line reads @<idx>: <ifname> inet <addr> ...@, so the address is the
--   fourth field. Matching that field, rather than the first word containing
--   a '/', avoids returning the @peer <addr>/<len>@ of a point-to-point link.
--   On such a link the local address is printed without a prefix length and
--   is returned as-is.
--
--   Returns 'Nothing' when no matching line is found. Exceptions from 'ip'
--   (missing binary, non-zero exit) propagate to the caller.
detectLocalCIDR :: IO (Maybe String)
detectLocalCIDR = do
  output <- ip ["-o", "-4", "addr", "show", "scope", "global"]
  let addrs =
        [ addr
        | line <- lines output
        , (_idx : _ifname : "inet" : addr : _) <- [words line]
        ]
  pure $ case addrs of
    []      -> Nothing
    (x : _) -> Just x

-- | Create the XRAY and XRAY_MASK chains in the mangle table, hook them into
--   PREROUTING/OUTPUT, and add the policy-routing route and rule for
--   table 100.
--
--   * XRAY (PREROUTING) returns early for the local CIDR, the listed
--     reserved ranges, and any source outside the local CIDR. Remaining TCP
--     and UDP traffic is sent to TPROXY on port 12345 with mark 1.
--   * XRAY_MASK (OUTPUT, tcp/udp) returns early for traffic owned by gid 988
--     and for the listed destinations, and marks everything else with 1.
--     NOTE(review): gid 988 is presumably the group the proxy runs as, so
--     its own traffic is not looped back. Confirm.
--   * Mark 1 is routed via table 100, whose @local 0.0.0.0/0 dev lo@ route
--     delivers those packets locally.
--
--   The local CIDR comes from 'detectLocalCIDR'. If no address is detected,
--   it falls back to a hard-coded @192.168.2.184/24@.
--   NOTE(review): 224.0.0.0/3 already covers 224.0.0.0/4 and 240.0.0.0/4.
--   Running this twice appends duplicate rules, because every rule uses -A.
applyRules :: IO ()
applyRules = do
  cidr <- maybe "192.168.2.184/24" id <$> detectLocalCIDR
  putStrLn $ "Using local CIDR: " ++ cidr
  mapM_ (iptables . inMangle) (xrayChain cidr ++ xrayMaskChain cidr)
  mapM_ (runCmd ipBin)
    [ ["route", "add", "local", "0.0.0.0/0", "dev", "lo", "table", "100"]
    , ["rule", "add", "fwmark", "1", "table", "100"]
    ]
  where
    -- Every iptables rule here targets the mangle table.
    inMangle :: [String] -> [String]
    inMangle rule = "-t" : "mangle" : rule

    -- "Append to <chain>: destination <dst> leaves the chain early."
    returnFor :: String -> String -> [String]
    returnFor chain dst = ["-A", chain, "-d", dst, "-j", "RETURN"]

    tproxy :: String -> [String]
    tproxy proto =
      ["-A", "XRAY", "-p", proto, "-j", "TPROXY", "--on-port", "12345", "--tproxy-mark", "1"]

    xrayReserved :: [String]
    xrayReserved =
      [ "224.0.0.0/3", "0.0.0.0/8", "10.0.0.0/8", "100.64.0.0/10", "127.0.0.0/8"
      , "169.254.0.0/16", "172.16.0.0/12", "192.168.0.0/16", "198.18.0.0/15"
      , "224.0.0.0/4", "240.0.0.0/4" ]

    xrayChain :: String -> [[String]]
    xrayChain localNet =
      [["-N", "XRAY"]]
        ++ map (returnFor "XRAY") (localNet : xrayReserved)
        ++ [ ["-A", "XRAY", "!", "-s", localNet, "-j", "RETURN"]
           , tproxy "tcp"
           , tproxy "udp"
           , ["-A", "PREROUTING", "-j", "XRAY"]
           ]

    xrayMaskChain :: String -> [[String]]
    xrayMaskChain localNet =
      [ ["-N", "XRAY_MASK"]
      , ["-A", "XRAY_MASK", "-m", "owner", "--gid-owner", "988", "-j", "RETURN"]
      ]
        ++ map (returnFor "XRAY_MASK")
             [ "0.0.0.0/8", "10.0.0.0/8", "127.0.0.0/8", "169.254.0.0/16", "172.16.0.0/12"
             , localNet, "224.0.0.0/4", "240.0.0.0/4" ]
        ++ [ ["-A", "XRAY_MASK", "-j", "MARK", "--set-mark", "1"]
           , ["-A", "OUTPUT", "-p", "tcp", "-j", "XRAY_MASK"]
           , ["-A", "OUTPUT", "-p", "udp", "-j", "XRAY_MASK"]
           ]

-- | Remove what 'applyRules' installs: the jumps into XRAY (PREROUTING) and
--   XRAY_MASK (OUTPUT, tcp and udp), both chains, and the fwmark-1 rule and
--   local route in table 100.
--
--   Every command is best-effort through 'runCmd'. If a rule, chain or route
--   is already missing, that is reported on stderr and the remaining cleanup
--   still runs.
--
--   TODO: consider using an ipset to clear the udp2raw rules as well.
clearRules :: IO ()
clearRules = do
  putStrLn "Clearing XRAY and XRAY_MASK chains..."
  -- iptables refuses to delete (-X) a chain that is still referenced, so
  -- drop every jump into XRAY / XRAY_MASK before flushing and deleting.
  iptables ["-t", "mangle", "-D", "PREROUTING", "-j", "XRAY"]
  iptables ["-t", "mangle", "-D", "OUTPUT", "-p", "tcp", "-j", "XRAY_MASK"]
  iptables ["-t", "mangle", "-D", "OUTPUT", "-p", "udp", "-j", "XRAY_MASK"]
  iptables ["-t", "mangle", "-F", "XRAY"]
  iptables ["-t", "mangle", "-X", "XRAY"]
  iptables ["-t", "mangle", "-F", "XRAY_MASK"]
  iptables ["-t", "mangle", "-X", "XRAY_MASK"]

  -- Mirror the route/rule added by applyRules exactly.
  runCmd ipBin ["route", "del", "local", "0.0.0.0/0", "dev", "lo", "table", "100"]
  runCmd ipBin ["rule", "del", "fwmark", "1", "table", "100"]


-- | Insert a rule at the head (@-I@) of the INPUT chain of the default
--   (filter) table. The rule drops TCP packets from @server@ whose source
--   port is @port@.
--   NOTE(review): presumably this supports udp2raw's fake-TCP mode, keeping
--   the kernel from reacting to that traffic. Confirm.
applyFakeTCP :: String -> Int -> IO ()
applyFakeTCP server port =
  iptables ("-I" : "INPUT" : dropFromServer)
  where
    dropFromServer =
      ["-s", server, "-p", "tcp", "-m", "tcp", "--sport", show port, "-j", "DROP"]

-- | Delete (@-D@) the INPUT rule that 'applyFakeTCP' inserts for the same
--   @server@ and @port@. A single @-D@ removes one matching rule.
clearFakeTCP :: String -> Int -> IO ()
clearFakeTCP server port =
  iptables ("-D" : "INPUT" : dropFromServer)
  where
    dropFromServer =
      ["-s", server, "-p", "tcp", "-m", "tcp", "--sport", show port, "-j", "DROP"]


-- | Insert a rule at the head (@-I@) of the INPUT chain of the default
--   (filter) table. The rule drops ICMP type 0 (echo reply) packets from
--   @server@.
applyICMP :: String -> IO ()
applyICMP server =
  iptables ("-I" : "INPUT" : dropEchoReply)
  where
    dropEchoReply = ["-s", server, "-p", "icmp", "--icmp-type", "0", "-j", "DROP"]

-- | Delete (@-D@) the INPUT rule that 'applyICMP' inserts for the same
--   @server@: one rule dropping ICMP type 0 (echo reply) from it.
clearICMP :: String -> IO ()
clearICMP server =
  iptables ("-D" : "INPUT" : dropEchoReply)
  where
    dropEchoReply = ["-s", server, "-p", "icmp", "--icmp-type", "0", "-j", "DROP"]