|
27 | 27 | "from operator import itemgetter\n",
|
28 | 28 | "\n",
|
29 | 29 | "\n",
|
| 30 | + "class BaseAlmanac[T: \"AlmanacMap\"]:\n", |
| 31 | + " map_cls: type[T]\n", |
| 32 | + " maps: dict[str, T]\n", |
| 33 | + "\n", |
| 34 | + " @classmethod\n", |
| 35 | + " def from_entries(cls, *entries: str) -> t.Self:\n", |
| 36 | + " seeds_line, *table = entries\n", |
| 37 | + " seeds = [int(seed) for seed in seeds_line.partition(\": \")[-1].split()]\n", |
| 38 | + " maps = {map_.from_: map_ for map_ in map(cls.map_cls.from_entry, table)}\n", |
| 39 | + " return cls(seeds, maps)\n", |
| 40 | + "\n", |
| 41 | + " def seed_location(self, seed: int) -> int:\n", |
| 42 | + " current = \"seed\"\n", |
| 43 | + " value = seed\n", |
| 44 | + " while current != \"location\":\n", |
| 45 | + " map_ = self.maps[current]\n", |
| 46 | + " current = map_.to_\n", |
| 47 | + " value, _ = map_.lookup(value)\n", |
| 48 | + " return value\n", |
| 49 | + "\n", |
| 50 | + "\n", |
30 | 51 | "@dataclass\n",
|
31 | 52 | "class AlmanacMap:\n",
|
32 | 53 | " from_: str\n",
|
|
43 | 64 | " ]\n",
|
44 | 65 | " return cls(from_, to_, sorted(ranges, key=itemgetter(0)))\n",
|
45 | 66 | "\n",
|
46 |
| - " def __getitem__(self, value: int) -> tuple[int, int | None]:\n", |
| 67 | + " def lookup(self, value: int) -> tuple[int, int | None]:\n", |
47 | 68 | " \"\"\"Map a value through the almanac table\n",
|
48 | 69 | "\n",
|
49 | 70 | " Returns the new value, and the remaining length of the source section it\n",
|
|
61 | 82 | "\n",
|
62 | 83 | "\n",
|
63 | 84 | "@dataclass\n",
|
64 |
| - "class Almanac:\n", |
| 85 | + "class Almanac(BaseAlmanac[AlmanacMap]):\n", |
| 86 | + " map_cls = AlmanacMap\n", |
65 | 87 | " seeds: list[int]\n",
|
66 | 88 | " maps: dict[str, AlmanacMap]\n",
|
67 | 89 | "\n",
|
68 |
| - " @classmethod\n", |
69 |
| - " def from_entries(cls, *entries: str) -> t.Self:\n", |
70 |
| - " seeds_line, *entries = entries\n", |
71 |
| - " seeds = [int(seed) for seed in seeds_line.partition(\": \")[-1].split()]\n", |
72 |
| - " maps = {map_.from_: map_ for map_ in map(AlmanacMap.from_entry, entries)}\n", |
73 |
| - " return cls(seeds, maps)\n", |
74 |
| - "\n", |
75 |
| - " def __getitem__(self, seed: int) -> int:\n", |
76 |
| - " current = \"seed\"\n", |
77 |
| - " value = seed\n", |
78 |
| - " while current != \"location\":\n", |
79 |
| - " map_ = self.maps[current]\n", |
80 |
| - " current = map_.to_\n", |
81 |
| - " value, _ = map_[value]\n", |
82 |
| - " return value\n", |
83 |
| - "\n", |
84 | 90 | "\n",
|
85 | 91 | "test_almanac_text = \"\"\"\\\n",
|
86 | 92 | "seeds: 79 14 55 13\n",
|
|
118 | 124 | "56 93 4\n",
|
119 | 125 | "\"\"\"\n",
|
120 | 126 | "test_almanac = Almanac.from_entries(*test_almanac_text.split(\"\\n\\n\"))\n",
|
121 |
| - "assert min(test_almanac[seed] for seed in test_almanac.seeds) == 35" |
| 127 | + "assert min(map(test_almanac.seed_location, test_almanac.seeds)) == 35" |
122 | 128 | ]
|
123 | 129 | },
|
124 | 130 | {
|
|
137 | 143 | "source": [
|
138 | 144 | "import aocd\n",
|
139 | 145 | "\n",
|
140 |
| - "almanac = Almanac.from_entries(*aocd.get_data(day=5, year=2023).split(\"\\n\\n\"))\n", |
141 |
| - "print(\"Part 1:\", min(almanac[seed] for seed in almanac.seeds))" |
| 146 | + "almanac_entries = aocd.get_data(day=5, year=2023).split(\"\\n\\n\")\n", |
| 147 | + "almanac = Almanac.from_entries(*almanac_entries)\n", |
| 148 | + "print(\"Part 1:\", min(map(almanac.seed_location, almanac.seeds)))" |
142 | 149 | ]
|
143 | 150 | },
|
144 | 151 | {
|
|
166 | 173 | "\n",
|
167 | 174 | "\n",
|
168 | 175 | "class RangeAlmanacMap(AlmanacMap):\n",
|
169 |
| - " def __getitem__(self, values: tuple[range, ...]) -> tuple[range, ...]:\n", |
170 |
| - " results = []\n", |
| 176 | + " def lookup_range(self, values: tuple[range, ...]) -> tuple[range, ...]:\n", |
| 177 | + " results: list[range] = []\n", |
171 | 178 | " queue = deque(values)\n",
|
172 | 179 | " while queue:\n",
|
173 | 180 | " value = queue.popleft()\n",
|
174 | 181 | " size = len(value)\n",
|
175 |
| - " dst, remainder = super().__getitem__(value.start)\n", |
| 182 | + " dst, remainder = self.lookup(value.start)\n", |
176 | 183 | " if remainder and size > remainder:\n",
|
177 | 184 | " # process the section that doesn't fit\n",
|
178 | 185 | " queue.append(value[remainder:])\n",
|
|
183 | 190 | "\n",
|
184 | 191 | "\n",
|
185 | 192 | "@dataclass\n",
|
186 |
| - "class RangeAlmanac(Almanac):\n", |
| 193 | + "class RangeAlmanac(BaseAlmanac[RangeAlmanacMap]):\n", |
| 194 | + " map_cls = RangeAlmanacMap\n", |
| 195 | + " seeds: list[int]\n", |
187 | 196 | " maps: dict[str, RangeAlmanacMap]\n",
|
188 | 197 | "\n",
|
189 |
| - " @classmethod\n", |
190 |
| - " def from_entries(cls, *entries: str) -> t.Self:\n", |
191 |
| - " inst = super().from_entries(*entries)\n", |
192 |
| - " inst.maps = {\n", |
193 |
| - " to_: RangeAlmanacMap(**vars(map_)) for to_, map_ in inst.maps.items()\n", |
194 |
| - " }\n", |
195 |
| - " return inst\n", |
196 |
| - "\n", |
197 |
| - " def __getitem__(self, values: tuple[range, ...]) -> int:\n", |
| 198 | + " def seed_locations(self, values: tuple[range, ...]) -> int:\n", |
198 | 199 | " current = \"seed\"\n",
|
199 | 200 | " while current != \"location\":\n",
|
200 | 201 | " map_ = self.maps[current]\n",
|
201 | 202 | " current = map_.to_\n",
|
202 |
| - " values = map_[values]\n", |
| 203 | + " values = map_.lookup_range(values)\n", |
203 | 204 | " return min(v.start for v in values)\n",
|
204 | 205 | "\n",
|
205 | 206 | " @property\n",
|
|
209 | 210 | "\n",
|
210 | 211 | "\n",
|
211 | 212 | "test_almanac = RangeAlmanac.from_entries(*test_almanac_text.split(\"\\n\\n\"))\n",
|
212 |
| - "assert test_almanac[test_almanac.seed_ranges] == 46" |
| 213 | + "assert test_almanac.seed_locations(test_almanac.seed_ranges) == 46" |
213 | 214 | ]
|
214 | 215 | },
|
215 | 216 | {
|
|
226 | 227 | }
|
227 | 228 | ],
|
228 | 229 | "source": [
|
229 |
| - "almanac = RangeAlmanac.from_entries(*aocd.get_data(day=5, year=2023).split(\"\\n\\n\"))\n", |
230 |
| - "print(\"Part 2:\", almanac[almanac.seed_ranges])" |
| 230 | + "almanac = RangeAlmanac.from_entries(*almanac_entries)\n", |
| 231 | + "print(\"Part 2:\", almanac.seed_locations(almanac.seed_ranges))" |
231 | 232 | ]
|
232 | 233 | }
|
233 | 234 | ],
|
|
247 | 248 | "name": "python",
|
248 | 249 | "nbconvert_exporter": "python",
|
249 | 250 | "pygments_lexer": "ipython3",
|
250 |
| - "version": "3.12.0" |
| 251 | + "version": "3.12.1" |
251 | 252 | }
|
252 | 253 | },
|
253 | 254 | "nbformat": 4,
|
|
0 commit comments