8000 Add tests to logical operation in BinaryOpsKernel.cpp (#41515) · pytorch/pytorch@e324ea8 · GitHub
[go: up one dir, main page]

Skip to content

Commit e324ea8

Browse files
scintillerfacebook-github-bot
authored andcommitted
Add tests to logical operation in BinaryOpsKernel.cpp (#41515)
Summary: Pull Request resolved: #41515 add test in atest.cpp to cover logical_and_kernel, logical_or_kernel and logical_nor_kernel in Aten/native/cpu/BinaryOpsKernel.cpp https://pxl.cl/1drmV Test Plan: CI Reviewed By: malfet Differential Revision: D22565235 fbshipit-source-id: 7ad9fd8420d7fdd23fd9a703c75da212f72bde2c
1 parent f49d97a commit e324ea8

File tree

1 file changed

+33
-9
lines changed

1 file changed

+33
-9
lines changed

aten/src/ATen/test/atest.cpp

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
#include <ATen/ATen.h>
44

5-
#include<iostream>
5+
#include <iostream>
66
using namespace std;
77
using namespace at;
88

@@ -33,6 +33,32 @@ TEST(atest, operators) {
3333
ASSERT_TRUE(tensor({a ^ b}).equal(a_tensor ^ b_tensor));
3434
}
3535

36+
template <class T>
37+
void run_logical_op_test(const Tensor& exp, T func) {
38+
auto x_tensor = tensor({1, 1, 0, 1, 0});
39+
auto y_tensor = tensor({0, 1, 0, 1, 1});
40+
// Test op over integer tensors
41+
auto out_tensor = empty({5}, kInt);
42+
func(out_tensor, x_tensor, y_tensor);
43+
ASSERT_EQ(out_tensor.dtype(), kInt);
44+
ASSERT_TRUE(exp.equal(out_tensor));
45+
// Test op over boolean tensors
46+
out_tensor = empty({5}, kBool);
47+
func(out_tensor, x_tensor.to(kBool), y_tensor.to(kBool));
48+
ASSERT_EQ(out_tensor.dtype(), kBool);
49+
ASSERT_TRUE(out_tensor.equal(exp.to(kBool)));
50+
}
51+
52+
TEST(atest, logical_and_operators) {
53+
run_logical_op_test(tensor({0, 1, 0, 1, 0}), logical_and_out);
54+
}
55+
TEST(atest, logical_or_operators) {
56+
run_logical_op_test(tensor({1, 1, 0, 1, 1}), logical_or_out);
57+
}
58+
TEST(atest, logical_xor_operators) {
59+
run_logical_op_test(tensor({1, 0, 0, 0, 1}), logical_xor_out);
60+
}
61+
3662
// TEST_CASE( "atest", "[]" ) {
3763
TEST(atest, atest) {
3864
manual_seed(123);
@@ -84,17 +110,15 @@ TEST(atest, atest) {
84110
{
85111
int isgone = 0;
86112
{
87-
auto f2 =
88-
from_blob(data, {1, 2, 3}, [&](void*) { isgone++; });
113+
auto f2 = from_blob(data, {1, 2, 3}, [&](void*) { isgone++; });
89114
}
90115
ASSERT_EQ(isgone, 1);
91116
}
92117
{
93118
int isgone = 0;
94119
Tensor a_view;
95120
{
96-
auto f2 =
97-
from_blob(data, {1, 2, 3}, [&](void*) { isgone++; });
121+
auto f2 = from_blob(data, {1, 2, 3}, [&](void*) { isgone++; });
98122
a_view = f2.view({3, 2, 1});
99123
}
100124
ASSERT_EQ(isgone, 0);
@@ -105,17 +129,17 @@ TEST(atest, atest) {
105129
if (at::hasCUDA()) {
106130
int isgone = 0;
107131
{
108-
auto base = at::empty({1,2,3}, TensorOptions(kCUDA));
132+
auto base = at::empty({1, 2, 3}, TensorOptions(kCUDA));
109133
auto f2 = from_blob(base.data_ptr(), {1, 2, 3}, [&](void*) { isgone++; });
110134
}
111135
ASSERT_EQ(isgone, 1);
112136

113137
// Attempt to specify the wrong device in from_blob
114-
auto t = at::empty({1,2,3}, TensorOptions(kCUDA, 0));
115-
EXPECT_ANY_THROW(from_blob(t.data_ptr(), {1,2,3}, at:: 6632 Device(kCUDA, 1)));
138+
auto t = at::empty({1, 2, 3}, TensorOptions(kCUDA, 0));
139+
EXPECT_ANY_THROW(from_blob(t.data_ptr(), {1, 2, 3}, at::Device(kCUDA, 1)));
116140

117141
// Infers the correct device
118-
auto t_ = from_blob(t.data_ptr(), {1,2,3}, kCUDA);
142+
auto t_ = from_blob(t.data_ptr(), {1, 2, 3}, kCUDA);
119143
ASSERT_EQ(t_.device(), at::Device(kCUDA, 0));
120144
}
121145
}

0 commit comments

Comments
 (0)
0