8000 Add code for the session on the 25th of march. · mathspp/training@e595215 · GitHub
[go: up one dir, main page]

8000
Skip to content

Commit e595215

Browse files
Add code for the session on the 25th of march.
1 parent eaa2177 commit e595215

File tree

1 file changed

+90
-0
lines changed

1 file changed

+90
-0
lines changed

sessions/nn_18032021/mnist.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
#import csv
2+
import numpy as np
3+
from nn import NeuralNetwork, Layer, LeakyReLU, MSELoss
4+
5+
def load_data(path):
6+
with open(path, "r") as f:
7+
contents = f.read()
8+
# list comprehension
9+
rows = contents.split("\n")[:-1]
10+
10000 data = [list(map(int, row.split(","))) for row in rows]
11+
return np.array(data)
12+
#return np.genfromtxt(path, delimiter=',')
13+
14+
def test(net, data):
15+
"""Test the network on the rows of the data."""
16+
17+
correct = 0
18+
for i, row in enumerate(data):
19+
if i % 1000 == 0:
20+
print(i)
21+
digit = row[0]
22+
x = row[1:].reshape((784, 1))
23+
out = net.forward_pass(x)
24+
if digit == np.argmax(out):
25+
correct += 1
26+
return correct/data.shape[0]
27+
28+
def train_student(teacher, student, data):
29+
"""Traverse the data and teach the student to act like the teacher."""
30+
31+
for i, row in enumerate(data):
32+
if i % 1000 == 0:
33+
print(i)
34+
digit = row[0]
35+
x = row[1:].reshape((784, 1))
36+
t = teacher.forward_pass(x)
37+
student.train(x, t)
38+
39+
def train(net, data):
40+
"""Train the network on the rows of the data."""
41+
# Precompute the target column vectors.
42+
ts = {}
43+
for digit in range(10):
44+
t = np.zeros((10, 1))
45+
t[digit] = 1
46+
ts[digit] = t
47+
48+
for i, row in enumerate(data):
49+
if i % 1000 == 0:
50+
print(i)
51+
digit = row[0]
52+
x = row[1:].reshape((784, 1))
53+
net.train(x, ts[digit])
54+
55+
if __name__ == "__main__":
56+
layers = [
57+
Layer(784, 16, LeakyReLU()),
58+
Layer(16, 16, LeakyReLU()),
59+
Layer(16, 10, LeakyReLU()),
60+
]
61+
net = NeuralNetwork(layers, MSELoss(), 0.001)
62+
# CrossEntropyLoss ← um pouco mais chato
63+
# Sigmoid ← wikipedia
64+
65+
print("Loading data...")
66+
train_data = load_data("mnistdata/mnist_train.csv")
67+
print("Done.")
68+
69+
print("Training network...")
70+
train(net, train_data)
71+
print("Done.")
72+
73+
print("Loading data...")
74+
test_data = load_data("mnistdata/mnist_test.csv")
75+
print("Done.")
76+
77+
print("Testing network...")
78+
accuracy = test(net, test_data)
79+
print(round(100*accuracy, 2))
80+
81+
student = NeuralNetwork(
82+
[Layer(784, 10, LeakyReLU())], MSELoss(), 0.005
83+
)
84+
85+
print("Training the student...")
86+
train_student(net, student, train_data)
87+
print("Done.")
88+
89+
accuracy = test(student, test_data)
90+
print(round(100*accuracy, 2))

0 commit comments

Comments
 (0)
0