10000 Use in-place tensor operators where possible. · dotnet/TorchSharpExamples@4c7dbb0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4c7dbb0

Browse files
Use in-place tensor operators where possible.
1 parent 0b757b3 commit 4c7dbb0

File tree

1 file changed

+105
-21
lines changed

1 file changed

+105
-21
lines changed

tutorials/CSharp/synthetic_data.ipynb

Lines changed: 105 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
},
1515
{
1616
"cell_type": "code",
17-
"execution_count": null,
17+
"execution_count": 41,
1818
"metadata": {
1919
"dotnet_interactive": {
2020
"language": "csharp"
@@ -26,7 +26,17 @@
2626
"languageId": "polyglot-notebook"
2727
}
2828
},
29-
"outputs": [],
29+
"outputs": [
30+
{
31+
"data": {
32+
"text/html": [
33+
"<div><div></div><div></div><div><strong>Installed Packages</strong><ul><li><span>TorchSharp-cpu, 0.100.3</span></li></ul></div></div>"
34+
]
35+
},
36+
"metadata": {},
37+
"output_type": "display_data"
38+
}
39+
],
3040
"source": [
3141
"#r \"nuget: TorchSharp-cpu\"\n",
3242
"\n",
@@ -46,7 +56,7 @@
4656
},
4757
{
4858
"cell_type": "code",
49-
"execution_count": null,
59+
"execution_count": 42,
5060
"metadata": {
5161
"vscode": {
5262
"languageId": "polyglot-notebook"
@@ -78,8 +88,8 @@
7888
" public override torch.Tensor forward(torch.Tensor input)\n",
7989
" {\n",
8090
" using var _ = torch.NewDisposeScope();\n",
81-
" var z = torch.tanh(hid1.call(input));\n",
82-
" z = torch.sigmoid(oupt.call(z));\n",
91+
" var z = hid1.call(input).tanh_();\n",
92+
" z = oupt.call(z).sigmoid_();\n",
8393
" return z.MoveToOuterDisposeScope();\n",
8494
" }\n",
8595
"}"
@@ -95,7 +105,7 @@
95105
},
96106
{
97107
"cell_type": "code",
98-
"execution_count": 32,
108+
"execution_count": 43,
99109
"metadata": {
100110
"dotnet_interactive": {
101111
"language": "csharp"
@@ -142,7 +152,7 @@
142152
},
143153
{
144154
"cell_type": "code",
145-
"execution_count": null,
155+
"execution_count": 44,
146156
"metadata": {
147157
"dotnet_interactive": {
148158
"language": "csharp"
@@ -169,7 +179,7 @@
169179
},
170180
{
171181
"cell_type": "code",
172-
"execution_count": null,
182+
"execution_count": 45,
173183
"metadata": {
174184
"dotnet_interactive": {
175185
"language": "csharp"
@@ -197,7 +207,7 @@
197207
},
198208
{
199209
"cell_type": "code",
200-
"execution_count": null,
210+
"execution_count": 46,
201211
"metadata": {
202212
"dotnet_interactive": {
203213
"language": "csharp"
@@ -225,7 +235,7 @@
225235
},
226236
{
227237
"cell_type": "code",
228-
"execution_count": null,
238+
"execution_count": 47,
229239
"metadata": {
230240
"dotnet_interactive": {
231241
"language": "csharp"
@@ -254,8 +264,8 @@
254264
" public override torch.Tensor forward(torch.Tensor input)\n",
255265
" {\n",
256266
" using var _ = torch.NewDisposeScope();\n",
257-
" var z = torch.nn.functional.relu(hid1.call(input));\n",
258-
" z = torch.sigmoid(oupt.call(z));\n",
267+
" var z = hid1.call(input).relu_();\n",
268+
" z = oupt.call(z).sigmoid_();\n",
259269
" return z.MoveToOuterDisposeScope();\n",
260270
" }\n",
261271
"}"
@@ -271,7 +281,7 @@
271281
},
272282
{
273283
"cell_type": "code",
274-
"execution_count": null,
284+
"execution_count": 48,
275285
"metadata": {
276286
"dotnet_interactive": {
277287
"language": "csharp"
@@ -295,12 +305,12 @@
295305
"cell_type": "markdown",
296306
"metadata": {},
297307
"source": [
298-
"A standard training loop. It ends with evaluating the trained model on the training set."
308+
"We need an optimizer."
299309
]
300310
},
301311
{
302312
"cell_type": "code",
303-
"execution_count": null,
313+
"execution_count": 52,
304314
"metadata": {
305315
"dotnet_interactive": {
306316
"language": "csharp"
@@ -315,11 +325,44 @@
315325
"outputs": [],
316326
"source": [
317327
"var learning_rate = 0.01f;\n",
318-
"\n",
328+
"var optimizer = torch.optim.SGD(model.parameters(), learning_rate);"
329+
]
330+
},
331+
{
332+
"attachments": {},
333+
"cell_type": "markdown",
334+
"metadata": {},
335+
"source": [
336+
"A standard training loop. It ends with evaluating the trained model on the training set."
337+
]
338+
},
339+
{
340+
"cell_type": "code",
341+
"execution_count": 59,
342+
"metadata": {
343+
"dotnet_interactive": {
344+
"language": "csharp"
345+
},
346+
"polyglot_notebook": {
347+
"kernelName": "csharp"
348+
},
349+
"vscode": {
350+
"languageId": "polyglot-notebook"
351+
}
352+
},
353+
"outputs": [
354+
{
355+
"name": "stdout",
356+
"output_type": "stream",
357+
"text": [
358+
" initial loss = 0.023704259\n",
359+
" final loss = 0.023490703\n"
360+
]
361+
}
362+
],
363+
"source": [
319364
"Console.WriteLine(\" initial loss = \" + loss.forward(model.forward(X_train), y_train).item<float>().ToString());\n",
320365
"\n",
321-
"var optimizer = torch.optim.SGD(model.parameters(), learning_rate);\n",
322-
"\n",
323366
"for (int i = 0; i < 10000; i++) {\n",
324367
" // Compute the loss\n",
325368
" using var output = loss.forward(model.forward(X_train), y_train);\n",
@@ -346,7 +389,7 @@
346389
},
347390
{
348391
"cell_type": "code",
349-
"execution_count": null,
392+
"execution_count": 60,
350393
"metadata": {
351394
"dotnet_interactive": {
352395
"language": "csharp"
@@ -358,14 +401,55 @@
358401
"languageId": "polyglot-notebook"
359402
}
360403
},
361-
"outputs": [],
404+
"outputs": [
405+
{
406+
"data": {
407+
"text/html": [
408+
"<div class=\"dni-plaintext\"><pre>0.021710658</pre></div><style>\r\n",
409+
".dni-code-hint {\r\n",
410+
" font-style: italic;\r\n",
411+
" overflow: hidden;\r\n",
412+
" white-space: nowrap;\r\n",
413+
"}\r\n",
414+
".dni-treeview {\r\n",
415+
" white-space: nowrap;\r\n",
< 10000 /td>416+
"}\r\n",
417+
".dni-treeview td {\r\n",
418+
" vertical-align: top;\r\n",
419+
" text-align: start;\r\n",
420+
"}\r\n",
421+
"details.dni-treeview {\r\n",
422+
" padding-left: 1em;\r\n",
423+
"}\r\n",
424+
"table td {\r\n",
425+
" text-align: start;\r\n",
426+
"}\r\n",
427+
"table tr { \r\n",
428+
" vertical-align: top; \r\n",
429+
" margin: 0em 0px;\r\n",
430+
"}\r\n",
431+
"table tr td pre \r\n",
432+
"{ \r\n",
433+
" vertical-align: top !important; \r\n",
434+
" margin: 0em 0px !important;\r\n",
435+
"} \r\n",
436+
"table th {\r\n",
437+
" text-align: start;\r\n",
438+
"}\r\n",
439+
"</style>"
440+
]
441+
},
442+
"metadata": {},
443+
"output_type": "display_data"
444+
}
445+
],
362446
"source": [
363447
"loss.forward(model.forward(X_test), y_test).item<float>()"
364448
]
365449
},
366450
{
367451
"cell_type": "code",
368-
"execution_count": null,
452+
"execution_count": 51,
369453
"metadata": {
370454
"dotnet_interactive": {
371455
"language": "csharp"

0 commit comments

Comments
 (0)
0