8000 Train a network on the MNIST data. · mathspp/training@14157e0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 14157e0

Browse files
Train a network on the MNIST data.
1 parent 048461f commit 14157e0

File tree

1 file changed

+61
-0
lines changed
  • neural-networks-fundamentals-with-python

1 file changed

+61
-0
lines chang 10000 ed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from nn import LeakyReLU, MSELoss, Layer, NeuralNetwork
2+
import numpy as np
3+
4+
def load_data(filepath, *args, **kwargs):
5+
print(f"Loading {filepath}...")
6+
data = np.genfromtxt(filepath, *args, **kwargs)
7+
print("Done.")
8+
return data
9+
10+
def to_col(vec):
11+
return vec.reshape((vec.size, 1))
12+
13+
def test(net, test_data):
14+
correct = 0
15+
for i, test_row in enumerate(test_data):
16+
if not i%1000:
17+
print(i)
18+
19+
t = test_row[0]
20+
x = to_col(test_row[1:])
21+
out = net.forward_pass(x)
22+
guess = np.argmax(out)
23+
if t == guess:
24+
correct += 1
25+
26+
return correct
27+
28+
def train(net, train_data):
29+
# Precompute all target vectors.
30+
ts = {}
31+
for t in range(10):
32+
tv = np.zeros((10, 1))
33+
tv[t] = 1
34+
ts[t] = tv
35+
36+
for i, train_row in enumerate(train_data):
37+
if not i%1000:
38+
print(i)
39+
40+
t = ts[train_row[0]]
41+
x = to_col(train_row[1:])
42+
net.train(x, t)
43+
44+
45+
layers = [
46+
Layer(784, 16, LeakyReLU()),
47+
Layer(16, 16, LeakyReLU()),
48+
Layer(16, 10, LeakyReLU()),
49+
]
50+
net = NeuralNetwork(layers, MSELoss(), 0.001)
51+
52+
test_data = load_data("mnistdata/mnist_test.csv", delimiter=",", dtype=int)
53+
54+
correct = test(net, test_data)
55+
print(f"Accuracy is {100*correct/test_data.shape[0]:.2f}%") # Expected to be around 10%
56+
57+
train_data = load_data("mnistdata/mnist_train.csv", delimiter=",", dtype=int)
58+
train(net, train_data)
59+
60+
correct = test(net, test_data)
61+
print(f"Accuracy is {100*correct/test_data.shape[0]:.2f}%")

0 commit comments

Comments
 (0)
0