Skip to content

Commit f4367b7

Browse files
committed
refactor: call model forward directly instead of base model backbone
1 parent 11f8cf9 commit f4367b7

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

trl/experimental/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ def get_reward(
372372
input_ids=input_ids,
373373
attention_mask=attention_mask,
374374
position_ids=position_ids
375-
).last_hidden_state
375+
).logits
376376
sequence_lengths = first_true_indices(query_responses[:, context_length:] == pad_token_id) - 1 + context_length
377377
# https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454
378378
return (

0 commit comments

Comments
 (0)