Raw logits or softmax probability outputs in nn.CrossEntropyLoss()? #252
-
In section 4, we have code for multiclass classification. I was experimenting with the code and tried to pass both the raw logits and the prediction probabilities (after passing the raw logits through softmax()) to the loss function. I was wondering which one is recommended for cross-entropy, raw logits or softmax probabilities?

CASE 1 - training loop where I am passing logits to cross-entropy
CASE 2 - training loop where I am passing prediction probabilities (after passing logits to softmax()) to cross-entropy

In both cases, my model reaches 97% test accuracy.
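The notebook's actual training loops aren't shown here, so below is a minimal sketch of the two cases with a toy model and data (`model`, `X_train` and `y_train` are stand-ins, not the section 4 code):

```python
import torch
from torch import nn

torch.manual_seed(42)

# Toy stand-ins for the notebook's model and data (assumed, not the original code)
NUM_FEATURES, NUM_CLASSES, NUM_SAMPLES = 2, 4, 32
model = nn.Sequential(nn.Linear(NUM_FEATURES, 8), nn.ReLU(), nn.Linear(8, NUM_CLASSES))
X_train = torch.randn(NUM_SAMPLES, NUM_FEATURES)
y_train = torch.randint(0, NUM_CLASSES, (NUM_SAMPLES,))

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(100):
    model.train()
    logits = model(X_train)          # raw logits, shape [NUM_SAMPLES, NUM_CLASSES]

    # CASE 1: pass raw logits straight to CrossEntropyLoss
    loss = loss_fn(logits, y_train)

    # CASE 2: pass softmax probabilities instead (swap these two lines in)
    # probs = torch.softmax(logits, dim=1)
    # loss = loss_fn(probs, y_train)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```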
-
This is an interesting observation! If you changed the loss_fn parameters like below, the backward pass fails: the reason is that y_pred doesn't have any grad_fn, so the loss cannot do the backward step. Could you please restart the kernel of the notebook and re-run it again? Also, if possible, please provide more information and pieces of code. I am curious about this topic that you raised.
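A rough sketch (not the snippet the reply refers to) of when a prediction tensor does or doesn't have a grad_fn: passing logits through softmax keeps them attached to the autograd graph, while detaching them (or computing them under torch.no_grad()) removes grad_fn and makes loss.backward() raise an error.

```python
import torch
from torch import nn

model = nn.Linear(2, 4)
X = torch.randn(8, 2)
y = torch.randint(0, 4, (8,))
loss_fn = nn.CrossEntropyLoss()

logits = model(X)
probs = torch.softmax(logits, dim=1)

print(logits.grad_fn)   # e.g. <AddmmBackward0 ...>   -> loss.backward() works
print(probs.grad_fn)    # e.g. <SoftmaxBackward0 ...> -> still in the graph

# A tensor that has left the graph (detached or built under torch.no_grad())
# has no grad_fn, so a loss computed from it cannot backpropagate into the model
y_pred = logits.detach()
print(y_pred.grad_fn)   # None
loss = loss_fn(y_pred, y)
# loss.backward()  # would raise: element 0 of tensors does not require grad ...
```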
-
As I mentioned, it is indeed an interesting observation. So, thank you for sharing this.
But back to your question. According to the PyTorch documentation:
The input is expected to contain the unnormalized logits for each class (which do not need to be positive or sum to 1, in general).
Therefore it is recommended to pass the raw logits instead of probabilities.
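A quick illustration of that point with made-up numbers: nn.CrossEntropyLoss applies log_softmax internally, so feeding it probabilities applies softmax twice and produces a different (flatter) loss value than feeding it the raw logits.

```python
import torch
from torch import nn

loss_fn = nn.CrossEntropyLoss()

# Made-up logits for 3 samples and 4 classes, plus integer class targets
logits = torch.tensor([[ 2.0, -1.0,  0.5, 0.0],
                       [ 0.1,  3.0, -2.0, 0.3],
                       [-0.5,  0.2,  1.5, 0.0]])
targets = torch.tensor([0, 1, 2])

probs = torch.softmax(logits, dim=1)

print(loss_fn(logits, targets))  # intended usage: raw logits
print(loss_fn(probs, targets))   # softmax applied twice -> a different, flatter loss
```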
But it is beyond that. Have you noticed the shape of the inputs, `logits` and `y_train` (or `probs` and `y_train`)? If you have a closer look, you will notice that `logits` and `probs` are in the shape of [number_of_samples, number_of_classes], but `y_train` is in [number_of_samples]! So, how does cross-entropy loss calculate the loss? …
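A small sketch of how those shapes fit together (a generic example, not the continuation of the reply): the input has shape [number_of_samples, number_of_classes], the target has shape [number_of_samples] and holds class indices, and nn.CrossEntropyLoss is equivalent to taking log_softmax over the class dimension and averaging the negative log-probability of each sample's target class.

```python
import torch
from torch import nn

torch.manual_seed(0)
number_of_samples, number_of_classes = 5, 3
logits = torch.randn(number_of_samples, number_of_classes)           # [number_of_samples, number_of_classes]
y_train = torch.randint(0, number_of_classes, (number_of_samples,))  # [number_of_samples], class indices

# Built-in loss
ce = nn.CrossEntropyLoss()(logits, y_train)

# Manual equivalent: log_softmax over the class dimension, then take the
# negative log-probability of each sample's target class and average
log_probs = torch.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(number_of_samples), y_train].mean()

print(ce, manual)  # the two values match (up to floating-point error)
```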