Fix fallback FBGEMM implementation for Big Endian systems. (#96422) · pytorch/pytorch@60bb02a · GitHub
[go: up one dir, main page]

Skip to content

Commit 60bb02a

Browse files
Fix fallback FBGEMM implementation for Big Endian systems. (#96422)
This change fixes multiple tests in test/test_quantization.py::TestQuantizedEmbeddingOps.

Pull Request resolved: #96422
Approved by: https://github.com/huydhn
1 parent 49e964c commit 60bb02a

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,33 +108,69 @@ at::Tensor& embedding_lookup_fallback_impl(
108108
const uint8_t* scale_bias =
109109
weight_data + (idx + 1) * weight_size - 2 * sizeof(float);
110110
uint32_t scale_val_int32 = 0;
111+
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
111112
scale_val_int32 = scale_val_int32 |
112113
(scale_bias[0]) |
113114
(scale_bias[1] << 8) |
114115
(scale_bias[2] << 16) |
115116
(scale_bias[3] << 24);
117+
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
118+
scale_val_int32 = scale_val_int32 |
119+
(scale_bias[3]) |
120+
(scale_bias[2] << 8) |
121+
(scale_bias[1] << 16) |
122+
(scale_bias[0] << 24);
123+
#else
124+
#error Unexpected or undefined __BYTE_ORDER__
125+
#endif
116126
float scale_val = (reinterpret_cast<float*>(&scale_val_int32))[0];
117127
uint32_t bias_val_int32 = 0;
128+
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
118129
bias_val_int32 = bias_val_int32 |
119130
(scale_bias[4]) |
120131
(scale_bias[5] << 8) |
121132
(scale_bias[6] << 16) |
122133
(scale_bias[7] << 24);
134+
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
135+
bias_val_int32 = bias_val_int32 |
136+
(scale_bias[7]) |
137+
(scale_bias[6] << 8) |
138+
(scale_bias[5] << 16) |
139+
(scale_bias[4] << 24);
140+
#else
141+
#error Unexpected or undefined __BYTE_ORDER__
142+
#endif
123143
float bias_val = (reinterpret_cast<float*>(&bias_val_int32))[0];
124144
scale = weight_val * scale_val;
125145
bias = weight_val * bias_val;
126146
} else {
127147
const uint8_t* scale_bias =
128148
weight_data + (idx + 1) * weight_size - 2 * sizeof(at::Half);
129149
uint16_t scale_val_int16 = 0;
150+
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
130151
scale_val_int16 = scale_val_int16 |
131152
(scale_bias[0]) |
132153
(scale_bias[1] << 8);
154+
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
155+
scale_val_int16 = scale_val_int16 |
156+
(scale_bias[1]) |
157+
(scale_bias[0] << 8);
158+
#else
159+
#error Unexpected or undefined __BYTE_ORDER__
160+
#endif
133161
at::Half scale_val = (reinterpret_cast<at::Half*>(&scale_val_int16))[0];
134162
uint16_t bias_val_int16 = 0;
163+
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
135164
bias_val_int16 = bias_val_int16 |
136165
(scale_bias[2]) |
137166
(scale_bias[3] << 8);
167+
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
168+
bias_val_int16 = bias_val_int16 |
169+
(scale_bias[3]) |
170+
(scale_bias[2] << 8);
171+
#else
172+
#error Unexpected or undefined __BYTE_ORDER__
173+
#endif
138174
at::Half bias_val = (reinterpret_cast<at::Half*>(&bias_val_int16))[0];
139175
scale = weight_val * scale_val;
140176
bias = weight_val * bias_val;

0 commit comments

Comments (0)