Learning Python with Advent of Code Walkthroughs

Dazbo's Advent of Code solutions, written in Python


Advent of Code 2021 - Day 18

Day 18: Snailfish

Useful Links

Concepts and Packages Demonstrated

Regular expressions, reduce, permutations, math, literal_eval, deque, recursion, staticmethod, binary tree, Depth-First Search (DFS)

Problem Intro

There are some people in the world (out of the roughly 200,000 who took part in AoC 2021) who managed to solve this challenge in under 30 minutes!

AoC 2021 Day 18 Leaderboard

I’m not one of those people. This was tricky, and it took me hours!

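I created two solutions to this problem: the first manipulates each snailfish number as a string, and the second represents each snailfish number as a binary tree.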

So, we need to help a snailfish with its math homework. Snailfish math is pretty darn weird. The input looks like this:


Note how each snailfish number has the exact same structure as a Python (nested) list. This will come in handy!!
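To see this in action, here is a quick check using ast.literal_eval(), which safely evaluates a string containing a Python literal (unlike eval(), it won’t execute arbitrary code). The snailfish number used here is just an illustrative value. Note that str() on the resulting list puts a space after each comma, which matters later when we manipulate numbers as strings.

```python
from ast import literal_eval

line = "[[1,2],[[3,4],5]]"   # an illustrative snailfish number, as read from a line of input
number = literal_eval(line)  # becomes a real nested Python list

print(number)        # [[1, 2], [[3, 4], 5]]
print(number[1][0])  # [3, 4]  (the left element of the right-hand pair)
print(str(number))   # [[1, 2], [[3, 4], 5]]  (note the ", " separators)
```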

The thing that caught me out for a while was that I hadn’t fully understood the explode instructions. In particular:

Part 1

We have to add all the snailfish numbers in our input, and then determine the magnitude of the resulting number.

We’re told that:

Solution 1

This solution takes our list and converts it to a str in order to check if it can be exploded or split, and to enable us to modify the str according to the rules.


from __future__ import annotations
import logging
from pathlib import Path
import time
import re
from functools import reduce
from math import ceil, floor
from itertools import permutations
from ast import literal_eval

SCRIPT_DIR = Path(__file__).parent
INPUT_FILE = Path(SCRIPT_DIR, "input/input.txt")
# INPUT_FILE = Path(SCRIPT_DIR, "input/sample_input.txt")

logger = logging.getLogger(__name__)

The only new imports here are:

We’ll talk about those when we get to them.


First, we read in the data:

with open(INPUT_FILE, mode="rt") as f:
    # Each input line is a nested list. Use literal_eval to convert to Python lists.
    data = [FishNumber(literal_eval(line)) for line in f.read().splitlines()]

Now for the FishNumber class:

class FishNumber():
    """ FishNumber stores a snailfish number internally. This class knows how to:
    - Add two FishNumbers to create a new FishNumber. 
    - Reduce snailfish numbers according to rules. """
    EXPLODE_BRACKETS = 4  # a pair nested inside 4 pairs (i.e. more than 4 brackets deep) explodes
    SPLIT_MIN = 10
    def __init__(self, fish_list: list) -> None:
        self._number = fish_list # internal representation as a list
    @property
    def number(self) -> list:
        return self._number
    def add(self, other: FishNumber) -> FishNumber:
        """ Creates a new FishNumber by concatenating two FishNumbers.
        I.e. the result is the new pair [self, other]. """
        return FishNumber([self.number] + [other.number])
    def reduce(self):
        """ Performs 'reduction' logic. I.e. explode and split, as required.
        Explodes take priority; we only split when nothing can explode. """
        while True:
            if self._can_explode():
                self._number = self._explode()
            elif self._can_split():
                self._number = self._split()
            else:
                break   # nothing left to explode or split, so we're fully reduced

    def __repr__(self) -> str:
        return str(self.number)

Things to say about this class:

Now let’s add the FishNumber methods that do the hard work. First, exploding…

    def _can_explode(self) -> bool:
        """ Checks if we can explode by counting brackets. """
        str_repr = str(self._number)
        depth_count = 0
        for char in str_repr:
            if char == "[":
                depth_count += 1
            if char == "]":
                depth_count -= 1
            if depth_count > FishNumber.EXPLODE_BRACKETS:
                return True
        return False

    def _explode(self) -> list:
        """ Explodes the current list.
        Looks for first opening bracket that is sufficiently nested. Takes the pair of digits within.  
        Adds LH to first digit on the left. (If there is one.)
        Adds RH to the first digit on the right. (If there is one.)
        Then replaces the entire bracket with 0. """
        str_repr = str(self._number)    # convert list to str
        depth_count = 0
        for i, char in enumerate(str_repr):
            if char == "[":
                depth_count += 1
            if char == "]":
                depth_count -= 1
            if depth_count > FishNumber.EXPLODE_BRACKETS:
                assert str_repr[i+1].isdigit(), "Should have been a digit here"
                left_bracket_posn = i
                comma_posn = i+1 + str_repr[i+1:].find(",")
                right_bracket_posn = comma_posn + str_repr[comma_posn:].find("]")
                left_num = int(str_repr[i+1: comma_posn])
                right_num = int(str_repr[comma_posn+1:right_bracket_posn])
                # process left of pair
                # This regex looks for the first matching digits at the end
                if match := re.match(r".*\D+(\d+).*$", str_repr[:left_bracket_posn]):
                    # match first group, i.e. (\d+)
                    num_start, num_end = match.span(1)[0], match.span(1)[1]
                    new_num = int(str_repr[num_start:num_end]) + left_num
                    # We might be inserting a bigger number
                    l_increase = len(str(new_num)) - (num_end-num_start)            
                    str_repr = str_repr[:num_start] + str(new_num) + str_repr[num_end:]
                    left_bracket_posn += l_increase
                    comma_posn += l_increase
                    right_bracket_posn += l_increase
                # process right of pair
                if match := re.search(r"\d+", str_repr[right_bracket_posn:]):
                    # match whole group
                    num_start = right_bracket_posn + match.span(0)[0]
                    num_end = right_bracket_posn + match.span(0)[1]
                    new_num = int(str_repr[num_start:num_end]) + right_num
                    str_repr = str_repr[:num_start] + str(new_num) + str_repr[num_end:]
                # replace the original pair with 0
                str_repr = str_repr[:left_bracket_posn] + "0" + str_repr[right_bracket_posn+1:]
                break   # only explode one pair at a time; positions are now stale
        new_num = literal_eval(str_repr)    # convert back to list
        return new_num
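The two regular expressions in _explode() do opposite jobs. A small standalone check, using made-up prefix and suffix strings (hypothetical examples, not taken from the puzzle input), shows that the greedy re.match() pattern picks out the last number before the exploding pair, while re.search() finds the first number after it:

```python
import re

prefix = "[[1, 23], ["       # hypothetical text to the LEFT of an exploding pair
suffix = "], 4], [56, 7]]"   # hypothetical text to the RIGHT (starts at the pair's "]")

# The greedy leading ".*" backtracks only as far as needed for \D+(\d+) to match,
# so group 1 captures the LAST complete number in the prefix.
# (Without \D+, (\d+) would capture just the final "3" of "23".)
left = re.match(r".*\D+(\d+).*$", prefix)
print(left.group(1), left.span(1))    # 23 (5, 7)

# re.search() scans from the start, so it returns the FIRST number in the suffix.
right = re.search(r"\d+", suffix)
print(right.group(0), right.span(0))  # 4 (3, 4)

# With no number to the left, there is no match (so nothing to add to).
print(re.match(r".*\D+(\d+).*$", "[[[["))  # None
```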

Next, splitting…

    def _can_split(self) -> bool:
        """ We can split if there is a number >= 10 """
        str_repr = str(self._number)
        if re.search(r"(\d{2,})", str_repr):
            return True
        return False
    def _split(self) -> list:
        """ Split our fish number by taking the first n >= 10,
        and replacing with [floor(n/2), ceil(n/2)] """
        str_repr = str(self._number)
        if match := re.search(r"(\d{2,})", str_repr):
            num = int(match.groups()[0])
            if (num >= FishNumber.SPLIT_MIN):
                new_left_num = floor(num/2)
                new_right_num = ceil(num/2)
                new_str = "[" + str(new_left_num) + ", " + str(new_right_num) + "]"
                str_repr = re.sub(r"(\d{2,})", new_str, str_repr, count=1)

            new_num = literal_eval(str_repr)     # convert back to list
            return new_num
        assert False, "We should never get here since we're checking if we can split"
        return []
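As an aside, re.sub() also accepts a function as its replacement argument, which lets you find and split the first multi-digit number in a single pass. This is an alternative sketch, not the code used in this solution; split_first() is a hypothetical helper name:

```python
import re
from math import ceil, floor

def split_first(str_repr: str) -> str:
    """Hypothetical helper: replace only the FIRST number >= 10 with [floor(n/2), ceil(n/2)]."""
    def halve(match):
        n = int(match.group(0))
        return f"[{floor(n / 2)}, {ceil(n / 2)}]"
    # count=1 means only the first (leftmost) match is replaced
    return re.sub(r"\d{2,}", halve, str_repr, count=1)

print(split_first("[[11, 3], 12]"))  # [[[5, 6], 3], 12]
print(split_first("[[1, 3], 4]"))    # [[1, 3], 4]  (nothing to split)
```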

The last thing we need to add to our FishNumber class is a way to determine its magnitude:

    @staticmethod
    def magnitude(fish_num) -> int:
        """ Magnitude is given by 3*LHS + 2*RHS for any pair of values. 
        If the values are themselves lists, we must recurse.
        If the values are themselves ints, we return the int value. 
        If the value is not part of a pair, simply return the value. """
        mag = 0
        # First check if this is a pair (list)
        if isinstance(fish_num, list):
            mag = 3*FishNumber.magnitude(fish_num[0]) + 2*FishNumber.magnitude(fish_num[1])
        elif isinstance(fish_num, int): # must be a single value
            mag = fish_num
        return mag

Every snailfish number is a pair, so we need to determine the magnitude of that pair. And since each element in the pair can itself be another pair, recursion is a natural way to get the magnitude.

Thus, this method works by checking if the input parameter is itself an int, or a list. If it’s an int, we just return that value. If it’s a list, we know it represents a pair, so we need to return 3*left + 2*right. And we recurse to get the values of left and right.
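Here is the same recursion as a standalone function, with the arithmetic worked through by hand for an illustrative number so you can see how the 3× and 2× weights compound as we come back up from the innermost pairs:

```python
def magnitude(fish_num) -> int:
    """Standalone version of the recursion: a regular number is its own magnitude;
    a pair's magnitude is 3 * left + 2 * right."""
    if isinstance(fish_num, int):
        return fish_num
    left, right = fish_num
    return 3 * magnitude(left) + 2 * magnitude(right)

# Working from the inside out:
#   [1, 2]                -> 3*1  + 2*2  = 7
#   [3, 4]                -> 3*3  + 2*4  = 17
#   [[3, 4], 5]           -> 3*17 + 2*5  = 61
#   [[1, 2], [[3, 4], 5]] -> 3*7  + 2*61 = 143
print(magnitude([[1, 2], [[3, 4], 5]]))  # 143
```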

Now let’s run it:

# Part 1
result = reduce(fish_add, data)  # Reduce to add n to n+1, then the sum to n+2, etc
logger.info("Result = %s", result)
mag = FishNumber.magnitude(result.number)
logger.info("Part 1 magnitude = %d", mag)

Just to avoid any potential confusion: in this code snippet, I’m using functools.reduce(), not FishNumber.reduce(). We’ve come across functools.reduce() before. It applies the specified function (the first parameter) to the first two items in the iterable (the second parameter). This generates a result, and it then applies the function to this result and the third item. And then to the result of that and the fourth item. And so on.

In this way, we can use functools.reduce() to apply the fish_add() function cumulatively to every number in the data supplied.
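Here’s the same folding behaviour on plain lists, with a hypothetical pair_up() function standing in for fish_add(). It forms the new pair but skips the snailfish reduction step, so you can see how each running total becomes the left-hand side of the next addition:

```python
from functools import reduce

def pair_up(left, right):
    """Hypothetical stand-in for fish_add(): forms a new pair, with no reduction step."""
    return [left, right]

numbers = [[1, 1], [2, 2], [3, 3], [4, 4]]
total = reduce(pair_up, numbers)
# Equivalent to pair_up(pair_up(pair_up(numbers[0], numbers[1]), numbers[2]), numbers[3])
print(total)  # [[[[1, 1], [2, 2]], [3, 3]], [4, 4]]
```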

The fish_add() function looks like this:

def fish_add(left: FishNumber, right: FishNumber) -> FishNumber:
    """ Create new FishNumber by concatenating left and right.
    Then reduce the resulting number and return it """
    new_fish_num = left.add(right)
    new_fish_num.reduce()   # explode and split until fully reduced
    return new_fish_num

This just uses the add() method from our FishNumber class, and then uses the FishNumber’s reduce() method on the resulting FishNumber.

Part 2

What is the largest magnitude you can get from adding only two of the snailfish numbers?

Very little additional code is required here, since we’ve done all the hard work.

# Part 2
mags = []
for perm in permutations(data, 2): # All permutations of 2 fish numbers
    result = fish_add(perm[0], perm[1])
    mags.append(FishNumber.magnitude(result.number))
logger.info("Part 2: max magnitude = %d", max(mags))

We use itertools.permutations() to get all permutations of two of the fish numbers, given the list of all the fish numbers. Note that unlike itertools.combinations(), permutations() considers order, i.e. (a, b) is different to (b, a). That matters here, because snailfish addition is not commutative. We then add each pair of fish numbers, and find the largest magnitude.
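The difference between permutations() and combinations() is easy to see on a tiny example. The second half of the sketch shows why order matters for this puzzle: swapping the two operands of a snailfish addition produces a different pair, with a different magnitude (no reduction is needed for these small illustrative numbers).

```python
from itertools import combinations, permutations

items = ["a", "b", "c"]
print(list(permutations(items, 2)))
# [('a', 'b'), ('a', 'c'), ('b', 'a'), ('b', 'c'), ('c', 'a'), ('c', 'b')]
print(list(combinations(items, 2)))
# [('a', 'b'), ('a', 'c'), ('b', 'c')]

def magnitude(n):
    """3 * left + 2 * right for a pair; a regular number is its own magnitude."""
    return n if isinstance(n, int) else 3 * magnitude(n[0]) + 2 * magnitude(n[1])

a, b = [1, 2], [3, 4]                        # two small illustrative snailfish numbers
print(magnitude([a, b]), magnitude([b, a]))  # 55 65: a + b is not the same as b + a
```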

Phew, that part was easy!

The final output looks like this:

21:34:16.431:INFO:__main__:     Result = [[[[7, 7], [7, 7]], [[7, 8], [0, 8]]], [[[8, 9], [9, 9]], [7, 7]]]
21:34:16.432:INFO:__main__:     Part 1 magnitude = 3869
21:34:45.502:INFO:__main__:     Part 2: max magnitude = 4671
21:34:45.504:INFO:__main__:     Execution time: 15.5432 seconds

Yay, it works! But it was a little slow. All that converting between list and str takes its toll. And all that str manipulation is quite slow.

We can do better!

Solution 2

This solution doesn’t do any string manipulation. Instead, we create a binary tree from the list.

A tree is defined as a finite set of nodes, made up of a single root node and zero or more child nodes, each of which is either a leaf node or itself the root of a subtree. A binary tree is a special type of tree where:

Thus, a binary tree looks something like this:

Binary Tree

Our FishNumber is a special type of binary tree, with these properties:

Let’s take a look at a bit of this solution’s FishNumber class:

from __future__ import annotations
from collections import deque
from typing import Optional

class FishNumber:
    """ A FishNumber is either a leaf node or a pair of FishNumbers """
    EXPLODE_BRACKETS = 4  # a pair at depth >= 4 (the root is depth 0) is nested inside 4 pairs, so explodes
    SPLIT_MIN = 10
    def __init__(self, val=None):
        """ Create a new FishNumber. 
        If val is an int, then this is a leaf, and left/right will be None. """
        self.val: Optional[int] = val  # leaf node value
        self.left: Optional[FishNumber] = None
        self.right: Optional[FishNumber] = None
        self.parent: Optional[FishNumber] = None
    def __str__(self):
        if isinstance(self.val, int):
            return str(self.val)
        assert isinstance(self.left, FishNumber) and isinstance(self.right, FishNumber)
        return f"[{str(self.left)},{str(self.right)}]" # print recursively
    def __repr__(self):
        msg = str(self.val) if isinstance(self.val, int) else f"[{str(self.left)},{str(self.right)}]"
        return msg if self.parent else "FishNumber(" + msg + ")"

    def fish_reduce(self):
        """ Reduce a FishNumber 
        - Explode any pairs that are nested inside four pairs. Repeat explode until no more explosions possible.
        - Split any numbers that are >= 10. We only split when nothing can explode. """
        still_reducing = True
        while still_reducing:
            still_reducing = False  # assume nothing more to do
            # DFS through the tree, starting at the root, to see if we have pairs to explode
            stack = deque()
            stack.append((self, 0))    # (tree, depth)
            while len(stack) > 0:
                node, depth = stack.pop()

                # If we're at sufficient depth and this node is a pair, explode it
                if depth >= FishNumber.EXPLODE_BRACKETS and node.val is None:
                    self._explode(node)
                    still_reducing = True
                    break   # we've just exploded, so start loop again

                # otherwise, add children to the DFS frontier, ensuring left is always popped first
                if node.right and node.left: 
                    stack.append((node.right, depth + 1))
                    stack.append((node.left, depth + 1))

            if still_reducing:   # We've just exploded
                continue  # So loop
            # No explosions, so now try splitting
            assert not still_reducing, "Done exploding"
            assert len(stack) == 0, "Stack should be empty"
            stack.append(self)    # Add root node. We don't care about depth now.
            while len(stack) > 0:
                node = stack.pop()
                if node.val is not None:    # we've found a leaf
                    assert node.left is None and node.right is None
                    if node.val >= FishNumber.SPLIT_MIN:
                        self._split(node)
                        still_reducing = True
                        break   # back to the top, since a split may create a pair that must explode
                else:   # not a leaf node, so must have children
                    stack.append(node.right)  # push right first...
                    stack.append(node.left)   # ...so that left is popped first

    @staticmethod
    def parse(parse_input: list|int) -> FishNumber:
        """ Parse a list and convert to a FishNumber. 
        Recurses any nested lists, including leaf values. """
        node = FishNumber()
        if isinstance(parse_input, int):   # If a leaf node with no children
            node.val = parse_input
            return node

        assert len(parse_input) == 2, "Must be a pair in a list"
        node.left = FishNumber.parse(parse_input[0])
        node.right = FishNumber.parse(parse_input[1])
        node.left.parent = node
        node.right.parent = node

        return node

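A FishNumber is a node, and has four properties: val (the regular number, if this node is a leaf), left and right (the child nodes, if this node is a pair), and parent (the node above this one, or None for the root).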

We use the recursive static method parse() to create a FishNumber from a top-level list; it recurses into each nested item. This method doesn’t actually need to be part of the FishNumber class. It is static, meaning it doesn’t use or modify any FishNumber instance attributes; rather, it creates new FishNumber instances. I could have written it as a separate function, independent of the FishNumber class. However, creating a FishNumber is conceptually related to the FishNumber class, so I’ve elected to make it a static method of the class.
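The same pattern, stripped down to its essentials, looks like this. Point is a hypothetical class used purely for illustration: the static parse() method receives neither self nor cls, is called on the class itself, and simply builds and returns a new instance.

```python
class Point:
    """Hypothetical class, only used to illustrate the static factory pattern."""
    def __init__(self, x: int, y: int) -> None:
        self.x = x
        self.y = y

    @staticmethod
    def parse(text: str) -> "Point":
        """Takes no self (or cls): it just builds and returns a new Point."""
        x, y = (int(part) for part in text.split(","))
        return Point(x, y)

p = Point.parse("3,4")  # called on the class itself, not on an instance
print(p.x, p.y)         # 3 4
```

A @classmethod, which receives cls, is a common alternative for factory methods like this because it also works for subclasses; for this puzzle, a static method is perfectly adequate.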

We then use a Depth-First Search (DFS) to traverse our tree, starting at the root node and working all the way down to the bottom of the tree, from left to right. Note that the DFS is basically the same as the BFS that we’ve used before, but with one key difference: instead of popping first-in, first-out (FIFO), as we do for BFS, or based on priority, as we do for Dijkstra, we pop last-in, first-out (LIFO). I.e. the last thing we added to the frontier is the first thing we explore further.
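The same LIFO traversal works directly on nested lists. This minimal sketch uses a hypothetical first_deep_pair() helper that returns the leftmost pair nested inside at least four pairs, mirroring the explode search in fish_reduce(). Pushing the right child before the left is what makes the search go left to right:

```python
from collections import deque

def first_deep_pair(number, max_depth=4):
    """Hypothetical helper: return the leftmost pair nested inside at least
    max_depth pairs, using an iterative DFS over a nested list (None if none)."""
    stack = deque([(number, 0)])        # (node, depth); the outermost pair is depth 0
    while stack:
        node, depth = stack.pop()       # LIFO: most recently pushed node comes out first
        if isinstance(node, list):      # a pair (regular numbers are ints)
            if depth >= max_depth:
                return node
            stack.append((node[1], depth + 1))  # push right first...
            stack.append((node[0], depth + 1))  # ...so that left is explored first
    return None

print(first_deep_pair([[[[[9, 8], 1], 2], 3], 4]))  # [9, 8]
print(first_deep_pair([7, [6, [5, [4, [3, 2]]]]]))  # [3, 2]
print(first_deep_pair([[1, 2], 3]))                 # None
```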

This is how the code works:

Now let’s look at how splitting works. This is the easier of the two reduce operations. The objective is to remove a given node value, and replace it with a new pair. Thus, the current node becomes the parent of a new pair of leaf nodes.

    def _split(self, node):
        """ Split a single value into a pair of two halves.
        (Rounding down on the left, and rounding up on the right.)
        The current node becomes the parent of new left/right nodes. """
        assert node.val >= 10, "We can only split numbers >= 10"
        node.left = FishNumber(node.val//2) # new left val
        node.right = FishNumber(node.val - (node.val//2)) # new right val
        node.left.parent = node   # left node parent is current node
        node.right.parent = node  # right node parent is current node
        node.val = None  # current node value is cleared

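The method checks that the value is at least 10, creates two new leaf children holding the rounded-down and rounded-up halves, sets each child’s parent to the current node, and finally clears the current node’s own value, turning it into a pair.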
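The halving in _split() uses integer arithmetic only: node.val // 2 rounds down, and node.val - node.val // 2 is the remainder, which is the rounded-up half, so math.floor() and math.ceil() aren’t needed here. A quick check of that claim, using a hypothetical split_value() helper:

```python
from math import ceil, floor

def split_value(n):
    """Hypothetical helper: the (left, right) values _split() would create for n."""
    return n // 2, n - n // 2   # round down on the left, the remainder (rounded up) on the right

print(split_value(10))  # (5, 5)
print(split_value(11))  # (5, 6)

# Confirm this matches floor/ceil of n/2 for a range of values that could be split
assert all(split_value(n) == (floor(n / 2), ceil(n / 2)) for n in range(10, 1000))
```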

Exploding is much more difficult.

We start from a node that is sufficiently nested and contains a pair of regular numbers. The goal is:

Let’s use this diagram to help explain it:

Navigate Tree

In this example:

The strategy is:

    def _explode(self, node: FishNumber):
        """ Explode a pair. The node passed to this method itself contains a pair of leaf values.
        Left node value is added to first value on the left, if there is one.
        Right node value is added to first value on the right, if there is one.
        Current node value is then set to 0. 

        Args: node ([FishNumber]): The node containing a pair we need to explode """
        # First explode the left side
        prev_node = node.left
        current_node = node  # the parent of our pair of leaf values
        # Move UP the tree until we identify a node with a left (different) child
        # or until we can go no further
        while (current_node is not None and 
               (current_node.left == prev_node or current_node.left is None)):
            prev_node = current_node  # prev node moves up one
            current_node = current_node.parent  # current node now points to parent

        # Current node will be None if we previously reached the root from the left.
        # Otherwise, we must have identified a left node, so come back DOWN the left
        if current_node is not None:
            assert current_node.left is not None, "There must be a left node"
            current_node = current_node.left
            while current_node.val is None: # must have two children; keep going down until we reach a leaf
                if current_node.right is not None:
                    current_node = current_node.right   # if there's a number on the right of this node, it's nearest
                else:
                    current_node = current_node.left

            assert current_node.val is not None, "We've reached the value on the left"
            current_node.val += node.left.val   # add to the left

        # Now explode the right side
        prev_node = node.right
        current_node = node
        # traverse up the tree until we identify a node with a right (different) child
        # or until we can go no further
        while (current_node is not None and 
                (current_node.right == prev_node or current_node.right is None)):
            prev_node = current_node
            current_node = current_node.parent

        # current node will be null if we previously reached the root (so no right value)
        # otherwise, we must have identified a right node, so come back down the right
        if current_node is not None: 
            current_node = current_node.right
            while current_node.val is None:  # keep going down until we reach a leaf
                if current_node.left is not None:
                    current_node = current_node.left    # the leftmost number of this subtree is nearest
                else:
                    current_node = current_node.right

            current_node.val += node.right.val  # add to the right

        # Final explode updates - set original node value to 0 and clear the children
        node.val = 0 
        node.left = None
        node.right = None
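One way to convince yourself that this up-then-down walk finds the right numbers: if you list the leaves from left to right, the first regular number to the left of the exploding pair is simply the leaf immediately before it, and the first regular number to the right is the leaf immediately after it. The sketch below uses a different technique from the tree code above: it flattens a nested list into [value, depth] entries and explodes on that flat list. flatten() and explode_flat() are hypothetical helpers, and the sketch assumes leaves are never nested more than five pairs deep, which holds when adding two already-reduced numbers.

```python
def flatten(number, depth=0):
    """Hypothetical helper: list the leaves left to right as [value, depth] pairs,
    where depth is the number of pairs enclosing that leaf."""
    if isinstance(number, int):
        return [[number, depth]]
    return flatten(number[0], depth + 1) + flatten(number[1], depth + 1)

def explode_flat(leaves, max_depth=4):
    """Hypothetical helper: explode the leftmost pair nested inside more than max_depth pairs.
    Modifies leaves in place and returns True if an explosion happened."""
    for i in range(len(leaves) - 1):
        (l_val, l_depth), (r_val, r_depth) = leaves[i], leaves[i + 1]
        if l_depth > max_depth and l_depth == r_depth:
            if i > 0:
                leaves[i - 1][0] += l_val         # nearest regular number on the left
            if i + 2 < len(leaves):
                leaves[i + 2][0] += r_val         # nearest regular number on the right
            leaves[i:i + 2] = [[0, l_depth - 1]]  # the pair becomes a 0, one level up
            return True
    return False

leaves = flatten([[[[[9, 8], 1], 2], 3], 4])
explode_flat(leaves)
print(leaves)  # [[0, 4], [9, 4], [2, 3], [3, 2], [4, 1]], i.e. [[[[0,9],2],3],4]
```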

Finally, we need to be able to determine the magnitude. We can do this with recursion, just like before:

    def magnitude(self):
        """ Magnitude is given by 3*LHS + 2*RHS for any pair of values. 
        If the children are themselves pairs, we must recurse.
        If this node is a leaf, we return its int value. """
        if isinstance(self.val, int):
            return self.val

        assert self.left and self.right, "Must have children"
        return 3 * self.left.magnitude() + 2 * self.right.magnitude()

We run it like this:

def add(left_tree: FishNumber, right_tree: FishNumber) -> FishNumber:
    """ Add two FishNumbers together.
    Creates a new parent node, with the supplied left and right set to its children. """
    new_root = FishNumber()
    new_root.left = left_tree
    new_root.right = right_tree
    new_root.left.parent = new_root
    new_root.right.parent = new_root
    new_root.fish_reduce()  # Note that this modifies the original supplied FishNumbers
    return new_root

with open(INPUT_FILE, mode="rt") as f:
    # Each input line is a nested list. 
    # Use literal_eval to convert each to a Python list.
    data = [literal_eval(line) for line in f.read().splitlines()]

# Part 1 - Sum all numbers and report magnitude
result = reduce(add, map(FishNumber.parse, data))  # Reduce to add n to n+1, then to n+2, etc
logger.info("Result = %s", result)
logger.info("Part 1 magnitude = %d", result.magnitude())

This is basically the same as Solution 1. I.e. we read in each line as a nested list using literal_eval(), convert each list into a FishNumber tree with FishNumber.parse(), and then use functools.reduce() to add each FishNumber to the next.

Part 2

This is basically the same as Part 2 for Solution #1.

mags = []
for perm in permutations(data, 2): # All permutations of 2 fish numbers
    # Quicker to parse the input data each time than deepcopy a FishNumber
    # (we need fresh trees, since add() modifies the FishNumbers it is given)
    result = add(FishNumber.parse(perm[0]), FishNumber.parse(perm[1]))
    mags.append(result.magnitude())
logger.info("Part 2: max magnitude = %d", max(mags))

This solution runs about 8x quicker than Solution #1:

2022-01-25 21:56:40.267:INFO:__main__:  Result = [[[[7,7],[7,7]],[[7,8],[0,8]]],[[[8,9],[9,9]],[7,7]]]
2022-01-25 21:56:40.268:INFO:__main__:  Part 1 magnitude = 3869
2022-01-25 21:56:47.176:INFO:__main__:  Part 2: max magnitude = 4671
2022-01-25 21:56:47.179:INFO:__main__:  Execution time: 2.0004 seconds