8000 maxunpool2d complete · DiffSharp/DiffSharp@601a146 · GitHub
[go: up one dir, main page]

Skip to content

Commit 601a146

Browse files
committed
maxunpool2d complete
1 parent b848836 commit 601a146

File tree

2 files changed

+218
-1
lines changed

2 files changed

+218
-1
lines changed

src/DiffSharp.Core/Tensor.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1650,7 +1650,7 @@ type Tensor =
16501650
| MaxPool1DT(a, indices, kernelSize) -> push ((t.derivative.maxunpool1d(indices, kernelSize=kernelSize, outputSize=a.shape), a) :: tt)
16511651
| MaxPool2DT(a, indices, kernelSizes) -> push ((t.derivative.maxunpool2d(indices, kernelSizes=kernelSizes, outputSize=a.shape), a) :: tt)
16521652
| MaxUnpool1DT(a, indices) -> push ((t.derivative.gather(dim=2, indices=indices), a) :: tt)
1653-
| MaxUnpool2DT(a, indices) -> failwith "Not implemented" // push ((t.derivative.gather(dim=2, indices=indices), a) :: tt)
1653+
| MaxUnpool2DT(a, indices) -> push ((t.derivative.flatten(startDim=2).gather(dim=2, indices=indices.flatten(startDim=2)).viewAs(a), a) :: tt)
16541654
| Conv1DTT(a,b,stride,padding) ->
16551655
let aderivative, bderivative = t.conv1dReverseDiff(a, b, false, false, stride, padding)
16561656
push ((aderivative, a) :: (bderivative, b) :: tt)

src/DiffSharp.Tests/TestDerivatives.fs

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,223 @@ type TestDerivatives () =
607607
Assert.True(revz.allclose(revzCorrect, 0.01))
608608
Assert.True(revxd.allclose(revxdCorrect, 0.01))
609609

610+
[<Test>]
611+
member this.TestDerivativeMaxUnpool2D () =
612+
let indices = dsharp.tensor([[[[10, 21],
613+
[41, 45]],
614+
615+
[[ 8, 5],
616+
[40, 28]]],
617+
618+
619+
[[[ 8, 21],
620+
[32, 36]],
621+
622+
[[ 9, 13],
623+
[25, 27]]]], dtype=DType.Int32)
624+
let fwdx = dsharp.tensor([[[[1.8489, 1.1338],
625+
[0.6819, 1.6331]],
626+
627+
[[1.0867, 2.1048],
628+
[2.7646, 1.0156]]],
629+
630+
631+
[[[2.1120, 0.8666],
632+
[0.9141, 1.7133]],
633+
634+
[[1.4250, 1.8228],
635+
[1.2607, 0.5448]]]])
636+
let fwdx = fwdx.forwardDiff(dsharp.tensor([[[[ 1.3110, 1.5369],
637+
[-0.4640, 0.1933]],
638+
639+
[[ 0.2313, -0.4964],
640+
[-0.1616, -1.2032]]],
641+
642+
643+
[[[ 0.0377, 2.1561],
644+
[-0.3110, -1.6315]],
645+
646+
[[ 0.4036, 0.7063],
647+
[ 0.0583, 1.8215]]]]))
648+
let fwdz = dsharp.maxunpool2d(fwdx, indices, 3, outputSize=[2;2;8;8])
649+
let fwdzCorrect = dsharp.tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
650+
[0.0000, 0.0000, 1.8489, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
651+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.1338, 0.0000, 0.0000],
652+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
653+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
654+
[0.0000, 0.6819, 0.0000, 0.0000, 0.0000, 1.6331, 0.0000, 0.0000],
655+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
656+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
657+
658+
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.1048, 0.0000, 0.0000],
659+
[1.0867, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
660+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
661+
[0.0000, 0.0000, 0.0000, 0.0000, 1.0156, 0.0000, 0.0000, 0.0000],
662+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
663+
[2.7646, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
664+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
665+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],
666+
667+
668+
[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
669+
[2.1120, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
670+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8666, 0.0000, 0.0000],
671+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
672+
[0.9141, 0.0000, 0.0000, 0.0000, 1.7133, 0.0000, 0.0000, 0.0000],
673+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
674+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
675+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
676+
677+
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
678+
[0.0000, 1.4250, 0.0000, 0.0000, 0.0000, 1.8228, 0.0000, 0.0000],
679+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
680+
[0.0000, 1.2607, 0.0000, 0.5448, 0.0000, 0.0000, 0.0000, 0.0000],
681+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
682+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
683+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
684+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
685+
let fwdzd = fwdz.derivative
686+
let fwdzdCorrect = dsharp.tensor([[[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
687+
[ 0.0000, 0.0000, 1.3110, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
688+
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.5369, 0.0000, 0.0000],
689+
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
690+
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
691+
[ 0.0000, -0.4640, 0.0000, 0.0000, 0.0000, 0.1933, 0.0000, 0.0000],
692+
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
693+
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
694+
695+
[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -0.4964, 0.0000, 0.0000],
696+
[ 0.2313, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
697+
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
698+
[ 0.0000, 0.0000, 0.0000, 0.0000, -1.2032, 0.0000, 0.0000, 0.0000],
699+
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
700+
[-0.1616, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
701+
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
702+
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],
703+
704+
705+
[[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
706+
[ 0.0377, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
707+
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.1561, 0.0000, 0.0000],
708+
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
709+
[-0.3110, 0.0000, 0.0000, 0.0000, -1.6315, 0.0000, 0.0000, 0.0000],
710+
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
711+
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
712+
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
713+
714+
[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
715+
[ 0.0000, 0.4036, 0.0000, 0.0000, 0.0000, 0.7063, 0.0000, 0.0000],
716+
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
717+
[ 0.0000, 0.0583, 0.0000, 1.8215, 0.0000, 0.0000, 0.0000, 0.0000],
718+
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
719+
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
720+
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
721+
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
722+
723+
let revx = dsharp.tensor([[[[1.8489, 1.1338],
724+
[0.6819, 1.6331]],
725+
726+
[[1.0867, 2.1048],
727+
[2.7646, 1.0156]]],
728+
729+
730+
[[[2.1120, 0.8666],
731+
[0.9141, 1.7133]],
732+
733+
[[1.4250, 1.8228],
734+
[1.2607, 0.5448]]]]).reverseDiff()
735+
let revz = dsharp.maxunpool2d(revx, indices, 3, outputSize=[2;2;8;8])
736+
let revzCorrect = dsharp.tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
737+
[0.0000, 0.0000, 1.8489, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
738+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.1338, 0.0000, 0.0000],
739+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
740+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
741+
[0.0000, 0.6819, 0.0000, 0.0000, 0.0000, 1.6331, 0.0000, 0.0000],
742+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
743+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
744+
745+
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.1048, 0.0000, 0.0000],
746+
[1.0867, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
747+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
748+
[0.0000, 0.0000, 0.0000, 0.0000, 1.0156, 0.0000, 0.0000, 0.0000],
749+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
750+
[2.7646, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
751+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
752+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],
753+
754+
755+
[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
756+
[2.1120, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
757+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8666, 0.0000, 0.0000],
758+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
759+
[0.9141, 0.0000, 0.0000, 0.0000, 1.7133, 0.0000, 0.0000, 0.0000],
760+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
761+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
762+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
763+
764+
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
765+
[0.0000, 1.4250, 0.0000, 0.0000, 0.0000, 1.8228, 0.0000, 0.0000],
766+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
767+
[0.0000, 1.2607, 0.0000, 0.5448, 0.0000, 0.0000, 0.0000, 0.0000],
768+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
769+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
770+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
771+
[0.0000< 10000 span class="pl-k">, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
772+
revz.reverse(dsharp.tensor([[[[ 0.8984, -0.1793, -0.6322, -0.0690, -0.9264, 0.0128, -0.7905, -1.1720],
773+
[-0.9698, -0.5440, 0.4486, 1.2180, 0.2225, 0.9609, -0.9034, -0.0780],
774+
[-0.4951, 0.8007, 0.9854, 0.0834, -0.6088, -0.0391, -1.2598, 0.9716],
775+
[ 0.3452, -0.8826, -0.6269, 0.2676, -0.7513, -1.2730, 1.3117, -1.0414],
776+
[-0.6497, 0.3582, -1.9917, -0.4683, -0.7881, -1.3295, -0.2698, 0.5392],
777+
[-0.5929, 0.5991, -0.9721, 3.0464, -0.4441, -2.0565, -1.5350, -2.4916],
778+
[-0.6468, 0.4241, 0.1965, -0.9907, 0.1452, 0.4475, -0.1735, 0.1073],
779+
[ 2.5096, 1.4079, -0.6148, -0.4607, -0.4818, 0.0415, -1.3375, 0.9602]],
780+
781+
[[ 0.2130, 1.5446, -0.5831, 0.1359, -1.0135, -0.6529, 2.1866, 1.2187],
782+
[ 1.1122, -0.1649, 0.0473, -0.6117, 2.1489, -0.4845, 0.3153, -1.8326],
783+
[-1.9014, -0.3670, 0.8990, -0.4523, 0.3366, -1.0262, 1.0180, -1.6572],
784+
[-0.0980, -0.7111, 1.0891, 0.2800, -1.7344, 1.7927, 0.1482, 0.7804],
785+
[-0.2373, 0.1023, -0.4915, -0.7444, 1.1870, -0.2154, 2.6652, -0.5908],
786+
[ 0.1938, 0.7860, -0.2982, 0.3848, 0.6933, 0.0560, 0.3348, -1.7360],
787+
[ 1.6364, 0.0241, 0.5503, 1.0353, 0.2991, -0.0626, -0.1946, -0.8325],
788+
[ 0.2725, -1.0100, -1.0443, -0.1072, -0.7004, -1.5560, 1.2907, -0.4360]]],
789+
790+
791+
[[[-0.2089, 0.7702, 0.0477, 0.4665, 0.9363, -0.1614, 2.6186, -0.7527],
792+
[-1.1260, -1.0446, -1.1232, -1.4205, -0.5242, -0.6396, 1.2601, 0.6628],
793+
[-0.1578, -0.9871, -0.0246, 0.9263, -0.8434, 2.0567, 0.1197, -1.4947],
794+
[ 0.6996, -1.6738, 1.4169, 1.6045, 1.7552, -1.2681, -0.6705, 0.2275],
795+
[-0.8477, 0.2387, 2.1242, -1.0176, -0.6977, -1.1640, -0.6400, -0.5105],
796+
[-1.2159, 0.0419, -0.5746, 0.1896, -1.2158, -0.7044, 0.7461, -1.2746],
797+
[ 0.0053, 1.7733, -1.2196, 1.2569, 0.8448, -0.3330, 0.5204, 1.0973],
798+
[ 0.5777, 0.6732, 0.1366, -0.8237, 0.2497, -1.0159, -2.3620, -0.2002]],
799+
800+
[[-0.3120, -0.2487, -0.3603, 0.2290, -1.3754, 0.1596, -0.1769, -1.2327],
801+
[-0.3505, 0.4122, -0.9472, -0.4892, -0.4146, -1.5960, -0.5563, 0.3567],
802+
[ 0.7608, 0.6693, -1.0732, -1.6005, -0.2449, -1.2722, 0.3509, -1.3285],
803+
[-1.1414, 0.0691, 1.0393, 1.2117, -0.1610, 0.7500, 0.3646, 0.4578],
804+
[ 1.2932, 0.4704, 1.3387, 0.8193, 1.8205, -0.0931, 0.0629, -0.1365],
805+
[ 0.5605, 0.6983, -1.1321, -1.4662, 2.1607, 0.1176, 0.0903, 0.7739],
806+
[ 0.3887, 0.3413, 0.4066, 0.8743, 3.5218, -0.4829, -0.8280, -1.2032],
807+
[ 0.8603, -0.4883, 0.0139, 0.4995, 1.0476, -2.1789, 0.1493, 1.1592]]]]))
808+
let revxd = revx.derivative
809+
let revxdCorrect = dsharp.tensor([[[[ 0.4486, -0.0391],
810+
[ 0.5991, -2.0565]],
811+
812+
[[ 1.1122, -0.6529],
813+
[ 0.1938, -1.7344]]],
814+
815+
816+
[[[-1.1260, 2.0567],
817+
[-0.8477, -0.6977]],
818+
819+
[[ 0.4122, -1.5960],
820+
[ 0.0691, 1.2117]]]])
821+
822+
Assert.True(fwdz.allclose(fwdzCorrect, 0.01))
823+
Assert.True(fwdzd.allclose(fwdzdCorrect, 0.01))
824+
Assert.True(revz.allclose(revzCorrect, 0.01))
825+
Assert.True(revxd.allclose(revxdCorrect, 0.01))
826+
610827
[<Test>]
611828
member this.TestDerivativeConv1D () =
612829
let fwdx = dsharp.tensor([[[ 0.1264; 5.3183; 6.6905; -10.6416];

0 commit comments

Comments
 (0)
0