2
2
3
3
#include < ATen/ATen.h>
4
4
5
- #include < iostream>
5
+ #include < iostream>
6
6
using namespace std ;
7
7
using namespace at ;
8
8
@@ -33,6 +33,32 @@ TEST(atest, operators) {
33
33
ASSERT_TRUE (tensor ({a ^ b}).equal (a_tensor ^ b_tensor));
34
34
}
35
35
36
+ template <class T >
37
+ void run_logical_op_test (const Tensor& exp, T func) {
38
+ auto x_tensor = tensor ({1 , 1 , 0 , 1 , 0 });
39
+ auto y_tensor = tensor ({0 , 1 , 0 , 1 , 1 });
40
+ // Test op over integer tensors
41
+ auto out_tensor = empty ({5 }, kInt );
42
+ func (out_tensor, x_tensor, y_tensor);
43
+ ASSERT_EQ (out_tensor.dtype (), kInt );
44
+ ASSERT_TRUE (exp.equal (out_tensor));
45
+ // Test op over boolean tensors
46
+ out_tensor = empty ({5 }, kBool );
47
+ func (out_tensor, x_tensor.to (kBool ), y_tensor.to (kBool ));
48
+ ASSERT_EQ (out_tensor.dtype (), kBool );
49
+ ASSERT_TRUE (out_tensor.equal (exp.to (kBool )));
50
+ }
51
+
52
+ TEST (atest, logical_and_operators) {
53
+ run_logical_op_test (tensor ({0 , 1 , 0 , 1 , 0 }), logical_and_out);
54
+ }
55
+ TEST (atest, logical_or_operators) {
56
+ run_logical_op_test (tensor ({1 , 1 , 0 , 1 , 1 }), logical_or_out);
57
+ }
58
+ TEST (atest, logical_xor_operators) {
59
+ run_logical_op_test (tensor ({1 , 0 , 0 , 0 , 1 }), logical_xor_out);
60
+ }
61
+
36
62
// TEST_CASE( "atest", "[]" ) {
37
63
TEST (atest, atest) {
38
64
manual_seed (123 );
@@ -84,17 +110,15 @@ TEST(atest, atest) {
84
110
{
85
111
int isgone = 0 ;
86
112
{
87
- auto f2 =
88
- from_blob (data, {1 , 2 , 3 }, [&](void *) { isgone++; });
113
+ auto f2 = from_blob (data, {1 , 2 , 3 }, [&](void *) { isgone++; });
89
114
}
90
115
ASSERT_EQ (isgone, 1 );
91
116
}
92
117
{
93
118
int isgone = 0 ;
94
119
Tensor a_view;
95
120
{
96
- auto f2 =
97
- from_blob (data, {1 , 2 , 3 }, [&](void *) { isgone++; });
121
+ auto f2 = from_blob (data, {1 , 2 , 3 }, [&](void *) { isgone++; });
98
122
a_view = f2.view ({3 , 2 , 1 });
99
123
}
100
124
ASSERT_EQ (isgone, 0 );
@@ -105,17 +129,17 @@ TEST(atest, atest) {
105
129
if (at::hasCUDA ()) {
106
130
int isgone = 0 ;
107
131
{
108
- auto base = at::empty ({1 ,2 , 3 }, TensorOptions (kCUDA ));
132
+ auto base = at::empty ({1 , 2 , 3 }, TensorOptions (kCUDA ));
109
133
auto f2 = from_blob (base.data_ptr (), {1 , 2 , 3 }, [&](void *) { isgone++; });
110
134
}
111
135
ASSERT_EQ (isgone, 1 );
112
136
113
137
// Attempt to specify the wrong device in from_blob
114
- auto t = at::empty ({1 ,2 , 3 }, TensorOptions (kCUDA , 0 ));
115
- EXPECT_ANY_THROW (from_blob (t.data_ptr (), {1 ,2 , 3 }, at::
6632
Device (kCUDA , 1 )));
138
+ auto t = at::empty ({1 , 2 , 3 }, TensorOptions (kCUDA , 0 ));
139
+ EXPECT_ANY_THROW (from_blob (t.data_ptr (), {1 , 2 , 3 }, at::Device (kCUDA , 1 )));
116
140
117
141
// Infers the correct device
118
- auto t_ = from_blob (t.data_ptr (), {1 ,2 , 3 }, kCUDA );
142
+ auto t_ = from_blob (t.data_ptr (), {1 , 2 , 3 }, kCUDA );
119
143
ASSERT_EQ (t_.device (), at::Device (kCUDA , 0 ));
120
144
}
121
145
}
0 commit comments