From c84a4e510b98a17d8883b928871f315cb30b4cd7 Mon Sep 17 00:00:00 2001 From: iamshnoo Date: Tue, 11 Aug 2020 01:01:45 +0530 Subject: [PATCH 1/6] Pixel Shuffle layer. Commit 1. --- src/mlpack/methods/ann/layer/CMakeLists.txt | 2 + src/mlpack/methods/ann/layer/layer.hpp | 1 + src/mlpack/methods/ann/layer/layer_types.hpp | 2 + .../methods/ann/layer/pixel_shuffle.hpp | 179 ++++++++++++++++++ .../methods/ann/layer/pixel_shuffle_impl.hpp | 144 ++++++++++++++ src/mlpack/tests/ann_layer_test.cpp | 74 ++++++++ 6 files changed, 402 insertions(+) create mode 100644 src/mlpack/methods/ann/layer/pixel_shuffle.hpp create mode 100644 src/mlpack/methods/ann/layer/pixel_shuffle_impl.hpp diff --git a/src/mlpack/methods/ann/layer/CMakeLists.txt b/src/mlpack/methods/ann/layer/CMakeLists.txt index c3ae086c870..b4034f580c3 100644 --- a/src/mlpack/methods/ann/layer/CMakeLists.txt +++ b/src/mlpack/methods/ann/layer/CMakeLists.txt @@ -79,6 +79,8 @@ set(SOURCES noisylinear_impl.hpp parametric_relu.hpp parametric_relu_impl.hpp + pixel_shuffle.hpp + pixel_shuffle_impl.hpp recurrent.hpp recurrent_impl.hpp recurrent_attention.hpp diff --git a/src/mlpack/methods/ann/layer/layer.hpp b/src/mlpack/methods/ann/layer/layer.hpp index 8e8e00691eb..66734116165 100644 --- a/src/mlpack/methods/ann/layer/layer.hpp +++ b/src/mlpack/methods/ann/layer/layer.hpp @@ -55,6 +55,7 @@ #include "noisylinear.hpp" #include "padding.hpp" #include "parametric_relu.hpp" +#include "pixel_shuffle.hpp" #include "recurrent_attention.hpp" #include "recurrent.hpp" #include "reinforce_normal.hpp" diff --git a/src/mlpack/methods/ann/layer/layer_types.hpp b/src/mlpack/methods/ann/layer/layer_types.hpp index 0f52a24df79..9e8dd0ce360 100644 --- a/src/mlpack/methods/ann/layer/layer_types.hpp +++ b/src/mlpack/methods/ann/layer/layer_types.hpp @@ -40,6 +40,7 @@ #include #include #include +#include #include #include #include @@ -272,6 +273,7 @@ using LayerTypes = boost::variant< NoisyLinear*, Padding*, PReLU*, + PixelShuffle*, Softmax*, TransposedConvolution, NaiveConvolution, diff --git a/src/mlpack/methods/ann/layer/pixel_shuffle.hpp b/src/mlpack/methods/ann/layer/pixel_shuffle.hpp new file mode 100644 index 00000000000..27dea9ba381 --- /dev/null +++ b/src/mlpack/methods/ann/layer/pixel_shuffle.hpp @@ -0,0 +1,179 @@ +/** + * @file methods/ann/layer/pixel_shuffle.hpp + * @author Anjishnu Mukherjee + * + * Definition of the PixelShuffle class. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ +#ifndef MLPACK_METHODS_ANN_LAYER_PIXEL_SHUFFLE_HPP +#define MLPACK_METHODS_ANN_LAYER_PIXEL_SHUFFLE_HPP + +#include + +namespace mlpack { +namespace ann /** Artificial Neural Network. */ { + +/** + * Implementation of the PixelShuffle layer. + * + * For more information, refer to the following paper, + * + * @code + * @article{Shi16, + * author = {Wenzhe Shi, Jose Caballero,Ferenc Huszár, Johannes Totz, + * Andrew P. Aitken, Rob Bishop, Daniel Rueckert, Zehan Wang}, + * title = {Real-Time Single Image and Video Super-Resolution Using an + * Efficient Sub-Pixel Convolutional Neural Network}, + * journal = {CoRR}, + * volume = {abs/1609.05158}, + * year = {2016}, + * url = {https://arxiv.org/abs/1609.05158}, + * eprint = {1609.05158}, + * } + * @endcode + * + * @tparam InputDataType Type of the input data (arma::colvec, arma::mat, + * arma::sp_mat or arma::cube). + * @tparam OutputDataType Type of the output data (arma::colvec, arma::mat, + * arma::sp_mat or arma::cube). + */ +template < + typename InputDataType = arma::mat, + typename OutputDataType = arma::mat +> +class PixelShuffle +{ + public: + //! Create the PixelShuffle object. + PixelShuffle(); + /** + * Create the PixelShuffle object using the specified parameters. + * The number of input channels should be an integral multiple of the square + * of the upscale factor. + * + * @param upscaleFactor The scaling factor for Pixel Shuffle. + * @param height The height of each input image. + * @param width The width of each input image. + * @param size The number of channels of each input image. + */ + PixelShuffle( size_t upscaleFactor, + size_t height, + size_t width, + size_t size); + + /** + * Ordinary feed forward pass of the PixelShuffle layer. + * + * @param input Input data used for evaluating the specified function. + * @param output Resulting output activation. + */ + template + void Forward(const arma::Mat& input, arma::Mat& output); + + /** + * Ordinary feed backward pass of the PixelShuffle layer. + * + * @param * (input) The propagated input activation. + * @param gy The backpropagated error. + * @param g The calculated gradient. + */ + template + void Backward(const arma::Mat& input, + const arma::Mat& gy, + arma::Mat& g); + + //! Get the output parameter. + OutputDataType const& OutputParameter() const { return outputParameter; } + //! Modify the output parameter. + OutputDataType& OutputParameter() { return outputParameter; } + + //! Get the delta. + OutputDataType const& Delta() const { return delta; } + //! Modify the delta. + OutputDataType& Delta() { return delta; } + + //! Get the upscale factor. + size_t UpscaleFactor() const { return upscaleFactor; } + + //! Modify the upscale factor. + size_t& UpscaleFactor() { return upscaleFactor; } + + //! Get the input image height. + size_t InputHeight() const { return height; } + + //! Modify the input image height. + size_t& InputHeight() { return height; } + + //! Get the input image width. + size_t InputWidth() const { return width; } + + //! Modify the input image width. + size_t& InputWidth() { return width; } + + //! Get the number of input channels. + size_t InputChannels() const { return size; } + + //! Modify the number of input channels. + size_t& InputChannels() { return size; } + + //! Get the output image height. + size_t OutputHeight() const { return outputHeight; } + + //! Get the output image width. + size_t OutputWidth() const { return outputWidth; } + + //! Get the number of output channels. + size_t OutputChannels() const { return sizeOut; } + + /** + * Serialize the layer. + */ + template + void serialize(Archive& ar, const unsigned int /* version */); + + private: + //! Locally-stored delta object. + OutputDataType delta; + + //! Locally-stored output parameter object. + OutputDataType outputParameter; + + //! The scaling factor for Pixel Shuffle. + size_t upscaleFactor; + + //! The height of each input image. + size_t height; + + //! The width of each input image. + size_t width; + + //! The number of channels of each input image. + size_t size; + + //! The number of images in the batch. + size_t batchSize; + + //! The height of each output image. + size_t outputHeight; + + //! The width of each output image. + size_t outputWidth; + + //! The number of channels of each output image. + size_t sizeOut; + + //! A boolean used to do some internal calculations once initially. + bool reset; +}; // class PixelShuffle + +} // namespace ann +} // namespace mlpack + +// Include implementation. +#include "pixel_shuffle_impl.hpp" + +#endif diff --git a/src/mlpack/methods/ann/layer/pixel_shuffle_impl.hpp b/src/mlpack/methods/ann/layer/pixel_shuffle_impl.hpp new file mode 100644 index 00000000000..cfbd4287b9e --- /dev/null +++ b/src/mlpack/methods/ann/layer/pixel_shuffle_impl.hpp @@ -0,0 +1,144 @@ +/** + * @file methods/ann/layer/pixel_shuffle_impl.hpp + * @author Anjishnu Mukherjee + * + * Implementation of the PixelShuffle class. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ +#ifndef MLPACK_METHODS_ANN_LAYER_PIXEL_SHUFFLE_IMPL_HPP +#define MLPACK_METHODS_ANN_LAYER_PIXEL_SHUFFLE_IMPL_HPP + +// In case it hasn't yet been included. +#include "pixel_shuffle.hpp" + +namespace mlpack { +namespace ann /** Artificial Neural Network. */ { + +template +PixelShuffle::PixelShuffle() : + upscaleFactor(0), + height(0), + width(0), + size(0), + reset(false) +{ + // Nothing to do here. +} + +template +PixelShuffle::PixelShuffle( + size_t upscaleFactor, + size_t height, + size_t width, + size_t size) : + upscaleFactor(upscaleFactor), + height(height), + width(width), + size(size), + reset(false) +{ + // Nothing to do here. +} + +template +template +void PixelShuffle::Forward( + const arma::Mat& input, arma::Mat& output) +{ + if(!reset) + { + batchSize = input.n_cols; + sizeOut = size / std::pow(upscaleFactor, 2); + outputHeight = height * upscaleFactor; + outputWidth = width * upscaleFactor; + reset = true; + } + output.zeros(outputHeight * outputWidth * sizeOut, batchSize); + for(size_t n = 0; n < batchSize; n++) + { + arma::mat inputImage = input.col(n); + arma::mat outputImage = output.col(n); + arma::cube inputTemp(const_cast(inputImage).memptr(), height, + width, size, false, false); + arma::cube outputTemp(const_cast(outputImage).memptr(), + outputHeight, outputWidth, sizeOut, false, false); + + for (size_t c = 0; c < sizeOut ; c++) + { + for (size_t h = 0; h < outputHeight; h++) + { + for (size_t w = 0; w < outputWidth; w++) + { + size_t height_index = h / upscaleFactor; + size_t width_index = w / upscaleFactor; + size_t channel_index = (upscaleFactor * (h % upscaleFactor)) + + (w % upscaleFactor) + (c * std::pow(upscaleFactor, 2)); + outputTemp(w, h, c) = inputTemp(width_index, height_index, + channel_index); + } + } + } + output.col(n) = outputImage; + } +} + +template +template +void PixelShuffle::Backward( + const arma::Mat& input, const arma::Mat& gy, arma::Mat& g) +{ + g.zeros(arma::size(input)); + for(size_t n = 0; n < batchSize; n++) + { + arma::mat gyImage = gy.col(n); + arma::mat gImage = g.col(n); + arma::cube gyTemp(const_cast(gyImage).memptr(), outputHeight, + outputWidth, sizeOut, false, false); + arma::cube gTemp(const_cast(gImage).memptr(), height, width, + size, false, false); + + for (size_t c = 0; c < sizeOut ; c++) + { + for (size_t h = 0; h < outputHeight; h++) + { + for (size_t w = 0; w < outputWidth; w++) + { + size_t height_index = h / upscaleFactor; + size_t width_index = w / upscaleFactor; + size_t channel_index = (upscaleFactor * (h % upscaleFactor)) + + (w % upscaleFactor) + (c * std::pow(upscaleFactor, 2)); + gTemp(width_index, height_index, channel_index) = gyTemp(w, h, c); + } + } + } + + g.col(n) = gImage; + } +} + +template +template +void PixelShuffle::serialize( + Archive& ar, + const unsigned int /* version */) +{ + ar & BOOST_SERIALIZATION_NVP(delta); + ar & BOOST_SERIALIZATION_NVP(outputParameter); + ar & BOOST_SERIALIZATION_NVP(upscaleFactor); + ar & BOOST_SERIALIZATION_NVP(height); + ar & BOOST_SERIALIZATION_NVP(width); + ar & BOOST_SERIALIZATION_NVP(size); + ar & BOOST_SERIALIZATION_NVP(batchSize); + ar & BOOST_SERIALIZATION_NVP(outputHeight); + ar & BOOST_SERIALIZATION_NVP(outputWidth); + ar & BOOST_SERIALIZATION_NVP(sizeOut); +} + +} // namespace ann +} // namespace mlpack + +#endif diff --git a/src/mlpack/tests/ann_layer_test.cpp b/src/mlpack/tests/ann_layer_test.cpp index 29e16273b82..38fec13c059 100644 --- a/src/mlpack/tests/ann_layer_test.cpp +++ b/src/mlpack/tests/ann_layer_test.cpp @@ -4186,3 +4186,77 @@ TEST_CASE("BatchNormDeterministicTest", "[ANNLayerTest]") // The model should switch to training mode for predicting. REQUIRE(boost::get*>(module.Model()[0])->Deterministic() == 0); } + +/** + * Simple Test for PixelShuffle layer. + */ +TEST_CASE("PixelShuffleLayerTest", "[ANNLayerTest]") +{ + arma::mat input, output, gy, g, outputExpected, gExpected; + PixelShuffle<> module(2, 2, 2, 4); + + // Input is a batch of 2 images, each of size (2,2) and having 4 channels. + input << 1 << 3 << 2 << 4 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 + << 0 << 0 << arma::endr << 5 << 7 << 6 << 8 << 0 << 0 << 0 << 0 << 0 << 0 + << 0 << 0 << 0 << 0 << 0 << 0 << arma::endr; + + gy << 1 << 5 << 9 << 13 << 2 << 6 << 10 << 14 << 3 << 7 << 11 << 15 << 4 << 8 + << 12 << 16 << arma::endr << 17 << 21 << 25 << 29 << 18 << 22 << 26 << 30 + << 19 << 23 << 27 << 31 << 20 << 24 << 28 << 32 << arma::endr; + + // Calculated using torch.nn.PixelShuffle(). + outputExpected << 1 << 0 << 3 << 0 << 0 << 0 << 0 << 0 << 2 << 0 << 4 << 0 + << 0 << 0 << 0 << 0 << arma::endr << 5 << 0 << 7 << 0 << 0 << 0 << 0 << 0 + << 6 << 0 << 8 << 0 << 0 << 0 << 0 << 0 << arma::endr; + gExpected << 1 << 9 << 3 << 11 << 5 << 13 << 7 << 15 << 2 << 10 << 4 << 12 + << 6 << 14 << 8 << 16 << arma::endr << 17 << 25 << 19 << 27 << 21 << 29 + << 23 << 31 << 18 << 26 << 20 << 28 << 22 << 30 << 24 << 32 << arma::endr; + + input = input.t(); + outputExpected = outputExpected.t(); + gy = gy.t(); + gExpected = gExpected.t(); + + // Check the Forward pass of the layer. + module.Forward(input, output); + CheckMatrices(output, outputExpected); + + // Check the Backward pass of the layer. + module.Backward(input, gy, g); + CheckMatrices(g, gExpected); +} + +/** + * Test that the function that can access the parameters of the + * PixelShuffle layer works. + */ +TEST_CASE("PixelShuffleLayerParametersTest", "[ANNLayerTest]") +{ + // Create the layer using the empty constructor. + PixelShuffle<> layer; + + // Set the different input parameters of the layer. + layer.UpscaleFactor() = 2; + layer.InputHeight() = 2; + layer.InputWidth() = 2; + layer.InputChannels() = 4; + + // Make sure we can get the parameters successfully. + REQUIRE(layer.UpscaleFactor() == 2); + REQUIRE(layer.InputHeight() == 2); + REQUIRE(layer.InputWidth() == 2); + REQUIRE(layer.InputChannels() == 4); + + arma::mat input, output; + // Input is a batch of 2 images, each of size (2,2) and having 4 channels. + input << 1 << 3 << 2 << 4 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 + << 0 << 0 << arma::endr << 5 << 7 << 6 << 8 << 0 << 0 << 0 << 0 << 0 << 0 + << 0 << 0 << 0 << 0 << 0 << 0 << arma::endr; + input = input.t(); + layer.Forward(input, output); + + // Check whether output parameters are returned correctly. + REQUIRE(layer.OutputHeight() == 4 ); + REQUIRE(layer.OutputWidth() == 4); + REQUIRE(layer.OutputChannels() == 1); +} From 64dd26145a1b1c02593c105a10e2622d1aa33634 Mon Sep 17 00:00:00 2001 From: iamshnoo Date: Fri, 14 Aug 2020 19:08:50 +0530 Subject: [PATCH 2/6] Fix style issues. --- src/mlpack/methods/ann/layer/pixel_shuffle.hpp | 10 +++++----- src/mlpack/methods/ann/layer/pixel_shuffle_impl.hpp | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/mlpack/methods/ann/layer/pixel_shuffle.hpp b/src/mlpack/methods/ann/layer/pixel_shuffle.hpp index 27dea9ba381..2fda7cf5f4c 100644 --- a/src/mlpack/methods/ann/layer/pixel_shuffle.hpp +++ b/src/mlpack/methods/ann/layer/pixel_shuffle.hpp @@ -60,10 +60,10 @@ class PixelShuffle * @param width The width of each input image. * @param size The number of channels of each input image. */ - PixelShuffle( size_t upscaleFactor, - size_t height, - size_t width, - size_t size); + PixelShuffle(size_t upscaleFactor, + size_t height, + size_t width, + size_t size); /** * Ordinary feed forward pass of the PixelShuffle layer. @@ -77,7 +77,7 @@ class PixelShuffle /** * Ordinary feed backward pass of the PixelShuffle layer. * - * @param * (input) The propagated input activation. + * @param input The propagated input activation. * @param gy The backpropagated error. * @param g The calculated gradient. */ diff --git a/src/mlpack/methods/ann/layer/pixel_shuffle_impl.hpp b/src/mlpack/methods/ann/layer/pixel_shuffle_impl.hpp index cfbd4287b9e..3ca2852074a 100644 --- a/src/mlpack/methods/ann/layer/pixel_shuffle_impl.hpp +++ b/src/mlpack/methods/ann/layer/pixel_shuffle_impl.hpp @@ -49,7 +49,7 @@ template void PixelShuffle::Forward( const arma::Mat& input, arma::Mat& output) { - if(!reset) + if (!reset) { batchSize = input.n_cols; sizeOut = size / std::pow(upscaleFactor, 2); @@ -58,7 +58,7 @@ void PixelShuffle::Forward( reset = true; } output.zeros(outputHeight * outputWidth * sizeOut, batchSize); - for(size_t n = 0; n < batchSize; n++) + for (size_t n = 0; n < batchSize; n++) { arma::mat inputImage = input.col(n); arma::mat outputImage = output.col(n); @@ -92,7 +92,7 @@ void PixelShuffle::Backward( const arma::Mat& input, const arma::Mat& gy, arma::Mat& g) { g.zeros(arma::size(input)); - for(size_t n = 0; n < batchSize; n++) + for (size_t n = 0; n < batchSize; n++) { arma::mat gyImage = gy.col(n); arma::mat gImage = g.col(n); From 5da525cde653ee3ed2bd5c4c979bdacdbd9b7959 Mon Sep 17 00:00:00 2001 From: iamshnoo Date: Fri, 14 Aug 2020 23:49:48 +0530 Subject: [PATCH 3/6] Fix consistency issue for constructor format. --- src/mlpack/methods/ann/layer/pixel_shuffle.hpp | 8 ++++---- src/mlpack/methods/ann/layer/pixel_shuffle_impl.hpp | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/mlpack/methods/ann/layer/pixel_shuffle.hpp b/src/mlpack/methods/ann/layer/pixel_shuffle.hpp index 2fda7cf5f4c..c62b0f0ecbf 100644 --- a/src/mlpack/methods/ann/layer/pixel_shuffle.hpp +++ b/src/mlpack/methods/ann/layer/pixel_shuffle.hpp @@ -60,10 +60,10 @@ class PixelShuffle * @param width The width of each input image. * @param size The number of channels of each input image. */ - PixelShuffle(size_t upscaleFactor, - size_t height, - size_t width, - size_t size); + PixelShuffle(const size_t upscaleFactor, + const size_t height, + const size_t width, + const size_t size); /** * Ordinary feed forward pass of the PixelShuffle layer. diff --git a/src/mlpack/methods/ann/layer/pixel_shuffle_impl.hpp b/src/mlpack/methods/ann/layer/pixel_shuffle_impl.hpp index 3ca2852074a..d65a0c456d0 100644 --- a/src/mlpack/methods/ann/layer/pixel_shuffle_impl.hpp +++ b/src/mlpack/methods/ann/layer/pixel_shuffle_impl.hpp @@ -31,10 +31,10 @@ PixelShuffle::PixelShuffle() : template PixelShuffle::PixelShuffle( - size_t upscaleFactor, - size_t height, - size_t width, - size_t size) : + const size_t upscaleFactor, + const size_t height, + const size_t width, + const size_t size) : upscaleFactor(upscaleFactor), height(height), width(width), From 438cc009898dad31326636af9f3d47d908ced4af Mon Sep 17 00:00:00 2001 From: iamshnoo Date: Tue, 25 Aug 2020 11:47:28 +0530 Subject: [PATCH 4/6] Use suggestions from code review. --- .../methods/ann/layer/pixel_shuffle_impl.hpp | 1 - src/mlpack/tests/ann_layer_test.cpp | 56 ++++++++++++++----- 2 files changed, 41 insertions(+), 16 deletions(-) diff --git a/src/mlpack/methods/ann/layer/pixel_shuffle_impl.hpp b/src/mlpack/methods/ann/layer/pixel_shuffle_impl.hpp index d65a0c456d0..31d137eb595 100644 --- a/src/mlpack/methods/ann/layer/pixel_shuffle_impl.hpp +++ b/src/mlpack/methods/ann/layer/pixel_shuffle_impl.hpp @@ -115,7 +115,6 @@ void PixelShuffle::Backward( } } } - g.col(n) = gImage; } } diff --git a/src/mlpack/tests/ann_layer_test.cpp b/src/mlpack/tests/ann_layer_test.cpp index c3c0b11edc6..a1ed47db97a 100644 --- a/src/mlpack/tests/ann_layer_test.cpp +++ b/src/mlpack/tests/ann_layer_test.cpp @@ -4302,38 +4302,64 @@ TEST_CASE("TransposedConvolutionWeightInitializationTest", "[ANNLayerTest]") */ TEST_CASE("PixelShuffleLayerTest", "[ANNLayerTest]") { - arma::mat input, output, gy, g, outputExpected, gExpected; - PixelShuffle<> module(2, 2, 2, 4); + arma::mat input1, output1, gy1, g1, outputExpected1, gExpected1; + arma::mat input2, output2, gy2, g2, outputExpected2, gExpected2; + PixelShuffle<> module1(2, 2, 2, 4); + PixelShuffle<> module2(2, 2, 2, 4); + + // Input is a single image, of size (2,2) and having 4 channels. + input1 << 1 << 3 << 2 << 4 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 + << 0 << 0 << arma::endr; + gy1 << 1 << 5 << 9 << 13 << 2 << 6 << 10 << 14 << 3 << 7 << 11 << 15 << 4 << 8 + << 12 << 16 << arma::endr; + + // Calculated using torch.nn.PixelShuffle(). + outputExpected1 << 1 << 0 << 3 << 0 << 0 << 0 << 0 << 0 << 2 << 0 << 4 << 0 + << 0 << 0 << 0 << 0 << arma::endr; + gExpected1 << 1 << 9 << 3 << 11 << 5 << 13 << 7 << 15 << 2 << 10 << 4 << 12 + << 6 << 14 << 8 << 16 << arma::endr; + + input1 = input1.t(); + outputExpected1 = outputExpected1.t(); + gy1 = gy1.t(); + gExpected1 = gExpected1.t(); + + // Check the Forward pass of the layer. + module1.Forward(input1, output1); + CheckMatrices(output1, outputExpected1); + + // Check the Backward pass of the layer. + module1.Backward(input1, gy1, g1); + CheckMatrices(g1, gExpected1); // Input is a batch of 2 images, each of size (2,2) and having 4 channels. - input << 1 << 3 << 2 << 4 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 + input2 << 1 << 3 << 2 << 4 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << arma::endr << 5 << 7 << 6 << 8 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << arma::endr; - - gy << 1 << 5 << 9 << 13 << 2 << 6 << 10 << 14 << 3 << 7 << 11 << 15 << 4 << 8 + gy2 << 1 << 5 << 9 << 13 << 2 << 6 << 10 << 14 << 3 << 7 << 11 << 15 << 4 << 8 << 12 << 16 << arma::endr << 17 << 21 << 25 << 29 << 18 << 22 << 26 << 30 << 19 << 23 << 27 << 31 << 20 << 24 << 28 << 32 << arma::endr; // Calculated using torch.nn.PixelShuffle(). - outputExpected << 1 << 0 << 3 << 0 << 0 << 0 << 0 << 0 << 2 << 0 << 4 << 0 + outputExpected2 << 1 << 0 << 3 << 0 << 0 << 0 << 0 << 0 << 2 << 0 << 4 << 0 << 0 << 0 << 0 << 0 << arma::endr << 5 << 0 << 7 << 0 << 0 << 0 << 0 << 0 << 6 << 0 << 8 << 0 << 0 << 0 << 0 << 0 << arma::endr; - gExpected << 1 << 9 << 3 << 11 << 5 << 13 << 7 << 15 << 2 << 10 << 4 << 12 + gExpected2 << 1 << 9 << 3 << 11 << 5 << 13 << 7 << 15 << 2 << 10 << 4 << 12 << 6 << 14 << 8 << 16 << arma::endr << 17 << 25 << 19 << 27 << 21 << 29 << 23 << 31 << 18 << 26 << 20 << 28 << 22 << 30 << 24 << 32 << arma::endr; - input = input.t(); - outputExpected = outputExpected.t(); - gy = gy.t(); - gExpected = gExpected.t(); + input2 = input2.t(); + outputExpected2 = outputExpected2.t(); + gy2 = gy2.t(); + gExpected2 = gExpected2.t(); // Check the Forward pass of the layer. - module.Forward(input, output); - CheckMatrices(output, outputExpected); + module2.Forward(input2, output2); + CheckMatrices(output2, outputExpected2); // Check the Backward pass of the layer. - module.Backward(input, gy, g); - CheckMatrices(g, gExpected); + module2.Backward(input2, gy2, g2); + CheckMatrices(g2, gExpected2); } /** From c6cd6b860c3e56c76c2df9abe0bfe17c3bf7d1e1 Mon Sep 17 00:00:00 2001 From: iamshnoo Date: Tue, 25 Aug 2020 16:25:52 +0530 Subject: [PATCH 5/6] FIx static analysis issue. --- src/mlpack/methods/ann/layer/pixel_shuffle_impl.hpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/mlpack/methods/ann/layer/pixel_shuffle_impl.hpp b/src/mlpack/methods/ann/layer/pixel_shuffle_impl.hpp index 31d137eb595..796ee54071d 100644 --- a/src/mlpack/methods/ann/layer/pixel_shuffle_impl.hpp +++ b/src/mlpack/methods/ann/layer/pixel_shuffle_impl.hpp @@ -24,6 +24,10 @@ PixelShuffle::PixelShuffle() : height(0), width(0), size(0), + batchSize(0), + outputHeight(0), + outputWidth(0), + sizeOut(0), reset(false) { // Nothing to do here. @@ -39,6 +43,10 @@ PixelShuffle::PixelShuffle( height(height), width(width), size(size), + batchSize(0), + outputHeight(0), + outputWidth(0), + sizeOut(0), reset(false) { // Nothing to do here. From cbb151000528700e9c57933030fc71e4c0803b70 Mon Sep 17 00:00:00 2001 From: iamshnoo Date: Wed, 26 Aug 2020 09:46:37 +0530 Subject: [PATCH 6/6] Update HISTORY.md for Pixel Shuffle layer. --- HISTORY.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index c04b2aeca17..ba5bb0b1944 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,7 @@ ### mlpack ?.?.? ###### ????-??-?? + * Added Pixel Shuffle layer (#2563). + * Force CMake to show error when it didn't find Python/modules (#2568). * Refactor `ProgramInfo()` to separate out all the different