{-# LANGUAGE OverloadedStrings #-}

module Ip (waitForIP) where

import Control.Concurrent (threadDelay)
import Control.Exception (IOException, try)
import System.IO (hPutStrLn, stderr)
import qualified Data.ByteString.Char8 as B
import Network.DNS
import Network.Socket
import Data.List (isPrefixOf)
import Data.Maybe (listToMaybe)

-- | Poll the nameserver @ns@ for @domain@ until at least one address comes
-- back, then return the first one.
--
-- Every empty result is reported on stderr and followed by a two-second
-- pause before the next attempt. There is no upper bound on the number of
-- attempts.
waitForIP :: String -> String -> IO String
waitForIP ns domain = poll
  where
    poll :: IO String
    poll = resolveWithHostname ns domain >>= firstOrRetry

    firstOrRetry :: [String] -> IO String
    firstOrRetry (ip : _) = pure ip
    firstOrRetry [] = do
      hPutStrLn stderr ("Waiting for 2 seconds for " ++ domain)
      threadDelay retryDelayMicros
      poll

    -- threadDelay takes its argument in microseconds.
    retryDelayMicros :: Int
    retryDelayMicros = 2 * 1000000

-- | Look up the A records of @domain@ through the nameserver @nsHost@.
--
-- @nsHost@ is first turned into an address with 'resolveNameServer'; that
-- address is then used as the sole resolver for the query. The result is the
-- list of returned addresses rendered with 'show'.
--
-- Every failure is logged to stderr and reported as an empty list so that the
-- caller can simply retry: failure to resolve the nameserver, an
-- 'IOException' raised while setting up or running the resolver, or a DNS
-- error returned by the query.
resolveWithHostname :: String -> String -> IO [String]
resolveWithHostname nsHost domain = do
  nsIP <- resolveNameServer nsHost
  case nsIP of
    Left err -> do
      hPutStrLn stderr $ show err
      pure []
    Right Nothing -> do
      hPutStrLn stderr $ "No server found for " ++ domain
      pure []
    Right (Just ipStr) -> do
      hPutStrLn stderr $ "Querying " ++ domain ++ " using nameserver " ++ ipStr
      -- old network-dns only supports RCFilePath or RCHostName
      let conf = defaultResolvConf { resolvInfo = RCHostName ipStr }
      -- Building the seed and opening the resolver can raise IOException.
      -- Catch it here so it is handled like every other failure instead of
      -- escaping to the caller.
      outcome <- try $ do
        seed <- makeResolvSeed conf
        withResolver seed $ \resolver -> lookupA resolver (B.pack domain)
      case outcome of
        Left ioErr -> do
          hPutStrLn stderr $ "DNS query failed: " ++ show (ioErr :: IOException)
          pure []
        Right (Left err) -> do
          hPutStrLn stderr $ "DNS query failed: " ++ show err
          pure []
        Right (Right ips) -> do
          hPutStrLn stderr $ "DNS reply: " ++ show ips
          pure (map show ips)

-- | Turn a nameserver host into an address string.
--
-- A @host@ made up only of digits and dots is returned unchanged without any
-- lookup. Anything else is passed to 'getAddrInfo' (service @"53"@); the
-- first result's socket address is rendered with 'show' and passed through
-- 'stripPort'.
--
-- Returns @Left@ with the 'IOException' raised by the lookup, @Right Nothing@
-- when the lookup yields no entries, and @Right (Just addr)@ otherwise.
resolveNameServer :: String -> IO (Either IOException (Maybe String))
resolveNameServer host
  | looksNumeric = pure (Right (Just host))
  | otherwise = fmap firstAddress <$> try lookupHost
  where
    looksNumeric :: Bool
    looksNumeric = all (`elem` (".0123456789" :: String)) host

    lookupHost :: IO [AddrInfo]
    lookupHost = getAddrInfo (Just defaultHints) (Just host) (Just "53")

    -- Only the first entry is used; the rest are ignored.
    firstAddress :: [AddrInfo] -> Maybe String
    firstAddress entries = case entries of
      [] -> Nothing
      entry : _ -> Just (stripPort (show (addrAddress entry)))

-- | Drop the port from a textual socket address, keeping only the host part.
--
-- Two shapes are recognised:
--
-- * @[addr]:port@, the bracketed form used for IPv6 socket addresses: the
--   text between the brackets is returned, colons included.
-- * @addr:port@ or plain @addr@: everything before the first colon is
--   returned, which is the whole string when there is no colon.
--
-- A leading @[@ with no matching @]@ is not treated as the bracketed form.
stripPort :: String -> String
stripPort s = case s of
  '[' : rest
    | (addr, ']' : _) <- break (== ']') rest -> addr
  _ -> takeWhile (/= ':') s