Today’s problem was all about connected components where the nodes were 3D positions, so the first thing I did was to define a data structure for 3D positions:
structure Position3D where
x : Int
y : Int
z : Int
To avoid leaving the beautiful, simple world of integers, I defined a method for computing the squared distance between positions, hoping that I wouldn’t need to find the actual distances:
def Position3D.squaredDistance (p1 p2 : Position3D) : Int :=
(p1.x - p2.x)^2 + (p1.y - p2.y)^2 + (p1.z - p2.z)^2
I then parsed each position and was on my merry way. The first thing I did was
to build an array with each pair of positions and their squared distance, sorted
by that distance. I then created an array of HashSets to represent each
connected component, manually removing components as I inserted a merged one:
-- part 1
-- Fold over the pairs (closest first), merging the circuits of each pair's two
-- boxes. The accumulator is (circuits, remaining cords); once the cord count is
-- 0 the first match arm returns the accumulator unchanged, but the fold still
-- visits every remaining pair.
let (connected, _) := distances.foldl (fun acc dist =>
match acc with
| ⟨_,0⟩ => acc
| ⟨circuits,cords⟩ =>
let ⟨⟨b1,b2⟩, d⟩ := dist
-- linear scans to find the circuit holding each box (get! panics if none does)
let b1Circuit := circuits.find? (b1 ∈ ·) |>.get!
let b2Circuit := circuits.find? (b2 ∈ ·) |>.get!
-- drop both old circuits and append their union; if both boxes were already in
-- the same circuit this re-inserts that same set
let circuits' := circuits.filter (λ c => b1 ∉ c && b2 ∉ c)
|>.push (b1Circuit ∪ b2Circuit)
(circuits', cords - 1)
) (circuits, 1000)
let largestThree := connected.map (·.size) |>.qsort |>.reverse.toList.take 3
However, this approach is both ugly and inefficient. The fold is also begging for
an early termination: once the cord count reaches zero it still walks every
remaining pair, just returning the accumulator unchanged. So why not write
imperative code with Lean’s imperative features? Behold! Mutation and a `for` loop:
let mut circuits' := circuits
let mut cords := 1000
for ⟨⟨b1, b2⟩, _⟩ in distances do
let b1Circuit := circuits'.find? (b1 ∈ ·) |>.get!
let b2Circuit := circuits'.find? (b2 ∈ ·) |>.get!
circuits' := circuits'.filter (λ c => b1 ∉ c && b2 ∉ c)
|>.push (b1Circuit ∪ b2Circuit)
cords := cords - 1
if cords == 0 then
break
This linear scan to find which component a box belongs to on every iteration is
still an eyesore, though, so let’s use the good old union-find. It turns out
there is a union-find (UF) implementation in the Batteries library, so we don’t
even have to write our own. However, I ended up adding two helpers I needed to
my utility library:
/-- Sizes of all clusters in the union-find, as `(root, size)` pairs sorted from
largest to smallest cluster (order among equal sizes is unspecified).

Counts every element's root in one pass with a hash map instead of
de-duplicating the roots (`List.eraseDups` is quadratic) and re-filtering the
whole root array once per cluster. -/
def Batteries.UnionFind.clusterSizes (self : Batteries.UnionFind) : Array (Nat × Nat) :=
  let counts := Array.range self.size |>.foldl (λ acc i =>
    let root := self.rootD i
    acc.insert root (acc.getD root 0 + 1)
  ) (∅ : Std.HashMap Nat Nat)
  counts.toArray.qsort (·.snd > ·.snd)
def Batteries.UnionFind.numClusters (self : Batteries.UnionFind) : Nat :=
Array.range self.size |>.foldl (λ acc i =>
acc.insert (self.rootD i)
) (∅ : Std.HashSet Nat)
|>.size
With that in place, our imperative solution for part 1 is much simpler and more efficient:
-- Connect the 1000 closest pairs: each iteration merges the two boxes' clusters
-- in the union-find and uses up one cord; `break` leaves the loop as soon as the
-- cord count reaches 0.
let mut circuits' := circuits
let mut cords := 1000
for ⟨⟨b1, b2⟩, _⟩ in distances do
circuits' := circuits'.union! b1 b2
cords := cords - 1
if cords == 0 then
break
let largestThree := circuits'.clusterSizes.map (·.snd) |>.toList.take 3
Things I (re-)learned today
- Using mutable variables.
- Using imperative-style `for` loops.
Solution
import Batteries.Data.List
import Batteries.Data.UnionFind
import Aoc
open Aoc
/-- A junction box's location: a point with integer coordinates in 3D space. -/
structure Position3D where
  (x y z : Int)
  deriving BEq, Hashable, Repr, Inhabited
/-- Parse a line of the form `"x,y,z"` (three integers, optionally negative)
into a `Position3D`.

Returns `Except.error` when the line does not have exactly three comma-separated
fields or when any field is not an integer. The previous `toInt!` calls
panicked on a bad field and silently substituted `0`, despite the `Except`
return type. -/
def Position3D.parse (s : String) : Except String Position3D :=
  match (s.splitOn ",").map String.toInt? with
  | [some x, some y, some z] => Except.ok { x := x, y := y, z := z }
  | _ => Except.error s!"Invalid Position3D: {s}"
/-- Squared Euclidean distance between two positions. Keeping the square stays
in exact integer arithmetic and ranks pairs the same way the true distance
would. -/
def Position3D.squaredDistance (p1 p2 : Position3D) : Int :=
  let dx := p1.x - p2.x
  let dy := p1.y - p2.y
  let dz := p1.z - p2.z
  dx * dx + dy * dy + dz * dz
/-- Day 8 entry point.

Reads one junction-box position per line from `Day08/input.txt`, then
* part 1: connects the `connections` closest pairs of boxes and prints the
  product of the three largest circuit sizes;
* part 2: keeps connecting closest pairs until every box is in one circuit and
  prints the product of the X coordinates of the pair that completed it. -/
def main : IO Unit := do
  let input <- readLines "Day08/input.txt"
  let boxes <- input |> Array.mapM Position3D.parse |> IO.ofExcept
  -- how many closest pairs part 1 connects
  let connections := 1000
  -- a union-find data structure representing the circuits, one singleton per box
  let circuits := boxes.foldl (λ uf _ => uf.push) (Batteries.UnionFind.empty)
  -- every index pair (i, j) with i < j and its squared distance, from smallest
  -- to largest. Pushing onto an Array keeps construction O(n²); the former List
  -- fold with `acc ++ newDists` re-copied the growing accumulator for every box.
  let mut pairs : Array ((Nat × Nat) × Int) := #[]
  for i in [0:boxes.size] do
    for j in [i+1:boxes.size] do
      pairs := pairs.push ((i, j), boxes[i]!.squaredDistance boxes[j]!)
  let distances := pairs.qsort (·.snd < ·.snd)
  -- part 1
  let mut circuits' := circuits
  let mut cords := connections
  for ⟨⟨b1, b2⟩, _⟩ in distances do
    -- checked before connecting, so a count of 0 really connects nothing
    if cords == 0 then
      break
    circuits' := circuits'.union! b1 b2
    cords := cords - 1
  let largestThree := circuits'.clusterSizes.map (·.snd) |>.toList.take 3
  IO.println s!"Part 1: {largestThree.prod}"
  -- part 2: track the number of circuits incrementally instead of recounting
  -- every box's root after each union
  circuits' := circuits
  let mut remaining := boxes.size
  let mut lastPair : Option (Nat × Nat) := none
  for ⟨⟨b1, b2⟩, _⟩ in distances do
    -- only a pair spanning two different circuits merges anything
    if circuits'.rootD b1 != circuits'.rootD b2 then
      circuits' := circuits'.union! b1 b2
      remaining := remaining - 1
      if remaining == 1 then
        lastPair := some (b1, b2)
        break
  match lastPair with
  | some (b1i, b2i) =>
    IO.println s!"Part 2: {boxes[b1i]!.x * boxes[b2i]!.x}"
  | none =>
    -- fewer than two boxes: no pair ever completes a single circuit
    IO.println "Part 2: the boxes never form a single circuit"