use std::collections::HashMap;
use std::convert::TryFrom;
use std::env;
use std::str::FromStr;
/// A set of forced bits: every listed bit position is overwritten with a fixed value.
#[derive(Clone)]
struct Mask {
    /// `(bit, value)` pairs. `bit` is the position counted from the least-significant
    /// bit (0); `value` is `true` to force that bit to 1 and `false` to force it to 0.
    /// Bit positions that do not appear here are left untouched by [`Mask::apply`].
    content: Vec<(u8, bool)>,
}
impl Mask {
    /// Returns `arg` with every bit listed in `content` overwritten by its forced value.
    ///
    /// Entries are applied in order, so a later entry for the same bit wins.
    fn apply(&self, arg: u64) -> u64 {
        let mut out = arg;
        for &(bit, value) in &self.content {
            // `wrapping_shl` takes the shift amount modulo 64 instead of panicking.
            let selector = 1_u64.wrapping_shl(u32::from(bit));
            out = if value { out | selector } else { out & !selector };
        }
        out
    }
}
impl FromStr for Mask {
    type Err = <u8 as TryFrom<usize>>::Error;

    /// Parses a mask pattern such as `XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X`.
    ///
    /// The first character maps to bit 35, and each later character maps to the next
    /// lower bit, so a full pattern is 36 characters long. `'1'` forces the bit to 1,
    /// `'0'` forces it to 0, and any other character (normally `'X'`) leaves the bit
    /// out of the mask.
    ///
    /// # Errors
    ///
    /// Returns an error when the pattern is longer than 36 characters. A character
    /// past the 36th has no bit position; previously `35 - ix` underflowed there,
    /// panicking in debug builds and silently producing a bogus bit in release builds.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let content = s
            .chars()
            .enumerate()
            .map(|(ix, c)| -> Result<Option<(u8, bool)>, Self::Err> {
                // A negative position (character past the 36th) fails the `u8`
                // conversion instead of underflowing.
                let bit = u8::try_from(35_i64 - i64::try_from(ix)?)?;
                Ok(match c {
                    '0' => Some((bit, false)),
                    '1' => Some((bit, true)),
                    _ => None,
                })
            })
            // Drop the `Ok(None)` entries (floating bits) but keep errors.
            .filter_map(Result::transpose)
            .collect::<Result<Vec<_>, _>>()?;
        Ok(Mask { content })
    }
}
/// A [`Mask`] split into the two kinds of bits used when the mask rewrites addresses.
#[derive(Debug)]
struct AddressMask {
    /// Bit positions in `0..36` that the mask does not mention, in ascending order.
    floating: Vec<u8>,
    /// Bit positions the mask forces to 1, in the mask's own order.
    set: Vec<u8>,
}
impl From<&Mask> for AddressMask {
    fn from(mask: &Mask) -> Self {
        // True when `bit` appears anywhere in the mask, whether forced to 0 or 1.
        let mentioned = |bit: &u8| mask.content.iter().any(|&(b, _)| b == *bit);
        let floating: Vec<u8> = (0..36_u8).filter(|bit| !mentioned(bit)).collect();
        let set: Vec<u8> = mask
            .content
            .iter()
            .filter(|&&(_, forced_one)| forced_one)
            .map(|&(bit, _)| bit)
            .collect();
        AddressMask { floating, set }
    }
}
/// Iterator producing one concrete [`Mask`] per assignment of the floating bits.
struct AddressMaskIter {
    address_mask: AddressMask,
    /// Index of the next assignment; bit `i` of this counter is the value given to the
    /// `i`-th floating bit.
    last: u64,
}
impl IntoIterator for AddressMask {
    type Item = Mask;
    type IntoIter = AddressMaskIter;
    fn into_iter(self) -> Self::IntoIter {
        AddressMaskIter {
            last: 0,
            address_mask: self,
        }
    }
}
impl Iterator for AddressMaskIter {
    type Item = Mask;

    /// Yields a mask that forces every `set` bit to 1 and every floating bit to the
    /// value chosen by the current counter, then advances the counter.
    fn next(&mut self) -> Option<Self::Item> {
        let combinations = 1_u64.wrapping_shl(self.address_mask.floating.len() as u32);
        if self.last >= combinations {
            return None;
        }
        let choice = self.last;
        let mut content: Vec<(u8, bool)> = self
            .address_mask
            .set
            .iter()
            .map(|&bit| (bit, true))
            .collect();
        for (ix, &bit) in self.address_mask.floating.iter().enumerate() {
            let selected = (choice & 1_u64.wrapping_shl(ix as u32)) != 0;
            content.push((bit, selected));
        }
        self.last += 1;
        Some(Mask { content })
    }
}
/// One line of the input program.
enum Instruction {
    /// `mask = <pattern>`: replaces the current mask.
    SetMask(Mask),
    /// `mem[<address>] = <value>`: writes `value` at `address` (fields in that order).
    SetMem(u64, u64),
}
/// Returned when a line matches neither `mask = …` nor `mem[…] = …`.
#[derive(Debug)]
struct ParseInstructionError {}
impl std::fmt::Display for ParseInstructionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        f.write_str("ParseInstructionError")
    }
}
impl std::error::Error for ParseInstructionError {}
impl FromStr for Instruction {
    type Err = Box<dyn std::error::Error>;

    /// Parses `mask = <pattern>` or `mem[<address>] = <value>`.
    ///
    /// The line must contain exactly three whitespace-separated tokens. For a `mem`
    /// line, the key must be exactly `mem[<digits>]`: a missing `]` or text after it
    /// is rejected rather than silently ignored.
    ///
    /// # Errors
    ///
    /// - [`ParseInstructionError`] for a malformed line shape.
    /// - The number-parse error for a non-numeric address or value.
    /// - The mask's own parse error for an invalid pattern.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut parts = s.split_whitespace();
        let (key, value) = match (parts.next(), parts.next(), parts.next(), parts.next()) {
            (Some(key), Some("="), Some(value), None) => (key, value),
            _ => return Err(ParseInstructionError {}.into()),
        };
        if key == "mask" {
            return Ok(Instruction::SetMask(Mask::from_str(value)?));
        }
        let address = key
            .strip_prefix("mem[")
            .and_then(|rest| rest.strip_suffix(']'))
            .ok_or(ParseInstructionError {})?;
        Ok(Instruction::SetMem(
            u64::from_str(address)?,
            u64::from_str(value)?,
        ))
    }
}
impl Instruction {
    /// Executes this instruction against the machine state.
    ///
    /// * `mask`: the current mask. `SetMask` replaces it; `SetMem` reads it.
    /// * `mem`: sparse memory, mapping address to value.
    /// * `version`: selects how `SetMem` uses the mask.
    ///   - `1`: the mask rewrites the stored value.
    ///   - `2`: the mask rewrites the address, and the value is written to every
    ///     address produced by the mask's floating-bit combinations.
    ///
    /// # Panics
    ///
    /// On a `SetMem` instruction when `version` is neither 1 nor 2.
    fn run(self, mask: &mut Mask, mem: &mut HashMap<u64, u64>, version: u8) {
        match self {
            Instruction::SetMask(m) => *mask = m,
            Instruction::SetMem(address, value) => match version {
                1 => {
                    mem.insert(address, mask.apply(value));
                }
                2 => {
                    // `concrete` is one fully-resolved address mask; the name avoids
                    // shadowing the `mask` parameter.
                    for concrete in AddressMask::from(&*mask) {
                        mem.insert(concrete.apply(address), value);
                    }
                }
                other => panic!("unsupported version {} (expected 1 or 2)", other),
            },
        }
    }
}
/// Runs the program in `<input-file>` and prints the sum of all memory values.
///
/// Usage: `<program> <input-file> <version>`, where `version` is 1 or 2 (see
/// `Instruction::run`). Blank lines are skipped. Malformed input, a missing argument or
/// an unknown version produces an error message instead of a bare panic.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let args: Vec<String> = env::args().collect();
    let (path, version) = match args.as_slice() {
        [_, path, version, ..] => (path, version),
        _ => return Err("usage: <program> <input-file> <version: 1 or 2>".into()),
    };
    let version = u8::from_str(version)?;
    // Validate once up front rather than failing on the first `mem` line.
    if version != 1 && version != 2 {
        return Err(format!("unsupported version {} (expected 1 or 2)", version).into());
    }
    let mut mask = Mask {
        content: Vec::new(),
    };
    let mut mem = HashMap::new();
    for (index, line) in std::fs::read_to_string(path)?.lines().enumerate() {
        if line.trim().is_empty() {
            continue;
        }
        let instruction = Instruction::from_str(line)
            .map_err(|e| format!("line {}: {}: {:?}", index + 1, e, line))?;
        instruction.run(&mut mask, &mut mem, version);
    }
    // Sum in u128 so that many large 36-bit values cannot overflow.
    println!("{}", mem.values().map(|&v| u128::from(v)).sum::<u128>());
    Ok(())
}