When I change the loss in the example code from `DiceLoss` to `CrossEntropyLoss`:

```python
# loss = smp.utils.losses.DiceLoss()        # original loss from the example
loss = smp.utils.losses.CrossEntropyLoss()  # replacement that triggers the error
```

I get this error:
```
Traceback (most recent call last):
  File "train.py", line 131, in <module>
    train_logs = train_epoch.run(train_loader)
  File "/home/qianjinhao/anaconda3cn10/envs/open-mmlab/lib/python3.7/site-packages/segmentation_models_pytorch/utils/train.py", line 47, in run
    loss, y_pred = self.batch_update(x, y)
  File "/home/qianjinhao/anaconda3cn10/envs/open-mmlab/lib/python3.7/site-packages/segmentation_models_pytorch/utils/train.py", line 88, in batch_update
    loss = self.loss(prediction, y)
  File "/home/qianjinhao/anaconda3cn10/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/qianjinhao/anaconda3cn10/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/modules/loss.py", line 916, in forward
    ignore_index=self.ignore_index, reduction=self.reduction)
  File "/home/qianjinhao/anaconda3cn10/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/functional.py", line 2021, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/home/qianjinhao/anaconda3cn10/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/functional.py", line 1840, in nll_loss
    ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target' in call to _thnn_nll_loss2d_forward
```
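The last frame of the traceback is `nll_loss2d`, which in this PyTorch version expects `target` to be a Long tensor of class indices with shape (N, H, W), whereas `DiceLoss` works with a float mask. The sketch below assumes the dataset yields a float one-hot mask of shape (N, C, H, W) (an assumption inferred from the `DiceLoss` setup, not confirmed here) and converts it with a hypothetical helper `to_class_indices`. It uses plain `torch.nn.CrossEntropyLoss`; the traceback shows that `cross_entropy` applies `log_softmax(input, 1)` itself, so the prediction passed in should be raw logits rather than softmax outputs. The tensor sizes are illustrative only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def to_class_indices(mask: torch.Tensor) -> torch.Tensor:
    """Convert a one-hot float mask (N, C, H, W) into a Long index map (N, H, W).

    Hypothetical helper: assumes exactly one class is "on" per pixel.
    """
    return mask.argmax(dim=1).long()


# Illustrative sizes (assumptions): batch of 2, 3 classes, 4x4 images.
N, C, H, W = 2, 3, 4, 4
logits = torch.randn(N, C, H, W)  # raw model scores, no softmax/sigmoid applied

# Float one-hot mask in (N, C, H, W) layout, like the mask fed to DiceLoss.
labels = torch.randint(0, C, (N, H, W))
one_hot_mask = F.one_hot(labels, C).permute(0, 3, 1, 2).float()

target = to_class_indices(one_hot_mask)  # dtype torch.int64, shape (N, H, W)
loss = nn.CrossEntropyLoss()(logits, target)
print(target.dtype, tuple(target.shape), loss.item())
```

In a training loop like the one in the traceback, `batch_update` passes `y` to the loss unchanged, so a conversion like this would most likely have to happen in the dataset (or a wrapper around the loss) rather than inside the loop.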