Learning Python with Advent of Code Walkthroughs

Dazbo's Advent of Code solutions, written in Python

Wizard Fight

Advent of Code 2015 - Day 22

Day 22: Wizard Simulator 20XX

Useful Links

Concepts and Packages Demonstrated

Dataclass, Cache, Enum, Comprehension, Classmethod, Class Factory, Raising Exceptions, vars, generator

Problem Intro

Following on from Day 21, we’re asked to turn our RPG simulator into a Wizard Fight Simulator!

This was the worst!

It took me hours to write. It works, but it takes a few hours to run! I think I should probably have used a depth-first search to minimise the solution space. But anyway, here goes…


Spell details:

Spell          | Mana Cost | Description
---------------|-----------|--------------------------------------------------------------
Magic Missile  | 53        | Does 4 instant damage
Drain          | 73        | Does 2 instant damage and adds 2 hit points
Shield         | 113       | Effect: effective armor is increased by 7, for 6 turns
Poison         | 173       | Effect: deals 3 damage at the start of each turn, for 6 turns
Recharge       | 229       | Effect: adds 101 mana at the start of each turn, for 5 turns

Part 1

Boss stats are given in the input.

What is the least amount of mana you can spend and still win the fight?

(The mana recharge effect does not count as “spending” negative mana.)

First, I can re-use my Player class from day 21:

from __future__ import annotations  # lets methods refer to Player before it's fully defined
import logging
from math import ceil

logger = logging.getLogger(__name__)

class Player:
    """A player has three key attributes:
      hit_points (life) - When this reaches 0, the player has been defeated
      damage - Attack strength
      armor - Attack defence

    Damage done per attack = this player's damage - opponent's armor.  (With a min of 1.)
    Hit_points are decremented by an enemy attack.
    """
    def __init__(self, name: str, hit_points: int, damage: int, armor: int):
        self._name = name
        self._hit_points = hit_points
        self._damage = damage
        self._armor = armor

    @property
    def name(self) -> str:
        return self._name

    @property
    def hit_points(self) -> int:
        return self._hit_points

    @property
    def armor(self) -> int:
        return self._armor

    @property
    def damage(self) -> int:
        return self._damage

    def take_hit(self, loss: int):
        """ Remove this hit from the current hit points """
        self._hit_points -= loss

    def is_alive(self) -> bool:
        return self._hit_points > 0

    def _damage_inflicted_on_opponent(self, other_player: Player) -> int:
        """Damage inflicted in an attack.  Given by this player's damage minus other player's armor.
        Returns: damage inflicted per attack """
        return max(self._damage - other_player.armor, 1)

    def get_attacks_needed(self, other_player: Player) -> int:
        """ The number of attacks needed for this player to defeat the other player. """
        return ceil(other_player.hit_points / self._damage_inflicted_on_opponent(other_player))

    def will_defeat(self, other_player: Player) -> bool:
        """ Determine if this player will win a fight with an opponent.
        I.e. if this player needs fewer (or same) attacks than the opponent.
        Assumes this player always goes first. """
        return (self.get_attacks_needed(other_player)
                <= other_player.get_attacks_needed(self))

    def attack(self, other_player: Player):
        """ Perform an attack on another player, inflicting damage """
        attack_damage = self._damage_inflicted_on_opponent(other_player)
        other_player.take_hit(attack_damage)

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        return f"Player: {self._name}, hit points={self._hit_points}, damage={self._damage}, armor={self._armor}"

Nothing more to say about that!

Next, a bunch of useful spell stuff:

from dataclasses import dataclass
from enum import Enum

@dataclass
class SpellAttributes:
    """ Define the attributes of a Spell """
    name: str
    mana_cost: int
    effect_duration: int
    is_effect: bool
    heal: int
    damage: int
    armor: int
    mana_regen: int
    delay_start: int

class SpellType(Enum):
    """ Possible spell types.
    Any given spell_type.value will return an instance of SpellAttributes. """
    MAGIC_MISSILES = SpellAttributes('MAGIC_MISSILES', 53, 0, False, 0, 4, 0, 0, 0)
    DRAIN = SpellAttributes('DRAIN', 73, 0, False, 2, 2, 0, 0, 0)
    SHIELD = SpellAttributes('SHIELD', 113, 6, True, 0, 0, 7, 0, 0)
    POISON = SpellAttributes('POISON', 173, 6, True, 0, 3, 0, 0, 0)
    RECHARGE = SpellAttributes('RECHARGE', 229, 5, True, 0, 0, 0, 101, 0)

spell_key_lookup = {
    0: SpellType.MAGIC_MISSILES, # 53
    1: SpellType.DRAIN, # 73
    2: SpellType.SHIELD, # 113
    3: SpellType.POISON, # 173
    4: SpellType.RECHARGE # 229
}

# map each spell key (digit) to the mana cost of that spell
spell_costs = {spell_key: spell_type.value.mana_cost
               for spell_key, spell_type in spell_key_lookup.items()}

The SpellAttributes class is simply a dataclass that defines the properties that make up any given spell. Think of SpellAttributes as the schematic for a spell, rather than an instance of a spell.

Then, I use a SpellType Enum to create a set of constants, where each SpellType member is mapped to an instance of SpellAttributes with the required properties for that spell. This makes it easy to cast a spell of a specific type later on.

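Then, there are a couple of useful variables, defined alongside the Enum: spell_key_lookup maps each digit from 0 to 4 to a SpellType, and spell_costs maps each of those digits to the mana cost of the corresponding spell.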
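As a standalone sketch of how these pieces fit together: because Enum members iterate in definition order, the digit lookup could also be derived from the Enum with enumerate(), rather than typed out by hand. That's a variation on my code, not what it does; I've also made the template dataclass frozen=True here (an assumption on my part) so the schematics can't be modified by accident.

```python
from dataclasses import dataclass
from enum import Enum


@dataclass(frozen=True)  # frozen: the spell templates can't be changed by accident
class SpellAttributes:
    name: str
    mana_cost: int
    effect_duration: int
    is_effect: bool
    heal: int
    damage: int
    armor: int
    mana_regen: int
    delay_start: int


class SpellType(Enum):
    MAGIC_MISSILES = SpellAttributes('MAGIC_MISSILES', 53, 0, False, 0, 4, 0, 0, 0)
    DRAIN = SpellAttributes('DRAIN', 73, 0, False, 2, 2, 0, 0, 0)
    SHIELD = SpellAttributes('SHIELD', 113, 6, True, 0, 0, 7, 0, 0)
    POISON = SpellAttributes('POISON', 173, 6, True, 0, 3, 0, 0, 0)
    RECHARGE = SpellAttributes('RECHARGE', 229, 5, True, 0, 0, 0, 101, 0)


# Enum members iterate in definition order, so enumerate() numbers them 0 to 4
spell_key_lookup = dict(enumerate(SpellType))
spell_costs = {key: spell_type.value.mana_cost for key, spell_type in spell_key_lookup.items()}

print(spell_costs)                      # {0: 53, 1: 73, 2: 113, 3: 173, 4: 229}
print(SpellType["SHIELD"].value.armor)  # look up a member by name: 7
```

The trade-off is that the digit assigned to each spell now depends on the order of the Enum members, so reordering them would silently change what an attack string like "4013" means.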

Now I go ahead and create the Spell class:

@dataclass
class Spell:
    """ Spells should be created using create_spell_by_type() factory method.

    Spells have a number of attributes.  Of note:
    - effects last for multiple turns, and apply on both player and opponent turns.
    - effect_duration is the number of turns an effect lasts for
    - mana_cost is the cost of the spell
    """
    name: str
    mana_cost: int
    effect_duration: int
    is_effect: bool
    heal: int = 0
    damage: int = 0
    armor: int = 0
    mana_regen: int = 0
    delay_start: int = 0
    effect_applied_count = 0 # no type annotation, so this is not a dataclass field

    @classmethod
    def check_spell_castable(cls, spell_type: SpellType, wiz: Wizard):
        """ Determine if this Wizard can cast this spell.
        Spell can only be cast if the wizard has sufficient mana, and if the spell is not already active.

        Raises:
            ValueError: If the spell is not castable

        Returns:
            [bool]: True if castable
        """
        # not enough mana
        if wiz.mana < spell_type.value.mana_cost:
            raise ValueError(f"Not enough mana for {spell_type}. "
                             f"Need {spell_type.value.mana_cost}, have {wiz.mana}.")

        # spell already active
        if spell_type in wiz.get_active_effects():
            raise ValueError(f"Spell {spell_type} already active.")
        return True

    @classmethod
    def create_spell_by_type(cls, spell_type: SpellType):
        # Unpack the spell_type.value, which will be a SpellAttributes instance.
        # Get all its values, and unpack them as positional args to the Spell constructor.
        attrs_dict = vars(spell_type.value)
        return cls(*attrs_dict.values())

    def __repr__(self) -> str:
        return f"Spell: {self.name}, cost: {self.mana_cost}, " \
                    f"is effect: {self.is_effect}, " \
                    f"remaining duration: {self.effect_duration - self.effect_applied_count}"

    def increment_effect_applied_count(self):
        self.effect_applied_count += 1

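Some interesting things to say about this:

Spell is a dataclass whose fields match those of SpellAttributes, in the same order. It also has an effect_applied_count, which tracks how many times an effect has been applied. Because effect_applied_count has no type annotation, the dataclass doesn't treat it as a field: the class-level 0 is just the starting value, and incrementing it creates a counter on that particular instance.

check_spell_castable() and create_spell_by_type() are class methods, so they're called on the class itself, e.g. Spell.create_spell_by_type(SpellType.POISON).

create_spell_by_type() is a factory method. It uses vars() to get the attributes of the SpellAttributes instance as a dict, and then unpacks the values as positional arguments to the Spell constructor. This works because both classes declare their fields in the same order.

check_spell_castable() raises a ValueError, rather than returning False, when a spell can't be cast. It's up to the caller to decide what to do about it.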
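Here's a cut-down, standalone sketch of the same factory idea. It is a simplified illustration rather than my actual classes: the templates only have three fields, and the counter is a real dataclass field with a default. It shows two things. First, each call to the factory returns a brand new Spell, so the per-cast counter of one cast can't leak into another. Second, unpacking the vars() dict by keyword (**), rather than by position, means the two classes don't need to declare their fields in the same order.

```python
from dataclasses import dataclass
from enum import Enum


@dataclass(frozen=True)
class SpellAttributes:
    """Cut-down spell template for this sketch."""
    name: str
    mana_cost: int
    effect_duration: int


class SpellType(Enum):
    SHIELD = SpellAttributes("SHIELD", 113, 6)
    POISON = SpellAttributes("POISON", 173, 6)


@dataclass
class Spell:
    name: str
    mana_cost: int
    effect_duration: int
    effect_applied_count: int = 0  # per-cast state; not part of the template

    @classmethod
    def create_spell_by_type(cls, spell_type: SpellType) -> "Spell":
        # vars() returns the template's attributes as a dict; unpacking by
        # keyword matches them to Spell's fields by name, not by position
        return cls(**vars(spell_type.value))


shield = Spell.create_spell_by_type(SpellType.SHIELD)
print(shield)  # Spell(name='SHIELD', mana_cost=113, effect_duration=6, effect_applied_count=0)
```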

Now for the Wizard class. It extends the Player class with wizard-specific behaviour. There’s nothing complicated about the class itself. The tricky bit is making sure that spells and effects apply at the right time during a turn. The rules are very specific!
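To make the turn order concrete, here is a compact, standalone sketch of the rules, separate from my classes. It's a simplification written for illustration. Spells are plain tuples in a hypothetical SPELLS dict, active effects are tracked as a dict of remaining ticks, and a hypothetical simulate() function plays one fight with a fixed list of spells, returning whether the player won and how much mana was spent. It follows the puzzle's turn order: at the start of every turn (the player's and the boss's), each active effect is applied once. Then the player casts, or the boss attacks. In hard mode, the player first loses 1 hit point at the start of their own turn.

```python
# Hypothetical spell table: name -> (cost, damage, heal, armor, mana_regen, duration).
# A duration of 0 marks an instant spell; anything else starts an effect.
SPELLS = {
    "MAGIC_MISSILES": (53, 4, 0, 0, 0, 0),
    "DRAIN": (73, 2, 2, 0, 0, 0),
    "SHIELD": (113, 0, 0, 7, 0, 6),
    "POISON": (173, 3, 0, 0, 0, 6),
    "RECHARGE": (229, 0, 0, 0, 101, 5),
}


def simulate(spell_names: list[str], player_hp: int, mana: int,
             boss_hp: int, boss_damage: int, hard_mode: bool = False) -> tuple[bool, int]:
    """Play one fight using a fixed list of spells. Returns (player_won, mana_spent)."""
    timers: dict[str, int] = {}  # active effect name -> ticks remaining
    spent = 0

    def tick() -> int:
        """Apply every active effect once (start of any turn). Returns current armor."""
        nonlocal boss_hp, mana
        armor = 0
        for name in list(timers):  # iterate over a copy, as we may delete entries
            _, damage, _, spell_armor, regen, _ = SPELLS[name]
            boss_hp -= damage
            mana += regen
            armor += spell_armor
            timers[name] -= 1
            if timers[name] == 0:
                del timers[name]  # wears off, so it can be recast this same turn
        return armor

    for name in spell_names:
        # --- Player turn ---
        if hard_mode:
            player_hp -= 1
            if player_hp <= 0:
                return False, spent
        tick()
        if boss_hp <= 0:
            return True, spent
        cost, damage, heal, _, _, duration = SPELLS[name]
        if cost > mana or name in timers:
            return False, spent  # can't cast this spell, so the player loses
        mana -= cost
        spent += cost
        if duration:
            timers[name] = duration  # an effect first ticks on the boss's turn
        else:
            boss_hp -= damage  # instant spell: no armor applies
            player_hp += heal
            if boss_hp <= 0:
                return True, spent
        # --- Boss turn ---
        armor = tick()
        if boss_hp <= 0:
            return True, spent
        player_hp -= max(boss_damage - armor, 1)
        if player_hp <= 0:
            return False, spent
    return False, spent  # ran out of spells before the boss died
```

With the puzzle's first worked example (Poison then Magic Missile, player with 10 hit points and 250 mana, boss with 13 hit points and 8 damage), this reports a win costing 226 mana. With the five-spell sequence used in the 42130 unit test (boss with 14 hit points), it reports a win costing 641 mana.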

class Wizard(Player):
    """ Extends Player.
    Also has attribute 'mana', which powers spells.
    Wizard has no armor (except when provided by spells) and no inherent damage (except from spells).

    For each wizard turn, we must cast_spell() and apply_effects().
    On each opponent's turn, we must apply_effects().
    """
    def __init__(self, name: str, hit_points: int, mana: int, damage: int = 0, armor: int = 0):
        """ Wizards have 0 mundane armor or damage.

        Args:
            name (str): Wizard name
            hit_points (int): Total life.
            mana (int): Used to power spells.
            damage (int, optional): mundane damage. Defaults to 0.
            armor (int, optional): mundane armor. Defaults to 0.
        """
        super().__init__(name, hit_points, damage, armor)
        self._mana = mana

        # store currently active effects, where key = spell name, and value = spell
        self._active_effects: dict[str, Spell] = {}

    @property
    def mana(self) -> int:
        return self._mana

    def use_mana(self, mana_used: int):
        if mana_used > self._mana:
            raise ValueError("Not enough mana!")
        self._mana -= mana_used

    def get_active_effects(self) -> dict[str, Spell]:
        return self._active_effects

    def take_turn(self, spell_key: SpellType, other_player: Player) -> int:
        """ This player takes a turn.
        This means: applying any effects, fading any expired effects, and then casting a spell.

        Args:
            spell_key (SpellType): The spell to cast, e.g. a value from spell_key_lookup
            other_player (Player): The opponent

        Returns:
            int: The mana consumed by this turn
        """
        self._turn(other_player) # effects apply at the start of the turn, before casting
        if not other_player.is_alive():
            return 0 # effects killed the opponent, so there's no need to cast
        mana_consumed = self.cast_spell(spell_key, other_player)

        return mana_consumed

    def _turn(self, other_player: Player):
        """ Start-of-turn processing, common to the wizard's and the opponent's turns. """
        self.apply_effects(other_player)
        self.fade_effects()

    def opponent_takes_turn(self, other_player: Player):
        """ An opponent takes their turn.  (Not the wizard.)
        We must apply any Wizard effects on their turn (and fade), before their attack.
        This method does not include their attack.

        Args:
            other_player (Player): The opponent
        """
        self._turn(other_player)

    def cast_spell(self, spell_type: SpellType, other_player: Player) -> int:
        """ Casts a spell.
        - If spell is not an effect, it applies once.
        - Otherwise, it applies for the spell's duration, on both player and opponent turns.

        Args:
            spell_type (SpellType): a SpellType constant.
            other_player (Player): The player to cast against

        Raises:
            ValueError: If the spell can't be cast

        Returns:
            [int]: Mana consumed
        """
        Spell.check_spell_castable(spell_type, self) # can this wizard cast this spell?
        spell = Spell.create_spell_by_type(spell_type)
        try:
            self.use_mana(spell.mana_cost)
        except ValueError as err:
            raise ValueError(f"Unable to cast {spell_type}: Not enough mana! "
                             f"Needed {spell.mana_cost}; have {self._mana}.") from err

        logger.debug("%s cast %s", self._name, spell)

        if spell.is_effect:
            # add to active effects, apply later
            # this might replace a fading effect
            self._active_effects[spell_type.name] = spell
        else:
            # apply now.
            # opponent's armor counts for nothing against a magical attack
            attack_damage = spell.damage
            if attack_damage:
                logger.debug("%s attack. Inflicting damage: %s.", self._name, attack_damage)
                other_player.take_hit(attack_damage)

            heal = spell.heal
            if heal:
                logger.debug("%s: healing by %s.", self._name, heal)
                self._hit_points += heal

        return spell.mana_cost

    def fade_effects(self):
        """ Remove any effects that have been applied for their full duration. """
        effects_to_remove = []
        for effect_name, effect in self._active_effects.items():
            if effect.effect_applied_count >= effect.effect_duration:
                logger.debug("%s: fading effect %s", self._name, effect_name)
                if effect.armor:
                    # restore armor to pre-effect levels
                    self._armor -= effect.armor

                # Now we've faded the effect, flag it for removal
                effects_to_remove.append(effect_name)

        # now remove any effects flagged for removal
        # (we can't delete from a dict while iterating over it)
        for effect_name in effects_to_remove:
            del self._active_effects[effect_name]

    def apply_effects(self, other_player: Player):
        """ Apply effects in the active_effects dict.

        Args:
            other_player (Player): The opponent
        """
        for effect_name, effect in self._active_effects.items():
            # effect should be active if we've used it fewer times than the duration
            if effect.effect_applied_count < effect.effect_duration:
                effect.increment_effect_applied_count()
                if logger.getEffectiveLevel() == logging.DEBUG:
                    logger.debug("%s: applying effect %s, leaving %d turns.",
                            self._name, effect_name, effect.effect_duration - effect.effect_applied_count)

                if effect.armor:
                    if effect.effect_applied_count == 1:
                        # increment armor on first use, and persist this level until the effect fades
                        self._armor += effect.armor

                if effect.damage:
                    other_player.take_hit(effect.damage)

                if effect.mana_regen:
                    self._mana += effect.mana_regen

    def attack(self, other_player: Player):
        """ A Wizard cannot perform a mundane attack.
        Use cast_spell() instead.
        """
        raise NotImplementedError("Wizards cast spells")

    def __repr__(self):
        return f"{self._name} (Wizard): hit points={self._hit_points}, " \
                        f"damage={self._damage}, armor={self._armor}, mana={self._mana}"

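Things to note:

Active effects are stored in the _active_effects dict, keyed by spell name. Instant spells (Magic Missile and Drain) are applied as soon as they’re cast. Effects, by contrast, are only added to this dict when cast, and are then applied by apply_effects() at the start of each subsequent turn, on both the wizard’s and the boss’s turns.

Shield’s armor is added when the effect is first applied, and removed again when fade_effects() retires the effect.

fade_effects() makes two passes: it first collects the names of expired effects, and then deletes them. That’s because you can’t remove keys from a dict while iterating over it.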

Most of the hard work is done. Now we’re ready to solve the problem!

Next, I’ll create a function that generates successive attack combinations to try. It returns each attack combination as a string of digits, where each digit is the key lookup for an attack. E.g. “4013” would mean:

  1. 4 = RECHARGE
  2. 0 = MAGIC_MISSILES
  3. 1 = DRAIN
  4. 3 = POISON

from collections.abc import Iterable

def attack_combos_generator(count_different_attacks: int) -> Iterable[str]:
    """ Generator that returns the next attack combo. Pass in the number of different attack types.
    E.g. with 5 different attacks, it will generate...
    0, 1, 2, 3, 4, 10, 11, 12, 13, 14, 20, 21, 22, 23, 24, etc
    """
    i = 0
    while True:
        # convert i to base-n (where n is the number of attacks we can choose from)
        # (td is my common type_defs module)
        yield td.to_base_n(i, count_different_attacks)
        i += 1

Note that this is a generator. Each time the next value is requested, it yields the current value of i (converted to base-n), and then increments i. And it is infinite. Later, we’ll use this generator in a loop, and we’ll need an exit condition, otherwise this will loop forever.
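One thing worth knowing about counting up in base 5 is that numbers don’t have leading zeros. After the single “0”, no generated combo starts with 0, so a sequence such as “04” (MAGIC_MISSILES then RECHARGE) is never produced. If you wanted every sequence of each length, one alternative is itertools.product(). Here, all_attack_combos() is a hypothetical name for that variation, not part of my solution. It’s still infinite, so itertools.islice() is used to take just a bounded number of values from it.

```python
from collections.abc import Iterator
from itertools import count, islice, product


def all_attack_combos(count_different_attacks: int) -> Iterator[str]:
    """Hypothetical alternative: infinite generator of every attack string,
    all those of length 1, then all of length 2, and so on (leading 0s included)."""
    digits = "0123456789"[:count_different_attacks]
    for length in count(1):
        for combo in product(digits, repeat=length):
            yield "".join(combo)


# The generator never ends, so take a bounded slice of it
first_thirty = list(islice(all_attack_combos(5), 30))
print(first_thirty[:7])  # ['0', '1', '2', '3', '4', '00', '01']
```

With 5 spells, this yields 25 combos of length 2, compared with the 20 two-digit numbers (10 to 44) that counting in base 5 produces.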

My generator also calls a new to_base_n() function, which I’ve moved to my common type_defs.py module. It converts any supplied number to a supplied base, and then returns the str representation of that number. Here, I’m using it to convert to base-5. Why? Because I have 5 unique attack types, and I want my attack lookup string to contain only the digits 0 to 4. So my to_base_n() function does this conversion:

Decimal Number | In Base-5
---------------|----------
0              | 0
1              | 1
2              | 2
3              | 3
4              | 4
5              | 10
6              | 11
7              | 12
8              | 13
9              | 14
10             | 20
11             | 21
12             | 22
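My to_base_n() itself isn’t shown here, so here is a hypothetical minimal version that produces the conversion in the table above. It assumes a non-negative number and a base from 2 to 10, so that every digit is a single character. It works by repeatedly using divmod() to peel off the least significant digit, and then reverses the collected digits.

```python
def to_base_n(number: int, base: int) -> str:
    """Hypothetical sketch: str representation of a non-negative int in base 2 to 10."""
    if number < 0 or not 2 <= base <= 10:
        raise ValueError("This sketch only supports non-negative numbers and bases 2 to 10")
    if number == 0:
        return "0"
    digits = []
    while number:
        number, remainder = divmod(number, base)  # peel off the lowest digit
        digits.append(str(remainder))
    return "".join(reversed(digits))  # most significant digit first


print([to_base_n(i, 5) for i in range(13)])
# ['0', '1', '2', '3', '4', '10', '11', '12', '13', '14', '20', '21', '22']
```

A quick way to check it is to round-trip the result through Python’s built-in int(), which accepts a base: int(to_base_n(1234, 5), 5) gives back 1234.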

Now I create a function that calculates the overall mana cost for any given attack sequence. Why? Because if we’re testing an attack sequence that costs more than a previous winning attack sequence, then there’s no point in even trying it. This saves us from playing the game with that sequence.

from functools import cache

@cache # I think there are only about 3000 different sorted attacks
def get_combo_mana_cost(attack_combo_lookup: str) -> int:
    """ Pass in attack combo lookup str, and return the cost of this attack combo.
    Ideally, the attack combo lookup should be sorted, because cost doesn't care about attack order;
    and providing a sorted value, we can use a cache. """
    return sum(spell_costs[int(attack)] for attack in attack_combo_lookup)

It works by using a generator expression to look up the mana cost of each digit in the attack sequence (via spell_costs), and then summing those costs.

The interesting thing about get_combo_mana_cost() is that it caches the calculated cost for a given attack sequence. But every attack sequence is unique, so why bother caching? Well, although every attack sequence is unique, the cost of a sequence depends only on which attacks it contains, not on their order. So, if I sort each attack sequence before checking its cost, it turns out that there are only a few thousand unique sorted sequences, and it’s very efficient to cache their costs.
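Here’s a small standalone sketch of the sort-then-cache idea, assuming the same digit-to-cost mapping as spell_costs. Every ordering of the same attacks sorts to the same key, so functools.cache computes that key’s cost just once, and cache_info() shows the resulting hits and misses. It also checks the “about 3000” comment. A sorted sequence of length n drawn from 5 spells is a multiset, and there are C(n+4, 4) of those. So across lengths 1 to 10, there are 3,002 distinct sorted sequences.

```python
from functools import cache
from math import comb

# Same digit -> mana cost mapping as spell_costs
spell_costs = {0: 53, 1: 73, 2: 113, 3: 173, 4: 229}


@cache
def get_combo_mana_cost(attack_combo_lookup: str) -> int:
    return sum(spell_costs[int(attack)] for attack in attack_combo_lookup)


# Four different orderings of the same four attacks
for combo in ("4213", "1234", "3421", "2143"):
    get_combo_mana_cost("".join(sorted(combo)))  # all sort to "1234"

info = get_combo_mana_cost.cache_info()
print(info.hits, info.misses)  # 3 hits, 1 miss

# Distinct sorted sequences (multisets) of length 1 to 10, drawn from 5 spells
print(sum(comb(n + 4, 4) for n in range(1, 11)))  # 3002
```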

Now let’s read in the boss stats from the input data:

def main():
    # boss stats are determined by an input file
    with open(path.join(locations.input_dir, BOSS_FILE), mode="rt") as f:
        boss_hit_points, boss_damage = process_boss_input(f.read().splitlines())
        actual_boss = Player("Actual Boss", hit_points=boss_hit_points, damage=boss_damage, armor=0)

    player = Wizard("Bob", hit_points=50, mana=500)

    winning_games, least_winning_mana = try_combos(actual_boss, player)

    message = "Winning solutions:\n" + "\n".join(f"Mana: {k}, Attack: {v}" for k, v in winning_games.items())
    logger.info("We found %d winning solutions. Lowest mana cost was %d.", len(winning_games), least_winning_mana)

def process_boss_input(data: list[str]) -> tuple[int, int]:
    """ Process boss file input and return tuple of hit_points, damage

    Returns:
        tuple: hit_points, damage
    """
    boss = {}
    for line in data:
        key, val = line.strip().split(":")
        boss[key] = int(val)

    return boss['Hit Points'], boss['Damage']

Recall that the input data looks something like this:

Hit Points: 71
Damage: 10

So we read each line, and split it at the colon to get the key and its value. We convert the value to an int, and store it in a dictionary called boss. Finally, we return the two values from this dictionary as a tuple. We use these two values to construct a new Player called Actual Boss.

And, of course, we create a Wizard to represent our player.

Then we need to try out the attack sequence combinations:

def try_combos(boss_stats: Player, plyr_stats: Wizard):
    """ Try successive attack combos against the boss.
    Returns: dict of winning games (mana consumed: attack combo), and the least winning mana. """
    winning_games = {}
    least_winning_mana = 2500 # ball-park of what will likely be larger than winning solution
    ignore_combo = "9999999" # can never match, since combos only contain the digits 0-4
    player_has_won = False
    last_attack_len = 0
    # This is an infinite generator, so we need an exit condition
    for attack_combo_lookup in attack_combos_generator(len(spell_key_lookup)):
        # play the game with this attack combo
        # since attack combos are returned sequentially,
        # we can ignore any that start with the same attacks as the last failed combo.
        if attack_combo_lookup.startswith(ignore_combo):
            continue

        # determine if the cost of the current attack is going to be more than an existing
        # winning solution. (Sort it, so we can cache the attack cost.)
        sorted_attack = ''.join(sorted(attack_combo_lookup))
        if get_combo_mana_cost(sorted_attack) >= least_winning_mana:
            continue

        # Much faster than a deep copy
        boss = Player(boss_stats.name, boss_stats.hit_points, boss_stats.damage, boss_stats.armor)
        player = Wizard(plyr_stats.name, plyr_stats.hit_points, plyr_stats.mana)
        if player_has_won and logger.getEffectiveLevel() == logging.DEBUG:
            logger.debug("Best winning attack: %s. Total mana: %s. Current attack: %s",
                        winning_games[least_winning_mana], least_winning_mana, attack_combo_lookup)
        else:
            logger.debug("Current attack: %s", attack_combo_lookup)

        player_won, mana_consumed, rounds_started = play_game(
                attack_combo_lookup, player, boss, mana_target=least_winning_mana)
        if player_won:
            player_has_won = True
            winning_games[mana_consumed] = attack_combo_lookup
            least_winning_mana = min(mana_consumed, least_winning_mana)
            logger.info("Found a winning solution, with attack %s consuming %d", attack_combo_lookup, mana_consumed)

        attack_len = len(attack_combo_lookup)
        if attack_len > last_attack_len:
            if player_has_won:
                # We can't play forever. Assume that if the last attack length didn't yield a better result
                # then we're not going to find a better solution.
                if len(attack_combo_lookup) > len(winning_games[least_winning_mana]) + 1:
                    logger.info("Probably not getting any better. Exiting.")
                    break # We're done!
            logger.info("Trying attacks of length %d", attack_len)
        last_attack_len = attack_len

        # we can ignore any attacks that start with the same attacks as what we tried last time
        ignore_combo = attack_combo_lookup[0:rounds_started]
    return winning_games, least_winning_mana

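Things to say about this…

Before playing a combo, it checks whether the combo starts with ignore_combo, which holds the first rounds_started attacks of the previous combo. Because combos are generated in order, the combos that follow one another often share a prefix, and any combo with that prefix would play out exactly the same way up to the round where the previous game ended. So its outcome is already known, and it’s skipped without being played.

It also skips any combo whose total mana cost is already at least the cheapest winning cost found so far.

Because the generator is infinite, the loop needs an exit condition. Once we have a winner, the loop stops as soon as the attack length exceeds the length of the best winning attack by more than one. This is an assumption rather than a guarantee, so it could, in principle, stop before finding the cheapest solution.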
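At the start, I said I should probably have used a depth-first search. For comparison, here is a minimal standalone sketch of that approach, with branch-and-bound pruning. It explores game states directly, rather than pre-generated attack strings, and abandons any branch whose mana spent would already equal or exceed the cheapest win found so far. This is not what my solution does. SPELLS, tick_effects() and least_mana_to_win() are hypothetical names, and the rules follow the puzzle’s turn order: effects tick at the start of every turn, and hard mode costs the player 1 hit point at the start of each of their turns.

```python
import math

# Hypothetical spell table: name -> (cost, damage, heal, armor, mana_regen, duration).
# A duration of 0 marks an instant spell; anything else starts an effect.
SPELLS = {
    "MAGIC_MISSILES": (53, 4, 0, 0, 0, 0),
    "DRAIN": (73, 2, 2, 0, 0, 0),
    "SHIELD": (113, 0, 0, 7, 0, 6),
    "POISON": (173, 3, 0, 0, 0, 6),
    "RECHARGE": (229, 0, 0, 0, 101, 5),
}


def tick_effects(timers: dict[str, int], boss_hp: int, mana: int):
    """Apply each active effect once. Returns (new timers, boss_hp, mana, armor)."""
    armor = 0
    remaining = {}
    for name, turns_left in timers.items():
        _, damage, _, spell_armor, regen, _ = SPELLS[name]
        boss_hp -= damage
        mana += regen
        armor += spell_armor
        if turns_left > 1:  # an effect on its last tick wears off now
            remaining[name] = turns_left - 1
    return remaining, boss_hp, mana, armor


def least_mana_to_win(player_hp: int, mana: int, boss_hp: int,
                      boss_damage: int, hard_mode: bool = False) -> float:
    """Depth-first search with branch-and-bound pruning.
    Returns the least mana spent in any winning fight, or math.inf if there is none."""
    best = math.inf

    def player_turn(player_hp, mana, boss_hp, timers, spent):
        nonlocal best
        if hard_mode:
            player_hp -= 1
            if player_hp <= 0:
                return
        timers, boss_hp, mana, _ = tick_effects(timers, boss_hp, mana)
        if boss_hp <= 0:
            best = min(best, spent)
            return
        for name, (cost, damage, heal, _, _, duration) in SPELLS.items():
            # Prune: unaffordable, effect still active, or can't beat the best so far
            if cost > mana or name in timers or spent + cost >= best:
                continue
            if duration:
                boss_turn(player_hp, mana - cost, boss_hp,
                          {**timers, name: duration}, spent + cost)
            else:
                boss_turn(player_hp + heal, mana - cost, boss_hp - damage,
                          dict(timers), spent + cost)

    def boss_turn(player_hp, mana, boss_hp, timers, spent):
        nonlocal best
        if boss_hp <= 0:  # killed by an instant spell
            best = min(best, spent)
            return
        timers, boss_hp, mana, armor = tick_effects(timers, boss_hp, mana)
        if boss_hp <= 0:
            best = min(best, spent)
            return
        player_hp -= max(boss_damage - armor, 1)
        if player_hp > 0:
            player_turn(player_hp, mana, boss_hp, timers, spent)

    player_turn(player_hp, mana, boss_hp, {}, 0)
    return best
```

Against the puzzle’s first example (player with 10 hit points and 250 mana; boss with 13 hit points and 8 damage), it returns 226. The pruning matters: once any win is found, every branch that can’t beat it is cut off immediately, rather than being played out.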

Finally, this is how we actually play the game:

def play_game(attack_combo_lookup: str, player: Wizard, boss: Player, **kwargs) -> tuple[bool, int, int]:
    """ Play a game, given a player (Wizard) and an opponent (boss)

    Args:
        attack_combo_lookup (str): Spells to cast, as a str of spell_key_lookup digits
        player (Wizard): A Wizard
        boss (Player): A mundane opponent
        hard_mode (Bool): Whether each player turn automatically loses 1 hit point
        mana_target (int): optional arg, that specifies a max mana consumed value which triggers a return

    Returns:
        tuple[bool, int, int]: player won, mana consumed, number of rounds
    """
    # Convert the attack combo to a list of spells. E.g. convert '00002320'
    # to [<SpellType.MAGIC_MISSILES: ..., <SpellType.MAGIC_MISSILES: ...,
    #    ... <SpellType.SHIELD: ..., <SpellType.MAGIC_MISSILES: ... >]
    attacks = [spell_key_lookup[int(attack)] for attack in attack_combo_lookup]

    game_round = 1
    current_player = player
    other_player = boss

    mana_consumed: int = 0
    mana_target = kwargs.get('mana_target', None)

    while (player.hit_points > 0 and boss.hit_points > 0):
        if current_player == player:
            # player (wizard) attack
            if logger.getEffectiveLevel() == logging.DEBUG:
                logger.debug("Round %s...", game_round)
                logger.debug("%s's turn:", current_player.name)

            try:
                mana_consumed += player.take_turn(attacks[game_round-1], boss)
                if mana_target and mana_consumed > mana_target:
                    logger.debug('Mana target %s exceeded; mana consumed=%s.', mana_target, mana_consumed)
                    return False, mana_consumed, game_round
            except ValueError as err:
                logger.debug(err) # e.g. not enough mana for this spell
                return False, mana_consumed, game_round
            except IndexError:
                logger.debug("No more attacks left.")
                return False, mana_consumed, game_round
        else:
            logger.debug("%s's turn:", current_player.name)
            # effects apply before opponent attacks
            player.opponent_takes_turn(boss)
            if boss.hit_points <= 0:
                logger.debug("Effects killed %s!", boss.name)
                break

            boss.attack(player)
            game_round += 1

        if logger.getEffectiveLevel() == logging.DEBUG:
            logger.debug("End of turn: %s", player)
            logger.debug("End of turn: %s", boss)

        # swap players
        current_player, other_player = other_player, current_player

    player_won = player.hit_points > 0
    return player_won, mana_consumed, game_round

This works by letting each player take a turn, and then swapping which player is the current player. It loops until one of the players no longer has any hit points. If we exit the loop and the player still has hit points, then the player has won.

And that’s it!

Part 2

Now we’re told to play the game in hard mode: at the start of each of the player’s turns, the player loses 1 hit point.

As before:

What is the least amount of mana you can spend and still win the fight?

Fortunately, the changes required here are trivial.

In play_game(), I just add this before the player takes their turn:

            if hard_mode:
                logger.debug("Hard mode hit. Player hit points reduced by 1.")
                player.take_hit(1)
                if player.hit_points <= 0:
                    logger.debug("Hard mode killed %s", player.name)
                    return False, mana_consumed, game_round

And add a hard_mode parameter to the function signature:

def play_game(attack_combo_lookup: str, player: Wizard, boss: Player, hard_mode=False, **kwargs) -> tuple[bool, int, int]:


There’s so much that can go wrong in this code. I thought it was sensible to have unit tests that would allow me to validate that the code is working correctly, and also to check that I’m not breaking anything when refactoring.

So here’s my test:

""" Unit testing for Spell_Casting 
    0: SpellType.MAGIC_MISSILES, # 53
    1: SpellType.DRAIN, # 73
    2: SpellType.SHIELD, # 113
    3: SpellType.POISON, # 173
    4: SpellType.RECHARGE # 229
import unittest
import logging
from spell_casting import (
        Player, Wizard, 
        play_game, try_combos, get_combo_mana_cost)

class TestPlayGame(unittest.TestCase):
    """ Test single game, and combos """
    def setUp(self):
    def run(self, result=None):
        """ Override run method so we can include method name in output """
        method_name = self._testMethodName
        logger.info("Running test: %s", method_name)
    def test_play_game_42130(self):
        """ Test a simple game, in _normal_ diffulty.
        As supplied in the game instructions. """  
        logger.setLevel(logging.DEBUG) # So we can look at each turn and compare to the instructions
        player = Wizard("Bob", hit_points=10, mana=250)
        boss = Player("Boss", hit_points=14, damage=8, armor=0)
        player_won, mana_consumed, rounds_started = play_game("42130", player, boss)
        self.assertEqual(player_won, True)
        self.assertEqual(mana_consumed, 641)
        self.assertEqual(rounds_started, 5)

    def test_play_game_42130_hard_mode(self):
        """ Test a simple game, in _hard_ diffulty. I.e. player loses 1 hitpoint per turn. """        
        player = Wizard("Bob", hit_points=10, mana=250)
        boss = Player("Boss", hit_points=14, damage=8, armor=0)
        player_won, mana_consumed, rounds_started = play_game("42130", player, boss, hard_mode=True)
        self.assertEqual(player_won, False)
        self.assertEqual(mana_consumed, 229)
        self.assertEqual(rounds_started, 2)
    def test_play_game_304320_hard_mode(self):
        player = Wizard("Bob", hit_points=50, mana=500)
        boss = Player("Boss", hit_points=40, damage=10, armor=0)
        player_won, mana_consumed, rounds_started = play_game("304320", player, boss, hard_mode=True)
        self.assertEqual(player_won, True)
        self.assertEqual(mana_consumed, 794)
        self.assertEqual(rounds_started, 6)
    def test_play_game_34230000_hard_mode(self):
        player = Wizard("Bob", hit_points=50, mana=500)
        boss = Player("Boss", hit_points=45, damage=10, armor=0)
        player_won, mana_consumed, rounds_started = play_game("34230000", player, boss, hard_mode=True)
        self.assertEqual(player_won, True)
        self.assertEqual(mana_consumed, 847) # only uses 7 of the 8 attacks, otherwise would be 900
        self.assertEqual(rounds_started, 7)
    def test_play_game_224304300300_hard_mode(self):
        player = Wizard("Bob", hit_points=50, mana=500)
        boss = Player("Boss", hit_points=71, damage=10, armor=0)
        player_won, mana_consumed, rounds_started = play_game("224304300300", player, boss, hard_mode=True)
        self.assertEqual(player_won, True)
        self.assertEqual(mana_consumed, 1468)
        self.assertEqual(rounds_started, 12)        
    def test_calculate_mana_cost(self):
        self.assertEqual(get_combo_mana_cost("42130"), 641)
        self.assertEqual(get_combo_mana_cost("34230000"), 900)
    def test_try_combos(self):
        """ Try multiple games, testing combos to find the winning combo that consumes the least mana """
        boss = Player("Boss", hit_points=40, damage=10, armor=0) # Use hp 50 to see improving solutions
        player = Wizard("Bob", hit_points=50, mana=500)
        winning_games, least_winning_mana = try_combos(boss, player)
        logger.info("We found %d winning solutions. Lowest mana cost was %d.", len(winning_games), least_winning_mana)
        message = "Winning solutions:\n" + "\n".join(f"Mana: {k}, Attack: {v}" for k, v in winning_games.items())
        self.assertEqual(least_winning_mana, 794) # with 40, 10, 8

if __name__ == '__main__':

Some things to say about this…

It’s quick to run, and easy to change.


It’s not short. It’s not even fast.

Boss Fight Output