|
1 | 1 | from typing import Callable, Optional, Union
|
2 | 2 |
|
3 | 3 | import torch
|
4 |
| - |
5 | 4 | import torch._prims as prims
|
6 | 5 | import torch._prims_common as utils
|
7 | 6 | import torch._refs as refs
|
|
25 | 24 | _make_elementwise_unary_reference,
|
26 | 25 | )
|
27 | 26 |
|
| 27 | +from torch._subclasses.fake_tensor import FakeTensor |
| 28 | + |
28 | 29 | __all__ = [
|
29 | 30 | "celu",
|
30 | 31 | "dropout",
|
|
36 | 37 | "l1_loss",
|
37 | 38 | "margin_ranking_loss",
|
38 | 39 | "mish",
|
| 40 | + "nll_loss", |
39 | 41 | "mse_loss",
|
40 | 42 | "poisson_nll_loss",
|
41 | 43 | "prelu",
|
@@ -435,6 +437,159 @@ def hinge_embedding_loss(
|
435 | 437 | return _apply_loss_reduction(loss, reduction)
|
436 | 438 |
|
437 | 439 |
|
def _nll_loss_nd(
    input: TensorLikeType,
    target: TensorLikeType,
    weight: Optional[TensorLikeType],
    reduction: str,
    ignore_index: int,
) -> TensorLikeType:
    """
    Negative log likelihood loss for inputs with 1, 2 or 3 dimensions.

    Args:
        input: tensor of rank 1 ``(C)``, rank 2 ``(N, C)`` or rank 3
            ``(N, C, K)``; the class axis is indexed by ``target``.
        target: class indices. It is flattened for indexing and the
            un-reduced loss is reshaped back to ``target.shape``.
        weight: optional per-class weights, indexed by class id. ``None``
            means every class has weight 1.
        reduction: validated by ``_check_reduction_value``; ``"none"`` returns
            the per-element loss, ``"sum"`` its sum, anything else accepted by
            the validator the weighted mean.
        ignore_index: target value whose entries get zero weight and zero loss.

    Returns:
        The per-element loss (``reduction == "none"``) or a reduced tensor.
    """
    utils.check(
        input.ndim > 0 and input.ndim <= 3,
        lambda: f"Expected input dimension to be either [1, 2, 3] but received {input.ndim}.",
    )

    utils.check(
        (input.ndim == 1) or (input.shape[0] == target.shape[0]),
        lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.",
    )

    _check_reduction_value(reduction)

    flat_target = torch.flatten(target)
    ignore_classes_mask = torch.eq(flat_target, ignore_index)
    # Ignored entries may hold any value (the default ignore_index is -100),
    # so they must never be used as an index. Point them at class 0 instead;
    # their contribution is masked out below.
    safe_target = torch.where(
        ignore_classes_mask, torch.zeros_like(flat_target), flat_target
    )

    # TODO: Enable data-dependent checks with debug mode
    # TODO: This check does not work with FakeTensor inputs; See Issue #85834
    # Explicit cast for class_check to bool; See Issue #78071
    #
    # num_classes = input.shape[1] if input.ndim > 1 else input.shape[0]
    # valid_classes_mask = torch.logical_and(
    #     (flat_target >= 0), (flat_target < num_classes)
    # )
    # class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask))
    # utils.check(
    #     isinstance(target, FakeTensor) or bool(class_check.item()),
    #     lambda: "A target class is out-of-bounds and not the ignore index.",
    # )

    ignore_class_weight = torch.scalar_tensor(0, dtype=input.dtype, device=input.device)
    class_weight = (
        torch.scalar_tensor(1, dtype=input.dtype, device=input.device)
        if weight is None
        else weight[safe_target]
    )
    current_weight = torch.where(
        ignore_classes_mask,
        ignore_class_weight,
        class_weight,
    )

    # Gather input[..., target, ...] for every target element. Index tensors
    # are created on the input's device so that they match `safe_target`.
    if input.ndim == 1:
        # implicit batch size = 1
        # input (C classes)
        picked = input[safe_target]
    elif input.ndim == 2:
        # input (N batch size, C classes)
        batch_idx = torch.arange(input.shape[0], device=input.device)
        picked = input[batch_idx, safe_target]
    else:
        # input (N batch size, C classes, K dimensions)
        batch_size = input.shape[0]
        extent = input.shape[2]
        flat_idx = torch.arange(batch_size * extent, device=input.device)
        bdx = flat_idx // extent
        kdx = flat_idx % extent
        picked = input[bdx, safe_target, kdx]

    # Select zero explicitly for ignored entries instead of relying on the
    # multiplication by a zero weight, which yields NaN for non-finite inputs.
    loss = torch.where(
        ignore_classes_mask, ignore_class_weight, -picked * current_weight
    )
    loss = torch.reshape(loss, target.shape)

    if reduction == "none":
        return loss
    elif reduction == "sum":
        return torch.sum(loss)
    else:
        # calculate weighted mean of the loss function
        return torch.sum(loss) / torch.sum(current_weight)
| 516 | + |
| 517 | + |
@register_decomposition(torch.ops.aten.nll_loss)
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
@out_wrapper()
def nll_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    weight: Optional[TensorLikeType] = None,
    size_average: Optional[bool] = None,
    ignore_index: int = -100,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.nll_loss

    Args:
        input: tensor with at least one dimension: ``(C)``, ``(N, C)`` or
            ``(N, C, d1, ..., dk)``.
        target: class indices: ``()``, ``(N)`` or ``(N, d1, ..., dk)``.
        weight: optional per-class weights.
        size_average: deprecated; when given (together with ``reduce``) it is
            converted to a ``reduction`` string and overrides ``reduction``.
        ignore_index: target value that does not contribute to the loss.
        reduce: deprecated; see ``size_average``.
        reduction: ``"none"``, ``"sum"`` or ``"mean"``.

    Returns:
        The loss, shaped like ``target`` for ``reduction == "none"`` and a
        scalar tensor otherwise.
    """
    utils.check(
        input.ndim > 0,
        lambda: f"Expected input tensor to have 1 or more dimensions (got {input.ndim})",
    )

    # TODO: raise exception instead of converting value
    # msg = "size_average and reduce args are deprecated, please use reduction argument."
    # Convert these options for consistency with the eager mode
    if size_average is not None or reduce is not None:
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)

    # The expected behavior when the target and input have zero elements:
    #   reduction = 'none' --- tensor([])
    #   reduction = 'sum'  --- tensor(0.)
    #   reduction = 'mean' --- tensor(nan)
    # Mean reduction on empty tensors produces NaN. See the discussion in
    # https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162
    if input.numel() == 0 and target.numel() == 0:
        if reduction == "none":
            # Same (empty) shape as the target, but in the input's dtype:
            # the result is a loss value, not a class index.
            return torch.zeros_like(target, dtype=input.dtype)
        # Reduced results are scalars; building them from `target` would give
        # an integer-typed, target-shaped (and for "sum" uninitialized) tensor.
        fill_value = 0.0 if reduction == "sum" else float("nan")
        return torch.scalar_tensor(
            fill_value, dtype=input.dtype, device=input.device
        )

    # The _nll_loss_nd helper function handles the most common cases.
    # ndim == 1 (Single Example)
    #   => Batch Size: 1, Input: (C), Target: ()
    # ndim == 2 (k = 1)
    #   => Batch Size: N, Input: (N, C), Target: (N)
    # ndim == 3 (k > 1)
    #   => Batch Size: N, Input: (N, C, K), Target: (N, K)
    if input.ndim <= 3:
        return _nll_loss_nd(input, target, weight, reduction, ignore_index)

    # For ndim > 3, we reshape the input and target to 3-D case.
    # Input (N batch-size, C classes, k-dimensions)
    # Target (N batch-size, k-dimensions)
    batch_size = input.shape[0]
    num_classes = input.shape[1]
    # Computed before the check so that the error message can refer to it.
    out_size = [batch_size] + list(input.shape[2:])
    utils.check(
        list(target.shape) == out_size,
        lambda: f"Expected target shape {out_size} but got {target.shape}",
    )

    input = torch.reshape(input, [batch_size, num_classes, -1])
    target = torch.reshape(target, [batch_size, -1])
    result = _nll_loss_nd(input, target, weight, reduction, ignore_index)
    if reduction == "none":
        # reshape flattened inner-dim to original k-dimensions
        return torch.reshape(result, out_size)
    return result
| 591 | + |
| 592 | + |
438 | 593 | # TODO: This ref supports int reduction and out kwarg to be compatible with ATen:
|
439 | 594 | # https://github.com/pytorch/pytorch/issues/83931
|
440 | 595 | # TODO: Could be rewritten to support complex:
|
|
0 commit comments