nn

A lightweight neural network implementation built on top of the Value autograd framework.

source

Module

 Module ()

Base class for the modules defined below; it provides the shared zero_grad() and parameters() interface that Neuron, Layer and MLP rely on.
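The collapsed source is not reproduced on this page, so here is a minimal sketch of what Module plausibly looks like, inferred only from how zero_grad() and parameters() are called later in the notebook; the actual implementation may differ.

class Module:
    def zero_grad(self):
        # reset every parameter's gradient before a fresh backward pass
        for p in self.parameters():
            p.grad = 0.0

    def parameters(self):
        # subclasses override this to return their list of Value parameters
        return []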

Module()
<__main__.Module>

Neuron Implementation


source

Neuron

 Neuron (nin)

A single neuron with nin inputs: one weight per input plus a bias, combined as a weighted sum and passed through a squashing non-linearity.
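The source is collapsed here as well; below is a sketch under the assumptions that the weights and bias are initialised uniformly in [-1, 1] and that the non-linearity is tanh (both consistent with the parameter values and the near-saturated output shown afterwards, though the real code may differ).

import random

class Neuron(Module):
    def __init__(self, nin):
        # one weight per input plus a bias, assumed to start uniform in [-1, 1]
        self.weights = [Value(random.uniform(-1, 1)) for _ in range(nin)]
        self.bias = Value(random.uniform(-1, 1))

    def __call__(self, xs):
        # weighted sum of the inputs plus the bias, squashed through tanh (assumed)
        act = sum(w * x for w, x in zip(self.weights, xs)) + self.bias
        return act.tanh()

    def parameters(self):
        # weights first, then the bias (matches the ordering inspected later)
        return self.weights + [self.bias]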

xs = [1, 2, 3, 4]

n = Neuron(len(xs))
n(xs)
Value(data=-0.9999816874198546)
n.parameters()  # 5 parameters: 4 weights + 1 bias
[Value(data=0.2996475661487965),
 Value(data=-0.6267630844346634),
 Value(data=-0.9246264014404468),
 Value(data=-0.4674702685842562),
 Value(data=-0.20289127806601925)]

Layer Implementation


source

Layer

 Layer (nin, nout)

A layer of nout neurons, each reading the same nin inputs; calling the layer returns one output per neuron.
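Again a sketch rather than the actual source: a Layer presumably just owns nout Neurons and concatenates their parameters, and the num_params() helper used below is assumed to be a thin wrapper around len(parameters()).

class Layer(Module):
    def __init__(self, nin, nout):
        # nout independent neurons, each looking at the same nin inputs
        self.neurons = [Neuron(nin) for _ in range(nout)]

    def __call__(self, xs):
        outs = [n(xs) for n in self.neurons]
        # assumed convenience: a single-neuron layer returns its Value directly
        return outs[0] if len(outs) == 1 else outs

    def parameters(self):
        return [p for n in self.neurons for p in n.parameters()]

    def num_params(self):
        return len(self.parameters())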

xs = [1, 2, 3, 4]

l = Layer(len(xs), 3)
l(xs)
[Value(data=0.7048653358992089),
 Value(data=0.7624645201923977),
 Value(data=-0.5361441307353372)]
# len(l.parameters()), l.parameters()
l.num_params(), l.parameters()  # 3 neurons x (4 weights + 1 bias) = 15 parameters
(15,
 [Value(data=-0.735292639179753),
  Value(data=-0.46621362338222494),
  Value(data=0.31216478566074635),
  Value(data=0.32439718817347507),
  Value(data=0.31054160829965216),
  Value(data=0.5389804840056651),
  Value(data=0.49574071178176227),
  Value(data=-0.52360851024096),
  Value(data=0.26985245436198113),
  Value(data=-0.03697049273643227),
  Value(data=-0.20940467223829673),
  Value(data=0.6855711092076937),
  Value(data=-0.2123166216168204),
  Value(data=-0.1822824061529016),
  Value(data=-0.39438649023126526)])

Multilayer Perceptron Implementation


source

MLP

 MLP (nin, nouts)

A multilayer perceptron: a stack of Layers with sizes taken from nouts, fed with nin inputs; each layer's outputs become the next layer's inputs.
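A sketch under the same assumptions; the commented-out zip([nin]+nouts[:-1], nouts) in the next cell suggests how consecutive layer sizes are paired up.

class MLP(Module):
    def __init__(self, nin, nouts):
        # pair consecutive sizes: (nin, nouts[0]), (nouts[0], nouts[1]), ...
        self.layers = [Layer(lin, lout) for lin, lout in zip([nin] + nouts[:-1], nouts)]

    def __call__(self, xs):
        # feed the input through each layer in turn
        for layer in self.layers:
            xs = layer(xs)
        return xs

    def parameters(self):
        return [p for layer in self.layers for p in layer.parameters()]

    def num_params(self):
        return len(self.parameters())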

nin = len(xs)
nouts = [3,4,4,1]
# for lin,lout in zip([nin]+nouts[:-1], nouts):
#     print(lin, lout)

mlp = MLP(nin, nouts)
mlp(xs), mlp.num_params()  # (4*3+3) + (3*4+4) + (4*4+4) + (4*1+1) = 56 parameters
(Value(data=-0.07189402118195996), 56)
xs = [
    [2.0, 3.0, -1.0],
    [3.0, -1.0, 0.5],
    [0.5, 1.0, 0.0],
    [1.0, 1.0, -1.0]
]

ys = [1.0, -1.0, -1.0, 1.0]
nin = len(xs[0])
nouts = [4,4,1]
mlp = MLP(nin, nouts)
y_preds = [mlp(x) for x in xs]; y_preds
[Value(data=-0.9092726328584118),
 Value(data=-0.8971560386646303),
 Value(data=-0.8972030311657623),
 Value(data=-0.918073509910156)]
loss = sum((y-y_pred)**2 for y, y_pred in zip(ys, y_preds))
loss
Value(data=7.345472073185816)
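This loss is the sum of squared errors over the four training examples. A quick sanity check: with every prediction near -0.9 and targets of ±1, the two mismatched examples each contribute roughly (1.9)^2 ≈ 3.6 and the two matched ones almost nothing, which lands right around the 7.3 shown above.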
loss.backward()
draw_dot(loss)
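draw_dot renders the computation graph that the Value framework built up for the loss (the rendering itself is not reproduced here); since it spans all four forward passes, it is a fairly large graph.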

step_size = 0.01
# for p in mlp.parameters():
p = mlp.parameters()[0]
p.grad, p.data
(0.0003160544492504312, -0.20452920628415744)
mlp.layers[0].neurons[0].weights[0].grad, mlp.layers[0].neurons[0].weights[0]
(0.0003160544492504312, Value(data=-0.20452920628415744))
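With gradients available for every parameter, training is plain gradient descent: recompute the predictions and the loss, zero the stale gradients, backpropagate, and nudge each parameter a small step against its gradient, as in the loop below.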
step_size = 0.1
for k in range(200):
    # forward pass: predictions for every training example
    y_preds = [mlp(x) for x in xs]
    # sum of squared errors against the targets
    loss = sum((y-y_pred)**2 for y, y_pred in zip(ys, y_preds))
    # reset gradients, backpropagate, then step each parameter downhill
    mlp.zero_grad()
    loss.backward()
    for p in mlp.parameters():
        p.data -= step_size*p.grad
    print(k, loss.data)
0 0.0013144825642295425
1 0.0013065773994759486
2 0.0012987650979127516
3 0.0012910440439571425
4 0.0012834126591459902
5 0.0012758694010784333
6 0.0012684127623942536
7 0.001261041269786802
8 0.0012537534830490676
9 0.0012465479941515204
10 0.0012394234263506476
11 0.0012323784333267843
12 0.0012254116983503624
13 0.0012185219334753645
14 0.001211707878758976
15 0.0012049683015063265
16 0.0011983019955397005
17 0.001191707780490962
18 0.0011851845011165507
19 0.001178731026634071
20 0.001172346250079836
21 0.0011660290876864667
22 0.0011597784782798942
23 0.0011535933826950608
24 0.0011474727832095593
25 0.0011414156829947766
26 0.0011354211055836336
27 0.0011294880943545928
28 0.0011236157120311791
29 0.0011178030401966382
30 0.0011120491788230093
31 0.001106353245814232
32 0.0011007143765629024
33 0.001095131723519913
34 0.0010896044557768638
35 0.0010841317586605796
36 0.0010787128333394369
37 0.0010733468964410527
38 0.001068033179681
39 0.0010627709295021642
40 0.0010575594067242992
41 0.001052397886203644
42 0.0010472856565020721
43 0.0010422220195655046
44 0.0010372062904114698
45 0.0010322377968251485
46 0.0010273158790640878
47 0.0010224398895708389
48 0.001017609192693632
49 0.0010128231644146
50 0.0010080811920854943
51 0.0010033826741705134
52 0.0009987270199959923
53 0.0009941136495069342
54 0.000989541993030058
55 0.0009850114910430138
56 0.0009805215939498553
57 0.0009760717618623543
58 0.000971661464387156
59 0.0009672901804183503
60 0.0009629573979355989
61 0.0009586626138074349
62 0.0009544053335995889
63 0.000950185071388379
64 0.0009460013495787001
65 0.0009418536987268519
66 0.0009377416573676927
67 0.0009336647718463008
68 0.0009296225961537306
69 0.0009256146917670592
70 0.0009216406274932613
71 0.0009176999793170925
72 0.0009137923302526909
73 0.0009099172701987749
74 0.0009060743957975801
75 0.0009022633102969983
76 0.0008984836234163037
77 0.0008947349512150158
78 0.0008910169159649467
79 0.0008873291460253709
80 0.0008836712757211151
81 0.000880042945223655
82 0.0008764438004349868
83 0.0008728734928742112
84 0.0008693316795668931
85 0.0008658180229369572
86 0.0008623321907010905
87 0.0008588738557657298
88 0.0008554426961262763
89 0.0008520383947688007
90 0.000848660639573885
91 0.0008453091232227803
92 0.00084198354310567
93 0.0008386836012319776
94 0.0008354090041427868
95 0.0008321594628252636
96 0.0008289346926288292
97 0.000825734413183521
98 0.0008225583483198497
99 0.0008194062259906968
100 0.0008162777781948332
101 0.0008131727409021486
102 0.0008100908539805008
103 0.0008070318611242684
104 0.0008039955097843385
105 0.0008009815510997533
106 0.0007979897398306936
107 0.0007950198342931154
108 0.0007920715962945789
109 0.0007891447910715675
110 0.0007862391872282258
111 0.0007833545566762403
112 0.0007804906745760965
113 0.0007776473192796074
114 0.0007748242722735751
115 0.0007720213181247004
116 0.0007692382444256329
117 0.0007664748417421036
118 0.0007637309035612887
119 0.0007610062262410678
120 0.000758300608960478
121 0.0007556138536710872
122 0.0007529457650494999
123 0.0007502961504506618
124 0.0007476648198622622
125 0.0007450515858599953
126 0.0007424562635637399
127 0.0007398786705946248
128 0.0007373186270329898
129 0.0007347759553771366
130 0.0007322504805029422
131 0.000729742029624277
132 0.0007272504322541954
133 0.0007247755201669193
134 0.0007223171273605169
135 0.0007198750900203924
136 0.000717449246483458
137 0.0007150394372029725
138 0.000712645504714129
139 0.0007102672936002208
140 0.0007079046504596337
141 0.0007055574238732472
142 0.0007032254643726573
143 0.0007009086244088784
144 0.0006986067583217559
145 0.0006963197223098371
146 0.0006940473744009292
147 0.0006917895744231622
148 0.0006895461839765717
149 0.0006873170664053042
150 0.0006851020867702303
151 0.0006829011118221537
152 0.0006807140099755039
153 0.0006785406512825029
154 0.0006763809074078174
155 0.0006742346516036669
156 0.0006721017586854316
157 0.000669982105007664
158 0.0006678755684405401
159 0.000665782028346798
160 0.0006637013655589932
161 0.0006616334623573101
162 0.0006595782024476249
163 0.000657535470940069
164 0.0006555051543279252
165 0.000653487140466935
166 0.0006514813185549407
167 0.0006494875791119498
168 0.0006475058139604537
169 0.0006455359162062066
170 0.0006435777802192888
171 0.0006416313016154834
172 0.0006396963772380257
173 0.000637772905139629
174 0.0006358607845649011
175 0.0006339599159329587
176 0.0006320702008204078
177 0.0006301915419446457
178 0.0006283238431473647
179 0.0006264670093784374
180 0.0006246209466800182
181 0.0006227855621708926
182 0.0006209607640312016
183 0.0006191464614873015
184 0.0006173425647969686
185 0.0006155489852347961
186 0.0006137656350778798
187 0.0006119924275917459
188 0.0006102292770164558
189 0.0006084760985530041
190 0.0006067328083499474
191 0.00060499932349022
192 0.0006032755619781668
193 0.0006015614427268049
194 0.0005998568855453456
195 0.0005981618111268101
196 0.0005964761410359543
197 0.000594799797697332
198 0.0005931327043835735
199 0.0005914747852038544
y_preds
[Value(data=0.9894351129108765),
 Value(data=-0.998879606224376),
 Value(data=-0.9873883237450028),
 Value(data=0.982124086433846)]
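After 200 steps of gradient descent the predictions sit close to the targets [1.0, -1.0, -1.0, 1.0].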