
How to train using hard negative only? #71


Open
zcakzhuu opened this issue Jan 22, 2025 · 4 comments

Comments

@zcakzhuu

I understand that GritLM fine-tuning uses both in-batch negatives and hard negatives for contrastive learning, and that we can use in-batch negatives only by setting the train group size to 1.

However, in my case I can only use hard negatives, not in-batch negatives. Is there a way to disable in-batch negatives? If not, could you kindly advise which part of the code I should modify to implement the changes myself?
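For context, here is a minimal sketch of the difference between the two setups (illustrative only, not GritLM's actual training code; the variable names are mine):

import torch

batch_size, group_size, hidden = 4, 3, 8   # group_size = 1 positive + 2 hard negatives per query
q = torch.randn(batch_size, hidden)
p = torch.randn(batch_size * group_size, hidden)

# In-batch (+ hard) negatives: every query is scored against every passage in the batch;
# the target for query i is its own positive, which sits at column i * group_size.
scores_in_batch = q @ p.T                                     # [batch_size, batch_size * group_size]
target_in_batch = torch.arange(batch_size) * group_size

# Hard negatives only: each query is scored only against its own group of passages;
# the positive is always at index 0 within the group.
p_grouped = p.view(batch_size, group_size, hidden)
scores_hard_only = torch.einsum('bh,bgh->bg', q, p_grouped)   # [batch_size, group_size]
target_hard_only = torch.zeros(batch_size, dtype=torch.long)

loss_in_batch = torch.nn.functional.cross_entropy(scores_in_batch, target_in_batch)
loss_hard_only = torch.nn.functional.cross_entropy(scores_hard_only, target_hard_only)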

@Muennighoff
Collaborator

I think you'd need to modify

p_reps = self._dist_gather_tensor(p_reps)

(gritlm/gritlm/training/model.py, line 41 in 724df95) to only gather hard negatives and only use them for the loss. Would be great if you could share your code changes!

@zcakzhuu
Author

zcakzhuu commented Jan 25, 2025

> I think you'd need to modify
>
> p_reps = self._dist_gather_tensor(p_reps)
>
> (gritlm/gritlm/training/model.py, line 41 in 724df95) to only gather hard negatives and only use them for the loss. Would be great if you could share your code changes!

Thanks for the instructions! I’ve made the proposed code changes, which involve avoiding data gathering from other GPUs and reshaping p_reps and q_reps to ensure that the dot products are computed between each query and its corresponding passages. Let me know your thoughts on these changes.

Additionally, I was wondering about the function compute_similarity. Under what circumstances would len(p_reps.size()) not equal 2?

from typing import Optional

import torch
import torch.distributed as dist


class DistributedContrastiveLoss:
    def __init__(self, temperature: float, negatives_cross_device: bool, hard_negatives_only: bool):
        self.cross_entropy = torch.nn.CrossEntropyLoss(reduction='mean')
        self.temperature = temperature
        self.hard_negatives_only = hard_negatives_only
        # Do not gather other GPUs' batches when using hard negatives only
        self.negatives_cross_device = False if self.hard_negatives_only else negatives_cross_device
        
        if self.negatives_cross_device:
            if not dist.is_initialized():
                raise ValueError('Cannot do negatives_cross_device without distributed training')
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()

    def __call__(self, q_reps, p_reps):
        """
        q_reps: [batch_size, hidden_size]  # Query embeddings
        p_reps: [batch_size * (num_negatives + 1), hidden_size] # Passage embeddings. num_negatives + 1 = train_group_size
        """
        
        if self.negatives_cross_device:
            # This gathers both negatives and positives.
            # It could likely be optimized by only gathering negatives.
            q_reps = self._dist_gather_tensor(q_reps)
            p_reps = self._dist_gather_tensor(p_reps)
        
        if self.hard_negatives_only:
            # Reshape `p_reps` to group passages for each query
            p_reps = p_reps.view(q_reps.size(0), p_reps.size(0) // q_reps.size(0), -1)  # [batch_size, num_negatives + 1, hidden_size]

        scores = self.compute_similarity(q_reps, p_reps) / self.temperature
        scores = scores.view(q_reps.size(0), -1)

        if self.hard_negatives_only:
            # Target is always 0 since the first passage in each group is positive
            target = torch.zeros(scores.size(0), dtype=torch.long, device=scores.device)
        else:
            target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
            target *= (p_reps.size(0) // q_reps.size(0))
        return self.cross_entropy(scores, target)

    def _dist_gather_tensor(self, t: Optional[torch.Tensor]):
        if t is None: return None
        t = t.contiguous()

        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        # All tensors have the same shape, as pooling already applied to them
        dist.all_gather(all_tensors, t)

        all_tensors[self.rank] = t
        all_tensors = torch.cat(all_tensors, dim=0)

        return all_tensors

    def compute_similarity(self, q_reps, p_reps):
        if self.hard_negatives_only: 
            # Query embedding: [batch_size, hidden_size] -unsqueeze-> [batch_size, 1, hidden_size]
            # Passage embedding: [batch_size, num_negatives + 1, hidden_size] -transpose-> [batch_size, hidden_size, num_negatives + 1]
            # Resulting shape: [batch_size, 1, num_negatives + 1] -squeeze-> [batch_size, num_negatives + 1]
            return torch.matmul(q_reps.unsqueeze(1), p_reps.transpose(-2, -1)).squeeze(1) 
        if len(p_reps.size()) == 2: return torch.matmul(q_reps, p_reps.transpose(0, 1))
        return torch.matmul(q_reps, p_reps.transpose(-2, -1))
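
For reference, a quick single-GPU sanity check of the hard-negatives-only path with dummy tensors (illustrative only; the shapes and temperature value here are arbitrary):

import torch

loss_fn = DistributedContrastiveLoss(temperature=0.02, negatives_cross_device=False, hard_negatives_only=True)

batch_size, train_group_size, hidden = 4, 3, 16
q_reps = torch.randn(batch_size, hidden, requires_grad=True)                     # [batch_size, hidden_size]
p_reps = torch.randn(batch_size * train_group_size, hidden, requires_grad=True)  # first passage in each group is the positive

loss = loss_fn(q_reps, p_reps)  # cross-entropy over [batch_size, train_group_size] scores with target index 0
loss.backward()
print(loss.item())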

@zcakzhuu
Author

zcakzhuu commented Feb 4, 2025

Hi @Muennighoff,

Just checking in on my previous message regarding the code changes and the compute_similarity function. Would appreciate your thoughts when you get a chance.

Let me know if you need any clarifications!

@Muennighoff
Collaborator

Maybe try training and check if it works?
