Dazbo's Advent of Code solutions, written in Python
- Recursion (Wikipedia)
- Recursion Introduction (@ RealPython)
I have to admit, I used to struggle with recursion. A lot of people do. It can be a bit mind-bending. But it’s a pretty simple concept and can be very useful.
In short: a recursive function is a function that calls itself. Thus, the code that defines the function will include a call to the same function.
As an analogy, take a look at these recursive acronyms. See how each acronym's definition includes the acronym itself!
| Acronym | Definition |
|---|---|
| GNU | GNU’s Not Unix |
| LAME | LAME Ain’t an MP3 Encoder |
| YAML | YAML Ain’t Markup Language |
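Typical use cases include:
- Mathematical functions that are defined in terms of themselves, such as factorials and the Fibonacci sequence.
- Traversing nested data structures, such as nested lists.
- Extrapolating sequences by repeatedly taking differences.

We’ll look at examples of these in a bit.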
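When creating a recursive function, there are only two rules you need to know:
1. There must be a base case (an exit condition) at which the function stops calling itself.
2. Each recursive call must move closer to that base case.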
Suppose we want to create a recursive function that counts down from an arbitrary number `n` to 0. We can do it like this:
```python
def countdown(n):
    print(n)
    if n == 0:
        return  # Terminate recursion
    else:
        countdown(n - 1)  # Recursive call, one closer to the base case
```
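As per the rules:
- The base case is when `n` is 0, at which point we return and the recursion terminates.
- Each recursive call gets one step closer to the base case, since we decrement `n` each time by 1.

We can simplify this code: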
```python
def countdown(n):
    print(n)
    if n > 0:
        countdown(n - 1)  # Recursive call, one closer to the base case
```
Let’s try it. I’ve added the above code to a file called scratch.py, in my snippets folder. I’ll now execute it from the Python REPL:
```
>>> from snippets.scratch import *
>>> countdown(5)
5
4
3
2
1
0
```
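The base case really matters: without one, the function never stops calling itself on its own. The sketch below uses a hypothetical `countdown_forever()` with the base case removed. In CPython, the interpreter eventually aborts it with a `RecursionError` once the recursion limit (reported by `sys.getrecursionlimit()`) is exceeded.

```python
import sys


def countdown_forever(n):
    """Hypothetical broken countdown with no base case: nothing stops the recursion."""
    countdown_forever(n - 1)  # n keeps decreasing, but there is no base case to stop at


def hits_recursion_limit() -> bool:
    """Returns True if calling countdown_forever() raises RecursionError."""
    try:
        countdown_forever(5)
    except RecursionError:
        return True
    return False


print(sys.getrecursionlimit())  # CPython's default limit is 1000
print(hits_recursion_limit())   # True
```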
Recall the definition of factorial:

\(k! = k \times (k-1)!\), where \(0! = 1! = 1\)
This is slightly trickier than the previous example, since we’re not just printing a value with each call. Instead, we always multiply the current value of `n` by the result of the recursive call for `n - 1`.
So we can code it like this:
```python
def factorial(n):
    return 1 if n <= 1 else n * factorial(n - 1)
```
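- The base case is when `n == 1` (or less). In this situation, `factorial` should always return 1.
- Each recursive call gets one step closer to the base case, since `n` is decremented by 1.

Note that it’s common for any recursive function that calculates a product to have an exit condition that returns 1.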
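As a quick sanity check, a sketch like the one below compares the one-line recursive `factorial()` against the standard library's `math.factorial()` for small values of `n`, including the `n = 0` edge case that the `n <= 1` condition also covers.

```python
import math


def factorial(n):
    # Base case: 0! and 1! are both 1; otherwise n! = n * (n-1)!
    return 1 if n <= 1 else n * factorial(n - 1)


# Compare against the standard library for n = 0..9
for n in range(10):
    assert factorial(n) == math.factorial(n), n

print(factorial(5))  # 120
```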
We can see how the function works by adding some debugging statements:
```python
def factorial(n):
    print(f"factorial() called with n = {n}")
    return_value = 1 if n <= 1 else n * factorial(n - 1)
    print(f"-> factorial({n}) returns {return_value}")
    return return_value
```
Let’s run it from the REPL:
```
>>> from snippets.scratch import *
>>> factorial(4)
factorial() called with n = 4
factorial() called with n = 3
factorial() called with n = 2
factorial() called with n = 1
-> factorial(1) returns 1
-> factorial(2) returns 2
-> factorial(3) returns 6
-> factorial(4) returns 24
24
```
Note how each return value is the product of `n` and the previous return value.
Here we create a recursive function that counts all the individual elements in a list. If the list is nested, the function recurses into each sublist, adding the elements of that sublist to the overall count.
```python
def count_leaf_items(item_list):
    """Recursively counts and returns the number of leaf items
    in a (potentially nested) list. """
    count = 0
    for item in item_list:
        if isinstance(item, list):  # if the element is itself a list, recurse...
            count += count_leaf_items(item)
        else:  # count the item
            # this is the exit condition, i.e. when we've reached a leaf (element) rather than a nested list
            count += 1

    return count
```
Let’s try this…
```python
nested_list = [2, [3, 5], [[10, 20], 30]]
print(nested_list)
res = count_leaf_items(nested_list)
print(res)
```
Output:
```
[2, [3, 5], [[10, 20], 30]]
6
```
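The same structure (recurse into a sublist, treat anything else as a leaf) can also collect the leaf items rather than just count them. Here is a sketch of a hypothetical `flatten()` generator, which uses `yield from` to pass along the items produced by each recursive call:

```python
from typing import Any, Iterator


def flatten(item_list: list) -> Iterator[Any]:
    """Hypothetical helper: yields the leaf items of a (potentially nested) list, in order."""
    for item in item_list:
        if isinstance(item, list):
            yield from flatten(item)  # recurse into the sublist
        else:
            yield item  # a leaf: the exit condition


nested_list = [2, [3, 5], [[10, 20], 30]]
print(list(flatten(nested_list)))  # [2, 3, 5, 10, 20, 30]
```

Counting the flattened items gives the same answer as `count_leaf_items()`, i.e. 6.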
The Fibonacci sequence is an infinite sequence in which each number is the sum of the two preceding numbers:
1, 1, 2, 3, 5, 8, 13, 21...
I.e. to determine the \(n\)th value in the sequence:

\(f(n) = f(n-2) + f(n-1)\)
The base cases are where `n` is 1 or 2, each of which returns a value of 1.
```python
def fib(num: int) -> int:
    """ Recursive function to determine nth value of Fibonacci sequence.
    I.e. 1, 1, 2, 3, 5, 8, 13, 21...
    fib(n) = fib(n-2) + fib(n-1)

    Args:
        num (int): value of n, i.e. to determine nth value

    Returns:
        int: The nth value of the Fibonacci sequence
    """
    if num > 2:
        return fib(num-2) + fib(num-1)
    else:
        return 1

while True:
    try:
        input_val = input("Enter the value of n, or q to quit: ")
        if input_val.upper() == "Q":
            break
        print(fib(int(input_val)))
    except ValueError:
        print("Invalid input")
```
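Note: this isn’t a particularly efficient function. It doesn’t scale well, because each call makes two further calls and the same values are recomputed over and over, so the number of calls grows exponentially with `n`.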
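One common fix is memoisation: cache each result the first time it is computed, so that every `fib(k)` is only calculated once. The sketch below keeps the same recursion but wraps it with the standard library's `functools.lru_cache` (the names `fib_cached` and `fib_counted` are just for illustration). It also uses a hypothetical call counter to show how many calls the uncached version makes.

```python
from functools import lru_cache

call_count = 0


def fib_counted(num: int) -> int:
    """Same as fib(), but counts how many times it is called."""
    global call_count
    call_count += 1
    if num > 2:
        return fib_counted(num - 2) + fib_counted(num - 1)
    return 1


@lru_cache(maxsize=None)
def fib_cached(num: int) -> int:
    """Same recursion as fib(), but each result is cached after it is first computed."""
    if num > 2:
        return fib_cached(num - 2) + fib_cached(num - 1)
    return 1


result = fib_counted(20)
print(result, call_count)  # 6765 13529
print(fib_cached(100))     # 354224848179261915075
```

With `maxsize=None` the cache keeps every result, trading memory for speed. The recursion depth still grows with `n`, so very large values can still hit Python's recursion limit.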
An arithmetic progression (AP) is a sequence of numbers in which the difference of any two successive members is a constant. This difference is commonly referred to as the “common difference”. For example:
```
Progression:  0   3   6   9   12  15  18
Common diff:    3   3   3   3   3   3
```
A second-degree arithmetic progression is one in which the differences between terms are growing, but growing by a constant amount. Thus, the differences of the differences are common (constant).
Triangle numbers are a common example:
```
Progression:           1   3   6   10  15  21
First diff:              2   3   4   5   6
Second (common) diff:      1   1   1   1
```
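The differences above can be reproduced with NumPy's `np.diff()`, which the function below also relies on. Its optional `n` argument applies the differencing repeatedly, so `n=2` gives the second differences directly:

```python
import numpy as np

triangle = np.array([1, 3, 6, 10, 15, 21])

first_diff = np.diff(triangle)        # [2 3 4 5 6]
second_diff = np.diff(first_diff)     # [1 1 1 1]
also_second = np.diff(triangle, n=2)  # same as applying np.diff() twice

print(first_diff, second_diff)
print(np.all(second_diff == second_diff[0]))  # True: the second differences are common
```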
We can generalise this to the nth degree, i.e. the number of times you have to take differences before the differences become common. Once you have found the degree at which the differences are common, you can bubble the results back up to the top, in order to determine the next term in the sequence.
So this is a good candidate for a recursive function:
```python
import operator

import numpy as np


def recurse_diffs(sequence: np.ndarray, forwards=True) -> int:
    """
    Calculate the next value in a numeric sequence based on the pattern of differences.
    Recursively analyses the differences between consecutive elements of the sequence. Recurses until the differences remain constant. It then calculates the next value in the sequence based on this constant difference.

    Parameters:
    sequence (np.ndarray): A NumPy array representing the sequence.
    forwards (bool, optional): A flag to determine the direction of progression.
        If True (default), the function calculates the next value.
        If False, it calculates the previous value in the sequence.

    Returns:
    int: The next (or previous) value in the sequence
    """
    diffs = np.diff(sequence)
    op = operator.add if forwards else operator.sub
    term = sequence[-1] if forwards else sequence[0]

    # Check if all the diffs are constant
    # If they are, we've reached the deepest point in our recursion, and we know the constant diff
    if np.all(diffs == diffs[0]):
        next_val = op(term, diffs[0])
    else:  # if the diffs are not constant, then we need to recurse
        diff = recurse_diffs(diffs, forwards)
        next_val = op(term, diff)

    return int(next_val)
```
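For comparison, here is a sketch of the same recursion written with plain Python lists instead of NumPy (the names `diffs_of()`, `next_term()` and `prev_term()` are hypothetical). It follows the same logic as `recurse_diffs()`: take the differences, stop when they are all equal, otherwise recurse one degree deeper and add (or subtract) the result on the way back up.

```python
def diffs_of(sequence: list[int]) -> list[int]:
    """Differences between consecutive elements, like np.diff(). Assumes at least two terms."""
    return [b - a for a, b in zip(sequence, sequence[1:])]


def next_term(sequence: list[int]) -> int:
    diffs = diffs_of(sequence)
    if all(d == diffs[0] for d in diffs):   # base case: the differences are common
        return sequence[-1] + diffs[0]
    return sequence[-1] + next_term(diffs)  # recurse one degree deeper


def prev_term(sequence: list[int]) -> int:
    diffs = diffs_of(sequence)
    if all(d == diffs[0] for d in diffs):   # base case: the differences are common
        return sequence[0] - diffs[0]
    return sequence[0] - prev_term(diffs)   # recurse one degree deeper


print(next_term([0, 3, 6, 9, 12, 15, 18]))  # 21
print(next_term([1, 3, 6, 10, 15, 21]))     # 28
print(prev_term([1, 3, 6, 10, 15, 21]))     # 0
```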
`__lt__` compare - 2022 day 13