Merge pull request #388 from DiffSharp/gunes/fix21 · DiffSharp/DiffSharp@5a80ff1 · GitHub

Commit 5a80ff1

Merge pull request #388 from DiffSharp/gunes/fix21
Differentiable programs
2 parents 336a22d + df42259 commit 5a80ff1

14 files changed: +278 -114 lines changed

examples/classifier.fsx

Lines changed: 5 additions & 4 deletions
@@ -7,10 +7,11 @@
 
 // Libtorch binaries
 // Option A: you can use a platform-specific nuget package
-// #r "nuget: libtorch-cuda-11.1-win-x64, 1.8.0.7"
-// #r "nuget: libtorch-cuda-11.1-linux-x64, 1.9.0.10"
+#r "nuget: TorchSharp-cpu, 0.93.5"
+// #r "nuget: TorchSharp-cuda-linux, 0.93.5"
+// #r "nuget: TorchSharp-cuda-windows, 0.93.5"
 // Option B: you can use a local libtorch installation
-System.Runtime.InteropServices.NativeLibrary.Load("/home/gunes/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch.so")
+// System.Runtime.InteropServices.NativeLibrary.Load("/home/gunes/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch.so")
 
 
 open DiffSharp
@@ -77,7 +78,7 @@ let validSet = MNIST("../data", urls=urls, train=false)
 let validLoader = validSet.loader(batchSize=batchSize, shuffle=false)
 
 
-printfn "Model: %A" classifier
+printfn "Model:\n%s" (classifier.summary())
 
 let optimizer = Adam(classifier, lr=dsharp.tensor(0.001))
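
The change above replaces the structural "%A" print of the model with the new Model.summary() call, which returns a readable string description; the same substitution appears in gan.fsx, rnn.fsx, vae.fsx and vae_cnn.fsx below. A minimal sketch of the pattern, assuming a script with the usual DiffSharp opens (the tiny model here is hypothetical, not part of this commit):

    open DiffSharp
    open DiffSharp.Compose
    open DiffSharp.Model

    // Compose a small model in the same style as the examples in this PR
    let tiny =
        dsharp.view([-1; 10])
        --> Linear(10, 4)
        --> dsharp.sigmoid
        --> Linear(4, 1)

    // summary() returns a string, unlike the old "%A" structural dump
    printfn "Model:\n%s" (tiny.summary())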

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+#!/usr/bin/env -S dotnet fsi
+
+#I "../tests/DiffSharp.Tests/bin/Debug/net5.0"
+#r "DiffSharp.Core.dll"
+#r "DiffSharp.Data.dll"
+#r "DiffSharp.Backends.Torch.dll"
+
+// Libtorch binaries
+// Option A: you can use a platform-specific nuget package
+#r "nuget: TorchSharp-cpu, 0.93.5"
+// #r "nuget: TorchSharp-cuda-linux, 0.93.5"
+// #r "nuget: TorchSharp-cuda-windows, 0.93.5"
+// Option B: you can use a local libtorch installation
+// System.Runtime.InteropServices.NativeLibrary.Load("/home/gunes/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch.so")
+
+
+open DiffSharp
+open DiffSharp.Compose
+open DiffSharp.Model
+open DiffSharp.Data
+open DiffSharp.Optim
+open DiffSharp.Util
+open DiffSharp.Distributions
+
+open System.IO
+
+dsharp.config(backend=Backend.Torch, device=Device.CPU)
+dsharp.seed(1)
+
+type Model<'In, 'Out> with
+    member m.run = m.forward
+type DiffProg<'In, 'Out> = Model<'In, 'Out>
+
+
+let diffprog parameters (f:'In->'Out) : DiffProg<'In, 'Out> =
+    DiffProg<'In, 'Out>.create [] parameters [] f
+
+let param (x:Tensor) = Parameter(x)
+
+// Learn a differentiable program given an objective
+// DiffProg<'a,'b> -> (DiffProg<'a,'b> -> Tensor) -> DiffProg<'a,'b>
+let learn (diffprog:DiffProg<_,_>) loss =
+    let lr = 0.001
+    for i=0 to 10 do
+        diffprog.reverseDiff()
+        let l:Tensor = loss diffprog
+        l.reverse()
+        let p = diffprog.parametersVector
+        diffprog.parametersVector <- p.primal - lr * p.derivative
+        printfn "iteration %A, loss %A" i (float l)
+    diffprog
+
+// A linear model as a differentiable program
+// DiffProg<Tensor,Tensor>
+let dp =
+    let w = param (dsharp.randn([5; 1]))
+    diffprog [w]
+        (fun (x:Tensor) -> x.matmul(w.value))
+
+// Data
+let x = dsharp.randn([1024; 5])
+let y = dsharp.randn([1024; 1])
+
+// let a = diffprog.run x
+// printfn "%A %A %A " a.shape y.shape (dsharp.mseLoss(a, y))
+
+// Objective
+// DiffProg<Tensor,Tensor> -> Tensor
+let loss (diffprog:DiffProg<Tensor, Tensor>) = dsharp.mseLoss(diffprog.run x, y)
+
+// Learned differentiable program
+// DiffProg<Tensor,Tensor>
+let dpLearned = learn dp loss
+
+// Function that runs the differentiable program with new data
+// Tensor -> Tensor
+dpLearned.run
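
The new script ends by exposing dpLearned.run, a Tensor -> Tensor function. A minimal usage sketch, assuming the definitions above are in scope (xNew and yPred below are hypothetical, not part of the commit):

    // Apply the learned differentiable program to fresh inputs
    let xNew = dsharp.randn([10; 5])   // 10 new samples, 5 features each
    let yPred = dpLearned.run xNew     // shape [10; 1], a matmul with the learned weight
    printfn "%A" yPred.shape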

examples/gan.fsx

Lines changed: 6 additions & 7 deletions
@@ -7,10 +7,11 @@
 
 // Libtorch binaries
 // Option A: you can use a platform-specific nuget package
-// #r "nuget: libtorch-cuda-11.1-win-x64, 1.8.0.7"
-// #r "nuget: libtorch-cuda-11.1-linux-x64, 1.8.0.7"
+#r "nuget: TorchSharp-cpu, 0.93.5"
+// #r "nuget: TorchSharp-cuda-linux, 0.93.5"
+// #r "nuget: TorchSharp-cuda-windows, 0.93.5"
 // Option B: you can use a local libtorch installation
-System.Runtime.InteropServices.NativeLibrary.Load("/home/gunes/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch.so")
+// System.Runtime.InteropServices.NativeLibrary.Load("/home/gunes/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch.so")
 
 
 open DiffSharp
@@ -94,11 +95,9 @@ let discriminator =
     --> Linear(256, 1)
     --> dsharp.sigmoid
 
-print "Generator"
-print generator
+printfn "Generator\n%s" (generator.summary())
 
-print "Discriminator"
-print discriminator
+printfn "Discriminator\n%s" (discriminator.summary())
 
 let epochs = 10
 let batchSize = 16

examples/rnn.fsx

Lines changed: 5 additions & 4 deletions
@@ -7,10 +7,11 @@
 
 // Libtorch binaries
 // Option A: you can use a platform-specific nuget package
-// #r "nuget: libtorch-cuda-11.1-win-x64, 1.8.0.7"
-// #r "nuget: libtorch-cuda-11.1-linux-x64, 1.8.0.7"
+#r "nuget: TorchSharp-cpu, 0.93.5"
+// #r "nuget: TorchSharp-cuda-linux, 0.93.5"
+// #r "nuget: TorchSharp-cuda-windows, 0.93.5"
 // Option B: you can use a local libtorch installation
-System.Runtime.InteropServices.NativeLibrary.Load("/home/gunes/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch.so")
+// System.Runtime.InteropServices.NativeLibrary.Load("/home/gunes/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch.so")
 
 
 open DiffSharp
@@ -43,7 +44,7 @@ let languageModel =
     --> dsharp.view([-1; 512])
     --> Linear(512, dataset.numChars)
 
-print languageModel
+printfn "%s" (languageModel.summary())
 
 let modelFileName = "rnn_language_model.params"
 if File.Exists(modelFileName) then

examples/vae.fsx

Lines changed: 5 additions & 4 deletions
@@ -7,10 +7,11 @@
 
 // Libtorch binaries
 // Option A: you can use a platform-specific nuget package
-// #r "nuget: libtorch-cuda-11.1-win-x64, 1.8.0.7"
-// #r "nuget: libtorch-cuda-11.1-linux-x64, 1.9.0.10"
+#r "nuget: TorchSharp-cpu, 0.93.5"
+// #r "nuget: TorchSharp-cuda-linux, 0.93.5"
+// #r "nuget: TorchSharp-cuda-windows, 0.93.5"
 // Option B: you can use a local libtorch installation
-System.Runtime.InteropServices.NativeLibrary.Load("/home/gunes/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch.so")
+// System.Runtime.InteropServices.NativeLibrary.Load("/home/gunes/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch.so")
 
 
 open DiffSharp
@@ -100,7 +101,7 @@ let validSet = MNIST("../data", urls=urls, train=false, transform=id)
 let validLoader = validSet.loader(batchSize=batchSize, shuffle=false)
 
 let model = VAE(28*28, 20, [400])
-printfn "Model: %A" model
+printfn "Model\n%s" (model.summary())
 
 let optimizer = Adam(model, lr=dsharp.tensor(0.001))

examples/vae_cnn.fsx

Lines changed: 5 additions & 4 deletions
@@ -7,10 +7,11 @@
 
 // Libtorch binaries
 // Option A: you can use a platform-specific nuget package
-// #r "nuget: libtorch-cuda-11.1-win-x64, 1.8.0.7"
-// #r "nuget: libtorch-cuda-11.1-linux-x64, 1.9.0.10"
+#r "nuget: TorchSharp-cpu, 0.93.5"
+// #r "nuget: TorchSharp-cuda-linux, 0.93.5"
+// #r "nuget: TorchSharp-cuda-windows, 0.93.5"
 // Option B: you can use a local libtorch installation
-System.Runtime.InteropServices.NativeLibrary.Load("/home/gunes/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch.so")
+// System.Runtime.InteropServices.NativeLibrary.Load("/home/gunes/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch.so")
 
 
 open DiffSharp
@@ -57,7 +58,7 @@ let decoder =
 
 let model = VAE([1;28;28], 64, encoder, decoder)
 
-printfn "Model: %A" model
+printfn "Model\n%s" (model.summary())
 
 let optimizer = Adam(model, lr=dsharp.tensor(0.001))

src/DiffSharp.Core/DiffSharp.fs

Lines changed: 4 additions & 4 deletions
@@ -1392,15 +1392,15 @@ type dsharp with
     static member noDiff (tensor:Tensor) = tensor.noDiff()
 
     /// <summary>Produce a new tensor suitable for calculating the forward-mode derivative at the given level tag.</summary>
-    /// <param name="tag">The level tag.</param>
+    /// <param name="nestingTag">The level tag.</param>
     /// <param name="derivative">The derivative of the input.</param>
     /// <param name="tensor">The input.</param>
-    static member forwardDiff (tag:uint32) (derivative:Tensor) (tensor:Tensor) = tensor.forwardDiff(derivative, tag)
+    static member forwardDiff (nestingTag:uint32) (derivative:Tensor) (tensor:Tensor) = tensor.forwardDiff(derivative, nestingTag)
 
     /// <summary>Produce a new tensor suitable for calculating the reverse-mode derivative at the given level tag.</summary>
-    /// <param name="tag">The level tag.</param>
+    /// <param name="nestingTag">The level tag.</param>
     /// <param name="tensor">The output tensor.</param>
-    static member reverseDiff (tag:uint32) (tensor:Tensor) = tensor.reverseDiff(tag)
+    static member reverseDiff (nestingTag:uint32) (tensor:Tensor) = tensor.reverseDiff(nestingTag=nestingTag)
 
     /// <summary>Reset the reverse mode computation associated with the given output tensor.</summary>
     /// <param name="tensor">The output tensor.</param>
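
This hunk renames the tag parameter to nestingTag in the curried forward- and reverse-mode entry points. A minimal sketch of calling these members with an explicit nesting tag (the values are illustrative, not taken from the commit):

    open DiffSharp

    // Forward mode: attach the tangent 1.0 at nesting level 1, then read the
    // directional derivative of y = x*x + x at x = 3 (expected 2*3 + 1 = 7)
    let xf = dsharp.forwardDiff 1u (dsharp.tensor(1.0)) (dsharp.tensor(3.0))
    let yf = xf * xf + xf
    printfn "%A" yf.derivative

    // Reverse mode: tag the input, evaluate, propagate, then read the adjoint
    let xr = dsharp.reverseDiff 1u (dsharp.tensor(3.0))
    let yr = xr * xr + xr
    yr.reverse()
    printfn "%A" xr.derivative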

src/DiffSharp.Core/Model.Dropout.fs

Lines changed: 0 additions & 1 deletion
@@ -5,7 +5,6 @@
 
 namespace DiffSharp.Model
 
-open DiffSharp
 
 /// <summary>A model which during training, randomly zeroes some of the elements of the input tensor with probability p using samples from a Bernoulli distribution.</summary>
 type Dropout(?p:double) =

src/DiffSharp.Core/Model.VAE.fs

Lines changed: 0 additions & 1 deletion
@@ -6,7 +6,6 @@
 namespace DiffSharp.Model
 
 open DiffSharp
-open DiffSharp.Compose
 
 /// <summary>Variational auto-encoder base</summary>
 [<AbstractClass>]
