We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 11f8cf9 commit f4367b7Copy full SHA for f4367b7
trl/experimental/utils.py
@@ -372,7 +372,7 @@ def get_reward(
372
input_ids=input_ids,
373
attention_mask=attention_mask,
374
position_ids=position_ids
375
- ).last_hidden_state
+ ).logits
376
sequence_lengths = first_true_indices(query_responses[:, context_length:] == pad_token_id) - 1 + context_length
377
# https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454
378
return (
0 commit comments