functional model creation and composition · DiffSharp/DiffSharp@091447c

Commit 091447c (parent 30b9357)
5 files changed: +96 −40 lines

src/DiffSharp.Core/DiffSharp.fs

Lines changed: 2 additions & 0 deletions
@@ -157,6 +157,8 @@ type DiffSharp =
     static member conv1d(b:Tensor, ?stride:int, ?padding:int, ?dilation:int) = fun (a:Tensor) -> a.conv1d(b, ?stride=stride, ?padding=padding, ?dilation=dilation)
     static member conv2d(a:Tensor, b:Tensor, ?stride:int, ?strides:seq<int>, ?padding:int, ?paddings:seq<int>, ?dilation:int, ?dilations:seq<int>) = a.conv2d(b, ?stride=stride, ?strides=strides, ?padding=padding, ?paddings=paddings, ?dilation=dilation, ?dilations=dilations)
     static member conv2d(b:Tensor, ?stride:int, ?strides:seq<int>, ?padding:int, ?paddings:seq<int>, ?dilation:int, ?dilations:seq<int>) = fun (a:Tensor) -> a.conv2d(b, ?stride=stride, ?strides=strides, ?padding=padding, ?paddings=paddings, ?dilation=dilation, ?dilations=dilations)
+    static member conv3d(a:Tensor, b:Tensor, ?stride:int, ?strides:seq<int>, ?padding:int, ?paddings:seq<int>, ?dilation:int, ?dilations:seq<int>) = a.conv3d(b, ?stride=stride, ?strides=strides, ?padding=padding, ?paddings=paddings, ?dilation=dilation, ?dilations=dilations)
+    static member conv3d(b:Tensor, ?stride:int, ?strides:seq<int>, ?padding:int, ?paddings:seq<int>, ?dilation:int, ?dilations:seq<int>) = fun (a:Tensor) -> a.conv3d(b, ?stride=stride, ?strides=strides, ?padding=padding, ?paddings=paddings, ?dilation=dilation, ?dilations=dilations)

     // Methods mirroring F# array modules
     // TODO: update to support non-float types once we have backing DTypes implemented
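
The second conv3d overload mirrors the curried conv1d/conv2d forms above: it takes only the filter and returns a Tensor -> Tensor function, so it composes with |> and >>. A minimal sketch of both call styles (the tensor shapes are invented for illustration):

    // Illustrative only: input is [batch; channels; depth; height; width].
    let x = dsharp.randn([1; 1; 8; 8; 8])     // one single-channel 8x8x8 volume
    let w = dsharp.randn([4; 1; 3; 3; 3])     // four 3x3x3 filters
    let y1 = dsharp.conv3d(x, w, stride=1)    // method-style call
    let y2 = x |> dsharp.conv3d(w, stride=1)  // curried, pipeline-friendly
    // y1 and y2 compute the same convolution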

src/DiffSharp.Core/Model.fs

Lines changed: 40 additions & 8 deletions
@@ -47,20 +47,28 @@ type ParameterDict() =
         let dd = d.copy()
         dd.unflatten(tensors)
         dd
+    override d.ToString() =
+        let sb = System.Text.StringBuilder()
+        for KeyValue(n, p) in d.values do sb.AppendLine(sprintf "%A, %A" n p) |> ignore
+        sb.ToString()
+

 [<AbstractClass>]
 type Model() =
     member val Parameters:ParameterDict = ParameterDict()
     member val SubModels:Dictionary<string, Model> = Dictionary()
-    member inline m.add(parameters:list<string * 'a>) =
-        for n, p in parameters do
+    member m.add(parameters:seq<obj>, ?names:seq<string>) =
+        let parameters = parameters |> Seq.toArray
+        let names = defaultArg names (Seq.init (parameters.Length) (fun i -> sprintf "p__%d" i)) |> Seq.toArray
+        if parameters.Length <> names.Length then failwithf "Expecting parameters.Length (%A) and names.Length (%A) to be the same" parameters.Length names.Length
+        for p, n in Array.zip parameters names do
             match (box p) with
             | :? Parameter as p ->
                 m.Parameters.add(n, p)
             | :? Model as mm ->
                 m.SubModels.Add(n, mm)
                 m.Parameters.add(mm.Parameters.map(fun (nn, pp:Parameter) -> (n + "__" + nn, pp)))
-            | _ -> failwithf "Unsupported type. Expecting a list<string * 'a> where 'a is Parameter or Model"
+            | _ -> failwithf "Unsupported type. Expecting a Parameter or Model"
     member m.forwardDiff(derivatives:ParameterDict) = m.Parameters.forwarddiff(derivatives)
     member m.reverseDiff() = m.Parameters.reverseDiff()
     member m.noDiff() = m.Parameters.noDiff()

@@ -76,6 +84,12 @@ type Model() =
     member m.forwardLoss (f:Tensor->Tensor->Tensor) (input:Tensor) (target:Tensor) (parameters:Tensor) =
         m.forwardCompose (f target) input parameters
     abstract member forward: Tensor -> Tensor
+    static member create ps f =
+        let model = { new Model() with override __.forward(x) = f x }
+        model.add(ps)
+        model
+    static member compose (model1:Model) (model2:Model) =
+        Model.create [model1; model2] (model1.forward >> model2.forward)


 type Weight() =

@@ -95,7 +109,7 @@ type Linear(inFeatures, outFeatures, ?bias:bool) =
     let w = Parameter(Weight.kaiming(inFeatures, outFeatures))
     let k = 1./sqrt (float outFeatures)
     let b = Parameter(if bias then Weight.standard([|outFeatures|], k) else dsharp.zero())
-    do base.add(["weight", w; "bias", b])
+    do base.add([w;b], ["Linear__weight"; "Linear__bias"])
     override l.forward(value) =
         let f = dsharp.matmul(value, w.value)
         if bias then f + b.value else f

@@ -107,7 +121,7 @@ type Conv1d(inChannels:int, outChannels:int, kernelSize:int, ?stride:int, ?paddi
     let k = 1./ sqrt (float (inChannels*kernelSize))
     let w = Parameter <| Weight.standard([|outChannels; inChannels; kernelSize|], k)
     let b = Parameter <| if bias then Weight.standard([|outChannels|], k) else dsharp.zero()
-    do base.add(["weight", w; "bias", b])
+    do base.add([w;b], ["Conv1d__weight"; "Conv1d__bias"])
     override c.forward(value) =
         let f = dsharp.conv1d(value, w.value, ?stride=stride, ?padding=padding, ?dilation=dilation)
         if bias then f + b.value.expand([value.shape.[0]; outChannels]).view([value.shape.[0]; outChannels; 1]) else f

@@ -119,13 +133,31 @@ type Conv2d(inChannels:int, outChannels:int, ?kernelSize:int, ?stride:int, ?padd
         match kernelSize, kernelSizes with
         | Some _ , Some _ -> failwithf "Expecting only one of kernelSize, kernelSizes"
         | Some k, None -> [|k; k|]
-        | None, Some k -> k |> Array.ofSeq
+        | None, Some k -> let k = k |> Array.ofSeq in if k.Length <> 2 then failwithf "Expecting kernelSizes to have length two" else k
         | _ -> [|1; 1|]
     let bias = defaultArg bias true
     let k = 1./ sqrt (float (inChannels*kernelSizes.[0]*kernelSizes.[1]))
     let w = Parameter <| Weight.standard([|outChannels; inChannels; kernelSizes.[0]; kernelSizes.[1]|], k)
     let b = Parameter <| if bias then Weight.standard([|outChannels|], k) else dsharp.zero()
-    do base.add(["weight", w; "bias", b])
+    do base.add([w;b], ["Conv2d__weight"; "Conv2d__bias"])
     override c.forward(value) =
         let f = dsharp.conv2d(value, w.value, ?stride=stride, ?strides=strides, ?padding=padding, ?paddings=paddings, ?dilation=dilation, ?dilations=dilations)
-        if bias then f + b.value.expand([value.shape.[0]; outChannels]).view([value.shape.[0]; outChannels; 1; 1]) else f
+        if bias then f + b.value.expand([value.shape.[0]; outChannels]).view([value.shape.[0]; outChannels; 1; 1]) else f
+
+
+type Conv3d(inChannels:int, outChannels:int, ?kernelSize:int, ?stride:int, ?padding:int, ?dilation:int, ?kernelSizes:seq<int>, ?strides:seq<int>, ?paddings:seq<int>, ?dilations:seq<int>, ?bias:bool) =
+    inherit Model()
+    let kernelSizes =
+        match kernelSize, kernelSizes with
+        | Some _ , Some _ -> failwithf "Expecting only one of kernelSize, kernelSizes"
+        | Some k, None -> [|k; k; k|]
+        | None, Some k -> let k = k |> Array.ofSeq in if k.Length <> 3 then failwithf "Expecting kernelSizes to have length three" else k
+        | _ -> [|1; 1; 1|]
+    let bias = defaultArg bias true
+    let k = 1./ sqrt (float (inChannels*kernelSizes.[0]*kernelSizes.[1]*kernelSizes.[2]))
+    let w = Parameter <| Weight.standard([|outChannels; inChannels; kernelSizes.[0]; kernelSizes.[1]; kernelSizes.[2]|], k)
+    let b = Parameter <| if bias then Weight.standard([|outChannels|], k) else dsharp.zero()
+    do base.add([w;b], ["Conv3d__weight"; "Conv3d__bias"])
+    override c.forward(value) =
+        let f = dsharp.conv3d(value, w.value, ?stride=stride, ?strides=strides, ?padding=padding, ?paddings=paddings, ?dilation=dilation, ?dilations=dilations)
+        if bias then f + b.value.expand([value.shape.[0]; outChannels]).view([value.shape.[0]; outChannels; 1; 1; 1]) else f
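
Model.create above turns a list of parameters/sub-models plus a plain Tensor -> Tensor function into a Model, and Model.compose chains two models while registering both parameter sets. A small sketch of how this functional API might be used (the layer sizes are illustrative, not from the commit):

    // Hedged sketch, assuming the Linear layer and dsharp.relu shown elsewhere in this diff.
    let encoder =
        let fc = Linear(784, 32)
        Model.create [fc] (fc.forward >> dsharp.relu)
    let classifier =
        let fc = Linear(32, 10)
        Model.create [fc] fc.forward
    // compose registers both sub-models' parameters under auto-generated names
    let net = Model.compose encoder classifier
    let logits = net.forward (dsharp.randn([16; 784]))

Because add now auto-generates names (p__0, p__1, ...) when none are given, anonymous functional models need no manual naming, while the built-in layers pass explicit names such as "Linear__weight" to keep their parameter dictionaries readable.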

src/DiffSharp.Core/Tensor.fs

Lines changed: 6 additions & 6 deletions
@@ -1130,19 +1130,19 @@ type Tensor =
         match stride, strides with
         | Some _ , Some _ -> failwithf "Expecting only one of stride, strides"
         | Some s, None -> [|s; s|]
-        | None, Some s -> s |> Array.ofSeq
+        | None, Some s -> let s = s |> Array.ofSeq in if s.Length <> 2 then failwithf "Expecting strides to have length two" else s
         | _ -> [|1; 1|]
     let paddings =
         match padding, paddings with
         | Some _ , Some _ -> failwithf "Expecting only one of padding, paddings"
         | Some p, None -> [|p; p|]
-        | None, Some p -> p |> Array.ofSeq
+        | None, Some p -> let p = p |> Array.ofSeq in if p.Length <> 2 then failwithf "Expecting paddings to have length two" else p
         | _ -> [|0; 0|]
     let dilations =
         match dilation, dilations with
         | Some _ , Some _ -> failwithf "Expecting only one of dilation, dilations"
         | Some d, None -> [|d; d|]
-        | None, Some d -> d |> Array.ofSeq
+        | None, Some d -> let d = d |> Array.ofSeq in if d.Length <> 2 then failwithf "Expecting dilations to have length two" else d
         | _ -> [|1; 1|]
     checkCanConv2d a.shape b.shape strides paddings dilations
     let mutable b = b

@@ -1217,19 +1217,19 @@ type Tensor =
         match stride, strides with
         | Some _ , Some _ -> failwithf "Expecting only one of stride, strides"
         | Some s, None -> [|s; s; s|]
-        | None, Some s -> s |> Array.ofSeq
+        | None, Some s -> let s = s |> Array.ofSeq in if s.Length <> 3 then failwithf "Expecting strides to have length three" else s
         | _ -> [|1; 1; 1|]
     let paddings =
         match padding, paddings with
         | Some _ , Some _ -> failwithf "Expecting only one of padding, paddings"
         | Some p, None -> [|p; p; p|]
-        | None, Some p -> p |> Array.ofSeq
+        | None, Some p -> let p = p |> Array.ofSeq in if p.Length <> 3 then failwithf "Expecting paddings to have length three" else p
         | _ -> [|0; 0; 0|]
     let dilations =
         match dilation, dilations with
         | Some _ , Some _ -> failwithf "Expecting only one of dilation, dilations"
         | Some d, None -> [|d; d; d|]
-        | None, Some d -> d |> Array.ofSeq
+        | None, Some d -> let d = d |> Array.ofSeq in if d.Length <> 3 then failwithf "Expecting dilations to have length three" else d
         | _ -> [|1; 1; 1|]
     checkCanConv3d a.shape b.shape strides paddings dilations
     let mutable b = b
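
These six changes make conv2d and conv3d validate the lengths of the strides, paddings, and dilations sequences up front, failing fast instead of mis-indexing later. A hypothetical illustration for conv2d (the shapes are invented for the example):

    let x = dsharp.randn([1; 1; 8; 8])      // [batch; channels; height; width]
    let w = dsharp.randn([4; 1; 3; 3])      // four 3x3 filters
    let ok = x.conv2d(w, strides=[2; 2])    // valid: one stride per spatial dimension
    // x.conv2d(w, strides=[2; 2; 2])       // now fails fast with
    //   "Expecting strides to have length two"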

src/DiffSharp.Tests/TestModel.fs

Lines changed: 5 additions & 0 deletions
@@ -34,3 +34,8 @@ type TestModel () =
         let d3flat = d3.flatten()
         Assert.AreEqual(d1flatCorrect, d3flat)

+    // [<Test>]
+    // member this.TestLinear () =
+    //     let n, dIn, h, dOut = 64, 1000, 100, 10
+    //     let x = dsharp.randn(n, dIn)
+    //     let y = dsharp.randn(n, dOut)
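
The commit leaves this test as a commented-out stub. One plausible continuation using the new Model.create API is sketched below; this is purely an assumption about where the test is headed, not part of the commit:

    // Hypothetical continuation of the TestLinear stub above.
    // let fc1 = Linear(dIn, h)
    // let fc2 = Linear(h, dOut)
    // let net = Model.create [fc1; fc2] (fc1.forward >> dsharp.relu >> fc2.forward)
    // Assert.AreEqual(dIn*h + h + h*dOut + dOut, net.nparameters())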

src/Test/Program.fs

Lines changed: 43 additions & 26 deletions
@@ -29,14 +29,12 @@ open DiffSharp.Backend.None

 type Net() =
     inherit Model()
-    let conv1 = Conv2d(1, 32, 3)
-    let conv2 = Conv2d(32, 64, 3)
+    let conv1 = Conv2d(1, 2, 3)
+    let conv2 = Conv2d(2, 4, 3)
     let k = dsharp.randn([1;1;28;28]) |> conv1.forward |> conv2.forward |> dsharp.nelement
     let fc1 = Linear(k, 128)
     let fc2 = Linear(128, 10)
-    do
-        base.add(["conv1", conv1; "conv2", conv2])
-        base.add(["fc1", fc1; "fc2", fc2])
+    do base.add([conv1; conv2; fc1; fc2])
     override __.forward(x) =
         x
         // |> dsharp.view [-1; 28*28]

@@ -50,7 +48,6 @@ type Net() =
         |> fc2.forward


-
 [<EntryPoint>]
 let main _argv =
     printfn "Hello World from F#!"

@@ -60,30 +57,50 @@ let main _argv =
     let dataset = MNIST("./data", train=true)
     let dataloader = dataset.loader(8, shuffle=true, numBatches=50)

-    let net = Net()
+    // let net = Net()
+
+    let cnn () =
+        let conv1 = Conv2d(1, 2, 3)
+        let conv2 = Conv2d(2, 4, 3)
+        let k = dsharp.randn([1;1;28;28]) |> conv1.forward |> conv2.forward |> dsharp.nelement
+        let fc1 = Linear(k, 128)
+        let fc2 = Linear(128, 10)
+        Model.create [conv1; conv2; fc1; fc2]
+            (conv1.forward
+             >> dsharp.relu
+             >> conv2.forward
+             >> dsharp.relu
+             >> dsharp.flatten 1
+             >> fc1.forward
+             >> dsharp.relu
+             >> fc2.forward)
+    let net = cnn()
+
     printfn "params: %A" (net.nparameters())
+    // printfn "params: %A" (net.Parameters)

-    let optimizer = SGD(net, learningRate=dsharp.tensor(0.01), momentum=dsharp.tensor(0.9), nesterov=true)
-    let mutable epoch = -1
-    let mutable stop = false
-    while not stop do
-        epoch <- epoch + 1
-        for i, data, targets in dataloader.epoch() do
-            net.reverseDiff()
-            let o = net.forward(data)
-            let loss = dsharp.crossEntropyLoss(o, targets)
-            loss.reverse()
-            optimizer.step()
+    // let optimizer = SGD(net, learningRate=dsharp.tensor(0.01), momentum=dsharp.tensor(0.9), nesterov=true)
+    // let mutable epoch = -1
+    // let mutable stop = false
+    // while not stop do
+    //     epoch <- epoch + 1
+    //     for i, data, targets in dataloader.epoch() do
+    //         net.reverseDiff()
+    //         let o = net.forward(data)
+    //         let loss = dsharp.crossEntropyLoss(o, targets)
+    //         loss.reverse()
+    //         optimizer.step()

-            let loss = loss.toScalar() :?> float32
-            printfn "epoch %A, minibatch %A, loss %A\r" epoch i loss
+    //         let loss = loss.toScalar() :?> float32
+    //         printfn "epoch %A, minibatch %A, loss %A\r" epoch i loss

     // let loss data target p = net.forwardCompose (dsharp.crossEntropyLoss(target=target)) data p
-    // let loss = net.forwardLoss dsharp.crossEntropyLoss
-    // let mutable p = net.getParameters()
-    // for i, data, target in dataloader.epoch() do
-    //     let loss, g = dsharp.pgrad (loss data target) p
-    //     p <- p - 0.1 * g
-    //     printfn "%A %A" i loss
+    let loss = net.forwardLoss dsharp.crossEntropyLoss
+    let mutable p = net.getParameters()
+    for i, data, target in dataloader.epoch() do
+        let loss, g = dsharp.pgrad (loss data target) p
+        p <- p - 0.1 * g
+        printfn "%A %A" i loss


     0 // return an integer exit code
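
The live loop above treats training as pure function evaluation: net.forwardLoss makes the loss a function of the flattened parameter vector, and dsharp.pgrad returns the loss value and gradient together, so the update is plain tensor arithmetic. A hedged helper capturing that step (a sketch, assuming the (Tensor -> Tensor) -> Tensor -> Tensor * Tensor shape of pgrad seen in the loop; sgdStep is a hypothetical name, not part of the commit):

    // Hypothetical wrapper around the manual SGD step used in the loop above.
    let sgdStep (f:Tensor->Tensor) (lr:float) (p:Tensor) =
        let l, g = dsharp.pgrad f p   // loss value and gradient at p
        p - lr * g, l                 // updated parameters, plus the loss for logging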

0 commit comments