import datetime
import os
import random
import string
import sys
import plotly
import numpy as np
import plotly.express as px
import torch
import torchvision
import PIL
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import v2
SEED = 20240622
TRAINED_NN_PATH = "_autoencoder_state_dict.pt"
LOSS_DATA_PATH = "_loss.pt"
N_PASSTHROUGHS = 50

# Seed RNGs for reproducibility
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
Playing Visual Telephone with Autoencoders
A convoluted way to distort images with convolution
Alvin Lucier’s “I am Sitting in a Room” is an interesting contemporary sound art piece that I like. As Lucier describes in its introduction, he plays tape-recorded audio of his voice into an empty room and records the sound echoed from the walls. Each iteration of the piece plays back the recording from the previous one, and this process repeats until Lucier’s voice distorts into eerie chime-like sounds with pitches relating to the room’s fundamental frequency.
Inspired by this, the concept of model collapse in artificial intelligence, and the game of telephone, I wondered: can I do the same thing with image processing in artificial intelligence? More specifically, how many passes does it take for an AI to distort a recognizable image into something unrecognizable? And what would the end result look like? I decided to find out for myself by repeatedly passing handwritten letters through a convolutional autoencoder and watching how they get corrupted.
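The “visual telephone” game boils down to a feedback loop: each generation’s output becomes the next generation’s input, just as each of Lucier’s recordings is played back into the room. Here is a minimal sketch, assuming a trained model that maps a 1 x 28 x 28 tensor to a tensor of the same shape. The helper name `play_telephone` and the toy `Halve` stand-in model are hypothetical; in the article, the number of passes would come from the `N_PASSTHROUGHS` setting in the setup code.

```python
import torch
from torch import nn


def play_telephone(model: nn.Module, image: torch.Tensor, n_passes: int) -> list[torch.Tensor]:
    """Feed an image through `model` repeatedly and return every generation.

    Hypothetical helper: generation 0 is the original image, and generation k
    is the model's reconstruction of generation k - 1.
    """
    generations = [image]
    model.eval()  # switch off training-only behavior such as dropout
    with torch.no_grad():  # pure inference, so no gradients are needed
        current = image
        for _ in range(n_passes):
            current = model(current)
            generations.append(current)
    return generations


class Halve(nn.Module):
    """Hypothetical stand-in model that halves every pixel, to show the loop."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 0.5


history = play_telephone(Halve(), torch.ones(1, 28, 28), n_passes=3)
print(len(history))                 # 4: the original plus three passes
print(history[-1][0, 0, 0].item())  # 0.125
```

With a real autoencoder in place of `Halve`, plotting `history` side by side would show the letter drifting generation by generation.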
Setup
I’m going to use the PyTorch machine learning library1 to build my autoencoder because I like it more than other machine learning libraries. PyTorch doesn’t guarantee fully reproducible results across releases or platforms, though, so the results in this article might change over time.
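Seeding the random number generators (as the setup code does) is the first step toward repeatable runs, and PyTorch offers a couple of extra switches on top of that. This is a sketch of commonly used settings, not a guarantee: PyTorch’s own reproducibility notes say completely reproducible results aren’t promised across releases, commits, or platforms. The helper name `seed_everything` is hypothetical.

```python
import random

import numpy as np
import torch


def seed_everything(seed: int) -> None:
    """Hypothetical helper: seed every RNG used here and prefer deterministic kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the RNGs on all devices, CUDA included
    # Raise an error instead of silently using an operation that is nondeterministic.
    torch.use_deterministic_algorithms(True)
    # Stop cuDNN from benchmarking and possibly choosing different kernels per run.
    torch.backends.cudnn.benchmark = False


seed_everything(20240622)
a = torch.rand(3)
seed_everything(20240622)
b = torch.rand(3)
print(torch.equal(a, b))  # True
```

On CUDA, some cuBLAS operations additionally require the CUBLAS_WORKSPACE_CONFIG environment variable (for example `:4096:8`) once deterministic mode is on; PyTorch raises an error explaining this if it’s needed.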
Before we pass my own handwriting through the autoencoder, we’ll first need to train it on existing handwriting data. Let’s use the Extended MNIST (EMNIST) database, since PyTorch’s torchvision package already has utilities for downloading and extracting it. Each element in the EMNIST dataset is a \(28 \times 28\) grayscale image representing one of 62 different characters: lowercase letters, uppercase letters, and digits. In this case, let’s focus on the letters.2
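In the letters split, EMNIST merges the uppercase and lowercase forms of each letter into a single class, and torchvision numbers those classes 1 through 26 (index 0 is a placeholder “N/A” class). A small helper makes the mapping explicit. The name `label_to_letter` is hypothetical; it mirrors the `string.ascii_uppercase[... - 1]` lookup used later for the plot title.

```python
import string


def label_to_letter(label: int) -> str:
    """Map an EMNIST `letters` label (1-26) to an uppercase letter 'A'-'Z'."""
    if not 1 <= label <= 26:
        raise ValueError(f"EMNIST letters labels run from 1 to 26, got {label}")
    # Labels are 1-indexed, but Python strings are 0-indexed.
    return string.ascii_uppercase[label - 1]


print(label_to_letter(1))   # A
print(label_to_letter(23))  # W
```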
We also need to preprocess the dataset images before feeding them into the AI. Each pixel in the original image is an integer in the range \([0, 255]\); we’ll convert these to floating-point numbers and scale them to the range \([0, 1]\) so that the AI can better interpret them. The raw EMNIST images are also stored transposed, so we’ll transpose them back to make them easier to view when plotted.3
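To see what the preprocessing does to the numbers, here is a hand-rolled equivalent on a tiny 2 x 3 “image”. This is a sketch that assumes `v2.ToDtype(..., scale=True)` scales `uint8` pixels by dividing by 255; the transpose is the same `torch.transpose(x, 1, 2)` call used in the pipeline.

```python
import torch

# A tiny single-channel "image" with shape (channels, height, width) = (1, 2, 3).
raw = torch.tensor([[[0, 51, 255],
                     [102, 204, 153]]], dtype=torch.uint8)

# Convert to floats and rescale [0, 255] -> [0, 1].
scaled = raw.to(torch.float32) / 255.0

# Swap height and width (dims 1 and 2), undoing a transposed storage layout.
fixed = torch.transpose(scaled, 1, 2)

print(fixed.shape)  # torch.Size([1, 3, 2])
# fixed[0] is approximately [[0.0, 0.4], [0.2, 0.8], [1.0, 0.6]]
```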
# Convert to a float tensor in [0, 1], then undo EMNIST's transposed storage
emnist_preprocess = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float, scale=True),
        v2.Lambda(lambda x: torch.transpose(x, 1, 2)),
    ]
)
emnist_train = torchvision.datasets.EMNIST(
    root="_emnist",
    split="letters",
    train=True,
    download=True,
    transform=emnist_preprocess,
)
emnist_test = torchvision.datasets.EMNIST(
    root="_emnist",
    split="letters",
    train=False,
    download=True,
    transform=emnist_preprocess,
)
train_dataloader = DataLoader(emnist_train, batch_size=512, shuffle=False)
test_dataloader = DataLoader(emnist_test, batch_size=512, shuffle=True)
# Letters labels are 1-indexed: 1 -> "A", ..., 26 -> "Z"
example_label = string.ascii_uppercase[emnist_train[0][1] - 1]
fig = px.imshow(
    emnist_train[0][0][0, :, :],
    labels={"x": "x", "y": "y", "color": "value"},
    title=f"EMNIST Data Example 0: {example_label}",
)
fig.show()
Downloading https://biometrics.nist.gov/cs_links/EMNIST/gzip.zip to _emnist/EMNIST/raw/gzip.zip
Extracting _emnist/EMNIST/raw/gzip.zip to _emnist/EMNIST/raw
The first image in the training dataset, a capital letter ‘W’.
Designing the Autoencoder
An autoencoder is an hourglass-shaped neural network (NN) that learns how to replicate input data. It’s built from two smaller neural networks chained back-to-back: the encoder compresses the input data into a smaller space, and the decoder maps the smaller space back into the original input size. By training both the encoder and the decoder at the same time, the autoencoder learns how to efficiently store the most important features of its input data in the smaller space with minimal distortion.4
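To make the hourglass shape concrete, here is a minimal fully-connected sketch. It is not the convolutional model used in this article: the class name `TinyAutoencoder` and the 784 → 100 → 784 layer sizes are illustrative assumptions. The encoder squeezes the 784 pixels of a \(28 \times 28\) image into a 100-value code, and the decoder expands that code back to an image of the original size.

```python
import torch
from torch import nn


class TinyAutoencoder(nn.Module):
    """Hypothetical dense autoencoder, used only to illustrate the hourglass shape."""

    def __init__(self, n_latent: int = 100):
        super().__init__()
        # Encoder: flatten each 1 x 28 x 28 image and squeeze it into a small code
        self.encoder = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, n_latent), nn.ReLU())
        # Decoder: expand the code back to 784 pixels and restore the image shape
        self.decoder = nn.Sequential(
            nn.Linear(n_latent, 28 * 28), nn.Sigmoid(), nn.Unflatten(1, (1, 28, 28))
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(self.encoder(x))


model = TinyAutoencoder()
images = torch.rand(8, 1, 28, 28)  # a fake batch of 8 grayscale images
code = model.encoder(images)
reconstruction = model(images)
print(code.shape)            # torch.Size([8, 100])
print(reconstruction.shape)  # torch.Size([8, 1, 28, 28])
```

The output has exactly the same shape as the input, which is what lets us compare the two directly with a loss function during training.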
Since we’re working with two-dimensional image data, we can use convolutional layers in our autoencoder to help it better understand spatial relationships between adjacent cells in a grid.5 To keep the encoding and decoding halves of the autoencoder symmetrical, we’ll use transpose convolution layers to bring the data back to its original shape.
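As a quick sanity check of that symmetry, the sketch below (random weights, a fake batch) shows that a stride-2 `nn.Conv2d` halves a \(14 \times 14\) grid to \(7 \times 7\), and that the matching `nn.ConvTranspose2d` only restores \(14 \times 14\) when `output_padding=1` is set; without it, the upsampled grid comes out one pixel short.

```python
import torch
from torch import nn

x = torch.rand(1, 32, 14, 14)  # fake batch: 32 feature maps of size 14 x 14

# Downsample: stride 2 halves the spatial size, floor((14 + 2 - 3) / 2) + 1 = 7
down = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
y = down(x)

# Upsample without output_padding: (7 - 1) * 2 - 2 * 1 + (3 - 1) + 1 = 13
up_short = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1)
# output_padding=1 adds the missing row and column: 13 + 1 = 14
up_exact = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)

print(y.shape)            # torch.Size([1, 64, 7, 7])
print(up_short(y).shape)  # torch.Size([1, 32, 13, 13])
print(up_exact(y).shape)  # torch.Size([1, 32, 14, 14])
```

A stride-2 convolution maps both \(13 \times 13\) and \(14 \times 14\) inputs to \(7 \times 7\), so the transposed layer cannot infer which size to rebuild; `output_padding` resolves that ambiguity.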
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            # 1 x 28 x 28
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            # 32 x 14 x 14
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            # 64 x 7 x 7
            nn.Flatten(start_dim=-3),
            nn.Linear(3136, 256),
            nn.ReLU(),
            nn.Linear(256, 100),
        )
        self.decoder = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 64 * 7 * 7),
            nn.ReLU(),
            nn.Unflatten(-1, (64, 7, 7)),
            # 64 x 7 x 7
            nn.ConvTranspose2d(
                64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
            ),
            # 32 x 14 x 14
            nn.ReLU(),
            nn.ConvTranspose2d(
                32, 1, kernel_size=3, stride=2, padding=1, output_padding=1
            ),
            nn.Hardtanh(0, 1),
        )

    def forward(self, x):
        z = self.encoder(x)
        output = self.decoder(z)
        return output
Convolutional layers are a bit trickier to wire up than the standard linear layers in multi-layer perceptron neural networks, because their output size depends on the input size as well as on the kernel size, stride, and padding. I had to step through each layer of my NN to double-check that all my layers connected properly, and used padding (plus `output_padding` in the transposed convolutions) to get the output sizes to line up.
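One way to do that layer-by-layer check is to push a dummy batch through each module of an `nn.Sequential` and record the shape after every step. The sketch below rebuilds only the encoder half of the architecture above so that it runs on its own; the helper name `trace_shapes` is hypothetical.

```python
import torch
from torch import nn


def trace_shapes(layers: nn.Sequential, x: torch.Tensor) -> list[tuple[str, tuple[int, ...]]]:
    """Run x through each layer in order and record (layer class name, output shape)."""
    shapes = []
    for layer in layers:
        x = layer(x)
        shapes.append((layer.__class__.__name__, tuple(x.shape)))
    return shapes


# Same encoder layers as in the Autoencoder class above
encoder = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Flatten(start_dim=-3),
    nn.Linear(3136, 256),
    nn.ReLU(),
    nn.Linear(256, 100),
)

# Gradients are not needed just to inspect shapes
with torch.no_grad():
    for name, shape in trace_shapes(encoder, torch.rand(4, 1, 28, 28)):
        print(f"{name:>8}: {shape}")
```

If a layer's expected input size is wrong (for example, `nn.Linear(3000, 256)` instead of `3136`), the trace stops with a shape-mismatch error at exactly that layer, which makes the culprit easy to spot.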
Training the Autoencoder
Training an autoencoder is an unsupervised learning task: there is no separate reference output for the NN to match, because each input image serves as its own target. So, we can discard the training dataset's target labels, since they aren't needed for learning an efficient representation.
Here’s how our training loop works:
1. Pass a batch of images through the autoencoder.
2. Use the mean square error (MSE) loss metric to rate the NN's performance. This loss function gives us an at-a-glance measure of how well the autoencoder can replicate the handwritten letters.
3. Backpropagate the error and, using the Adam optimizer, adjust the autoencoder's weights to decrease it.
4. Repeat steps 1-3 until the NN has viewed all images.
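For reference, the MSE between a reconstruction \(\hat{x}\) and the original image \(x\) is the average of the squared pixel differences, \(\frac{1}{N}\sum_{i=1}^{N} (\hat{x}_i - x_i)^2\), where \(N\) counts every pixel in the batch. The short check below, on made-up tensors, computes it by hand and compares it with `nn.MSELoss`, which averages over all elements by default (`reduction="mean"`).

```python
import torch
from torch import nn

torch.manual_seed(0)
originals = torch.rand(2, 1, 28, 28)        # made-up "input" images
reconstructions = torch.rand(2, 1, 28, 28)  # made-up autoencoder outputs

# Average of squared differences over every pixel of every image in the batch
manual_mse = ((reconstructions - originals) ** 2).mean()
library_mse = nn.MSELoss()(reconstructions, originals)

print(manual_mse.item(), library_mse.item())
```

Because the differences are squared, a few badly reconstructed pixels raise the loss more than many slightly-off ones.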
In a normal machine learning training loop, the NN trains over the entire dataset multiple times in different epochs. In this case, I’m OK with using just one epoch since the autoencoder is good enough at that point.
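If one epoch turns out not to be enough, the usual pattern is to wrap the batch loop in an outer loop over epochs. The sketch below uses a small stand-in model and random data (a `TensorDataset` of fake images) rather than EMNIST, so it runs anywhere; `N_EPOCHS` is a hypothetical setting, not one used in this article.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
N_EPOCHS = 3  # hypothetical; this article trains for a single epoch

# Stand-in data: 64 fake 28 x 28 images with dummy labels (ignored during training)
fake_images = torch.rand(64, 1, 28, 28)
fake_labels = torch.zeros(64, dtype=torch.long)
loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=16, shuffle=True)

# Stand-in autoencoder: flatten, squeeze to 32 values, expand back to an image
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 32),
    nn.ReLU(),
    nn.Linear(32, 784),
    nn.Unflatten(1, (1, 28, 28)),
)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

losses = []
for epoch in range(N_EPOCHS):
    for inputs, _ in loader:  # labels are discarded: the input is its own target
        optimizer.zero_grad()
        loss = criterion(model(inputs), inputs)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    print(f"epoch {epoch}: last batch loss {losses[-1]:.4f}")
```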
autoencoder = Autoencoder()

# NN Training setup
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(autoencoder.parameters())
best_loss = 3.0e38
losses = []

# Run NN training if not cached
if not (os.path.exists(TRAINED_NN_PATH) and os.path.exists(LOSS_DATA_PATH)):
    # Train the NN
    for batch_idx, data in enumerate(train_dataloader):
        inputs, _ = data
        optimizer.zero_grad()
        outputs = autoencoder(inputs)
        loss = criterion(outputs, inputs)
        loss.backward()
        optimizer.step()
        loss_item = loss.item()
        losses.append(loss_item)

    # Save the NN
    torch.save(autoencoder.state_dict(), TRAINED_NN_PATH)
    losses = torch.tensor(losses)
    torch.save(losses, LOSS_DATA_PATH)
else:
    # Load the already-trained NN and its loss results
    autoencoder.load_state_dict(torch.load(TRAINED_NN_PATH))
    losses = torch.load(LOSS_DATA_PATH)

# Set the autoencoder in evaluation mode
autoencoder.eval()

fig = px.line(
    {"training loss": losses.numpy()}, title="Loss Over Training Batches"
).update_layout(xaxis_title="batch number", yaxis_title="MSE Loss [unitless]")
fig.show()
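The caching branch works because a `state_dict` holds every learnable tensor, so loading it into a freshly constructed model reproduces the trained network's outputs exactly. Here is a self-contained sketch of that round trip, using a small stand-in model (the helper `make_model` is hypothetical) and a temporary file instead of `TRAINED_NN_PATH`:

```python
import os
import tempfile

import torch
from torch import nn


def make_model() -> nn.Module:
    """Hypothetical stand-in model; any nn.Module round-trips the same way."""
    return nn.Sequential(nn.Flatten(), nn.Linear(784, 100), nn.ReLU(), nn.Linear(100, 784))


trained = make_model()
x = torch.rand(3, 1, 28, 28)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "state_dict.pt")
    torch.save(trained.state_dict(), path)

    restored = make_model()  # starts with new random weights...
    restored.load_state_dict(torch.load(path, weights_only=True))  # ...replaced by the saved ones
    restored.eval()

    with torch.no_grad():
        same = torch.equal(trained(x), restored(x))

print(same)  # True
```

Since the `Autoencoder` class contains no dropout or batch-normalization layers, `autoencoder.eval()` does not change its outputs here; it is still a good habit before inference.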
Demonstrating the Autoencoder
With the autoencoder fully trained, let's see how quickly it corrupts some test data. Before we can repeatedly pass the letters I drew on my computer through the autoencoder NN, they need a few extra pre-processing steps.
# Load my handwriting image
img = Image.open("my_handwriting.png").convert("L")

# Pre-process & reshape to match MNIST input data
my_handwriting = np.reshape((255 - np.array(img)) / 255, [17, 1, 28, 28]).astype(
    np.float32
)

# Remove grid borders
my_handwriting[:, :, -1, :] = 0.0
my_handwriting[:, :, :, -1] = 0.0

# Prepare plotting data
handwriting_data = []

# Rearrange letters to appear from left to right; append to plotting data
handwriting_data.append(np.hstack(my_handwriting.squeeze()))
for passthrough in range(N_PASSTHROUGHS):
    my_handwriting = (
        autoencoder(torch.tensor(my_handwriting, dtype=torch.float32)).detach().numpy()
    )
    handwriting_data.append(np.hstack(my_handwriting.squeeze()))

fig = px.imshow(
    np.array(handwriting_data),
    animation_frame=0,
    labels={"animation_frame": "n_passthroughs"},
)
fig.show()
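The reshape above only splits the image into separate letters if the 17 letters are stacked vertically, i.e. the PNG is 28 pixels wide and \(17 \times 28 = 476\) pixels tall; that layout is an inference from the code, not something stated here. The toy sketch below uses a made-up 3-letter strip to show how the inversion, scaling, reshape, and `np.hstack` steps fit together.

```python
import numpy as np

n_letters = 3
# Made-up grayscale strip: white (255) background, "letters" stacked top to bottom
strip = np.full((n_letters * 28, 28), 255, dtype=np.uint8)
for k in range(n_letters):
    strip[k * 28 + 10 : k * 28 + 18, 12:16] = 0  # a small black block per letter

# Invert (ink becomes bright, like EMNIST) and scale to [0, 1]
letters = ((255 - strip) / 255).astype(np.float32)
# Split into a batch of single-channel 28 x 28 images
letters = letters.reshape(n_letters, 1, 28, 28)
# Lay the letters out side by side for plotting
row = np.hstack(letters.squeeze())

print(letters.shape)  # (3, 1, 28, 28)
print(row.shape)      # (28, 84)
```

If the letters were instead laid out side by side in a 28-pixel-tall image, a plain reshape would interleave pixel rows from different letters, so the image layout matters.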
It’s pretty satisfying watching the AI corrupt my handwritten letters. Interestingly, some of the intermediate shapes resemble symbols that it was not trained on.
For a more quantitative approach, let’s also see how our loss metric (mean square error) from a test data batch changes over each autoencoder passthrough.
corruption_loss = []
# Measure the loss on the first test batch only
for inputs, _ in test_dataloader:
    corrupted = inputs.clone()
    # No gradients are needed when only running the trained autoencoder
    with torch.no_grad():
        for passthrough in range(N_PASSTHROUGHS):
            loss = criterion(corrupted, inputs)
            corruption_loss.append(loss.item())
            corrupted = autoencoder(corrupted)
    break

fig = px.line(
    {"testing loss": corruption_loss},
    title="Loss over Autoencoder Passthroughs",
).update_layout(xaxis_title="Passthroughs", yaxis_title="MSE Loss [unitless]")
fig.show()
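Repeated compression and decompression undoes the NN's progress at minimizing the mean square error loss: the loss climbs over successive passthroughs (the first point is zero, since it is measured before any passthrough) and eventually plateaus near the value the autoencoder started with in its first training iterations.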
Conclusion
I’m obviously not the first person to be inspired by I am Sitting in a Room to do similar things with technology. Marques Brownlee’s YouTube reuploading test showed some interesting video artifacts, and his original video used a script similar to the one in I am Sitting in a Room, which I thought was neat.
Package Versions & Last Run Date
Python 3.11.2 (main, Aug 26 2024, 07:20:54) [GCC 12.2.0]
NumPy 2.1.2
PIL 10.4.0
Plotly 5.24.1
PyTorch 2.4.1+cu121
Torchvision 0.19.1+cu121
Last Run 2024-10-13 21:39:13.414318