10000 fix unweighted combination · DiffSharp/DiffSharp@3edbce6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3edbce6

Browse files
committed
fix unweighted combination
1 parent 6d49247 commit 3edbce6

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

src/DiffSharp.Core/Distributions.fs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ type Categorical(?probs:Tensor, ?logits:Tensor) =
225225

226226

227227
/// <summary>Represents an Empirical distribution.</summary>
228-
type Empirical<'T when 'T:equality>(values:seq<'T>, ?weights:Tensor, ?logWeights:Tensor, ?combineDuplicates:bool) =
228+
type Empirical<'T when 'T:equality>(values:seq<'T>, ?weights:Tensor, ?logWeights:Tensor, ?combineDuplicates:bool, ?device:Device, ?dtype:Dtype, ?backend:Backend) =
229229
inherit Distribution<'T>()
230230
let _categorical, _weighted =
231231
match weights, logWeights with
@@ -239,7 +239,7 @@ type Empirical<'T when 'T:equality>(values:seq<'T>, ?weights:Tensor, ?logWeights
239239
let _valuesTensor =
240240
lazy(try _values |> Array.map (fun v -> box v :?> Tensor) |> dsharp.stack
241241
with | _ ->
242-
try _values |> Array.map (dsharp.tensor) |> dsharp.stack
242+
try _values |> Array.map (dsharp.tensor(device=defaultArg device Device.Default, backend=defaultArg backend Backend.Default, dtype=defaultArg dtype Dtype.Default)) |> dsharp.stack
243243
with | _ -> failwith "Not supported because Empirical does not hold values that are Tensors or can be converted to Tensors")
244244
do
245245
let combineDuplicates = defaultArg combineDuplicates false
@@ -256,7 +256,8 @@ type Empirical<'T when 'T:equality>(values:seq<'T>, ?weights:Tensor, ?logWeights
256256
Dictionary.copyKeys uniques, dsharp.stack(Dictionary.copyValues uniques).view(-1)
257257
else
258258
let vals, counts = _values |> Array.getUniqueCounts false
259-
vals, dsharp.tensor(counts)
259+
let c = dsharp.tensor(counts, device=defaultArg device Device.Default, backend=defaultArg backend Backend.Default, dtype=defaultArg dtype Dtype.Default)
260+
vals, probsToLogits (c/c.sum()) false
260261
_values <- newValues
261262
_categorical <- Categorical(logits=newLogWeights)
262263
_weighted <- not (Seq.allEqual (_categorical.probs.unstack()))

tests/DiffSharp.Tests/TestDistributions.fs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,23 @@ type TestDistributions () =
369369
Assert.CheckEqual(distUnweightedCombinedModeCorrect, distUnweightedCombinedMode)
370370
Assert.CheckEqual(distUnweightedCombinedLengthCorrect, distUnweightedCombinedLength)
371371

372+
[<Test>]
373+
member _.TestDistributionsEmpiricalCombineDuplicatesUnweighted () =
374+
for combo in Combos.AllDevicesAndBackendsFloat32 do
375+
let values = combo.tensor([0,1,1,1,2,2])
376+
377+
let dist = Empirical(values.unstack(), combineDuplicates=true, device=combo.device, backend=combo.backend, dtype=combo.dtype)
378+
let values = dist.valuesTensor
379+
let weights = dist.weights
380+
let valuesCorrect = combo.tensor([0,1,2])
381+
let weightsCorrect = combo.tensor([1./6., 3./6., 2./6.])
382+
printfn "%A" values.dtype
383+
printfn "%A" weights.dtype
384+
printfn "%A" valuesCorrect.dtype
385+
printfn "%A" weightsCorrect.dtype
386+
Assert.True(valuesCorrect.allclose(values, 0.1))
387+
Assert.True(weightsCorrect.allclose(weights, 0.1))
388+
372389
[<Test>]
373390
member _.TestDistributionsEmpiricalResampleFilter () =
374391
for combo in Combos.AllDevicesAndBackendsFloat32 do

0 commit comments

Comments
 (0)
0