const std = @import("std");
/// Shorthand alias for an immutable byte slice (`[]const u8`).
const Str = []const u8;
/// Build a Dijkstra shortest-path solver over states of type `T`.
///
/// `T` must declare a method `neighborFn(self, allocator)` returning an error
/// union of a mutable slice of `Graph(T).QueueType`. Each entry names a
/// neighbouring `state` and, in `cost`, the weight of the edge to it. The slice
/// must be allocated with the allocator passed in; `dijkstra` frees it after
/// use. States are compared with `std.meta.eql` and hashed with
/// `std.AutoHashMap`, so `T` must be usable as an `AutoHashMap` key.
pub fn Graph(comptime T: type) type {
    if (!@hasDecl(T, "neighborFn")) {
        @compileError("Given type (" ++ @typeName(T) ++
            ") must contain a function called `neighborFn`");
    }
    return struct {
        const Self = @This();

        /// A frontier entry: a state plus a cost. From `neighborFn` the cost is
        /// an edge weight; inside the queue it is the accumulated path cost.
        pub const QueueType = struct {
            state: T,
            cost: usize,
        };

        /// Heap ordering: entries with a smaller `cost` are removed first.
        fn lessThan(context: void, a: QueueType, b: QueueType) std.math.Order {
            _ = context;
            return std.math.order(a.cost, b.cost);
        }

        /// All search scratch memory (queue, visited set, neighbour slices)
        /// comes from this arena; it is released by `deinit`.
        arena: std.heap.ArenaAllocator,

        /// Create a solver whose arena is backed by `allocator`.
        pub fn init(allocator: std.mem.Allocator) Self {
            return .{ .arena = std.heap.ArenaAllocator.init(allocator) };
        }

        /// Release every allocation made through the arena.
        pub fn deinit(self: *Self) void {
            self.arena.deinit();
        }

        /// Return the minimum total cost of a path from `start` to `goal`.
        ///
        /// Returns 0 when `start` equals `goal`. Returns `error.NoPath` when the
        /// queue is exhausted without reaching `goal`. Propagates any error from
        /// `neighborFn` or from allocation. Cost accumulation uses checked `+=`,
        /// so an overflowing path cost traps in safe build modes.
        pub fn dijkstra(self: *Self, start: T, goal: T) !usize {
            const allocator = self.arena.allocator();

            var visited = std.AutoHashMap(T, void).init(allocator);
            defer visited.deinit();
            var queue = std.PriorityQueue(QueueType, void, lessThan).init(allocator, {});
            defer queue.deinit();

            try queue.add(.{ .state = start, .cost = 0 });
            while (queue.removeOrNull()) |curr| {
                // A state may be queued several times; only its first (cheapest)
                // removal settles it.
                if (visited.contains(curr.state)) continue;
                try visited.put(curr.state, {});
                if (std.meta.eql(curr.state, goal)) return curr.cost;

                const neighbors = try curr.state.neighborFn(allocator);
                defer allocator.free(neighbors);
                for (neighbors) |*next| {
                    // Already-settled states would be discarded on removal anyway.
                    if (visited.contains(next.state)) continue;
                    next.cost += curr.cost;
                    try queue.add(next.*);
                }
            }
            // Previously `unreachable`: UB in ReleaseFast for disconnected graphs.
            return error.NoPath;
        }
    };
}