package main

import (
	"bufio"
	"fmt"
	"log"
	"os"
	"sort"
	"strconv"
	"strings"
)

// JointDiff tallies the jolt differences between consecutive adapters in a
// sorted chain, as filled in by solveFirst: index 0 counts 1-jolt gaps,
// index 1 counts 2-jolt gaps and index 2 counts 3-jolt gaps.
type JointDiff [3]int

// main runs the solver and terminates the process with a non-zero exit
// status when realMain reports an error.
func main() {
	if err := realMain(); err != nil {
		// log.Fatal logs the error and then calls os.Exit(1). A plain Println
		// would let a failed run (e.g. missing input.txt) exit with status 0.
		log.Fatal(err)
	}
}

// realMain loads the adapter ratings from input.txt and builds the full
// chain: the ratings plus the device's built-in adapter plus the outlet (0).
// It sorts the chain and logs the answers to both puzzle parts. Only errors
// from reading the input are returned.
func realMain() error {
	ratings, err := parseInput("input.txt")
	if err != nil {
		return err
	}

	// Close both ends of the chain: the device adapter (max+3) is appended
	// first, then the charging outlet rated 0 jolts.
	chain := append(addDeviceJolt(ratings), 0)
	sort.Ints(chain)

	gaps := solveFirst(chain)
	log.Println("Part1:", gaps[0]*gaps[2])
	log.Println("Part2:", solveSecond(chain))
	return nil
}

// solveSecond counts the distinct adapter arrangements that connect
// jolts[0] to jolts[len(jolts)-1]. An arrangement is a subsequence that keeps
// both endpoints and in which every step rises by at most 3 jolts.
//
// jolts must be sorted in ascending order. Equal values are treated as
// distinct adapters. An empty or single-element slice yields 1, the trivial
// arrangement.
//
// ways[i] is the number of valid arrangements ending at jolts[i]. It is the
// sum of ways[j] over every earlier j that is within 3 jolts of jolts[i].
// This is exact for any gap pattern. The previous run-length shortcut had
// three problems: it dropped a trailing run of 1-jolt gaps, it ignored runs
// longer than 4, and it was wrong when 2-jolt gaps occurred.
func solveSecond(jolts []int) int {
	if len(jolts) == 0 {
		return 1
	}
	ways := make([]int, len(jolts))
	ways[0] = 1
	for i := 1; i < len(jolts); i++ {
		// The input is sorted, so scanning backwards can stop at the first
		// predecessor more than 3 jolts below jolts[i].
		for j := i - 1; j >= 0 && jolts[i]-jolts[j] <= 3; j-- {
			ways[i] += ways[j]
		}
	}
	return ways[len(jolts)-1]
}

// solveFirst walks the sorted chain and tallies each gap between
// neighbouring ratings. A 1-jolt gap goes in slot 0, a 2-jolt gap in slot 1
// and a 3-jolt gap in slot 2. Any other gap (0, negative or larger than 3)
// is rejected with log.Panicln.
func solveFirst(jolts []int) JointDiff {
	var tally JointDiff
	for k := 1; k < len(jolts); k++ {
		gap := jolts[k] - jolts[k-1]
		if gap < 1 || gap > 3 {
			log.Panicln("invalid diff")
		}
		tally[gap-1]++
	}
	return tally
}

// addDeviceJolt appends the device's built-in adapter rating, which is 3
// more than the largest value in jolts. The running maximum starts at 0, so
// an empty (or all non-positive) slice yields a device rating of 3. The
// result comes from append and may share jolts' backing array.
func addDeviceJolt(jolts []int) []int {
	highest := 0
	for i := range jolts {
		if jolts[i] > highest {
			highest = jolts[i]
		}
	}
	device := highest + 3
	return append(jolts, device)
}

// parseInput reads one integer adapter rating per line from fileName.
// Surrounding whitespace is ignored, and blank or whitespace-only lines are
// skipped.
//
// It returns an error in three cases:
//   - the file cannot be opened;
//   - a non-blank line is not a valid integer (the error names the file and
//     the 1-based line number);
//   - reading stops early because of an I/O error or a line longer than
//     bufio.Scanner's maximum token size.
//
// On error the returned slice is nil.
func parseInput(fileName string) ([]int, error) {
	fd, err := os.Open(fileName)
	if err != nil {
		return nil, err
	}
	defer fd.Close()

	var ret []int
	scanner := bufio.NewScanner(fd)
	for lineNo := 1; scanner.Scan(); lineNo++ {
		text := strings.TrimSpace(scanner.Text())
		if text == "" {
			continue // tolerate blank lines, e.g. an extra trailing newline
		}
		n, err := strconv.Atoi(text)
		if err != nil {
			return nil, fmt.Errorf("%s:%d: %w", fileName, lineNo, err)
		}
		ret = append(ret, n)
	}
	// Scan returns false on both EOF and failure. Only Err tells them apart.
	// Without this check a read error silently truncated the input.
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("reading %s: %w", fileName, err)
	}
	return ret, nil
}