@@ -225,7 +225,7 @@ type Categorical(?probs:Tensor, ?logits:Tensor) =
225
225
226
226
227
227
/// <summary>Represents an Empirical distribution.</summary>
228
- type Empirical < 'T when 'T:equality >( values :seq < 'T >, ? weights : Tensor , ? logWeights : Tensor , ? combineDuplicates : bool ) =
228
+ type Empirical < 'T when 'T:equality >( values :seq < 'T >, ? weights : Tensor , ? logWeights : Tensor , ? combineDuplicates : bool , ? device : Device , ? dtype : Dtype , ? backend : Backend ) =
229
229
inherit Distribution< 'T>()
230
230
let _categorical , _weighted =
231
231
match weights, logWeights with
@@ -239,7 +239,7 @@ type Empirical<'T when 'T:equality>(values:seq<'T>, ?weights:Tensor, ?logWeights
239
239
let _valuesTensor =
240
240
lazy ( try _ values |> Array.map ( fun v -> box v :?> Tensor) |> dsharp.stack
241
241
with | _ ->
242
- try _ values |> Array.map ( dsharp.tensor) |> dsharp.stack
242
+ try _ values |> Array.map ( dsharp.tensor( device = defaultArg device Device.Default , backend = defaultArg backend Backend.Default , dtype = defaultArg dtype Dtype.Default ) ) |> dsharp.stack
243
243
with | _ -> failwith " Not supported because Empirical does not hold values that are Tensors or can be converted to Tensors" )
244
244
do
245
245
let combineDuplicates = defaultArg combineDuplicates false
@@ -256,7 +256,8 @@ type Empirical<'T when 'T:equality>(values:seq<'T>, ?weights:Tensor, ?logWeights
256
256
Dictionary.copyKeys uniques, dsharp.stack( Dictionary.copyValues uniques) .view(- 1 )
257
257
else
258
258
let vals , counts = _ values |> Array.getUniqueCounts false
259
- vals, dsharp.tensor( counts)
259
+ let c = dsharp.tensor( counts, device= defaultArg device Device.Default, backend= defaultArg backend Backend.Default, dtype= defaultArg dtype Dtype.Default)
260
+ vals, probsToLogits ( c/ c.sum()) false
260
261
_ values <- newValues
261
262
_ categorical <- Categorical( logits= newLogWeights)
262
263
_ weighted <- not ( Seq.allEqual (_ categorical.probs.unstack()))
0 commit comments