Learning Python with Advent of Code Walkthroughs

Dazbo's Advent of Code solutions, written in Python

ALU

Advent of Code 2021 - Day 24

Day 24: Arithmetic Logic Unit

Concepts and Packages Demonstrated

introspection

tqdm progress bar

Problem Intro

Oh, this one looks easy enough. Wrong!
I just need to write an ALU simulator that knows how to process some instructions. Wrong!

Overview

We’re told that the sub runs off an arithmetic logic unit that takes four integer variables (w, x, y, and z), and is capable of performing six different instructions with these variables.

The ALU processes programs, which are sets of instructions. The instructions are processed in order, from beginning to end.

We’re given a few sample input programs. Like this one:

inp w
add z w
mod z 2
div w 2
add y w
mod y 2
div w 2
add x w
mod x 2
div w 2
mod w 2
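
To get a feel for the instruction set, here's a minimal sketch of a throwaway interpreter that runs the sample program above. The names SAMPLE, OPS and run are made up for illustration, and div is implemented with floor division, which is only safe here because the values never go negative.

```python
import operator

SAMPLE = """inp w
add z w
mod z 2
div w 2
add y w
mod y 2
div w 2
add x w
mod x 2
div w 2
mod w 2"""

# Map each two-operand instruction to a function (illustrative names only)
OPS = {
    "add": operator.add,
    "mul": operator.mul,
    "div": operator.floordiv,   # assumption: operands are non-negative
    "mod": operator.mod,
    "eql": lambda a, b: int(a == b),
}

def run(program: str, inputs: list[int]) -> dict[str, int]:
    """ Run the program, feeding each inp instruction the next input value. """
    regs = {"w": 0, "x": 0, "y": 0, "z": 0}
    feed = iter(inputs)
    for line in program.splitlines():
        op, target, *rest = line.split()
        if op == "inp":
            regs[target] = next(feed)
        else:
            operand = rest[0]
            value = regs[operand] if operand in regs else int(operand)
            regs[target] = OPS[op](regs[target], value)
    return regs

print(run(SAMPLE, [13]))  # {'w': 1, 'x': 1, 'y': 0, 'z': 1}
```

So this sample converts a number into binary: 13 is 1101, with the lowest bit ending up in z and the highest of the four bits in w.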

We’re told we need to use our ALU to validate the sub’s model number. We’re given a program called MONAD which takes any 14-digit number (where the digits must be 1 to 9, inclusive), and processes the number. My actual MONAD data, for example, contains 14 inp w instructions.

We’re told a given model number is only valid if, after processing all the instructions in the MONAD program, variable z is set to 0.

Part 1

What is the largest model number accepted by MONAD?

So the goal is to find the largest possible 14-digit number which results in a z value of 0, after running the number through our program.

The ALU Simulator

Having done a few AOCs before, I jumped straight to writing an ALU simulator. (Spoiler alert: this was a mistake!)

class ALU():
    """ Simulate processor with four registers and six instructions """
    def __init__(self) -> None:
        self._vars = {'w': 0, 'x': 0, 'y': 0, 'z': 0}
        
        self._input = None
        self._input_posn = 0    # which digit of the input value we're currently on
        
        self._instructions: list[tuple[str, list[str]]] = []     # list of instructions in the format [instr, [parms]]
        self._ip = 0
        
    @property
    def vars(self):
        return self._vars
    
    def _set_input(self, value: str):
        """ Take a number and store as a str representation """
        assert value.isdigit(), "Must be number"
        assert len(value) == 14, "Must be 14 digit input"
        self._input = value
        self._input_posn = 0        
     
    def _set_var(self, var, value):
        """ Sets the specified var to the specified value. """
        if var not in self._vars:
            raise KeyError(f"No such var '{var}'")
        
        self._vars[var] = value
    
    def _reset(self):
        for var in self._vars:
            self._vars[var] = 0

        self._input = None
        self._input_posn = 0
        self._ip = 0
        
    def run_program(self, input_str: str):
        """ Process instructions in the program. """
        self._reset()      
        self._set_input(input_str)
        
        for instruction in self._instructions:
            self._execute_instruction(instruction)
            self._ip += 1

    def set_program(self, instructions_input: list[str]):
        """ Create a list of instructions, 
        where each instruction is of the format: (str, list[str]) """
        self._instructions = []
        
        for line in instructions_input:
            instr_parts = line.split()
            instr = instr_parts[0]
            instr_parms = instr_parts[1:]
        
            self._instructions.append((instr, instr_parms))
        
    def _execute_instruction(self, instruction:tuple[str, list[str]]):
        """ Takes an instruction, and calls the appropriate implementation method.

        Args:
            instruction (tuple): The instruction, in the format (instr, [parms])
            
        Raises:
            AttributeError if instruction is not understood
        """
        # logger.debug("Instruction: %s", instruction)
        instr = instruction[0]
        instr_parms = instruction[1]
        
        # call the appropriate instruction method
        try:
            self.__getattribute__(f"_op_{instr}")(instr_parms)         
        except AttributeError as err:
            raise AttributeError(f"Bad instruction {instr} at {self._ip}") from err

    def int_or_reg_val(self, x) -> int:
        """ Determine if the variable is an int value, or the value is a register """
        if x in self._vars:
            return self._vars[x]
        else:
            return int(x)
        
    def _op_inp(self, parms:list[str]):
        var = parms[0]
        assert self._input, "Input value not set"
        assert self._input_posn < len(self._input), "Too many input digits!"
        input_digit = int(self._input[self._input_posn])
        self._vars[var] = input_digit
        self._input_posn += 1
    
    def _op_add(self, parms:list[str]):
        """ Add a to b and store in a. Param b could be a var or a number. """
        self._vars[parms[0]] += self.int_or_reg_val(parms[1])
    
    def _op_mul(self, parms:list[str]):
        """ Multiply a by b and store in a. Param b could be a var or a number. """
        self._vars[parms[0]] *= self.int_or_reg_val(parms[1])
    
    def _op_div(self, parms:list[str]):
        """ Divide a by b and store in a. Param b could be a var or a number. """
        parm_b = self.int_or_reg_val(parms[1])
        assert parm_b != 0, "Integer division by 0 is bad."
        self._vars[parms[0]] //= parm_b
        
    def _op_mod(self, parms:list[str]):
        """ Modulo a by b and store in a. Param b could be a var or a number. """
        parm_a = self._vars[parms[0]]
        parm_b = self.int_or_reg_val(parms[1])
        try:
            assert parm_a >= 0 and parm_b != 0, "Integer division by 0 is bad." 
            self._vars[parms[0]] %= parm_b     
        except AssertionError as err:
            raise AttributeError(f"Bad instruction: {parm_a} mod {parm_b}") from err

    def _op_eql(self, parms:list[str]):
        """ Check if a and b are equal. Store 1 if equal, else 0. Param b could be a var or a number. """
        self._vars[parms[0]] = 1 if self._vars[parms[0]] == self.int_or_reg_val(parms[1]) else 0
                 
    def __repr__(self):
        return f"{self.__class__.__name__}{self._vars}"    

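So what does this do?

The class holds the four registers in a dictionary, and stores the program as a list of (instruction, parameters) tuples. Calling run_program() resets the registers, stores the 14-digit input string, and then executes each instruction in order. Each inp instruction consumes the next digit of the input. The _execute_instruction() method uses introspection to find the right method to call: it builds the method name from the instruction, e.g. _op_add, and looks that name up on the object.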
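
The introspection trick is easier to see in isolation. Here's a stripped-down sketch (the Dispatcher class is made up for illustration) that uses the built-in getattr() to look up a method by name, which has much the same effect as the __getattribute__ call in the class above.

```python
class Dispatcher:
    """ Hypothetical example: route an instruction name to a method. """

    def _op_add(self, a: int, b: int) -> int:
        return a + b

    def _op_mul(self, a: int, b: int) -> int:
        return a * b

    def execute(self, instr: str, a: int, b: int) -> int:
        # Build the method name from the instruction, then look it up
        handler = getattr(self, f"_op_{instr}", None)
        if handler is None:
            raise ValueError(f"Bad instruction {instr}")
        return handler(a, b)

d = Dispatcher()
print(d.execute("add", 2, 3))  # 5
print(d.execute("mul", 2, 3))  # 6
```

The appeal of this approach is that adding a new instruction only requires adding a new _op_ method; there is no if/elif chain to keep in sync.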

Great! So now all we have to do is create the ALU, initialise the program to our set of instructions, and then run it with every possible 14 digit number that doesn’t contain a 0. We could do something like this:

input_file = os.path.join(SCRIPT_DIR, INPUT_FILE)
with open(input_file, mode="rt") as f:
    data = f.read().splitlines()

alu = ALU()
alu.set_program(data)
for int_val in tqdm(range(99999999999999, 11111111111110, -1)):
    val = str(int_val)
    if '0' in val:
        continue
    
    alu.run_program(val)
    if alu.vars['z'] == 0:
        logger.info("%s verified.", val) 

As I’ve done with an earlier program, I’ve wrapped my for loop with tqdm, in order to give me a progress bar for this long running loop. Usefully, the progress bar includes an estimated completion time. So yeah… It’s going to take about 200 years. We’re going to need a bigger boat. Or a better solution.
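
A quick back-of-the-envelope sketch shows the scale of the problem. The rate of 3,500 ALU runs per second is an assumption picked purely for illustration, not a measured figure; substitute whatever rate tqdm reports on your machine.

```python
SECONDS_PER_YEAR = 365 * 24 * 60 * 60

candidates = 9 ** 14      # 14 digits, each from 1 to 9
assumed_rate = 3_500      # ALU runs per second (assumption, not a measurement)

years = candidates / assumed_rate / SECONDS_PER_YEAR
print(f"{candidates:,} candidates, roughly {years:.0f} years")
```

Even a simulator that ran 100 times faster would still need a couple of years, so optimising the simulator is not going to save us.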

TIME TO THROW ALL THAT AWAY!

So, what have we learned? We’ve learned that running every instruction in MONAD, for every candidate model number, is going to take too long. So instead…

Maybe we need to determine what MONAD is actually trying to do, and come up with a more efficient way to do that?

Understanding What MONAD Does

When I examine my MONAD input, it turns out that the program is made up of 14 repeating blocks of 18 nearly identical lines. Here I’m showing the first 5 repeats, side by side…

     1          2          3          4          5
     --------   --------   --------   --------   -------- 
 1   inp w      inp w      inp w      inp w      inp w
 2   mul x 0    mul x 0    mul x 0    mul x 0    mul x 0
 3   add x z    add x z    add x z    add x z    add x z
 4   mod x 26   mod x 26   mod x 26   mod x 26   mod x 26
 5   div z 1    div z 1    div z 1    div z 1    div z 26   
 6   add x 12   add x 12   add x 13   add x 12   add x -3
 7   eql x w    eql x w    eql x w    eql x w    eql x w
 8   eql x 0    eql x 0    eql x 0    eql x 0    eql x 0
 9   mul y 0    mul y 0    mul y 0    mul y 0    mul y 0
10   add y 25   add y 25   add y 25   add y 25   add y 25
11   mul y x    mul y x    mul y x    mul y x    mul y x
12   add y 1    add y 1    add y 1    add y 1    add y 1
13   mul z y    mul z y    mul z y    mul z y    mul z y
14   mul y 0    mul y 0    mul y 0    mul y 0    mul y 0
15   add y w    add y w    add y w    add y w    add y w
16   add y 7    add y 8    add y 2    add y 11   add y 6
17   mul y x    mul y x    mul y x    mul y x    mul y x
18   add z y    add z y    add z y    add z y    add z y

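Here’s what we know: every block starts with inp w, so each block reads one digit of the 14-digit input. The blocks only differ in the numeric parameter of instructions 5, 6 and 16.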

So, what do the blocks do?

I’ve rewritten these instructions, along with their net effect:

     1          2          3          4          5          Result?
     --------   --------   --------   --------   --------   ------------
 1   inp w      inp w      inp w      inp w      inp w      Input w (Any number 1 through 9)
 2   mul x 0    mul x 0    mul x 0    mul x 0    mul x 0    x = 0 (Reset x)
 3   add x z    add x z    add x z    add x z    add x z    x = z0
 4   mod x 26   mod x 26   mod x 26   mod x 26   mod x 26   x = z0 % 26
 5   div z 1    div z 1    div z 1    div z 1    div z 26   z1 = z0 // var1
 6   add x 12   add x 12   add x 13   add x 12   add x -3   x = z0 % 26 + var2
 7   eql x w    eql x w    eql x w    eql x w    eql x w    x = 1 if x == w (input), else 0
 8   eql x 0    eql x 0    eql x 0    eql x 0    eql x 0    x = 0 if x == 1, else 1
 9   mul y 0    mul y 0    mul y 0    mul y 0    mul y 0    y = 0 (Reset y)
10   add y 25   add y 25   add y 25   add y 25   add y 25   y = 25
11   mul y x    mul y x    mul y x    mul y x    mul y x    y = 25 if x == 1, else 0
12   add y 1    add y 1    add y 1    add y 1    add y 1    y = 26 if x == 1, else 1
13   mul z y    mul z y    mul z y    mul z y    mul z y    z2 = 26(z0 // var1) if x == 1, else z0 // var1
14   mul y 0    mul y 0    mul y 0    mul y 0    mul y 0    y = 0 (Reset y)
15   add y w    add y w    add y w    add y w    add y w    y = w
16   add y 7    add y 8    add y 2    add y 11   add y 6    y = w + var3
17   mul y x    mul y x    mul y x    mul y x    mul y x    y = w + var3 if x == 1, else 0
18   add z y    add z y    add z y    add z y    add z y    z = 26(z0 // var1) + w + var3 if x == 1, else z0 // var1

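Some important observations: w is overwritten by the inp instruction at the start of every block, and both x and y are reset to 0 (by mul x 0 and mul y 0) before they are used.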

Thus z is the only variable that persists between blocks.

That’s handy, since z is the value we ultimately care about. Recall that our goal is for z to be 0 when MONAD has finished.
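
Putting the Result column together, a whole block boils down to a few lines of Python. This is a sketch based on the table above, with run_block as a hypothetical helper name; var1, var2 and var3 are the parameters of instructions 5, 6 and 16.

```python
def run_block(z: int, w: int, var1: int, var2: int, var3: int) -> int:
    """ Return the value of z after one 18-instruction block. """
    x = (z % 26) + var2         # instructions 2, 3, 4 and 6
    z //= var1                  # instruction 5
    if x != w:                  # instructions 7 and 8 leave x == 1
        z = 26 * z + w + var3   # instructions 9 to 18
    return z

# Block 1 from the table (var1=1, var2=12, var3=7), with digit 9
print(run_block(0, 9, 1, 12, 7))     # 16

# Block 5 from the table (var1=26, var2=-3, var3=6), starting from z=90
print(run_block(90, 9, 26, -3, 6))   # 3, because 90 % 26 - 3 == 9
print(run_block(90, 8, 26, -3, 6))   # 92, because the digit did not match
```

Only z goes in and only z comes out, so we never need to track w, x or y between blocks.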

Let’s look at the values of the 3 variables in each of the 14 blocks:

var1  var2  var3
----  ----  ----  
   1    12     7
   1    12     8
   1    13     2
   1    12    11
  26    -3     6
   1    10    12
   1    14    14
  26   -16    13
   1    12    15
  26    -8    10
  26   -12     6
  26    -7    10
  26    -6     8
  26   -11     5

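Some observations: var1 is always either 1 or 26, and there are seven blocks of each. Whenever var1 is 1, var2 is at least 10. Whenever var1 is 26, var2 is negative.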

Thus, there appear to be two types of block.

Type 1 Block: var1 is 1, var2 > 9

     Block      Result?
     --------   ------------
 1   inp w      Input w (Any number 1 through 9)
 2   mul x 0    x = 0 (Reset x)
 3   add x z    x = z0
 4   mod x 26   x = z0 % 26
 5   div z 1    z1 = z0
 6   add x 14   x = z0 % 26 + var2
 7   eql x w    x = 0, since w is always <= 9, but var2 is always > 9.
 8   eql x 0    x = 1
 9   mul y 0    y = 0 (Reset y)
10   add y 25   y = 25
11   mul y x    y = 25
12   add y 1    y = 26
13   mul z y    z2 = 26z0
14   mul y 0    y = 0 (Reset y)
15   add y w    y = w
16   add y 7    y = w + var3
17   mul y x    y = w + var3
18   add z y    z = 26z0 + w + var3

In summary:

\(z_{next} = 26z_{prev} + w + a\)

Where \(a\) is var3, i.e. the parameter from instruction 16.

Type 2 Block: var1 is 26, var2 is negative

     Block      Result?
     --------   --------
 1   inp w      Input w (Any number 1 through 9)
 2   mul x 0    x = 0 (Reset x)
 3   add x z    x = z0
 4   mod x 26   x = z0 % 26
 5   div z 26   z1 = z0 // 26
 6   add x -3   x = z0 % 26 + var2
 7   eql x w    x = 1 if w == z0 % 26 + var2, else 0
 8   eql x 0    x = 0 if x == 1, else 1
 9   mul y 0    y = 0 (Reset y)
10   add y 25   y = 25
11   mul y x    y = 0 if x == 0, else 25
12   add y 1    y = 1 if x == 0, else 26
13   mul z y    z2 = z0 // 26 if x == 0, else z2 = 26(z0 // 26)
14   mul y 0    y = 0 (Reset y)
15   add y w    y = w
16   add y 6    y = w + var3
17   mul y x    y = 0 if x == 0, else y = w + var3
18   add z y    z = z0 // 26 if x == 0, else z = 26(z0 // 26) + w + var3

Thus, there are two possible outcomes of a Type 2 block:

When final \(x\) == 0: \(z_{next} = z_{prev} // 26\)

When final \(x\) == 1: \(z_{next} = 26(z_{prev} // 26) + w + a\), which is roughly \(z_{prev} + w + a\)

What Have We Learned?

Half the blocks are type 1. Each block results in a new value of z, according to the equation:

\(z_{next} = 26z_{prev} + w + a\)

I.e. z gets multiplied by 26, and then w + a is added. Thus, type 1 blocks result in z getting much larger.

Half the blocks are type 2. Each block results in a new value of z, according to the equations:

  1. \(z_{next} = z_{prev} // 26\), when \(x\) == 0

  2. \(z_{next} = 26(z_{prev} // 26) + w + a\), i.e. roughly \(z_{prev} + w + a\), when \(x\) == 1

The first equation results in z getting much smaller, whereas the second equation doesn’t shrink z at all: it stays at roughly the same size.

So, in order for z to be 0 at the end of MONAD, we need all the type 2 blocks to result in z getting smaller. And thus, all type 2 blocks require x to be equal to 0 when instruction 17 (mul y x) is run in each block.

How do we ensure that x == 0 at instruction 17?

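Instruction 8 flips the result of instruction 7, so x is only 0 at instruction 17 if instruction 7 found that x was equal to w. So now we have enough information to determine what value of w we need, in order for z to shrink in any type 2 block: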

\(w = x = (z \;\mathrm{mod}\; 26) + var2\)
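
As a small sketch of that rule (required_digit is a hypothetical helper, not part of the solution below), we can compute the one digit that lets a type 2 block shrink z, or discover that no such digit exists:

```python
from typing import Optional

def required_digit(z: int, var2: int) -> Optional[int]:
    """ Return the only w that makes a type 2 block shrink z,
    or None if that w is not a valid digit (1 to 9). """
    w = (z % 26) + var2
    return w if 1 <= w <= 9 else None

# Block 4 (var3=11) with digit 1 leaves z % 26 == 12; block 5 has var2=-3
print(required_digit(12, -3))   # 9

# Block 4 with digit 9 would leave z % 26 == 20, and 20 - 3 is not a digit
print(required_digit(20, -3))   # None
```

In other words, the digit we choose in a type 1 block constrains the digit in the type 2 block that later undoes it. Here, digit 5 must be digit 4 plus 8, so the only valid pairing is 1 followed by 9.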

The Solution

First, let’s read the data, and split it into our 14 blocks:

input_file = os.path.join(SCRIPT_DIR, INPUT_FILE)
with open(input_file, mode="rt") as f:
    data = f.read().splitlines()

COUNT_INSTRUCTION_BLOCKS = 14
EXPECTED_INSTRUCTIONS_PER_BLOCK = 18

instruction_block_size = len(data) // COUNT_INSTRUCTION_BLOCKS
assert instruction_block_size == EXPECTED_INSTRUCTIONS_PER_BLOCK

# Split all instructions into repeating blocks of instructions
instruction_blocks: list[list[str]] = []
for i in range(COUNT_INSTRUCTION_BLOCKS):
    instruction_blocks.append(data[i*instruction_block_size:(i+1)*instruction_block_size])

alu = ALU()
alu.set_program(data)
valid_vals = compute_valid_inputs(instruction_blocks)

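This code reads the input into a list of lines, checks that the program splits into 14 blocks of 18 instructions, and then slices the list into those 14 blocks.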

We then pass all 14 blocks into the compute_valid_inputs() function:

from itertools import product
from tqdm import tqdm

def compute_valid_inputs(instruction_blocks: list[list[str]]) -> list[int]:
    """ Our goal is determine valid values of w, 
    where w is each successive digit of the 14-digit input value.
    The 14 input values are used in the 14 blocks of instructions. """

    # instruction types, based on "div z" instruction parameter
    SHRINKAGE = 26
    
    div_z_instructions = []
    add_x_instructions = []
    add_y_instructions = []
    
    for block in instruction_blocks:
        # Retrieve the param value from each instruction
        # The instructions we care about are at specific locations in the block
        div_z_instructions.append(int(block[4].split()[-1]))
        add_x_instructions.append(int(block[5].split()[-1]))
        add_y_instructions.append(int(block[15].split()[-1]))
    
    # Values of these variables in our input data
    # z [1,   1,  1,  1, 26,  1,  1,  26,  1, 26,  26, 26, 26,  26]
    # x [12, 12, 13, 12, -3, 10, 14, -16, 12, -8, -12, -7, -6, -11]
    # y [7,   8,  2, 11,  6, 12, 14,  13, 15, 10,   6, 10,  8,   5]
    
    # E.g. [False, False, False, False, True...]
    shrink_instructions = [z == SHRINKAGE for z in div_z_instructions]
    shrink_count = sum(shrink_instructions)
    assert shrink_count == 7, "We expect 7 shrink types for our input"
    
    # Digits can only be chosen freely in the expansion blocks (also 7 for our input)
    expand_count = len(instruction_blocks) - shrink_count
    
    # list of tuples by index, e.g. (False, 12, 7)
    instruction_vars = list(zip(shrink_instructions, add_x_instructions, add_y_instructions))

    # Get the cartesian product of all digits where any digit is possible
    # E.g. 9999999, 9999998, 9999997, etc
    any_digits = list(product(range(9, 0, -1), repeat=expand_count))
    assert len(any_digits) == 9**expand_count, "Cartesian product messed up"
        
    valid: list[int] = []    # Store valid 14-digit input values
    for digits_candidate in tqdm(any_digits):
        num_blocks = len(instruction_blocks)
        z = 0
        digit = [0] * num_blocks
    
        early_exit = False
        digits_idx = 0
    
        for block_idx in range(num_blocks):
            is_shrink, add_x, add_y = instruction_vars[block_idx]
                  
            if is_shrink:
                # We want to compute a value w, where w = (z % 26) + a 
                digit[block_idx] = ((z % 26) + add_x)   # digit[block_idx] = w
                z //= 26    # New z is given by z = z//26
                if not (1 <= digit[block_idx] <= 9):
                    early_exit = True
                    break
            
            else:   # expansion type, so just use the candidate digit
                z = (26 * z) + digits_candidate[digits_idx] + add_y
                digit[block_idx] = digits_candidate[digits_idx]  
                digits_idx += 1
        
        if not early_exit:
            valid.append(int("".join(str(i) for i in digit)))
     
    return valid

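This code pulls out the three parameters that vary between blocks (from instructions 5, 6 and 16), and flags each block as a shrink type if its div z parameter is 26.

Here’s the clever bit… We only need to try every digit in the blocks where any digit is possible, i.e. the 7 type 1 blocks. That’s 9^7 = 4,782,969 candidates, rather than 9^14. In each type 2 block, we calculate the one digit that allows z to shrink. If that digit isn’t in the range 1 to 9, we abandon the candidate and move on to the next one.

Finally, return the valid list; i.e. all the valid model numbers, as integer values.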
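
If itertools.product is new to you, this tiny sketch shows the order in which it generates candidates, using 2 digits rather than 7 to keep the output short:

```python
from itertools import product

candidates = list(product(range(9, 0, -1), repeat=2))
print(candidates[:3])    # [(9, 9), (9, 8), (9, 7)]
print(candidates[-1])    # (1, 1)
print(len(candidates))   # 81
```

Because the candidates come out largest first, the valid list should end up in descending order too. Using max() and min() on the result avoids having to rely on that.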

Recall that Part 1 has asked us to determine the largest model number accepted by MONAD. So, we just need to determine the largest value of our valid values. That’s easy…

    if valid_vals:
        max_input_val = max(valid_vals)

Part 2

What is the smallest model number accepted by MONAD?

I can’t tell you how relieved I am about that!!

All we need to do is add one line…

    if valid_vals:
        max_input_val = max(valid_vals)
        min_input_val = min(valid_vals)

But for completeness, and so that my ALU simulator was not completely wasted, I’ve run my min and max values through the ALU, to verify that each one leaves z at 0.

So the final code looks like this:

alu = ALU()
alu.set_program(data)
valid_vals = compute_valid_inputs(instruction_blocks)
if valid_vals:
    max_input_val = max(valid_vals)
    min_input_val = min(valid_vals)
    
    # check them by running them through the ALU
    for val in (min_input_val, max_input_val):
        alu.run_program(str(val))
        if alu.vars['z'] == 0:
            logger.info("%s verified.", val) 
        else:
            logger.info("%s does not work??", val)
else:
    logger.info("Fail bus, all aboard.")

And the output looks like this:

100%|███████████████████████████████████████████████████████████| 4782969/4782969 [00:05<00:00, 908376.41it/s]
08:27:28.010:INFO:__main__:     51619131181131 verified.
08:27:28.010:INFO:__main__:     97919997299495 verified.
08:27:28.010:INFO:__main__:     Execution time: 5.7623 seconds
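
As a cross-check that doesn't need the ALU class at all, here's a sketch that applies the simplified block logic directly, using the parameter lists from the comments in compute_valid_inputs(). The names DIV_Z, ADD_X, ADD_Y and final_z are made up for illustration, and the numbers are specific to my input.

```python
# Parameters of instructions 5, 6 and 16 for each of the 14 blocks
DIV_Z = [1, 1, 1, 1, 26, 1, 1, 26, 1, 26, 26, 26, 26, 26]
ADD_X = [12, 12, 13, 12, -3, 10, 14, -16, 12, -8, -12, -7, -6, -11]
ADD_Y = [7, 8, 2, 11, 6, 12, 14, 13, 15, 10, 6, 10, 8, 5]

def final_z(model_number: int) -> int:
    """ Apply the net effect of each block to every digit in turn. """
    z = 0
    digits = map(int, str(model_number))
    for w, var1, var2, var3 in zip(digits, DIV_Z, ADD_X, ADD_Y):
        x = (z % 26) + var2
        z //= var1
        if x != w:
            z = 26 * z + w + var3
    return z

for number in (51619131181131, 97919997299495):
    print(number, final_z(number))   # both finish with z == 0
```

Both values finish with z at 0, which agrees with the ALU verification above.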

Wow. What a horrendous day.