package main

import (
	"log"
)

// Game configuration shared by both puzzle parts.
const (
	// 1-based starting squares of player 1 and player 2.
	p1s = 7
	p2s = 9

	// Winning score thresholds: maxScore1 for part one (deterministic die),
	// maxScore2 for part two (Dirac die).
	maxScore1 = 1000
	maxScore2 = 21
)

// player holds one participant's position and accumulated score.
// In part one (addScore) pos is a 1-based running total that is only reduced
// modulo 10 when scoring; in part two (countWin) pos is a 0-based square.
// Field order matters: callers build players with positional literals.
type player struct {
	pos, score int
}

// main runs the program and, if myMain reports an error, logs it and exits
// with a non-zero status so the failure is visible to whatever invoked us
// (previously the error was logged but the process still exited 0).
func main() {
	if err := myMain(); err != nil {
		log.Fatal(err)
	}
}

// myMain logs the answer to part one, then the answer to part two.
// It always returns nil.
func myMain() error {
	parts := []func() int{solveFirst, solveSecond}
	for _, solve := range parts {
		log.Println(solve())
	}
	return nil
}

// gameState is a part-two position used as countWin's memoization key:
// p1 is the player whose turn it is and p2 is their opponent.
type gameState struct {
	p1, p2 player
}

// cache memoizes countWin: it maps a game state to the per-player win counts
// reachable from it. It is package-level and never cleared.
var cache = make(map[gameState][2]int)

func solveSecond() int {
	p1 := player{p1s - 1, 0}
	p2 := player{p2s - 1, 0}
	res := countWin(gameState{p1, p2})

	if res[0] > res[1] {
		return res[0]
	}
	return res[1]
}

// diracRollWays lists every possible total of three rolls of the 3-sided Dirac
// die together with how many of the 3*3*3 = 27 split universes produce it.
var diracRollWays = [...]struct{ sum, ways int }{
	{3, 1}, {4, 3}, {5, 6}, {6, 7}, {7, 6}, {8, 3}, {9, 1},
}

// countWin returns how many universes each player wins from state s: index 0
// counts wins for s.p1 (the player about to move) and index 1 wins for s.p2.
// Positions are 0-based squares (the board square is pos+1, which is also the
// points scored for landing there); a player wins once their score reaches
// maxScore2. Results for non-terminal states are memoized in cache.
//
// This is a Go port of Jonathan Paulson's elegant solution:
// https://github.com/jonathanpaulson/AdventOfCode/blob/master/2021/21.py
func countWin(s gameState) [2]int {
	if s.p1.score >= maxScore2 {
		return [2]int{1, 0}
	}
	if s.p2.score >= maxScore2 {
		return [2]int{0, 1}
	}

	if ret, ok := cache[s]; ok {
		return ret
	}

	var ret [2]int
	// Group the 27 roll outcomes by their total: universes with the same total
	// lead to the same sub-game, so recurse once per total and weight by ways.
	for _, r := range diracRollWays {
		nextPos := (s.p1.pos + r.sum) % 10
		moved := player{nextPos, s.p1.score + nextPos + 1}

		// It is now the opponent's turn, so roles swap in the sub-game: its
		// index 1 counts the mover's wins and index 0 the opponent's.
		sub := countWin(gameState{s.p2, moved})
		ret[0] += r.ways * sub[1]
		ret[1] += r.ways * sub[0]
	}
	cache[s] = ret
	return ret
}

// solveFirst plays part one with the deterministic 100-sided die and returns
// the losing player's score multiplied by the total number of die rolls.
//
// The die value is never wrapped back to 1 after 100: a roll of 101 instead of
// 1 differs by 100, which is 0 modulo 10, so the player lands on the same
// square (see addScore).
func solveFirst() int {
	p1 := player{pos: p1s}
	p2 := player{pos: p2s}

	dice := 1  // value of the first of the next three consecutive rolls
	rolls := 0 // total number of die rolls so far
	for {
		// Player 1's turn.
		p1.addScore(dice)
		dice += 3
		rolls += 3
		if p1.score >= maxScore1 {
			return p2.score * rolls
		}

		// Player 2's turn.
		p2.addScore(dice)
		dice += 3
		rolls += 3
		if p2.score >= maxScore1 {
			return p1.score * rolls
		}
	}
}

// addScore plays one part-one turn for p: the three rolls are dice, dice+1 and
// dice+2, their total is added to p.pos, and the landing square is added to
// p.score. p.pos is never wrapped; the square is p.pos modulo 10, with a
// remainder of 0 meaning square 10.
func (p *player) addScore(dice int) {
	p.pos += 3*dice + 3
	square := p.pos % 10
	if square == 0 {
		square = 10
	}
	p.score += square
}