const std = @import("std");

/// Puzzle input, pulled in with `@embedFile` by `parseInput`.
const path = "data/day20/input.txt";
/// Number of enhancement passes applied by `second`.
const max_rounds = 50;

/// Coordinate component type; signed because the tracked region grows into
/// negative rows and columns.
const PointType = isize;
/// A pixel position as `{ row, col }`.
const Point = [2]PointType;

/// Set of lit pixel positions (the map carries no values).
const Pixels = std.AutoHashMap(Point, void);

/// Entries in the enhancement lookup table: one per 9-bit 3x3 neighbourhood.
const bitset_size = 1 << 9;
const BitSet = std.StaticBitSet(bitset_size);

/// Puzzle state carried between enhancement rounds.
const Image = struct {
    /// Image enhancement algorithm: bit `i` is set when neighbourhood index `i`
    /// produces a lit pixel.
    IEA: BitSet,
    /// Lit pixels inside the currently tracked region.
    lit: Pixels,
    /// Row count of the original input image (also used as its width).
    size: PointType,
    /// Whether the infinite background outside the tracked region is lit.
    flippy: bool = false,
};

/// Run part two, check it against the known answer for the bundled input,
/// and report the result together with the elapsed time in microseconds.
pub fn main() !void {
    var stopwatch = try std.time.Timer.start();
    const answer = try second();
    const elapsed_us = stopwatch.lap() / std.time.ns_per_us;

    // Regression guard: the expected answer for the embedded puzzle input.
    try std.testing.expectEqual(@as(usize, 19156), answer);

    std.debug.print("Day 20b result: {d} \t\ttime: {d}us\n", .{ answer, elapsed_us });
}

/// Solve part two: apply the image enhancement algorithm `max_rounds` times
/// and return the number of lit pixels in the result.
///
/// Each round grows the tracked region by one pixel on every side. Pixels
/// beyond that region form the infinite background, whose lit/dark state is
/// kept in `image.flippy`.
///
/// Returns `error.InfiniteLitPixels` if the background is lit after the last
/// round, because the lit count is then unbounded. Allocates with
/// `std.testing.allocator`; all memory is released before returning.
pub fn second() !usize {
    const allocator = std.testing.allocator;

    var image = try parseInput(allocator);
    // `defer` evaluates at scope exit, so this frees whichever map
    // `image.lit` holds at that point (the latest round's map).
    defer image.lit.deinit();

    var round: isize = 1;
    while (round <= max_rounds) : (round += 1) {
        var next = Pixels.init(allocator);
        // Free the half-built map if an insertion fails. After the swap below
        // nothing in this iteration can fail, and `image.lit` owns the map.
        errdefer next.deinit();

        var row: isize = -round;
        while (row < image.size + round) : (row += 1) {
            var col: isize = -round;
            while (col < image.size + round) : (col += 1) {
                const p = Point{ row, col };
                if (getValue(&image, p, round)) try next.put(p, {});
            }
        }

        image.lit.deinit();
        image.lit = next;

        // Every background pixel sees an all-dark (index 0) or all-lit
        // (index 511) neighbourhood, so the table gives its next state
        // directly. This covers the toggling case as well as the
        // permanently-lit one.
        const background_index: usize = if (image.flippy) bitset_size - 1 else 0;
        image.flippy = image.IEA.isSet(background_index);
    }

    if (image.flippy) return error.InfiniteLitPixels;
    return image.lit.count();
}

/// Neighbourhood offsets `{ d_row, d_col }`, listed from the bottom-right
/// neighbour to the top-left one. The position in this table is the bit
/// position in the 9-bit lookup index (bottom-right is the least significant
/// bit), so the order must not change.
const directions = [_][2]isize{
    .{ 1, 1 },  .{ 1, 0 },  .{ 1, -1 },
    .{ 0, 1 },  .{ 0, 0 },  .{ 0, -1 },
    .{ -1, 1 }, .{ -1, 0 }, .{ -1, -1 },
};

/// Decide whether pixel `pt` is lit after enhancement round `round`.
///
/// Reads the 3x3 neighbourhood of `pt` from the current image and packs it
/// into a 9-bit index. Bit `i` corresponds to `directions[i]`, so the
/// bottom-right neighbour is the least significant bit. That index is then
/// looked up in the image enhancement algorithm table.
///
/// Neighbours outside the region produced by the previous round
/// (`[1 - round, size + round - 1)` on both axes) count as lit when the
/// background is lit (`image.flippy`). Otherwise the stored pixel set decides.
fn getValue(image: *const Image, pt: Point, round: isize) bool {
    // Bounds of the region materialised by the previous round; loop-invariant.
    const lo = 1 - round;
    const hi = image.size + round - 1;

    var index: u9 = 0;
    for (directions) |d, bit| {
        const p: Point = .{ pt[0] + d[0], pt[1] + d[1] };
        const outside = p[0] < lo or p[0] >= hi or p[1] < lo or p[1] >= hi;
        const on = (image.flippy and outside) or image.lit.contains(p);
        // `bit` ranges over 0..8, which always fits the u4 shift amount of a u9.
        if (on) index |= @as(u9, 1) << @intCast(u4, bit);
    }

    return image.IEA.isSet(index);
}

/// Parse the embedded puzzle input into an `Image`.
///
/// The first line is the image enhancement algorithm (`#` = lit) and must be
/// exactly `bitset_size` characters. The input image follows; its lit pixels
/// are stored as `{ row, col }` with the top-left pixel at `{ 0, 0 }`.
/// Trailing carriage returns are stripped. Blank lines (the separator line and
/// any produced by a trailing newline) are skipped, so `size` is the number of
/// image rows. The image is assumed square: only the row count is recorded.
///
/// The returned image starts with a dark background (`flippy == false`).
/// The caller owns `lit` and must `deinit` it.
fn parseInput(a: std.mem.Allocator) !Image {
    const input = @embedFile(path);
    var lines = std.mem.split(u8, input, "\n");

    const first_line = lines.next() orelse return error.EmptyInput;
    const algorithm = std.mem.trimRight(u8, first_line, "\r");
    // Guards `iea.set` below against out-of-range indices.
    if (algorithm.len != bitset_size) return error.InvalidAlgorithm;

    var iea = BitSet.initEmpty();
    for (algorithm) |item, idx| {
        if (item == '#') iea.set(idx);
    }

    var lit = Pixels.init(a);
    errdefer lit.deinit();

    var row: isize = 0;
    while (lines.next()) |raw| {
        const line = std.mem.trimRight(u8, raw, "\r");
        if (line.len == 0) continue;
        for (line) |ch, col| {
            if (ch == '#') try lit.put(.{ row, @intCast(isize, col) }, {});
        }
        row += 1;
    }

    // A struct literal (rather than `undefined` + field stores) makes sure
    // `flippy` gets its `false` default.
    return Image{ .IEA = iea, .lit = lit, .size = row };
}

/// Debug helper: print the region tracked at `round` to stderr, one text row
/// per image row, with `#` for pixels in `image.lit` and `.` otherwise.
///
/// The background state (`image.flippy`) is not reflected in the output.
fn printImage(image: *Image, round: isize) void {
    const first = -round;
    const end = image.size + round;

    std.debug.print("\n", .{});
    var y = first;
    while (y < end) : (y += 1) {
        var x = first;
        while (x < end) : (x += 1) {
            if (!image.lit.contains(.{ y, x })) {
                std.debug.print(".", .{});
                continue;
            }
            std.debug.print("#", .{});
        }
        std.debug.print("\n", .{});
    }
}

// Regression test: part two on the embedded input must keep producing the known answer.
test "day20b" {
    const lit_count = try second();
    try std.testing.expectEqual(@as(usize, 19156), lit_count);
}