@@ -33,6 +33,7 @@
     "hinge_embedding_loss",
     "margin_ranking_loss",
     "mish",
+    "nll_loss",
     "relu",
     "selu",
     "softplus",
@@ -349,6 +350,173 @@ def hinge_embedding_loss(
     return _apply_loss_reduction(loss, reduction)


+def _get_string_reduction_arg(
+    size_average: Optional[bool], reduce: Optional[bool]
+) -> str:
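+    # Map the deprecated size_average/reduce flags onto a reduction string;
+    # unset (None) flags default to True, matching the legacy behavior.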
+    if size_average is None:
+        size_average = True
+    if reduce is None:
+        reduce = True
+    if size_average and reduce:
+        ret = "mean"
+    elif reduce:
+        ret = "sum"
+    else:
+        ret = "none"
+    return ret
+
+
+def _nll_loss_2d(
+    input: TensorLikeType,
+    target: TensorLikeType,
+    weight: Optional[TensorLikeType],
+    reduction: str,
+    ignore_index: int,
+) -> TensorLikeType:
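+    # Core helper for (N, C, H, W) inputs with (N, H, W) targets; nll_loss
+    # reshapes the 3-d and >4-d cases into this form before dispatching here.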
+    if input.ndim != 4:
+        msg = "Expected 4 dimensions for input but got {}."
+        raise ValueError(msg.format(input.ndim))
+
+    if target.ndim != 3:
+        msg = "Expected 3 dimensions for target but got {}."
+        raise ValueError(msg.format(target.ndim))
+
+    _check_reduction_value(reduction)
+
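+    # Build per-element weights: positions whose target equals ignore_index get
+    # weight 0, so they contribute to neither the sum nor the mean's denominator.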
+    current_target = torch.reshape(target, [-1])
+    ignore_classes_mask = torch.eq(current_target, ignore_index)
+    if weight is None:
+        current_weight = torch.where(
+            ignore_classes_mask,
+            torch.scalar_tensor(0, dtype=input.dtype, device=input.device),
+            torch.scalar_tensor(1, dtype=input.dtype, device=input.device),
+        )
+    else:
+        ignore_class_weight = torch.scalar_tensor(
+            0, dtype=input.dtype, device=input.device
+        ).expand_as(current_target)
+        current_weight = torch.where(
+            ignore_classes_mask, ignore_class_weight, weight[current_target]
+        )
+
+    batch_size = input.shape[0]
+    height = input.shape[2]
+    width = input.shape[3]
+    extent = height * width
+    numel = batch_size * extent
+
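+    # Decompose the flat indices 0..numel-1 into (batch, height, width)
+    # coordinates so advanced indexing can gather input[n, target[n, h, w], h, w].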
+    bdx = torch.arange(numel) // extent
+    hdx = (torch.arange(numel) - (bdx * extent)) // width
+    wdx = torch.arange(numel) % width
+
+    loss = -input[bdx, current_target, hdx, wdx] * current_weight
+    loss = torch.reshape(loss, target.shape)
+
+    if reduction == "none":
+        return loss
+    elif reduction == "sum":
+        return torch.sum(loss)
+    else:
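+        # "mean" divides by the summed weights rather than the element count,
+        # mirroring torch.nn.functional.nll_loss semantics.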
+        return torch.sum(loss) / torch.sum(current_weight)
+
+
+def _nll_loss_1d(
+    input: TensorLikeType,
+    target: TensorLikeType,
+    weight: Optional[TensorLikeType],
+    reduction: str,
+    ignore_index: int,
+) -> TensorLikeType:
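+    # Handles 1-d inputs of shape (C,) and 2-d inputs of shape (N, C).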
+    if input.ndim < 1:
+        msg = "Expected 1 or more dimensions for input but got {}."
+        raise ValueError(msg.format(input.ndim))
+
+    if input.ndim != 1 and input.shape[0] != target.shape[0]:
+        msg = "Expected input batch size ({}) to match target batch size ({})."
+        raise ValueError(msg.format(input.shape[0], target.shape[0]))
+
+    _check_reduction_value(reduction)
+
+    if target.ndim <= 1:
+        current_target = target
+    else:
+        current_target = target[:, 0]
+
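+    # Same weighting scheme as in _nll_loss_2d: targets equal to ignore_index
+    # receive weight 0 and drop out of every reduction.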
+    ignore_classes_mask = torch.eq(current_target, ignore_index)
+    if weight is None:
+        current_weight = torch.where(
+            ignore_classes_mask,
+            torch.scalar_tensor(0, dtype=input.dtype, device=input.device),
+            torch.scalar_tensor(1, dtype=input.dtype, device=input.device),
+        )
+    else:
+        ignore_class_weight = torch.scalar_tensor(
+            0, dtype=input.dtype, device=input.device
+        ).expand_as(current_target)
+        current_weight = torch.where(
+            ignore_classes_mask, ignore_class_weight, weight[current_target]
+        )
+
+    batch_size = input.shape[0]
+    if input.ndim == 1:
+        loss = -input[current_target] * current_weight
+    else:
+        loss = -input[torch.arange(batch_size), current_target] * current_weight
+
+    if reduction == "none":
+        return loss
+    elif reduction == "sum":
+        return torch.sum(loss)
+    else:
+        return torch.sum(loss) / torch.sum(current_weight)
+
+
+def nll_loss(
+    input: TensorLikeType,
+    target: TensorLikeType,
+    weight: Optional[TensorLikeType] = None,
+    size_average: Optional[bool] = None,
+    ignore_index: int = -100,
+    reduce: Optional[bool] = None,
+    reduction: str = "mean",
+) -> TensorLikeType:
+    if size_average is not None or reduce is not None:
+        # TODO raise exception instead of converting value
+        # msg = "size_average and reduce args are deprecated, please use reduction argument."
+        reduction = _get_string_reduction_arg(size_average, reduce)
+
+    if input.ndim == 1 or input.ndim == 2:
+        return _nll_loss_1d(input, target, weight, reduction, ignore_index)
+    elif input.ndim == 4:
+        return _nll_loss_2d(input, target, weight, reduction, ignore_index)
+    else:
+        # input.ndim == 3 or input.ndim > 4
+        batch_size = input.shape[0]
+        num_classes = input.shape[1]
+        out_size = [batch_size] + list(input.shape[2:])
+
+        if target.shape[1:] != input.shape[2:]:
+            msg = "Expected target size {} but got {}."
+            raise ValueError(msg.format(out_size, target.shape))
+
+        # support empty batches, see #15870
+        if input.numel() > 0:
+            input = torch.reshape(input, [batch_size, num_classes, 1, -1])
+        else:
+            input = torch.reshape(input, [batch_size, num_classes, 0, 0])
+
+        if target.numel() > 0:
+            target = torch.reshape(target, [batch_size, 1, -1])
+        else:
+            target = torch.reshape(target, [batch_size, 0, 0])
+
+        if reduction == "none":
+            result = _nll_loss_2d(input, target, weight, reduction, ignore_index)
+            # reshape the flattened, unreduced loss back to the expected output size
+            return torch.reshape(result, out_size)
+        else:
+            return _nll_loss_2d(input, target, weight, reduction, ignore_index)
+
+
 # tanhshrink does not use _make_elementwise_unary_reference because it does not support out
 @elementwise_unary_scalar_wrapper
 @elementwise_type_promotion_wrapper(
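For reviewers who want a quick sanity check, here is a minimal sketch (not part of this diff) comparing the reference against eager mode; it assumes the new function is exposed as `torch._refs.nn.functional.nll_loss`, per the `__all__` entry above:

# Sanity-check sketch; assumes this PR exposes nll_loss via torch._refs.nn.functional.
import torch
import torch.nn.functional as F
import torch._refs.nn.functional as refs_F

# 2-d case: (N, C) log-probabilities with class-index targets and class weights.
log_probs = F.log_softmax(torch.randn(5, 3), dim=1)
target = torch.tensor([0, 2, 1, 2, 0])
weight = torch.tensor([1.0, 0.5, 2.0])
for reduction in ("none", "sum", "mean"):
    torch.testing.assert_close(
        refs_F.nll_loss(log_probs, target, weight=weight, reduction=reduction),
        F.nll_loss(log_probs, target, weight=weight, reduction=reduction),
    )

# 4-d case: (N, C, H, W) with one ignored position.
log_probs4d = F.log_softmax(torch.randn(2, 3, 4, 4), dim=1)
target3d = torch.randint(0, 3, (2, 4, 4))
target3d[0, 0, 0] = -100  # masked out via the default ignore_index
torch.testing.assert_close(
    refs_F.nll_loss(log_probs4d, target3d),
    F.nll_loss(log_probs4d, target3d),
)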