Advent of Code 2025 in Lean 4 – Day 8

Today’s problem was all about connected components whose nodes are 3D positions, so the first thing I did was define a data structure for 3D positions:

/-- A position in 3D space with integer coordinates. -/
structure Position3D where
  x : Int
  y : Int
  z : Int

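To avoid leaving the beautiful, simple world of integers, I defined a method for computing the squared distance between positions, hoping that I wouldn’t need the actual distances (squaring preserves the ordering of non-negative distances, so sorting by squared distance gives the same order):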

/-- Squared Euclidean distance between `p1` and `p2`; no square root is taken,
so the result stays an `Int`. -/
def Position3D.squaredDistance (p1 p2 : Position3D) : Int :=
  (p1.x - p2.x)^2 + (p1.y - p2.y)^2 + (p1.z - p2.z)^2

I then parsed each position and was on my merry way. Next, I built an array containing every pair of positions together with their squared distance, sorted by that distance. I then represented each connected component as a HashSet in an array, removing the two original components whenever I pushed their merged union:

-- part 1: fold over the pairs from closest to farthest, merging the circuits
-- of each pair until 1000 cords have been placed
let (connected, _) := distances.foldl (fun acc dist =>
  match acc with
  -- no cords left: carry the accumulator through the remaining pairs unchanged
  | ⟨_, 0⟩ => acc
  | ⟨circuits, cords⟩ =>
    let ⟨⟨b1, b2⟩, _⟩ := dist
    let b1Circuit := circuits.find? (b1 ∈ ·) |>.get!
    let b2Circuit := circuits.find? (b2 ∈ ·) |>.get!
    -- drop both circuits and push their union (if b1 and b2 already share a
    -- circuit, it is removed and pushed back unchanged)
    let circuits' := circuits.filter (λ c => b1 ∉ c && b2 ∉ c)
      |>.push (b1Circuit ∪ b2Circuit)
    (circuits', cords - 1)
) (circuits, 1000)

let largestThree := connected.map (·.size) |>.qsort |>.reverse.toList.take 3

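However, this approach is both ugly and inefficient. It scans every component to find the circuit of each box, and the fold keeps walking through all the remaining pairs even after the 1000 cords are used up. That fold is begging for early termination, so why not reach for Lean’s imperative features? Behold! Mutation and a for loop: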

let mut circuits' := circuits
let mut cords := 1000
for ⟨⟨b1, b2⟩, _⟩ in distances do
  let b1Circuit := circuits'.find? (b1 ∈ ·) |>.get!
  let b2Circuit := circuits'.find? (b2 ∈ ·) |>.get!
  -- replace the circuits containing b1 and b2 with their union
  circuits' := circuits'.filter (λ c => b1 ∉ c && b2 ∉ c)
    |>.push (b1Circuit ∪ b2Circuit)
  cords := cords - 1
  -- stop as soon as 1000 cords have been placed
  if cords == 0 then
    break

The linear scan on every iteration to find the component each box belongs to is still an eyesore, though, so let’s use good old union-find. It turns out the Batteries library already has a union-find implementation, so we don’t even have to write our own. I did, however, add two helpers to my utility library:

/-- Sizes of all clusters in the union-find, as `(root, size)` pairs sorted
from the largest cluster to the smallest. The relative order of clusters with
equal size is unspecified. -/
def Batteries.UnionFind.clusterSizes (self : Batteries.UnionFind) : Array (Nat × Nat) :=
  -- count the members of each root in a single pass (instead of deduplicating
  -- the roots and then filtering the whole array once per root)
  let counts := (Array.range self.size).foldl
    (λ (acc : Std.HashMap Nat Nat) i =>
      let root := self.rootD i
      acc.insert root (acc.getD root 0 + 1))
    {}
  counts.toArray.qsort (·.snd > ·.snd)

/-- Number of distinct clusters in the union-find, i.e. the number of distinct
roots among all elements. -/
def Batteries.UnionFind.numClusters (self : Batteries.UnionFind) : Nat :=
  Array.range self.size |>.foldl (λ acc i =>
    acc.insert (self.rootD i)
  ) ({} : Std.HashSet Nat)
    |>.size

With that in place, our imperative solution for part 1 is much simpler and more efficient:

-- union-find version of part 1: connect the first 1000 pairs (closest first)
let mut circuits' := circuits
let mut cords := 1000
for ⟨⟨b1, b2⟩, _⟩ in distances do
      -- merge the circuits containing b1 and b2
      circuits' := circuits'.union! b1 b2
      cords := cords - 1
      if cords == 0 then
        break

-- clusterSizes is sorted largest first, so the first three sizes are the largest
let largestThree := circuits'.clusterSizes.map (·.snd) |>.toList.take 3

Things I (re-)learned today

  • Using mutable variables.
  • Using imperative-style for loops.

Solution

import Batteries.Data.List
import Batteries.Data.UnionFind
import Aoc
open Aoc

/-- A junction box position in 3D space with integer coordinates. -/
structure Position3D where
  (x y z : Int)
  deriving BEq, Hashable, Repr, Inhabited

/-- Parse one input line of the form `x,y,z` into a `Position3D`.

Whitespace around each coordinate (including a stray `\r`) is ignored.
Returns `Except.error` instead of panicking (as `String.toInt!` would) when the
line does not have exactly three fields or a field is not an integer. -/
def Position3D.parse (s : String) : Except String Position3D :=
  match (s.splitOn ",").map (·.trim.toInt?) with
  | [some x, some y, some z] => Except.ok { x := x, y := y, z := z }
  | _ => Except.error s!"Invalid Position3D: {s}"

/-- Squared Euclidean distance between two positions. Skipping the square root
keeps the result in `Int`. -/
def Position3D.squaredDistance (p1 p2 : Position3D) : Int :=
  let dx := p1.x - p2.x
  let dy := p1.y - p2.y
  let dz := p1.z - p2.z
  dx * dx + dy * dy + dz * dz

/-- Solve Day 8: read junction-box positions from `Day08/input.txt`, then print
the product of the three largest circuit sizes after connecting the 1000
closest pairs (part 1), and the product of the x-coordinates of the pair whose
connection first joins every box into a single circuit (part 2). -/
def main : IO Unit := do
  let input <- readLines "Day08/input.txt"
  -- skip blank lines (e.g. a trailing newline) instead of failing to parse them
  let boxes <- input.filter (fun line => !line.trim.isEmpty)
    |>.mapM Position3D.parse |> IO.ofExcept

  -- a union-find data structure with one singleton circuit per junction box
  let circuits := boxes.foldl (λ uf _ => uf.push) Batteries.UnionFind.empty

  -- every unordered pair of box indices with its squared distance, sorted from
  -- smallest to largest; pushing into one Array avoids the quadratic cost of
  -- repeatedly appending to an ever-growing List
  let mut pairs : Array ((Nat × Nat) × Int) := #[]
  for i in [0:boxes.size] do
    for j in [i+1:boxes.size] do
      pairs := pairs.push ((i, j), boxes[i]!.squaredDistance boxes[j]!)
  let distances := pairs.qsort (·.snd < ·.snd)

  -- part 1: connect the 1000 closest pairs
  let mut circuits' := circuits
  let mut cords := 1000
  for ⟨⟨b1, b2⟩, _⟩ in distances do
    circuits' := circuits'.union! b1 b2
    cords := cords - 1
    if cords == 0 then
      break

  -- clusterSizes is sorted largest first
  let largestThree := circuits'.clusterSizes.map (·.snd) |>.toList.take 3
  IO.println s!"Part 1: {largestThree.prod}"

  -- part 2: keep connecting pairs until every box is in a single circuit.
  -- The circuit count is tracked incrementally (it drops by one exactly when a
  -- pair joins two different circuits) instead of recounting every root after
  -- each union.
  circuits' := circuits
  let mut numCircuits := boxes.size
  let mut lastPair : Option (Nat × Nat) := none
  for ⟨⟨b1, b2⟩, _⟩ in distances do
    if circuits'.rootD b1 != circuits'.rootD b2 then
      circuits' := circuits'.union! b1 b2
      numCircuits := numCircuits - 1
      if numCircuits == 1 then
        lastPair := some (b1, b2)
        break

  match lastPair with
  | some (b1i, b2i) =>
    let b1 := boxes[b1i]!
    let b2 := boxes[b2i]!
    IO.println s!"Part 2: {b1.x * b2.x}"
  | none =>
    -- only reachable with fewer than two boxes (there are no pairs to connect)
    IO.println "Part 2: no answer (need at least two junction boxes)"