
fix: strip trailing newlines from collator templates before tokenizing#2247

Draft
lefft wants to merge 1 commit into main from tim/fix-collator-bpe-boundary-mismatch

Conversation

@lefft
Contributor

@lefft lefft commented Mar 6, 2026

NB: This PR was generated by Claude Code

Problem

A customer training job (job 630-7257, Qwen3-4B-Instruct) completed with loss=0.0 for the entire run — 62 steps across 2 epochs, zero gradient updates, saved checkpoint identical to the base model.

The DataCollatorForCompletionOnlyLM skipped every single example (973/973) because it could not find the instruction_template token IDs in the tokenized training data.

Root Cause

BPE token boundary mismatch between the collator's standalone-tokenized template and the same template as it appears in the full chat-template-tokenized sequence.

Step by step

  1. Collator init tokenizes "<|im_start|>user\n" standalone → [151644, 872, 198] (where 198 = \n)

  2. Chat template renders messages as <|im_start|>{role}\n{content}<|im_end|>. The customer's dataset had message content starting with \n (e.g. "\nPlease classify..."), so the rendered text became:

    <|im_start|>user\n\nPlease classify...
                     ^^
                     two consecutive newlines
    
  3. BPE merge: The Qwen3 tokenizer has a merge rule for double-newline: \n\n → token 271. So the full sequence contains [151644, 872, 271, ...], not [151644, 872, 198, ...].

  4. Subsequence search fails: The collator searches for [151644, 872, 198] in the token IDs, never finds it, sets all labels to ignore_index=-100, and loss becomes 0.
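The steps above can be reproduced without the tokenizer at all, using the token IDs from this report. Note that `find_subsequence` is an illustrative stand-in for the collator's search, and the trailing `9999` IDs are placeholders, not real Qwen tokens:

```python
def find_subsequence(haystack, needle):
    """Return the start index of needle in haystack, or -1 if absent."""
    n = len(needle)
    for i in range(len(haystack) - n + 1):
        if haystack[i:i + n] == needle:
            return i
    return -1

# What the collator encoded standalone: "<|im_start|>user\n"
template_ids = [151644, 872, 198]

# What the chat template actually produced: content began with "\n",
# so BPE merged "\n" + "\n" into token 271 instead of emitting 198.
sequence_ids = [151644, 872, 271, 9999, 9999]

find_subsequence(sequence_ids, template_ids)  # → -1: never found, all labels set to -100
```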

Verified experimentally

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")

tokenizer.encode("user\n",   add_special_tokens=False)  # → [872, 198]
tokenizer.encode("user\n\n", add_special_tokens=False)  # → [872, 271]  ← mismatch!

tokenizer.encode("\n",   add_special_tokens=False)  # → [198]
tokenizer.encode("\n\n", add_special_tokens=False)  # → [271]  ← BPE merge
```

This is not Qwen-specific — most modern BPE tokenizers (Llama, GPT, etc.) have similar merge rules for common character sequences like \n\n. Any dataset with leading \n in message content will trigger this.

Fix

Strip trailing \n from template strings before tokenizing them in the collator's __init__. This produces shorter, unambiguous token sequences (e.g. [151644, 872] for <|im_start|>user) that are found regardless of what follows in the content.

Before: tokenizer.encode("<|im_start|>user\n") → [151644, 872, 198] — fragile, breaks if content starts with \n

After: tokenizer.encode("<|im_start|>user") → [151644, 872] — robust, <|im_start|> is a special token (always 151644) and role names are unambiguous after it
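With the shortened template, the same plain-Python search succeeds regardless of whether the content introduced a BPE merge. As before, `find_subsequence` is illustrative and `9999` is a placeholder ID:

```python
def find_subsequence(haystack, needle):
    """Return the start index of needle in haystack, or -1 if absent."""
    n = len(needle)
    for i in range(len(haystack) - n + 1):
        if haystack[i:i + n] == needle:
            return i
    return -1

stripped_template_ids = [151644, 872]      # encode("<|im_start|>user")
sequence_ids = [151644, 872, 271, 9999]    # rendered turn where "\n\n" merged to 271

find_subsequence(sequence_ids, stripped_template_ids)  # → 0: match found
```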

No false positives

<|im_start|> (token 151644) only ever appears at the start of a chat turn, and is always followed by a role name token. The shorter two-token sequence [151644, 872] (<|im_start|>user) cannot match anywhere else in a well-formed chat-template-tokenized sequence. Verified against the actual problematic dataset — finds exactly the correct positions with no false matches.

Existing tests pass

All 14 tests in test_text_completions_collator_with_padding.py and test_collators.py continue to pass.

Changed file

src/oumi/core/collators/trl_data_collator_for_completion_only_lm.py — 2 lines changed (.rstrip("\n") added to both instruction_template and response_template before tokenizer.encode)
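A standalone sketch of the pattern (variable names are illustrative, not the file's actual code; the real change applies .rstrip("\n") inside the collator's __init__ before tokenizer.encode):

```python
# Template strings as rendered by the Qwen chat template (from the PR).
instruction_template = "<|im_start|>user\n"
response_template = "<|im_start|>assistant\n"

# Strip trailing newlines so the encoded IDs cannot straddle a BPE merge
# with whatever character follows in the rendered message content.
instruction_template = instruction_template.rstrip("\n")  # "<|im_start|>user"
response_template = response_template.rstrip("\n")        # "<|im_start|>assistant"
```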

Fixes a silent training failure where loss=0 for the entire run when
dataset message content starts with \n.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
