@@ -47,20 +47,28 @@ type ParameterDict() =
         let dd = d.copy()
         dd.unflatten(tensors)
         dd
+    override d.ToString() =
+        let sb = System.Text.StringBuilder()
+        for KeyValue(n, p) in d.values do sb.AppendLine(sprintf "%A, %A" n p) |> ignore
+        sb.ToString()
+
 
 [<AbstractClass>]
 type Model() =
     member val Parameters : ParameterDict = ParameterDict()
     member val SubModels : Dictionary<string, Model> = Dictionary()
-    member inline m.add(parameters:list<string * 'a>) =
-        for n, p in parameters do
+    member m.add(parameters:seq<obj>, ?names:seq<string>) =
+        let parameters = parameters |> Seq.toArray
+        let names = defaultArg names (Seq.init parameters.Length (fun i -> sprintf "p__%d" i)) |> Seq.toArray
+        if parameters.Length <> names.Length then failwithf "Expecting parameters.Length (%A) and names.Length (%A) to be the same" parameters.Length names.Length
+        for p, n in Array.zip parameters names do
             match (box p) with
             | :? Parameter as p ->
                 m.Parameters.add(n, p)
             | :? Model as mm ->
                 m.SubModels.Add(n, mm)
                 m.Parameters.add(mm.Parameters.map(fun (nn, pp:Parameter) -> (n + "__" + nn, pp)))
-            | _ -> failwithf "Unsupported type. Expecting a list<string * 'a> where 'a is Parameter or Model"
+            | _ -> failwithf "Unsupported type. Expecting a Parameter or Model"
     member m.forwardDiff(derivatives:ParameterDict) = m.Parameters.forwarddiff(derivatives)
     member m.reverseDiff() = m.Parameters.reverseDiff()
     member m.noDiff() = m.Parameters.noDiff()
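Editor's note, to illustrate the reworked `add` above (a minimal sketch, not part of the commit): parameters and sub-models are now passed as one `seq<obj>`, and names are optional, defaulting to `p__0`, `p__1`, and so on.

    let p1, p2 = Parameter(dsharp.zero()), Parameter(dsharp.zero())
    let m = { new Model() with override __.forward(x) = x }
    m.add([box p1; box p2])    // auto-named p__0, p__1
    printfn "%O" m.Parameters  // the new ToString lists each name, value pair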
@@ -76,6 +84,12 @@ type Model() =
     member m.forwardLoss (f:Tensor->Tensor->Tensor) (input:Tensor) (target:Tensor) (parameters:Tensor) =
         m.forwardCompose (f target) input parameters
     abstract member forward: Tensor -> Tensor
+    static member create ps f =
+        let model = { new Model() with override __.forward(x) = f x }
+        model.add(ps)
+        model
+    static member compose (model1:Model) (model2:Model) =
+        Model.create [model1; model2] (model1.forward >> model2.forward)
 
 
 type Weight() =
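The two combinators added above chain models while merging their parameters; a hedged sketch using the `Linear` layer defined later in this diff (editorial illustration only):

    let net = Model.compose (Linear(10, 16)) (Linear(16, 1))
    // net.forward is the first forward composed with the second, and
    // net.Parameters holds both layers' parameters under the auto-generated
    // p__0/p__1 prefixes assigned by Model.create.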
@@ -95,7 +109,7 @@ type Linear(inFeatures, outFeatures, ?bias:bool) =
     let w = Parameter(Weight.kaiming(inFeatures, outFeatures))
     let k = 1. / sqrt (float outFeatures)
     let b = Parameter(if bias then Weight.standard([|outFeatures|], k) else dsharp.zero())
-    do base.add(["weight", w; "bias", b])
+    do base.add([w; b], ["Linear__weight"; "Linear__bias"])
     override l.forward(value) =
         let f = dsharp.matmul(value, w.value)
         if bias then f + b.value else f
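For orientation, a short usage sketch of `Linear` with the new parameter names (editorial; `dsharp.randn` is assumed from the wider library and is not shown in this diff):

    let layer = Linear(3, 2)
    let y = layer.forward (dsharp.randn([4; 3]))  // [4; 3] -> [4; 2]
    // layer.Parameters now registers Linear__weight and Linear__bias.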
@@ -107,7 +121,7 @@ type Conv1d(inChannels:int, outChannels:int, kernelSize:int, ?stride:int, ?padding:int, ?dilation:int, ?bias:bool) =
     let k = 1. / sqrt (float (inChannels*kernelSize))
     let w = Parameter <| Weight.standard([|outChannels; inChannels; kernelSize|], k)
     let b = Parameter <| if bias then Weight.standard([|outChannels|], k) else dsharp.zero()
-    do base.add(["weight", w; "bias", b])
+    do base.add([w; b], ["Conv1d__weight"; "Conv1d__bias"])
     override c.forward(value) =
         let f = dsharp.conv1d(value, w.value, ?stride=stride, ?padding=padding, ?dilation=dilation)
         if bias then f + b.value.expand([value.shape.[0]; outChannels]).view([value.shape.[0]; outChannels; 1]) else f
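Note the bias broadcast in `Conv1d.forward`: `b` has shape [outChannels], so it is expanded to [N; outChannels] and then viewed as [N; outChannels; 1] to broadcast over the length dimension. A hedged sketch (again assuming `dsharp.randn`):

    let conv = Conv1d(2, 4, 3)
    let y = conv.forward (dsharp.randn([8; 2; 10]))  // [8; 2; 10] -> [8; 4; 8]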
@@ -119,13 +133,31 @@ type Conv2d(inChannels:int, outChannels:int, ?kernelSize:int, ?stride:int, ?padding:int, ?dilation:int, ?kernelSizes:seq<int>, ?strides:seq<int>, ?paddings:seq<int>, ?dilations:seq<int>, ?bias:bool) =
         match kernelSize, kernelSizes with
         | Some _, Some _ -> failwithf "Expecting only one of kernelSize, kernelSizes"
         | Some k, None -> [|k; k|]
-        | None, Some k -> k |> Array.ofSeq
+        | None, Some k -> let k = k |> Array.ofSeq in if k.Length <> 2 then failwithf "Expecting kernelSizes to have length two" else k
         | _ -> [|1; 1|]
     let bias = defaultArg bias true
     let k = 1. / sqrt (float (inChannels*kernelSizes.[0]*kernelSizes.[1]))
     let w = Parameter <| Weight.standard([|outChannels; inChannels; kernelSizes.[0]; kernelSizes.[1]|], k)
     let b = Parameter <| if bias then Weight.standard([|outChannels|], k) else dsharp.zero()
-    do base.add(["weight", w; "bias", b])
+    do base.add([w; b], ["Conv2d__weight"; "Conv2d__bias"])
     override c.forward(value) =
         let f = dsharp.conv2d(value, w.value, ?stride=stride, ?strides=strides, ?padding=padding, ?paddings=paddings, ?dilation=dilation, ?dilations=dilations)
-        if bias then f + b.value.expand([value.shape.[0]; outChannels]).view([value.shape.[0]; outChannels; 1; 1]) else f
+        if bias then f + b.value.expand([value.shape.[0]; outChannels]).view([value.shape.[0]; outChannels; 1; 1]) else f
+
+
+type Conv3d(inChannels:int, outChannels:int, ?kernelSize:int, ?stride:int, ?padding:int, ?dilation:int, ?kernelSizes:seq<int>, ?strides:seq<int>, ?paddings:seq<int>, ?dilations:seq<int>, ?bias:bool) =
+    inherit Model()
+    let kernelSizes =
+        match kernelSize, kernelSizes with
+        | Some _, Some _ -> failwithf "Expecting only one of kernelSize, kernelSizes"
+        | Some k, None -> [|k; k; k|]
+        | None, Some k -> let k = k |> Array.ofSeq in if k.Length <> 3 then failwithf "Expecting kernelSizes to have length three" else k
+        | _ -> [|1; 1; 1|]
+    let bias = defaultArg bias true
+    let k = 1. / sqrt (float (inChannels*kernelSizes.[0]*kernelSizes.[1]*kernelSizes.[2]))
+    let w = Parameter <| Weight.standard([|outChannels; inChannels; kernelSizes.[0]; kernelSizes.[1]; kernelSizes.[2]|], k)
+    let b = Parameter <| if bias then Weight.standard([|outChannels|], k) else dsharp.zero()
+    do base.add([w; b], ["Conv3d__weight"; "Conv3d__bias"])
+    override c.forward(value) =
+        let f = dsharp.conv3d(value, w.value, ?stride=stride, ?strides=strides, ?padding=padding, ?paddings=paddings, ?dilation=dilation, ?dilations=dilations)
+        if bias then f + b.value.expand([value.shape.[0]; outChannels]).view([value.shape.[0]; outChannels; 1; 1; 1]) else f
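Finally, a usage sketch for the new `Conv3d` (editorial; input shapes follow the NxCxDxHxW convention implied by the bias view above, and `dsharp.randn` is assumed):

    let conv = Conv3d(2, 4, kernelSize=3)
    let y = conv.forward (dsharp.randn([8; 2; 16; 16; 16]))
    // with the default stride 1 and no padding:
    // [8; 2; 16; 16; 16] -> [8; 4; 14; 14; 14]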