|
105 | 105 | },
|
106 | 106 | {
|
107 | 107 | "cell_type": "code",
|
108 |
| - "execution_count": 63, |
| 108 | + "execution_count": 80, |
109 | 109 | "metadata": {
|
110 | 110 | "dotnet_interactive": {
|
111 | 111 | "language": "csharp"
|
|
137 | 137 | " y += torch.randn(y.shape) * 0.01;\n",
|
138 | 138 | "\n",
|
139 | 139 | " // Make sure that the output isn't negative.\n",
|
140 |
| - " y += torch.where(y < 0.0, 0.01 * torch.randn(y.shape) + 0.01, torch.zeros(y.shape));\n", |
| 140 | + " //y += torch.where(y < 0.0, 0.01 * torch.randn(y.shape) + 0.01, torch.zeros(y.shape));\n", |
| 141 | + " y = y.relu_();\n", |
141 | 142 | "\n",
|
142 | 143 | " // Save the data in two separate, binary files.\n",
|
143 | 144 | " X.save(fileName + \".x\");\n",
|
|
152 | 153 | },
|
153 | 154 | {
|
154 | 155 | "cell_type": "code",
|
155 |
| - "execution_count": 64, |
| 156 | + "execution_count": 81, |
156 | 157 | "metadata": {
|
157 | 158 | "dotnet_interactive": {
|
158 | 159 | "language": "csharp"
|
|
179 | 180 | },
|
180 | 181 | {
|
181 | 182 | "cell_type": "code",
|
182 |
| - "execution_count": 65, |
| 183 | + "execution_count": 82, |
183 | 184 | "metadata": {
|
184 | 185 | "dotnet_interactive": {
|
185 | 186 | "language": "csharp"
|
|
207 | 208 | },
|
208 | 209 | {
|
209 | 210 | "cell_type": "code",
|
210 |
| - "execution_count": 66, |
| 211 | + "execution_count": 83, |
211 | 212 | "metadata": {
|
212 | 213 | "dotnet_interactive": {
|
213 | 214 | "language": "csharp"
|
|
235 | 236 | },
|
236 | 237 | {
|
237 | 238 | "cell_type": "code",
|
238 |
| - "execution_count": 67, |
| 239 | + "execution_count": 84, |
239 | 240 | "metadata": {
|
240 | 241 | "dotnet_interactive": {
|
241 | 242 | "language": "csharp"
|
|
281 | 282 | },
|
282 | 283 | {
|
283 | 284 | "cell_type": "code",
|
284 |
| - "execution_count": 68, |
| 285 | + "execution_count": 103, |
285 | 286 | "metadata": {
|
286 | 287 | "dotnet_interactive": {
|
287 | 288 | "language": "csharp"
|
|
310 | 311 | },
|
311 | 312 | {
|
312 | 313 | "cell_type": "code",
|
313 |
| - "execution_count": 69, |
| 314 | + "execution_count": 104, |
314 | 315 | "metadata": {
|
315 | 316 | "dotnet_interactive": {
|
316 | 317 | "language": "csharp"
|
|
325 | 326 | "outputs": [],
|
326 | 327 | "source": [
|
327 | 328 | "var learning_rate = 0.01f;\n",
|
328 |
| - "var optimizer = torch.optim.SGD(model.parameters(), learning_rate);" |
| 329 | + "var optimizer = torch.optim.Rprop(model.parameters(), learning_rate);" |
329 | 330 | ]
|
330 | 331 | },
|
331 | 332 | {
|
|
338 | 339 | },
|
339 | 340 | {
|
340 | 341 | "cell_type": "code",
|
341 |
| - "execution_count": 78, |
| 342 | + "execution_count": 115, |
342 | 343 | "metadata": {
|
343 | 344 | "dotnet_interactive": {
|
344 | 345 | "language": "csharp"
|
|
355 | 356 | "name": "stdout",
|
356 | 357 | "output_type": "stream",
|
357 | 358 | "text": [
|
358 | | - " initial loss = 0.0063750837\n", |
359 |
| - " final loss = 0.007656585\n" |
| 359 | + " initial loss = 0.00047994562\n", |
| 360 | + " final loss = 0.0004698771\n" |
360 | 361 | ]
|
361 | 362 | }
|
362 | 363 | ],
|
|
389 | 390 | },
|
390 | 391 | {
|
391 | 392 | "cell_type": "code",
|
392 |
| - "execution_count": 79, |
| 393 | + "execution_count": 116, |
393 | 394 | "metadata": {
|
394 | 395 | "dotnet_interactive": {
|
395 | 396 | "language": "csharp"
|
|
405 | 406 | {
|
406 | 407 | "data": {
|
407 | 408 | "text/html": [
|
408 |
| - "<div class=\"dni-plaintext\"><pre>0.008805012</pre></div><style>\r\n", |
| 409 | + "<div class=\"dni-plaintext\"><pre>0.00061915093</pre></div><style>\r\n", |
409 | 410 | ".dni-code-hint {\r\n",
|
410 | 411 | " font-style: italic;\r\n",
|
411 | 412 | " overflow: hidden;\r\n",
|
|
0 commit comments