8000 [Docs] Add clarification for target types in CrossEntropyLoss doc (#1… · pytorch/pytorch@f97307f · GitHub
[go: up one dir, main page]

Skip to content

Commit f97307f

Browse files
spzalapytorchmergebot
authored andcommitted
[Docs] Add clarification for target types in CrossEntropyLoss doc (#145444)
CrossEntropyLoss function requires that target for class indices are provided as a long and class probabilities are provided as a float datatype. The CrossEntropyLoss function distinguish the two scenarios (indices and probabilities) by comparing the shapes. When input and target shapes are the same it’s a case for probabilities otherwise it will be used as a class index as already covered in the doc. The related code is here, https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LossNLL.cpp#L624 I think the current documentation is great but seems like it can confuse users about types as reported in the issues so this PR adds a bit more clarification. Fixes #137188 Pull Request resolved: #145444 Approved by: https://github.com/mikaylagawarecki
1 parent 5ed5793 commit f97307f

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

torch/nn/modules/loss.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,8 +1245,10 @@ class probabilities only when a single class label per minibatch item is too res
12451245
- Input: Shape :math:`(C)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
12461246
in the case of `K`-dimensional loss.
12471247
- Target: If containing class indices, shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with
1248-
:math:`K \geq 1` in the case of K-dimensional loss where each value should be between :math:`[0, C)`.
1249-
If containing class probabilities, same shape as the input and each value should be between :math:`[0, 1]`.
1248+
:math:`K \geq 1` in the case of K-dimensional loss where each value should be between :math:`[0, C)`. The
1249+
target data type is required to be long when using class indices. If containing class probabilities, the
1250+
target must be the same shape input, and each value should be between :math:`[0, 1]`. This means the target
1251+
data type is required to be float when using class probabilities.
12501252
- Output: If reduction is 'none', shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
12511253
in the case of K-dimensional loss, depending on the shape of the input. Otherwise, scalar.
12521254

0 commit comments

Comments
 (0)
0