const std = @import("std");

/// Location of the puzzle input, which `parseInput` embeds with `@embedFile`.
const path = "data/day12/input.txt";

/// Borrowed byte string; cave names are slices into the embedded input.
const Str = []const u8;

/// Integer type of the path count; 18 bits holds values up to 262143.
const RetType = u18;

/// Adjacency map: cave name -> list of caves that may be entered from it.
const Routes = std.StringHashMap(std.ArrayList(Str));

/// Visit counter per small cave on the path currently being explored.
/// Two bits suffice because `seenTooMuch` refuses entry once a count reaches 2.
const Seen = std.StringHashMap(u2);

/// Parses the embedded puzzle input (one `left-right` cave connection per
/// line) into an adjacency map from each cave to the caves reachable from it.
///
/// Every connection is stored in both directions. The exceptions are that
/// no route ever leads into "start" and no route ever leaves "end", so the
/// path search never has to filter those moves itself.
///
/// Map keys and neighbour names are slices into the embedded input and need
/// no freeing. The caller owns the returned map and every `ArrayList` stored
/// in it, and must `deinit` all of them.
///
/// Lines may end in "\n" or "\r\n", and blank lines are skipped. A line that
/// does not contain two names separated by '-' yields `error.InvalidInput`.
/// Allocation failures are propagated. On any error, everything built so far
/// is freed.
fn parseInput(a: std.mem.Allocator) anyerror!Routes {
    const input = @embedFile(path);
    // Splitting on both '\r' and '\n' keeps a CRLF file from leaving '\r'
    // on cave names, which would otherwise make "end\r" never match "end".
    var lines = std.mem.tokenize(u8, input, "\r\n");

    var ret = Routes.init(a);
    // On failure, release everything allocated so far; on success the caller owns it.
    errdefer {
        var it = ret.valueIterator();
        while (it.next()) |list| list.deinit();
        ret.deinit();
    }

    while (lines.next()) |line| {
        var from_to = std.mem.tokenize(u8, line, "-");
        const left = from_to.next() orelse return error.InvalidInput;
        const right = from_to.next() orelse return error.InvalidInput;

        // Discard moves that can never be part of a valid path:
        // entering "start" or leaving "end".
        if (!std.mem.eql(u8, right, "start") and !std.mem.eql(u8, left, "end")) {
            try addRoute(&ret, a, left, right);
        }
        if (!std.mem.eql(u8, left, "start") and !std.mem.eql(u8, right, "end")) {
            try addRoute(&ret, a, right, left);
        }
    }

    return ret;
}

/// Appends `to` to the neighbour list of `from`, creating the list on first use.
fn addRoute(routes: *Routes, a: std.mem.Allocator, from: Str, to: Str) anyerror!void {
    const entry = try routes.getOrPut(from);
    if (!entry.found_existing) {
        // A freshly inserted value is uninitialized until assigned here.
        entry.value_ptr.* = std.ArrayList(Str).init(a);
    }
    try entry.value_ptr.append(to);
}

/// Solves part 2. It parses the embedded input and returns the number of
/// distinct paths from "start" to "end" in which at most one small cave is
/// visited twice.
///
/// All memory comes from a local GeneralPurposeAllocator that is torn down
/// before returning, so its leak detection runs on every call.
pub fn second() anyerror!RetType {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    // Registered first so it runs last, after every structure below has been
    // freed. GPA's deinit is where leaks are detected and reported. Its
    // result is discarded because the reporting itself is what matters here.
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();

    var cave = try parseInput(allocator);
    defer {
        // The map owns one ArrayList per cave; free those before the map itself.
        var lists = cave.valueIterator();
        while (lists.next()) |list| list.deinit();
        cave.deinit();
    }

    var seen = Seen.init(allocator);
    defer seen.deinit();

    return try genRoutes(&seen, &cave, "start");
}

/// Counts the distinct paths from `root` to "end" by depth-first search.
///
/// Caves whose name starts with a lowercase letter are tracked in `seen`.
/// Whether such a cave may be entered is decided by `seenTooMuch`. Caves
/// starting with any other character are never tracked and may be revisited
/// freely.
///
/// `seen` is incremented for `root` on entry and decremented again before a
/// successful return. On success it is therefore left as it was found.
/// `routes` is the adjacency map built by `parseInput`. A cave with no entry
/// in `routes` contributes no paths. The only error source is allocation
/// when `seen` grows.
fn genRoutes(seen: *Seen, routes: *Routes, root: Str) anyerror!RetType {
    // Reached the goal: this is exactly one complete path.
    if (std.mem.eql(u8, root, "end")) return 1;

    // Entering this cave again would break the visiting rules.
    if (seenTooMuch(seen, root)) return 0;

    // Record the visit for small (lowercase) caves only. "end" never gets
    // here because it returned above.
    if (std.ascii.isLower(root[0])) {
        const visits = try seen.getOrPut(root);
        if (visits.found_existing) {
            visits.value_ptr.* +|= 1;
        } else {
            visits.value_ptr.* = 1;
        }
    }

    var sum: RetType = 0;
    // A cave that only ever appears as a destination (e.g. one linked solely
    // to "start") has no map entry. Treat it as a dead end instead of panicking.
    if (routes.get(root)) |neighbours| {
        for (neighbours.items) |next| {
            sum += try genRoutes(seen, routes, next);
        }
    }

    // Backtrack so other paths may reuse this cave. Look the entry up again
    // rather than keeping an earlier pointer, because recursive inserts may
    // have resized the map.
    if (seen.getPtr(root)) |count| {
        count.* -= 1;
    }

    return sum;
}

/// Reports whether stepping into `root` now would exceed its visit budget.
///
/// - Caves absent from `seen`, or present with a count of 0, are allowed.
/// - A cave already counted 2 or more times is refused.
/// - A cave counted once is refused only if some *other* cave in `seen`
///   already has a count above 1, i.e. the single double visit is used up.
fn seenTooMuch(seen: *Seen, root: Str) bool {
    const visits = seen.get(root) orelse return false;

    return switch (visits) {
        0 => false,
        1 => used_double: {
            var entries = seen.iterator();
            while (entries.next()) |entry| {
                const is_self = std.mem.eql(u8, entry.key_ptr.*, root);
                if (!is_self and entry.value_ptr.* > 1) break :used_double true;
            }
            break :used_double false;
        },
        else => true,
    };
}

/// Runs part 2 once, checks the known answer, and prints the result together
/// with the wall-clock time in microseconds.
pub fn main() anyerror!void {
    var stopwatch = try std.time.Timer.start();
    const answer = try second();
    const elapsed_us = stopwatch.read() / std.time.ns_per_us;

    // Regression guard: fail loudly if the answer for the bundled input changes.
    try std.testing.expectEqual(@as(RetType, 134862), answer);

    std.debug.print("Day 12b result: {d} \t\ttime: {d}us\n", .{ answer, elapsed_us });
}

// Regression check: part 2 on the bundled input must keep producing the known answer.
test "day12b" {
    const expected: RetType = 134862;
    const actual = try second();
    try std.testing.expectEqual(expected, actual);
}