module STLC where

open import Data.Nat using (ℕ; zero; suc; _+_; _≤_)
open import Data.Nat.Properties using (m≤m+n)
open import Data.List using (List; []; _∷_; length)
open import Data.Product using (_×_; uncurry; proj₁; proj₂) renaming (_,_ to ⟨_,_⟩)
open import Relation.Nullary using (Dec; yes; no; ¬_)
open import Relation.Nullary.Product using (_×-dec_)
open import Relation.Nullary.Decidable using (map′)
open import Data.List.Relation.Unary.Any using (Any; here; there)
open import Relation.Binary.PropositionalEquality using (_≡_; _≢_; refl; cong; cong₂)

infix  4 _⊢_
infix  4 _∋_
infixl 5 _,_
infixl 5 `λ_
infixr 7 _`⇒_
infixl 7 _`$_
infix  8 `S_
infix  9 `_

Id : Set
Id = ℕ

data Ty : Set where
 `ℕ   : Ty
 _`⇒_ : Ty → Ty → Ty

module _ where
  -- without map′, cong₂ and `⇒-dec
  `⇒-inj₁ : ∀ {τ τ' σ σ'} → τ `⇒ σ ≡ τ' `⇒ σ' → τ ≡ τ'
  `⇒-inj₁ refl = refl
  
  `⇒-inj₂ : ∀ {τ τ' σ σ'} → τ `⇒ σ ≡ τ' `⇒ σ' → σ ≡ σ'
  `⇒-inj₂ refl = refl
  
  _≡?_ :  (τ : Ty) → (σ : Ty) → Dec (τ ≡ σ)
  `ℕ ≡? `ℕ = yes refl
  `ℕ ≡? (_ `⇒ _) = no (λ ())
  (_ `⇒ _) ≡? `ℕ = no (λ ())
  (τ `⇒ σ) ≡? (τ' `⇒ σ') with τ ≡? τ'
  ((τ `⇒ σ) ≡? (τ' `⇒ σ')) | yes τ≡τ' with σ ≡? σ'  
  ((τ `⇒ σ) ≡? (τ' `⇒ σ')) | yes τ≡τ'  | yes σ≡σ' rewrite τ≡τ' rewrite σ≡σ' = yes refl
  ((τ `⇒ σ) ≡? (τ' `⇒ σ')) | yes τ≡τ'  | no ¬σ≡σ' = no (λ p → ¬σ≡σ' (`⇒-inj₂ p))
  ((τ `⇒ σ) ≡? (τ' `⇒ σ')) | no ¬τ≡τ'  = no (λ p → ¬τ≡τ' (`⇒-inj₁ p))

`⇒-inj : ∀ {τ τ' σ σ'} → τ `⇒ σ ≡ τ' `⇒ σ' → τ ≡ τ' × σ ≡ σ'
`⇒-inj refl = ⟨ refl , refl ⟩

{--
map′ : (P → Q) → (Q → P) → Dec P → Dec Q
cong₂ : x ≡ y → u ≡ v → f x u ≡ f y v
--}
`⇒-dec :  ∀ {τ τ' σ σ'} → Dec (τ ≡ τ' × σ ≡ σ') → Dec (τ `⇒ σ ≡ τ' `⇒ σ')
`⇒-dec prf = map′ (uncurry (cong₂ _`⇒_)) `⇒-inj prf

_≟_ :  (τ : Ty) → (σ : Ty) → Dec (τ ≡ σ)
`ℕ ≟ `ℕ                 = yes refl
`ℕ ≟ (_ `⇒ _)           = no (λ ())
(_ `⇒ _) ≟ `ℕ           = no (λ ())
(t `⇒ t₁) ≟ (t' `⇒ t'') = `⇒-dec (t ≟ t' ×-dec t₁ ≟ t'')

Ctx : Set
Ctx = List Ty

_,_ : ∀ {A : Set} → List A → A → List A
Γ , τ = τ ∷ Γ

_∋_ : ∀ {A} → List A → A -> Set
Γ ∋ τ = Any (λ σ -> σ ≡ τ) Γ

private
  variable
    Γ Δ : Ctx
    τ τ' σ σ' : Ty
    n m : ℕ

data Arrow : Ty → Set where
  arrow : Arrow (τ `⇒ σ)
  
arrow? : (τ : Ty) → Dec (Arrow τ)
arrow? `ℕ       = no (λ ())
arrow? (τ `⇒ σ) = yes arrow

index : Γ ∋ τ → ℕ
index (here px) = zero
index (there p) = suc (index p)

data Lookup (Γ : Ctx) : ℕ → Set where
  inside : (τ : Ty) → (p : Γ ∋ τ) → Lookup Γ (index p)
  outside : (m : ℕ) → Lookup Γ (length Γ + m)

_!_ : (Γ : Ctx) → (m : ℕ) → Lookup Γ m
[] ! m         = outside m
(x ∷ Γ) ! zero = inside x (here refl)
(x ∷ Γ) ! suc m with Γ ! m
... | inside τ p = inside τ (there p)
... | outside m  = outside m

data UTerm : Set where
  `_     : Id → UTerm
  `λ_∙_  : Ty → UTerm → UTerm
  `let_`in_ : UTerm → UTerm → UTerm
  `Z     : UTerm
  `S_    : UTerm → UTerm
  `ucase : UTerm → UTerm → UTerm → UTerm
  _`$_   : UTerm → UTerm → UTerm

data _⊢_ : Ctx → Ty → Set where

  `_   : Γ ∋ τ
         --------
         → Γ ⊢ τ

  `λ_  : Γ , τ ⊢ σ
         --------------
         → Γ ⊢ τ `⇒ σ  

  `let_`in_ : Γ ⊢ τ
              → Γ , τ ⊢ σ
              ------------
              → Γ ⊢ σ
              
  `Z   : Γ ⊢ `ℕ

  `S_  : Γ ⊢ `ℕ
         ----------
         →  Γ ⊢ `ℕ
         
  `case_[Z⇒_|S⇒_] :
          Γ ⊢ `ℕ
          → Γ ⊢ τ
          → Γ , `ℕ ⊢ τ
          --------------
          → Γ ⊢ τ
         

  _`$_ : Γ ⊢ τ `⇒ σ
         → Γ ⊢ τ
         ---------
         → Γ ⊢ σ

erase : Γ ⊢ τ → UTerm
erase (` x)             = ` index x
erase {Γ}{τ `⇒ σ}(`λ t) = `λ τ ∙ erase t
erase (`let t `in t₁) = `let erase t `in erase t₁
erase `Z                = `Z
erase (`S t)            = `S erase t
erase (`case n [Z⇒ t₁ |S⇒ t₂ ]) = `ucase (erase n) (erase t₁) (erase t₂)
erase (t `$ t₁)         = erase t `$ erase t₁

data _⊬_ : Ctx → UTerm → Set where
  ill-scoped       : length Γ ≤ n →  Γ ⊬ (` n)
  not-a-nat-S      : (t : Γ ⊢ τ) → τ ≢ `ℕ → Γ ⊬ (`S erase t)
  not-a-nat-case   : ∀ {a b} → (n : Γ ⊢ τ) → τ ≢ `ℕ → Γ ⊬ (`ucase (erase n) a b)
  not-a-function   : ∀ {a} → (f : Γ ⊢ τ) → ¬ Arrow τ → Γ ⊬ (erase f `$ a)
  ty-mismatch-app  : (f : Γ ⊢ τ `⇒ σ) → (a : Γ ⊢ τ') → τ ≢ τ' → Γ ⊬ ((erase f) `$ (erase a))
  ty-mismatch-case : ∀ {n} → (a :  Γ ⊢ τ) → (b : Γ , `ℕ ⊢ σ) → τ ≢ σ → Γ ⊬ (`ucase n (erase a) (erase b))
  propagate-λ      : ∀ {t} → (Γ , τ) ⊬ t → Γ ⊬ (`λ τ ∙ t) 
  propagate-let₀   : ∀ {t t₁} → Γ ⊬ t → Γ ⊬ (`let t `in t₁)
  propagate-let₁   : ∀ {t t₁}{τ} → (Γ , τ) ⊬ t₁ → Γ ⊬ (`let t `in t₁)
  propagate-S      : ∀ {t} → Γ ⊬ t → Γ ⊬ (`S t)
  propagate-left   : ∀ {f a} → Γ ⊬ f → Γ ⊬ (f `$ a) 
  propagate-right  : ∀ {f a} → Γ ⊬ a → Γ ⊬ (f `$ a) 
  propagate-case-n : ∀ {n a b} → Γ ⊬ n → Γ ⊬ (`ucase n a b)
  propagate-case-Z : ∀ {n a b} → Γ ⊬ a → Γ ⊬ (`ucase n a b)
  propagate-case-S : ∀ {n a b} → (Γ , `ℕ) ⊬ b → Γ ⊬ (`ucase n a b)
  
data Infer (Γ : Ctx) : UTerm → Set where
  inf : (τ : Ty) → (t : Γ ⊢ τ) → Infer Γ (erase t)
  bad : {t : UTerm} → Γ ⊬ t → Infer Γ t

infer : (Γ : Ctx) → (t : UTerm) → Infer Γ t
infer Γ (` x)              with Γ ! x
infer Γ (` .(index p))      | inside τ p = inf τ (` p)
infer Γ (` .(length Γ + m)) | outside m  = bad (ill-scoped (m≤m+n (length Γ) m))
infer Γ (`λ τ ∙ t)         with infer (Γ , τ) t
infer Γ (`λ τ ∙ .(erase t)) | inf σ t = inf (τ `⇒ σ) (`λ t)
infer Γ (`λ τ ∙ t)          | bad e   = bad (propagate-λ e)
infer Γ (`let t `in t₁)                  with infer Γ t
infer Γ (`let .(erase t) `in t₁)          | inf τ t with infer (Γ , τ) t₁
infer Γ (`let .(erase t) `in .(erase t₁)) | inf τ t  | inf σ t₁ = inf σ (`let t `in t₁)
infer Γ (`let .(erase t) `in t₁)          | inf τ t  | bad x = bad (propagate-let₁ x)
infer Γ (`let t `in t₁)                   | bad x = bad (propagate-let₀ x)
infer Γ `Z = inf `ℕ `Z
infer Γ (`S t)         with infer Γ t
infer Γ (`S .(erase t)) | inf `ℕ t       = inf `ℕ (`S t)
infer Γ (`S .(erase t)) | inf (τ `⇒ τ₁) t = bad (not-a-nat-S t (λ ()))
infer Γ (`S t)          | bad e           = bad (propagate-S e)
infer Γ (`ucase n t₁ t₂)                             with infer Γ n
infer Γ (`ucase .(erase n) t₁ t₂)                     | inf `ℕ n with infer Γ t₁
infer Γ (`ucase .(erase n) .(erase t₁) t₂)            | inf `ℕ n  | inf τ t₁ with infer (Γ , `ℕ) t₂
infer Γ (`ucase .(erase n) .(erase t₁) .(erase t₂))   | inf `ℕ n  | inf τ t₁  | inf σ t₂ with τ ≟ σ
infer Γ (`ucase .(erase n) _ .(erase t₂))             | inf `ℕ n  | inf τ t₁  | inf σ t₂  | yes τ≡σ rewrite τ≡σ = inf σ `case n [Z⇒ t₁ |S⇒ t₂ ]
infer Γ (`ucase .(erase n) .(erase t₁) .(erase t₂))   | inf `ℕ n  | inf τ t₁  | inf σ t₂  | no ¬τ≡σ = bad (ty-mismatch-case t₁ t₂ ¬τ≡σ)
infer Γ (`ucase .(erase n) .(erase t₁) t₂)            | inf `ℕ n | inf τ t₁ | bad e = bad (propagate-case-S e)
infer Γ (`ucase .(erase n) t₁ t₂)                     | inf `ℕ n | bad e = bad (propagate-case-Z e)
infer Γ (`ucase .(erase n) t₁ t₂)                     | inf (τ `⇒ τ₁) n = bad (not-a-nat-case n (λ ()))
infer Γ (`ucase n t₁ t₂)                              | bad x = bad (propagate-case-n x)
infer Γ (f `$ a)                   with infer Γ f
infer Γ (.(erase t) `$ a)           | inf `ℕ t = bad (not-a-function t (λ ()))
infer Γ (.(erase t) `$ a)           | inf (τ `⇒ σ) t with infer Γ a
infer Γ (.(erase t) `$ .(erase t₁)) | inf (τ `⇒ σ) t  | inf τ₁ t₁ with τ ≟ τ₁
infer Γ (_ `$ .(erase t₁))          | inf (τ `⇒ σ) t  | inf τ₁ t₁  | yes τ≡τ₁ rewrite τ≡τ₁ = inf σ (t `$ t₁)
infer Γ (.(erase t) `$ .(erase t₁)) | inf (τ `⇒ σ) t  | inf τ₁ t₁  | no ¬τ≡τ₁ = bad (ty-mismatch-app t t₁ ¬τ≡τ₁)
infer Γ (.(erase t) `$ a)           | inf (τ `⇒ σ) t  | bad e = bad (propagate-right e)
infer Γ (f `$ a)                    | bad e = bad (propagate-left e)