@@ -607,6 +607,223 @@ type TestDerivatives () =
607
607
Assert.True( revz.allclose( revzCorrect, 0.01 ))
608
608
Assert.True( revxd.allclose( revxdCorrect, 0.01 ))
609
609
610
+ [<Test>]
611
+ member this.TestDerivativeMaxUnpool2D () =
612
+ let indices = dsharp.tensor([[[[ 10 , 21 ],
613
+ [ 41 , 45 ]],
614
+
615
+ [[ 8 , 5 ],
616
+ [ 40 , 28 ]]],
617
+
618
+
619
+ [[[ 8 , 21 ],
620
+ [ 32 , 36 ]],
621
+
622
+ [[ 9 , 13 ],
623
+ [ 25 , 27 ]]]], dtype= DType.Int32)
624
+ let fwdx = dsharp.tensor([[[[ 1.8489 , 1.1338 ],
625
+ [ 0.6819 , 1.6331 ]],
626
+
627
+ [[ 1.0867 , 2.1048 ],
628
+ [ 2.7646 , 1.0156 ]]],
629
+
630
+
631
+ [[[ 2.1120 , 0.8666 ],
632
+ [ 0.9141 , 1.7133 ]],
633
+
634
+ [[ 1.4250 , 1.8228 ],
635
+ [ 1.2607 , 0.5448 ]]]])
636
+ let fwdx = fwdx.forwardDiff( dsharp.tensor([[[[ 1.3110 , 1.5369 ],
637
+ [- 0.4640 , 0.1933 ]],
638
+
639
+ [[ 0.2313 , - 0.4964 ],
640
+ [- 0.1616 , - 1.2032 ]]],
641
+
642
+
643
+ [[[ 0.0377 , 2.1561 ],
644
+ [- 0.3110 , - 1.6315 ]],
645
+
646
+ [[ 0.4036 , 0.7063 ],
647
+ [ 0.0583 , 1.8215 ]]]]))
648
+ let fwdz = dsharp.maxunpool2d( fwdx, indices, 3 , outputSize=[ 2 ; 2 ; 8 ; 8 ])
649
+ let fwdzCorrect = dsharp.tensor([[[[ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
650
+ [ 0.0000 , 0.0000 , 1.8489 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
651
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 1.1338 , 0.0000 , 0.0000 ],
652
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
653
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
654
+ [ 0.0000 , 0.6819 , 0.0000 , 0.0000 , 0.0000 , 1.6331 , 0.0000 , 0.0000 ],
655
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
656
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ]],
657
+
658
+ [[ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 2.1048 , 0.0000 , 0.0000 ],
659
+ [ 1.0867 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
660
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
661
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 1.0156 , 0.0000 , 0.0000 , 0.0000 ],
662
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
663
+ [ 2.7646 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
664
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
665
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ]]],
666
+
667
+
668
+ [[[ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
669
+ [ 2.1120 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
670
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.8666 , 0.0000 , 0.0000 ],
671
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
672
+ [ 0.9141 , 0.0000 , 0.0000 , 0.0000 , 1.7133 , 0.0000 , 0.0000 , 0.0000 ],
673
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
674
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
675
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ]],
676
+
677
+ [[ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
678
+ [ 0.0000 , 1.4250 , 0.0000 , 0.0000 , 0.0000 , 1.8228 , 0.0000 , 0.0000 ],
679
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
680
+ [ 0.0000 , 1.2607 , 0.0000 , 0.5448 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
681
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
682
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
683
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
684
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ]]]])
685
+ let fwdzd = fwdz.derivative
686
+ let fwdzdCorrect = dsharp.tensor([[[[ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
687
+ [ 0.0000 , 0.0000 , 1.3110 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
688
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 1.5369 , 0.0000 , 0.0000 ],
689
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
690
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
691
+ [ 0.0000 , - 0.4640 , 0.0000 , 0.0000 , 0.0000 , 0.1933 , 0.0000 , 0.0000 ],
692
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
693
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ]],
694
+
695
+ [[ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , - 0.4964 , 0.0000 , 0.0000 ],
696
+ [ 0.2313 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
697
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
698
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , - 1.2032 , 0.0000 , 0.0000 , 0.0000 ],
699
+ [ 0.0000 , 0.0000 , 0.0000
E377
span>, 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
700
+ [- 0.1616 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
701
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
702
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ]]],
703
+
704
+
705
+ [[[ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
706
+ [ 0.0377 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
707
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 2.1561 , 0.0000 , 0.0000 ],
708
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
709
+ [- 0.3110 , 0.0000 , 0.0000 , 0.0000 , - 1.6315 , 0.0000 , 0.0000 , 0.0000 ],
710
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
711
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
712
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ]],
713
+
714
+ [[ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
715
+ [ 0.0000 , 0.4036 , 0.0000 , 0.0000 , 0.0000 , 0.7063 , 0.0000 , 0.0000 ],
716
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
717
+ [ 0.0000 , 0.0583 , 0.0000 , 1.8215 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
718
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
719
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
720
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
721
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ]]]])
722
+
723
+ let revx = dsharp.tensor([[[[ 1.8489 , 1.1338 ],
724
+ [ 0.6819 , 1.6331 ]],
725
+
726
+ [[ 1.0867 , 2.1048 ],
727
+ [ 2.7646 , 1.0156 ]]],
728
+
729
+
730
+ [[[ 2.1120 , 0.8666 ],
731
+ [ 0.9141 , 1.7133 ]],
732
+
733
+ [[ 1.4250 , 1.8228 ],
734
+ [ 1.2607 , 0.5448 ]]]]) .reverseDiff()
735
+ let revz = dsharp.maxunpool2d( revx, indices, 3 , outputSize=[ 2 ; 2 ; 8 ; 8 ])
736
+ let revzCorrect = dsharp.tensor([[[[ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
737
+ [ 0.0000 , 0.0000 , 1.8489 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
738
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 1.1338 , 0.0000 , 0.0000 ],
739
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
740
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
741
+ [ 0.0000 , 0.6819 , 0.0000 , 0.0000 , 0.0000 , 1.6331 , 0.0000 , 0.0000 ],
742
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
743
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ]],
744
+
745
+ [[ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 2.1048 , 0.0000 , 0.0000 ],
746
+ [ 1.0867 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
747
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
748
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 1.0156 , 0.0000 , 0.0000 , 0.0000 ],
749
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
750
+ [ 2.7646 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
751
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
752
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ]]],
753
+
754
+
755
+ [[[ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
756
+ [ 2.1120 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
757
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.8666 , 0.0000 , 0.0000 ],
758
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
759
+ [ 0.9141 , 0.0000 , 0.0000 , 0.0000 , 1.7133 , 0.0000 , 0.0000 , 0.0000 ],
760
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
761
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
762
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ]],
763
+
764
+ [[ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
765
+ [ 0.0000 , 1.4250 , 0.0000 , 0.0000 , 0.0000 , 1.8228 , 0.0000 , 0.0000 ],
766
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
767
+ [ 0.0000 , 1.2607 , 0.0000 , 0.5448 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
768
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
769
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
770
+ [ 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ],
771
+ [ 0.0000 <
10000
span class="pl-k">, 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 , 0.0000 ]]]])
772
+ revz.reverse( dsharp.tensor([[[[ 0.8984 , - 0.1793 , - 0.6322 , - 0.0690 , - 0.9264 , 0.0128 , - 0.7905 , - 1.1720 ],
773
+ [- 0.9698 , - 0.5440 , 0.4486 , 1.2180 , 0.2225 , 0.9609 , - 0.9034 , - 0.0780 ],
774
+ [- 0.4951 , 0.8007 , 0.9854 , 0.0834 , - 0.6088 , - 0.0391 , - 1.2598 , 0.9716 ],
775
+ [ 0.3452 , - 0.8826 , - 0.6269 , 0.2676 , - 0.7513 , - 1.2730 , 1.3117 , - 1.0414 ],
776
+ [- 0.6497 , 0.3582 , - 1.9917 , - 0.4683 , - 0.7881 , - 1.3295 , - 0.2698 , 0.5392 ],
777
+ [- 0.5929 , 0.5991 , - 0.9721 , 3.0464 , - 0.4441 , - 2.0565 , - 1.5350 , - 2.4916 ],
778
+ [- 0.6468 , 0.4241 , 0.1965 , - 0.9907 , 0.1452 , 0.4475 , - 0.1735 , 0.1073 ],
779
+ [ 2.5096 , 1.4079 , - 0.6148 , - 0.4607 , - 0.4818 , 0.0415 , - 1.3375 , 0.9602 ]],
780
+
781
+ [[ 0.2130 , 1.5446 , - 0.5831 , 0.1359 , - 1.0135 , - 0.6529 , 2.1866 , 1.2187 ],
782
+ [ 1.1122 , - 0.1649 , 0.0473 , - 0.6117 , 2.1489 , - 0.4845 , 0.3153 , - 1.8326 ],
783
+ [- 1.9014 , - 0.3670 , 0.8990 , - 0.4523 , 0.3366 , - 1.0262 , 1.0180 , - 1.6572 ],
784
+ [- 0.0980 , - 0.7111 , 1.0891 , 0.2800 , - 1.7344 , 1.7927 , 0.1482 , 0.7804 ],
785
+ [- 0.2373 , 0.1023 , - 0.4915 , - 0.7444 , 1.1870 , - 0.2154 , 2.6652 , - 0.5908 ],
786
+ [ 0.1938 , 0.7860 , - 0.2982 , 0.3848 , 0.6933 , 0.0560 , 0.3348 , - 1.7360 ],
787
+ [ 1.6364 , 0.0241 , 0.5503 , 1.0353 , 0.2991 , - 0.0626 , - 0.1946 , - 0.8325 ],
788
+ [ 0.2725 , - 1.0100 , - 1.0443 , - 0.1072 , - 0.7004 , - 1.5560 , 1.2907 , - 0.4360 ]]],
789
+
790
+
791
+ [[[- 0.2089 , 0.7702 , 0.0477 , 0.4665 , 0.9363 , - 0.1614 , 2.6186 , - 0.7527 ],
792
+ [- 1.1260 , - 1.0446 , - 1.1232 , - 1.4205 , - 0.5242 , - 0.6396 , 1.2601 , 0.6628 ],
793
+ [- 0.1578 , - 0.9871 , - 0.0246 , 0.9263 , - 0.8434 , 2.0567 , 0.1197 , - 1.4947 ],
794
+ [ 0.6996 , - 1.6738 , 1.4169 , 1.6045 , 1.7552 , - 1.2681 , - 0.6705 , 0.2275 ],
795
+ [- 0.8477 , 0.2387 , 2.1242 , - 1.0176 , - 0.6977 , - 1.1640 , - 0.6400 , - 0.5105 ],
796
+ [- 1.2159 , 0.0419 , - 0.5746 , 0.1896 , - 1.2158 , - 0.7044 , 0.7461 , - 1.2746 ],
797
+ [ 0.0053 , 1.7733 , - 1.2196 , 1.2569 , 0.8448 , - 0.3330 , 0.5204 , 1.0973 ],
798
+ [ 0.5777 , 0.6732 , 0.1366 , - 0.8237 , 0.2497 , - 1.0159 , - 2.3620 , - 0.2002 ]],
799
+
800
+ [[- 0.3120 , - 0.2487 , - 0.3603 , 0.2290 , - 1.3754 , 0.1596 , - 0.1769 , - 1.2327 ],
801
+ [- 0.3505 , 0.4122 , - 0.9472 , - 0.4892 , - 0.4146 , - 1.5960 , - 0.5563 , 0.3567 ],
802
+ [ 0.7608 , 0.6693 , - 1.0732 , - 1.6005 , - 0.2449 , - 1.2722 , 0.3509 , - 1.3285 ],
803
+ [- 1.1414 , 0.0691 , 1.0393 , 1.2117 , - 0.1610 , 0.7500 , 0.3646 , 0.4578 ],
804
+ [ 1.2932 , 0.4704 , 1.3387 , 0.8193 , 1.8205 , - 0.0931 , 0.0629 , - 0.1365 ],
805
+ [ 0.5605 , 0.6983 , - 1.1321 , - 1.4662 , 2.1607 , 0.1176 , 0.0903 , 0.7739 ],
806
+ [ 0.3887 , 0.3413 , 0.4066 , 0.8743 , 3.5218 , - 0.4829 , - 0.8280 , - 1.2032 ],
807
+ [ 0.8603 , - 0.4883 , 0.0139 , 0.4995 , 1.0476 , - 2.1789 , 0.1493 , 1.1592 ]]]]))
808
+ let revxd = revx.derivative
809
+ let revxdCorrect = dsharp.tensor([[[[ 0.4486 , - 0.0391 ],
810
+ [ 0.5991 , - 2.0565 ]],
811
+
812
+ [[ 1.1122 , - 0.6529 ],
813
+ [ 0.1938 , - 1.7344 ]]],
814
+
815
+
816
+ [[[- 1.1260 , 2.0567 ],
817
+ [- 0.8477 , - 0.6977 ]],
818
+
819
+ [[ 0.4122 , - 1.5960 ],
820
+ [ 0.0691 , 1.2117 ]]]])
821
+
822
+ Assert.True( fwdz.allclose( fwdzCorrect, 0.01 ))
823
+ Assert.True( fwdzd.allclose( fwdzdCorrect, 0.01 ))
824
+ Assert.True( revz.allclose( revzCorrect, 0.01 ))
825
+ Assert.True( revxd.allclose( revxdCorrect, 0.01 ))
826
+
610
827
[<Test>]
611
828
member this.TestDerivativeConv1D () =
612
829
let fwdx = dsharp.tensor([[[ 0.1264 ; 5.3183 ; 6.6905 ; - 10.6416 ];
0 commit comments