
Playing Visual Telephone with Autoencoders

A convoluted way to distort images with convolution

ai
interactive
programming
python
Author

Vincent “VM” Mercator

Published

June 24, 2024

Modified

June 25, 2024

Alvin Lucier’s I am Sitting in a Room is a contemporary sound art piece that I like. As Lucier describes in its introduction, he plays tape-recorded audio of his voice into an empty room and records the sound reflected off the walls. Each iteration of the piece plays back the recording from the previous one, and this process repeats until Lucier’s voice distorts into eerie, chime-like sounds with pitches related to the room’s resonant frequencies.

Inspired by this, by the concept of model collapse in artificial intelligence, and by the game of telephone, I wondered: can I do the same thing with image processing in artificial intelligence? More specifically, how quickly will an AI distort a recognizable image into something unrecognizable, and what would it look like at the end? I decided to find out by repeatedly passing handwritten letters through a convolutional autoencoder and watching the results for myself.

Setup

I’m going to use the PyTorch machine learning library1 to build my autoencoder because I like it more than other machine learning libraries. Exact reproducibility is notoriously hard to guarantee with PyTorch, though (results can drift across library versions and hardware), so the results for this article might change over time.

import datetime
import os
import random
import string
import sys

import plotly
import numpy as np
import plotly.express as px
import torch
import torchvision
import PIL
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import v2

SEED = 20240622
TRAINED_NN_PATH = "_autoencoder_state_dict.pt"
LOSS_DATA_PATH = "_loss.pt"
N_PASSTHROUGHS = 50

# Seed RNGs for reproducibility
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
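
If you want to push reproducibility a bit further, PyTorch also has a couple of switches that trade speed for stricter determinism. Here’s a minimal sketch of what that could look like (it isn’t used for the rest of this article):

# Optional: request deterministic algorithms where available
# (slower, and warns when an op has no deterministic implementation)
torch.use_deterministic_algorithms(True, warn_only=True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False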

Before we pass my own handwriting through the autoencoder, we’ll first need to train it on preexisting handwriting data. Let’s use the Extended MNIST database since PyTorch already has utilities for downloading and extracting it. Each element in the EMNIST dataset is a \(28 \times 28\) grayscale image representing one of 62 different characters: lowercase letters, uppercase letters, and digits. In this case, let’s focus on the letters split, which merges the upper- and lowercase forms of each letter into 26 classes.2

We also need to pre-process the dataset images before feeding them into the AI. Each pixel in the original images is an integer in the range \([0, 255]\); we’ll convert the pixels to floating-point numbers and scale them to the range \([0, 1]\) so that the AI can interpret them more easily. The raw EMNIST images are also stored transposed, so we’ll transpose them back to make them easier to view when plotted.3

emnist_preprocess = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float, scale=True),
        v2.Lambda(lambda x: torch.transpose(x, 1, 2)),
    ]
)

emnist_train = torchvision.datasets.EMNIST(
    root="_emnist",
    split="letters",
    train=True,
    download=True,
    transform=emnist_preprocess,
)
emnist_test = torchvision.datasets.EMNIST(
    root="_emnist",
    split="letters",
    train=False,
    download=True,
    transform=emnist_preprocess,
)

train_dataloader = DataLoader(emnist_train, batch_size=512, shuffle=False)
test_dataloader = DataLoader(emnist_test, batch_size=512, shuffle=True)

example_label = string.ascii_uppercase[emnist_train[0][1] - 1]
fig = px.imshow(
    emnist_train[0][0][0, :, :],
    labels={"x": "x", "y": "y", "color": "value"},
    title=f"EMNIST Data Example 0: {example_label}",
)
fig.show()
Downloading https://biometrics.nist.gov/cs_links/EMNIST/gzip.zip to _emnist/EMNIST/raw/gzip.zip
Extracting _emnist/EMNIST/raw/gzip.zip to _emnist/EMNIST/raw

The first image in the training dataset, a capital letter ‘W’.

Designing the Autoencoder

An autoencoder is an hourglass-shaped neural network (NN) that learns how to replicate input data. It’s built from two smaller neural networks chained back-to-back: the encoder compresses the input data into a smaller space, and the decoder maps the smaller space back into the original input size. By training both the encoder and the decoder at the same time, the autoencoder learns how to efficiently store the most important features of its input data in the smaller space with minimal distortion.4

Since we’re working with two-dimensional image data, we can use convolutional layers in our autoencoder to help it better understand spatial relationships between adjacent cells in a grid.5 To keep the encoding and decoding halves of the autoencoder symmetrical, we’ll use transpose convolution layers to bring the data back to its original shape.

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            # 1 x 28 x 28
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            # 32 x 14 x 14
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            # 64 x 7 x 7
            nn.Flatten(start_dim=-3),
            nn.Linear(3136, 256),
            nn.ReLU(),
            nn.Linear(256, 100),
        )
        self.decoder = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 64 * 7 * 7),
            nn.ReLU(),
            nn.Unflatten(-1, (64, 7, 7)),
            # 64 x 7 x 7
            nn.ConvTranspose2d(
                64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
            ),
            # 32 x 14 x 14
            nn.ReLU(),
            nn.ConvTranspose2d(
                32, 1, kernel_size=3, stride=2, padding=1, output_padding=1
            ),
            nn.Hardtanh(0, 1),
        )

    def forward(self, x):
        z = self.encoder(x)
        output = self.decoder(z)
        return output

An autoencoder architecture diagram.

Here is the layout of my convolutional autoencoder. Parts of this image were made with NN-SVG.

Convolutional layers are a bit trickier to wire up than the standard linear layers in a multi-layer perceptron, since their output sizes depend on their input sizes, kernel sizes, strides, and padding. I had to step through each layer of my NN to double-check that everything connected nicely, and I added padding (plus output padding on the transpose convolutions) so the decoder lands back on the original \(28 \times 28\) shape.
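
A quick way to do that sanity check is to push a dummy batch through the encoder one layer at a time and print the intermediate shapes. Here’s a minimal sketch:

# Feed a fake 1 x 28 x 28 image through each encoder layer
# and print the tensor shape after every step.
with torch.no_grad():
    x = torch.zeros(1, 1, 28, 28)
    for layer in Autoencoder().encoder:
        x = layer(x)
        print(f"{layer.__class__.__name__:>10} -> {tuple(x.shape)}")

The same trick works for the decoder, starting from a dummy \((1, 100)\) latent vector.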

Training the Autoencoder

Training an autoencoder is an unsupervised learning task: the target output is just the input itself, so there’s no separate reference output that the NN has to match. That means we can discard the training dataset’s target vectors; they aren’t needed for learning an efficient representation.

Here’s how our training loop works:

  1. Pass a batch of images through the autoencoder.
  2. Use the mean square error (MSE) loss metric to rate the NN’s performance. This loss gives us an at-a-glance measure of how well the autoencoder can replicate the handwritten letters (see the formula after this list).
  3. Using the Adam optimizer, determine how to tweak the autoencoder’s weights to decrease this error.
  4. Repeat steps 1-3 until the NN has viewed all images.
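
For reference, if a batch is flattened into \(n\) pixel values, the MSE between the autoencoder’s reconstruction \(\hat{x}\) and the original input \(x\) is

\[
\mathrm{MSE}(\hat{x}, x) = \frac{1}{n} \sum_{i=1}^{n} \left( \hat{x}_i - x_i \right)^2,
\]

which is exactly what nn.MSELoss() computes with its default mean reduction.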

In a normal machine learning training loop, the NN trains over the entire dataset multiple times, once per epoch. In this case, I’m OK with using just one epoch since the autoencoder is good enough at that point.

autoencoder = Autoencoder()

# NN Training setup
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(autoencoder.parameters())
best_loss = 3.0e38
losses = []

# Run NN training if not cached
if not (os.path.exists(TRAINED_NN_PATH) and os.path.exists(LOSS_DATA_PATH)):
    # Train the NN
    for batch_idx, data in enumerate(train_dataloader):
        inputs, _ = data
        optimizer.zero_grad()
        outputs = autoencoder(inputs)
        loss = criterion(outputs, inputs)
        loss.backward()
        optimizer.step()
        loss_item = loss.item()
        losses.append(loss_item)
    # Save the NN
    torch.save(autoencoder.state_dict(), TRAINED_NN_PATH)
    losses = torch.tensor(losses)
    torch.save(losses, LOSS_DATA_PATH)
else:
    # Load the already-trained NN and its loss results
    autoencoder.load_state_dict(torch.load(TRAINED_NN_PATH))
    losses = torch.load(LOSS_DATA_PATH)

# Set the autoencoder in evaluation mode
autoencoder.eval()

fig = px.line(
    {"training loss": losses.numpy()}, title="Loss Over Training Batches"
).update_layout(xaxis_title="batch number", yaxis_title="MSE Loss [unitless]")
fig.show()

Demonstrating the Autoencoder

With the autoencoder fully trained, let’s take a look at how quickly it corrupts some test data. We’ll need to apply some extra pre-processing steps to the letters I drew on my computer before we can repeatedly pass them through the autoencoder NN.

A vertical column of handwritten letters spelling the phrase 'I am sitting in a room'.

Here are some handwritten letters I made for the AI to corrupt.

# Load my handwriting image
img = Image.open("my_handwriting.png").convert("L")
# Pre-process & reshape to match MNIST input data
my_handwriting = np.reshape((255 - np.array(img)) / 255, [17, 1, 28, 28]).astype(
    np.float32
)
# Remove grid borders
my_handwriting[:, :, -1, :] = 0.0
my_handwriting[:, :, :, -1] = 0.0
# Prepare plotting data
handwriting_data = []
# Rearrange letters to appear from left to right; append to plotting data
handwriting_data.append(np.hstack(my_handwriting.squeeze()))

# Repeatedly pass the letters through the autoencoder, saving each generation
for passthrough in range(N_PASSTHROUGHS):
    my_handwriting = (
        autoencoder(torch.tensor(my_handwriting, dtype=torch.float32)).detach().numpy()
    )
    handwriting_data.append(np.hstack(my_handwriting.squeeze()))


fig = px.imshow(
    np.array(handwriting_data),
    animation_frame=0,
    labels={"animation_frame": "n_passthroughs"},
)
fig.show()

It’s pretty satisfying watching the AI corrupt my handwritten letters. Interestingly, some of the intermediate shapes resemble symbols that it was not trained on.

For a more quantitative approach, let’s also see how our loss metric (mean square error) from a test data batch changes over each autoencoder passthrough.

corruption_loss = []

# Measure how far a test batch drifts from the original after each passthrough.
# The first recorded loss is 0, since nothing has been corrupted yet.
for inputs, _ in test_dataloader:
    corrupted = inputs.clone()
    for passthrough in range(N_PASSTHROUGHS):
        loss = criterion(corrupted, inputs)
        corruption_loss.append(loss.item())
        corrupted = autoencoder(corrupted)
    break  # one batch is enough for this comparison

fig = px.line(
    {"testing loss": corruption_loss},
    title="Loss over Autoencoder Passthroughs",
).update_layout(xaxis_title="Passthroughs", yaxis_title="MSE Loss [unitless]")
fig.show()

The repeated compression and decompression undoes the NN’s progress in minimizing the mean square error loss; after enough passthroughs, the loss plateaus at a value close to what the autoencoder produced in its very first training iterations.

Conclusion

I’m obviously not the first person inspired by I am Sitting in a Room to try something similar with technology. Marques Brownlee’s YouTube reuploading test showed some interesting video artifacts. His original video used a script similar to I am Sitting in a Room’s, which I thought was neat.

Package Versions & Last Run Date

Python 3.11.2 (main, Aug 26 2024, 07:20:54) [GCC 12.2.0]
NumPy 2.1.2
PIL 10.4.0
Plotly 5.24.1
PyTorch 2.4.1+cu121
Torchvision 0.19.1+cu121
Last Run 2024-10-13 21:39:13.414318

References

Cohen, Gregory, Saeed Afshar, Jonathan Tapson, and André van Schaik. “EMNIST: An Extension of MNIST to Handwritten Letters.” arXiv, February 2017. https://doi.org/10.48550/arXiv.1702.05373.
Goodfellow, Ian, Yoshua Bengio, and Aaron Courville. Deep Learning. Adaptive Computation and Machine Learning. Cambridge, Massachusetts: The MIT Press, 2016.
Paszke, Adam, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory Chanan, Trevor Killeen, et al. “PyTorch: An Imperative Style, High-Performance Deep Learning Library.” arXiv.org, December 2019. https://arxiv.org/abs/1912.01703v1.

Footnotes

  1. Paszke et al., PyTorch.↩︎

  2. Cohen et al., EMNIST.↩︎

  3. Cohen et al.↩︎

  4. Goodfellow, Bengio, and Courville, Deep Learning.↩︎

  5. Goodfellow, Bengio, and Courville.↩︎

Reuse

CC BY-SA 4.0 International License