8000 Pixel Shuffle layer by iamshnoo · Pull Request #2563 · mlpack/mlpack · GitHub
[go: up one dir, main page]

Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pixel Shuffle layer #2563

Merged
merged 13 commits into from
Feb 26, 2021
Prev Previous commit
Next Next commit
Use suggestions from code review.
  • Loading branch information
iamshnoo committed Aug 25, 2020
commit 438cc009898dad31326636af9f3d47d908ced4af
1 change: 0 additions & 1 deletion src/mlpack/methods/ann/layer/pixel_shuffle_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ void PixelShuffle<InputDataType, OutputDataType>::Backward(
}
}
}

g.col(n) = gImage;
}
}
Expand Down
56 changes: 41 additions & 15 deletions src/mlpack/tests/ann_layer_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4302,38 +4302,64 @@ TEST_CASE("TransposedConvolutionWeightInitializationTest", "[ANNLayerTest]")
*/
TEST_CASE("PixelShuffleLayerTest", "[ANNLayerTest]")
{
arma::mat input, output, gy, g, outputExpected, gExpected;
PixelShuffle<> module(2, 2, 2, 4);
arma::mat input1, output1, gy1, g1, outputExpected1, gExpected1;
arma::mat input2, output2, gy2, g2, outputExpected2, gExpected2;
PixelShuffle<> module1(2, 2, 2, 4);
PixelShuffle<> module2(2, 2, 2, 4);

// Input is a single image, of size (2,2) and having 4 channels.
input1 << 1 << 3 << 2 << 4 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0
<< 0 << 0 << arma::endr;
gy1 << 1 << 5 << 9 << 13 << 2 << 6 << 10 << 14 << 3 << 7 << 11 << 15 << 4 << 8
<< 12 << 16 << arma::endr;

// Calculated using torch.nn.PixelShuffle().
outputExpected1 << 1 << 0 << 3 << 0 << 0 << 0 << 0 << 0 << 2 << 0 << 4 << 0
<< 0 << 0 << 0 << 0 << arma::endr;
gExpected1 << 1 << 9 << 3 << 11 << 5 << 13 << 7 << 15 << 2 << 10 << 4 << 12
<< 6 << 14 << 8 << 16 << arma::endr;

input1 = input1.t();
outputExpected1 = outputExpected1.t();
gy1 = gy1.t();
gExpected1 = gExpected1.t();

// Check the Forward pass of the layer.
module1.Forward(input1, output1);
CheckMatrices(output1, outputExpected1);

// Check the Backward pass of the layer.
module1.Backward(input1, gy1, g1);
CheckMatrices(g1, gExpected1);

// Input is a batch of 2 images, each of size (2,2) and having 4 channels.
input << 1 << 3 << 2 << 4 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0
input2 << 1 << 3 << 2 << 4 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0 << 0
<< 0 << 0 << arma::endr << 5 << 7 << 6 << 8 << 0 << 0 << 0 << 0 << 0 << 0
<< 0 << 0 << 0 << 0 << 0 << 0 << arma::endr;

gy << 1 << 5 << 9 << 13 << 2 << 6 << 10 << 14 << 3 << 7 << 11 << 15 << 4 << 8
gy2 << 1 << 5 << 9 << 13 << 2 << 6 << 10 << 14 << 3 << 7 << 11 << 15 << 4 << 8
<< 12 << 16 << arma::endr << 17 << 21 << 25 << 29 << 18 << 22 << 26 << 30
<< 19 << 23 << 27 << 31 << 20 << 24 << 28 << 32 << arma::endr;

// Calculated using torch.nn.PixelShuffle().
outputExpected << 1 << 0 << 3 << 0 << 0 << 0 << 0 << 0 << 2 << 0 << 4 << 0
outputExpected2 << 1 << 0 << 3 << 0 << 0 << 0 << 0 << 0 << 2 << 0 << 4 << 0
<< 0 << 0 << 0 << 0 << arma::endr << 5 << 0 << 7 << 0 << 0 << 0 << 0 << 0
<< 6 << 0 << 8 << 0 << 0 << 0 << 0 << 0 << arma::endr;
gExpected << 1 << 9 << 3 << 11 << 5 << 13 << 7 << 15 << 2 << 10 << 4 << 12
gExpected2 << 1 << 9 << 3 << 11 << 5 << 13 << 7 << 15 << 2 << 10 << 4 << 12
<< 6 << 14 << 8 << 16 << arma::endr << 17 << 25 << 19 << 27 << 21 << 29
<< 23 << 31 << 18 << 26 << 20 << 28 << 22 << 30 << 24 << 32 << arma::endr;

input = input.t();
outputExpected = outputExpected.t();
gy = gy.t();
gExpected = gExpected.t();
input2 = input2.t();
outputExpected2 = outputExpected2.t();
gy2 = gy2.t();
gExpected2 = gExpected2.t();

// Check the Forward pass of the layer.
module.Forward(input, output);
CheckMatrices(output, outputExpected);
module2.Forward(input2, output2);
CheckMatrices(output2, outputExpected2);

// Check the Backward pass of the layer.
module.Backward(input, gy, g);
CheckMatrices(g, gExpected);
module2.Backward(input2, gy2, g2);
CheckMatrices(g2, gExpected2);
}

/**
Expand Down
0