diff --git a/trl/experimental/utils.py b/trl/experimental/utils.py index fc53a34577..a1d1428e0d 100644 --- a/trl/experimental/utils.py +++ b/trl/experimental/utils.py @@ -374,10 +374,10 @@ def get_reward( attention_mask=attention_mask, position_ids=position_ids, return_dict=True, - output_hidden_states=True, + output_hidden_states=False, use_cache=False, # otherwise mistral-based RM would error out ) - reward_logits = model.score(output.hidden_states[-1]) + reward_logits = model.score(output.last_hidden_state) sequence_lengths = first_true_indices(query_responses[:, context_length:] == pad_token_id) - 1 + context_length # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 return (