[3/N] fix clang-tidy warnings in torch/csrc (#108024) · pytorch/pytorch@054f3f1 · GitHub

Commit 054f3f1

cyyever authored and pytorchmergebot committed

[3/N] fix clang-tidy warnings in torch/csrc (#108024)

Apply fixes for some of the issues found by clang-tidy in torch/csrc.

Pull Request resolved: #108024
Approved by: https://github.com/Skylion007, https://github.com/albanD, https://github.com/malfet
1 parent 356b8f6 commit 054f3f1

File tree

17 files changed: +105 -104 lines changed

aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h

Lines changed: 2 additions & 2 deletions

@@ -133,7 +133,7 @@ void _sparse_binary_op_intersection_kernel_impl(
     Tensor& res,
     const Tensor& x_,
     const Tensor& y_,
-    const std::vector<int64_t> broadcasted_shape,
+    const std::vector<int64_t>& broadcasted_shape,
     const c10::optional<Tensor>& x_hash_opt_ = c10::nullopt,
     const c10::optional<Tensor>& y_hash_opt_ = c10::nullopt,
     const bool accumulate_matches = true,
@@ -445,7 +445,7 @@ void _sparse_binary_op_intersection_kernel_out(
       return;
     }

-    const auto t_hash = *t_hash_opt;
+    const auto& t_hash = *t_hash_opt;
     TORCH_INTERNAL_ASSERT(
         t_hash.dim() == 1 && t_hash.scalar_type() == kLong && t_hash.size(-1) == t._indices().size(-1),
         NAME, "(): explicit hash values need to be a 1-dim Long tensor with the ",

torch/csrc/autograd/variable.cpp

Lines changed: 1 addition & 1 deletion

@@ -724,7 +724,7 @@ unsigned VariableHooks::_register_hook(
   auto& list = torch::autograd::impl::get_autograd_meta(self)->cpp_hooks_list_;
   if (!list) {
     torch::autograd::impl::create_cpp_hook(
-        self, /*is_retains_grad_hook=*/false);
+        self, /*is_retains_grad_hooks=*/false);
   }
   unsigned idx = list->size();
   list->push_back(hook);
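
This one-character fix comes from clang-tidy's bugprone-argument-comment check, which verifies that a `/*name=*/` comment matches the callee's declared parameter name (here evidently `is_retains_grad_hooks`, plural). A hypothetical illustration of the check:

#include <iostream>

// The declared parameter name is what bugprone-argument-comment checks against.
void create_cpp_hook_demo(bool is_retains_grad_hooks) {
  std::cout << std::boolalpha << is_retains_grad_hooks << '\n';
}

int main() {
  // Warns: comment `is_retains_grad_hook` does not match the parameter name.
  // create_cpp_hook_demo(/*is_retains_grad_hook=*/false);

  // Clean: comment matches the declaration exactly.
  create_cpp_hook_demo(/*is_retains_grad_hooks=*/false);
  return 0;
}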

torch/csrc/dynamo/cpython_defs.c

Lines changed: 4 additions & 4 deletions

@@ -69,9 +69,9 @@ THP_PyFrame_OpAlreadyRan(_PyInterpreterFrame *frame, int opcode, int oparg)
 int
 THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame) {
     /* Merge fast locals into f->f_locals */
-    PyObject *locals;
-    PyObject **fast;
-    PyCodeObject *co;
+    PyObject *locals = NULL;
+    PyObject **fast = NULL;
+    PyCodeObject *co = NULL;
     locals = frame->f_locals;
     if (locals == NULL) {
         locals = frame->f_locals = PyDict_New();
@@ -232,7 +232,7 @@ PyFrameObject *
 THP_PyFrame_MakeAndSetFrameObject(_PyInterpreterFrame *frame)
 {
     CHECK(frame->frame_obj == NULL);
-    PyObject *error_type, *error_value, *error_traceback;
+    PyObject *error_type = NULL, *error_value = NULL, *error_traceback = NULL;
     PyErr_Fetch(&error_type, &error_value, &error_traceback);

    PyFrameObject *f = THP_PyFrame_New_NoTrack(frame->f_code);
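
Both hunks apply cppcoreguidelines-init-variables: a pointer declared without an initializer holds garbage until the first assignment, and any path that reads it early is undefined behavior. Initializing to NULL (this file is C, so `nullptr` is unavailable) makes the empty state explicit. A small C++ sketch of the same idea:

#include <cstdio>

int main() {
  // Initialized at declaration: a stray read sees nullptr/0,
  // not indeterminate stack contents.
  std::FILE* log = nullptr;
  int status = 0;
  log = std::fopen("/dev/null", "w");
  if (log != nullptr) {
    status = std::fputs("ok\n", log);
    std::fclose(log);
  }
  return status < 0 ? 1 : 0;
}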

torch/csrc/dynamo/guards.cpp

Lines changed: 49 additions & 48 deletions

@@ -35,10 +35,10 @@ class TensorCheck {
         device_index_(v.device().index()),
         requires_grad_(v.requires_grad()),
         sizes_(std::move(dynamic_dims_sizes)),
-        strides_(std::move(dynamic_dims_strides)) {
+        strides_(std::move(dynamic_dims_strides)),
+        dim_(static_cast<int64_t>(sizes_.size())) {
     // TODO(voz): In cases where sizes_ and strides_ are fully dynamic, should
     // we just treat this as optional?
-    dim_ = sizes_.size();
   }

   // See note in guards.py [Note - On Export Tensor Guards]
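
The first hunk moves `dim_` from an assignment in the constructor body into the member-initializer list (clang-tidy's cppcoreguidelines-prefer-member-initializer) and makes the `size_t` to `int64_t` conversion explicit rather than implicit. A minimal sketch with a hypothetical class:

#include <cstdint>
#include <vector>

class ShapeCheck {
 public:
  explicit ShapeCheck(std::vector<int64_t> sizes)
      // Members initialize in declaration order: sizes_ first, so dim_
      // can safely read it; the cast makes size_t -> int64_t explicit.
      : sizes_(std::move(sizes)),
        dim_(static_cast<int64_t>(sizes_.size())) {}

 private:
  std::vector<int64_t> sizes_;
  int64_t dim_;
};

int main() {
  ShapeCheck check({2, 3, 4});
  (void)check;
  return 0;
}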
@@ -157,9 +157,9 @@ typedef struct {
 } TensorGuards;

 static void TensorGuards_dealloc(TensorGuards* self) {
-  if (self->checks != NULL) {
+  if (self->checks != nullptr) {
     delete self->checks;
-    self->checks = NULL;
+    self->checks = nullptr;
   }
   Py_TYPE(self)->tp_free((PyObject*)self);
 }
@@ -169,7 +169,7 @@ static PyObject* TensorGuards_new(
     PyObject* args,
     PyObject* kwds) {
   TensorGuards* self = (TensorGuards*)type->tp_alloc(type, 0);
-  if (self != NULL) {
+  if (self != nullptr) {
     self->checks = new ChecksList();
   }
   return (PyObject*)self;
@@ -191,7 +191,7 @@ static std::vector<std::optional<int64_t>> pyListToVecOptInt(PyObject* pyList) {
   for (Py_ssize_t i = 0; i < size; i++) {
     PyObject* item = PyList_GetItem(pyList, i);
     if (item == Py_None) {
-      vec.push_back(std::nullopt);
+      vec.emplace_back(std::nullopt);
     } else {
       int64_t value = PyLong_AsLongLong(item);
       if (value == -1 && PyErr_Occurred()) {
@@ -200,7 +200,7 @@ static std::vector<std::optional<int64_t>> pyListToVecOptInt(PyObject* pyList) {
             "Size or stride list item is not a valid integer.");
         TORCH_CHECK(false, "Size or stride list item is not a valid integer.");
       }
-      vec.push_back(value);
+      vec.emplace_back(value);
     }
   }
   return vec;
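
The `push_back` to `emplace_back` changes follow clang-tidy's modernize-use-emplace: the element is constructed in place inside the vector instead of first materializing a temporary `std::optional<int64_t>` and moving it in. A standalone sketch:

#include <cstdint>
#include <optional>
#include <vector>

int main() {
  std::vector<std::optional<int64_t>> vec;
  // push_back(42) would build a temporary optional and then move it;
  // emplace_back forwards the argument to optional's constructor in place.
  vec.emplace_back(std::nullopt);  // empty slot, no temporary
  vec.emplace_back(42);            // engaged slot constructed in the vector
  return vec.size() == 2 && !vec[0] && *vec[1] == 42 ? 0 : 1;
}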
@@ -231,13 +231,13 @@ static int TensorGuards_init(
   // Top level structure is List[List[Union[int, None]]]
   PyObject* dynamic_dims_sizes_py =
       PyDict_GetItemString(kwds, "dynamic_dims_sizes");
-  if (dynamic_dims_sizes_py == NULL) {
+  if (dynamic_dims_sizes_py == nullptr) {
     PyErr_SetString(PyExc_TypeError, "missing dynamic_dims_sizes=...");
     return -1;
   }
   PyObject* dynamic_dims_strides_py =
       PyDict_GetItemString(kwds, "dynamic_dims_strides");
-  if (dynamic_dims_strides_py == NULL) {
+  if (dynamic_dims_strides_py == nullptr) {
     PyErr_SetString(PyExc_TypeError, "missing dynamic_dims_strides=...");
     return -1;
   }
@@ -263,11 +263,11 @@ static int TensorGuards_init(
     }
     auto tensor = THPVariable_Unpack(item);
     std::vector<std::optional<int64_t>> tensor_dims_size =
-        per_tensor_dynamic_dims_sizes.size() == 0
+        per_tensor_dynamic_dims_sizes.empty()
         ? wrapIntegersInOptional(tensor.sizes())
         : per_tensor_dynamic_dims_sizes[i];
     std::vector<std::optional<int64_t>> tensor_dims_stride =
-        per_tensor_dynamic_dims_strides.size() == 0
+        per_tensor_dynamic_dims_strides.empty()
         ? wrapIntegersInOptional(tensor.strides())
         : per_tensor_dynamic_dims_strides[i];
     checks.emplace_back(
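
Swapping `size() == 0` for `empty()` is readability-container-size-empty: `empty()` states the intent directly and is available even on containers that do not expose a cheap `size()`, such as `std::forward_list`. A minimal sketch:

#include <forward_list>
#include <vector>

template <typename Container>
bool has_elements(const Container& c) {
  // `!c.empty()` works for every standard container, including
  // std::forward_list, which does not provide size() at all.
  return !c.empty();
}

int main() {
  std::vector<int> v{1, 2, 3};
  std::forward_list<int> fl{4, 5};
  return has_elements(v) && has_elements(fl) ? 0 : 1;
}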
@@ -286,7 +286,7 @@ PyObject* TensorGuards_check(
     PyObject* kwargs) {
   if (!PyTuple_CheckExact(args)) {
     PyErr_SetString(PyExc_TypeError, "expected tuple()");
-    return NULL;
+    return nullptr;
   }
   auto& checks = *self->checks;
   auto len = PyTuple_GET_SIZE(args);
@@ -295,7 +295,7 @@ PyObject* TensorGuards_check(

   if (static_cast<decltype(len)>(checks.size()) != len) {
     PyErr_SetString(PyExc_TypeError, "wrong length");
-    return NULL;
+    return nullptr;
   }

   LocalState state;
@@ -330,34 +330,34 @@ PyObject* TensorGuards_check_verbose(
     PyObject* kwargs) {
   if (!PyTuple_CheckExact(args)) {
     PyErr_SetString(PyExc_TypeError, "expected tuple()");
-    return NULL;
+    return nullptr;
   }
   auto& checks = *self->checks;
   auto len = PyTuple_GET_SIZE(args);

   if (static_cast<decltype(len)>(checks.size()) != len) {
     PyErr_SetString(PyExc_TypeError, "wrong length");
-    return NULL;
+    return nullptr;
   }

   PyObject* tensor_check_names_py =
       PyDict_GetItemString(kwargs, "tensor_check_names");
-  if (tensor_check_names_py == NULL) {
+  if (tensor_check_names_py == nullptr) {
     PyErr_SetString(PyExc_TypeError, "missing tensor_check_names kwarg");
-    return NULL;
+    return nullptr;
   }

   if (!PyList_Check(tensor_check_names_py)) {
     PyErr_SetString(PyExc_TypeError, "tensor_check_names kwarg must be a list");
-    return NULL;
+    return nullptr;
   }

   auto names_size = PyList_Size(tensor_check_names_py);
   if (names_size != static_cast<decltype(names_size)>(checks.size())) {
     PyErr_SetString(
         PyExc_TypeError,
         "tensor_check_names should be the same size as # tensors");
-    return NULL;
+    return nullptr;
   }

   std::vector<std::string> tensor_check_names;
@@ -367,7 +367,7 @@ PyObject* TensorGuards_check_verbose(
     if (!PyUnicode_Check(value)) {
       PyErr_SetString(
           PyExc_TypeError, "tensor_check_names must only contain strings");
-      return NULL;
+      return nullptr;
     }
     tensor_check_names.emplace_back(PyUnicode_AsUTF8(value));
   }
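
Most of the remaining churn is modernize-use-nullptr. `nullptr` has its own type (`std::nullptr_t`) and only converts to pointer types, while the `NULL` macro is an integer constant that can silently select an integer overload. A small sketch of the failure mode:

#include <iostream>

void handle(int) { std::cout << "int overload\n"; }
void handle(const char*) { std::cout << "pointer overload\n"; }

int main() {
  // handle(NULL) is ambiguous or picks the int overload on many
  // implementations, because NULL may expand to 0 or 0L.
  handle(nullptr);  // unambiguous: always the pointer overload
  return 0;
}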
@@ -407,6 +407,7 @@ PyObject* TensorGuards_check_verbose(
   Py_RETURN_TRUE;
 }

+// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
 static PyMethodDef TensorGuards_methods[] = {
     {"check",
      (PyCFunction)(void*)TensorGuards_check,
@@ -416,20 +417,19 @@ static PyMethodDef TensorGuards_methods[] = {
      (PyCFunction)(void*)TensorGuards_check_verbose,
      METH_VARARGS | METH_KEYWORDS,
      "verbose fail reasons for failed checks"},
-    {NULL} /* Sentinel */
+    {nullptr} /* Sentinel */
 };

-static PyTypeObject TensorGuardsType = {
-    // NOLINTNEXTLINE
-    PyVarObject_HEAD_INIT(NULL, 0)};
+static PyTypeObject TensorGuardsType = {PyVarObject_HEAD_INIT(nullptr, 0)};

 static PyObject* check_type_id(PyObject* dummy, PyObject* args) {
   // faster `lambda obj, expected: id(type(obj)) == expected`
-  PyObject* obj;
-  unsigned long long expected;
+  PyObject* obj = nullptr;
+  unsigned long long expected = 0;
   if (!PyArg_ParseTuple(args, "OK", &obj, &expected)) {
-    return NULL;
+    return nullptr;
   }
+  // NOLINTNEXTLINE(performance-no-int-to-ptr)
   if (Py_TYPE(obj) == (void*)expected) {
     Py_RETURN_TRUE;
   } else {
@@ -439,11 +439,12 @@ static PyObject* check_type_id(PyObject* dummy, PyObject* args) {

 static PyObject* check_obj_id(PyObject* dummy, PyObject* args) {
   // faster `lambda obj, expected: id(obj) == expected`
-  PyObject* obj;
-  unsigned long long expected;
+  PyObject* obj = nullptr;
+  unsigned long long expected = 0;
   if (!PyArg_ParseTuple(args, "OK", &obj, &expected)) {
-    return NULL;
+    return nullptr;
   }
+  // NOLINTNEXTLINE(performance-no-int-to-ptr)
   if (obj == (void*)expected) {
     Py_RETURN_TRUE;
   } else {
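
Where a warning cannot be fixed outright, the commit suppresses exactly one check on exactly one line with NOLINTNEXTLINE: `PyMethodDef` tables must be C arrays because the Python C API requires them, and converting an integer id back to a pointer is the entire purpose of these id checks. A sketch of the suppression mechanism (names hypothetical):

#include <cstdint>

// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
static const char* kNames[] = {"check", "check_verbose", nullptr};

bool same_object(const void* obj, std::uintptr_t expected) {
  // NOLINTNEXTLINE(performance-no-int-to-ptr)
  return obj == reinterpret_cast<const void*>(expected);
}

int main() {
  int x = 0;
  return same_object(&x, reinterpret_cast<std::uintptr_t>(&x)) &&
                 kNames[2] == nullptr
             ? 0
             : 1;
}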
@@ -456,25 +457,25 @@ static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
   Assert that a given tensor has a given size/stride, but ignore strides
   of size==1 dimensions. Implemented in C++ as this is on the hot path.
   */
-  PyObject* item;
-  PyObject* size;
-  PyObject* stride;
+  PyObject* item = nullptr;
+  PyObject* size = nullptr;
+  PyObject* stride = nullptr;
   if (!PyArg_ParseTuple(args, "OOO", &item, &size, &stride)) {
-    return NULL;
+    return nullptr;
   }
   if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
     PyErr_SetString(PyExc_TypeError, "expected Tensor()");
-    return NULL;
+    return nullptr;
   }
   if (!PyTuple_CheckExact(size) || !PyTuple_CheckExact(stride)) {
     PyErr_SetString(PyExc_TypeError, "expected tuple()");
-    return NULL;
+    return nullptr;
   }
   at::Tensor tensor = THPVariable_Unpack(item);
   int64_t ndim = tensor.ndimension();
   if (PyTuple_GET_SIZE(size) != ndim || PyTuple_GET_SIZE(stride) != ndim) {
     PyErr_SetString(PyExc_AssertionError, "wrong number of dimensions");
-    return NULL;
+    return nullptr;
   }
   for (auto i : c10::irange(ndim)) {
     int64_t want_size = THPUtils_unpackLong(PyTuple_GET_ITEM(size, i));
@@ -488,17 +489,18 @@ static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
       msg << "expected size " << actual_size << "==" << want_size << ", stride "
           << actual_stride << "==" << want_stride << " at dim=" << i;
       PyErr_SetString(PyExc_AssertionError, msg.str().c_str());
-      return NULL;
+      return nullptr;
     }
   }
   Py_RETURN_TRUE;
 }

+// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
 static PyMethodDef _methods[] = {
-    {"check_type_id", check_type_id, METH_VARARGS, NULL},
-    {"check_obj_id", check_obj_id, METH_VARARGS, NULL},
-    {"assert_size_stride", assert_size_stride, METH_VARARGS, NULL},
-    {NULL, NULL, 0, NULL}};
+    {"check_type_id", check_type_id, METH_VARARGS, nullptr},
+    {"check_obj_id", check_obj_id, METH_VARARGS, nullptr},
+    {"assert_size_stride", assert_size_stride, METH_VARARGS, nullptr},
+    {nullptr, nullptr, 0, nullptr}};

 static struct PyModuleDef _module = {
     PyModuleDef_HEAD_INIT,
@@ -521,19 +523,18 @@ PyObject* torch_c_dynamo_guards_init() {
   TensorGuardsType.tp_init = (initproc)TensorGuards_init;
   TensorGuardsType.tp_new = TensorGuards_new;

-  PyObject* m;
   if (PyType_Ready(&TensorGuardsType) < 0)
-    return NULL;
+    return nullptr;

-  m = PyModule_Create(&_module);
-  if (m == NULL)
-    return NULL;
+  auto m = PyModule_Create(&_module);
+  if (m == nullptr)
+    return nullptr;

   Py_INCREF(&TensorGuardsType);
   if (PyModule_AddObject(m, "TensorGuards", (PyObject*)&TensorGuardsType) < 0) {
     Py_DECREF(&TensorGuardsType);
     Py_DECREF(m);
-    return NULL;
+    return nullptr;
   }

   return m;
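
The final hunk also replaces a declare-early, assign-later `PyObject* m` with `auto m = PyModule_Create(&_module)` at the point of first use, eliminating the window in which `m` exists but holds no value. The same refactor in a standalone sketch (names hypothetical):

#include <memory>
#include <string>

std::unique_ptr<std::string> make_widget(bool ok) {
  return ok ? std::make_unique<std::string>("widget") : nullptr;
}

int main() {
  // Before: std::unique_ptr<std::string> w; ... w = make_widget(true);
  // After: declare at first use, so no uninitialized window exists.
  auto w = make_widget(true);
  if (w == nullptr)
    return 1;
  return 0;
}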

torch/csrc/dynamo/init.cpp

Lines changed: 4 additions & 4 deletions

@@ -6,26 +6,26 @@
 #include <torch/csrc/dynamo/python_compiled_autograd.h>

 static struct PyModuleDef _module =
-    {PyModuleDef_HEAD_INIT, "torch._C._dynamo", "", -1, NULL};
+    {PyModuleDef_HEAD_INIT, "torch._C._dynamo", "", -1, nullptr};

 namespace torch {
 namespace dynamo {
 using torch::dynamo::autograd::torch_c_dynamo_compiled_autograd_init;

 void initDynamoBindings(PyObject* torch) {
   PyObject* dynamo = PyModule_Create(&_module);
-  if (dynamo == NULL || PyModule_AddObject(torch, "_dynamo", dynamo) != 0) {
+  if (dynamo == nullptr || PyModule_AddObject(torch, "_dynamo", dynamo) != 0) {
     throw python_error();
   }

   PyObject* eval_frame = torch_c_dynamo_eval_frame_init();
-  if (eval_frame == NULL ||
+  if (eval_frame == nullptr ||
       PyModule_AddObject(dynamo, "eval_frame", eval_frame) != 0) {
     throw python_error();
   }

   PyObject* guards = torch_c_dynamo_guards_init();
-  if (guards == NULL || PyModule_AddObject(dynamo, "guards", guards) != 0) {
+  if (guards == nullptr || PyModule_AddObject(dynamo, "guards", guards) != 0) {
     throw python_error();
   }
