@@ -125,7 +125,7 @@ def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool:
'radd', 'rsub', 'rmul', 'rtruediv', 'rfloordiv', 'rpow',  # reverse arithmetic
'and', 'or', 'xor',  # logic
'iadd', 'iand', 'idiv', 'ilshift', 'imul',
- 'ior', 'irshift', 'isub', 'ixor',  # inplace ops
+ 'ior', 'irshift', 'isub', 'ixor', 'ifloordiv', 'imod',  # inplace ops
)
comparison_ops = ('eq', 'ne', 'ge', 'gt', 'lt', 'le')
unary_ops = ('neg', 'abs', 'invert')
@@ -324,6 +324,32 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -
'is_grad_enabled': ['def is_grad_enabled() -> _bool: ...'],
'nonzero': ['def nonzero(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...',
'def nonzero(input: Tensor, *, as_tuple: bool=...) -> Tensor: ...'],
+ 'binary_cross_entropy_with_logits': ['def binary_cross_entropy_with_logits(input: Tensor, target: Tensor, '
+ 'weight: Optional[Tensor] = None, size_average: Optional[bool] = None, '
+ 'reduce: Optional[bool] = None, reduction: str = ..., '
+ 'pos_weight: Optional[Tensor] = None) -> Tensor: ...'],
+ 'cosine_embedding_loss': ['def cosine_embedding_loss(input1: Tensor, input2: Tensor, '
+ 'target: Tensor, margin: float = ..., size_average: Optional[bool] = ..., '
+ 'reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...'],
+ 'ctc_loss': ['def ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor,'
+ ' blank: int = ..., reduction: str = ..., zero_infinity: bool = ...) -> Tensor: ...'],
+ 'hinge_embedding_loss': ['def hinge_embedding_loss(input: Tensor, target: Tensor, margin: float = ...,'
+ ' size_average: Optional[bool] = ..., reduce: Optional[bool] = ..., '
+ 'reduction: str = ...) -> Tensor: ...'],
+ 'kl_div': ['def kl_div(input: Tensor, target: Tensor, size_average: Optional[bool] = ..., '
+ 'reduce: Optional[bool] = ..., reduction: str = ..., log_target: bool = ...) -> Tensor: ...'],
+ 'margin_ranking_loss': ['def margin_ranking_loss(input1: Tensor, input2: Tensor, target: Tensor,'
+ ' margin: float = ..., size_average: Optional[bool] = ..., '
+ ' reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...'],
+ 'triplet_margin_loss': ['def triplet_margin_loss(anchor: Tensor, positive: Tensor, negative: Tensor, '
+ 'margin: float = ..., p: float = ..., eps: float = ..., swap: bool = ..., '
+ 'size_average: Optional[bool] = ..., '
+ 'reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...'],
+ 'dsmm': ['def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ...'],
+ 'hsmm': ['def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ...'],
+ 'saddmm': ['def saddmm(input: Tensor, mat1: Tensor, mat2: Tensor, *, beta: Number=1, '
+ 'alpha: Number=1, out: Optional[Tensor]=None) -> Tensor: ...'],
+ 'spmm': ['def spmm(input: Tensor, mat2: Tensor) -> Tensor: ...'],
})
for binop in ['mul', 'div', 'true_divide', 'floor_divide']:
unsorted_function_hints[binop].append(
@@ -382,10 +408,12 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -
'def __init__(self, size: _size, *, {}) -> None: ...'.format(DEVICE_PARAM),
],
'as_subclass': ["def as_subclass(self, cls: Tensor) -> Tensor: ..."],
+ '_make_subclass': ["def _make_subclass(cls, data: Tensor, require_grad: _bool = False) -> Tensor: ..."],
# clamp has no default values in the Declarations
'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf,"
" *, out: Optional[Tensor]=None) -> Tensor: ..."],
'clamp_': ["def clamp_(self, min: _float=-inf, max: _float=inf) -> Tensor: ..."],
+ '__get__': ["def __get__(self, instance, owner=None) -> Tensor: ..."],
'__getitem__': ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)],
'__setitem__': ["def __setitem__(self, {}, val: Union[Tensor, Number])"
" -> None: ...".format(INDICES)],
@@ -402,13 +430,17 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -
'numpy': ['def numpy(self) -> Any: ...'],
'apply_': ['def apply_(self, callable: Callable) -> Tensor: ...'],
'map_': ['def map_(self, tensor: Tensor, callable: Callable) -> Tensor: ...'],
+ 'map2_': ['def map2_(self, x: Tensor, y: Tensor, callable: Callable) -> Tensor: ...'],
'storage': ['def storage(self) -> Storage: ...'],
+ 'storage_type': ['def storage_type(self) -> Storage: ...'],
'type': ['def type(self, dtype: None=None, non_blocking: _bool=False) -> str: ...',
'def type(self, dtype: Union[str, _dtype], non_blocking: _bool=False) -> Tensor: ...',
],
'get_device': ['def get_device(self) -> _int: ...'],
'contiguous': ['def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ...'],
+ 'has_names': ['def has_names(self) -> _bool: ...'],
'is_contiguous': ['def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ...'],
+ '_is_view': ['def _is_view(self) -> _bool: ...'],
'is_cuda': ['is_cuda: _bool'],
'is_leaf': ['is_leaf: _bool'],
'is_sparse': ['is_sparse: _bool'],