How to train using hard negative only? #71
I think you'd need to modify `gritlm/gritlm/training/model.py` (line 41 at commit 724df95).
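For context, here is a minimal, illustrative sketch (not the exact repository code; shapes and names are toy values) of the default in-batch scoring that this kind of contrastive loss performs, where every query is scored against every passage in the batch and the target index points at the first passage of its own group:

```python
import torch

batch_size, group_size, hidden = 4, 3, 8                # toy shapes
q_reps = torch.randn(batch_size, hidden)                # [B, H] query embeddings
p_reps = torch.randn(batch_size * group_size, hidden)   # [B*G, H] passage embeddings

scores = q_reps @ p_reps.T                              # [B, B*G]: every query vs. every passage
target = torch.arange(batch_size) * group_size          # positive = first passage of each query's group
loss = torch.nn.functional.cross_entropy(scores, target)
```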
Thanks for the instructions! I've made the proposed code changes, which involve skipping the cross-device gathering of embeddings and reshaping `p_reps` so that the dot products are computed between each query and its own group of passages. Let me know your thoughts on these changes. Additionally, I was wondering about the function `compute_similarity`: under what circumstances would `len(p_reps.size())` not equal 2?

```python
from typing import Optional

import torch
import torch.distributed as dist


class DistributedContrastiveLoss:
    def __init__(self, temperature: float, negatives_cross_device: bool, hard_negatives_only: bool):
        self.cross_entropy = torch.nn.CrossEntropyLoss(reduction='mean')
        self.temperature = temperature
        self.hard_negatives_only = hard_negatives_only
        # Do not gather other GPUs' batches when using hard negatives only
        self.negatives_cross_device = False if self.hard_negatives_only else negatives_cross_device
        if self.negatives_cross_device:
            if not dist.is_initialized():
                raise ValueError('Cannot do negatives_cross_device without distributed training')
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()

    def __call__(self, q_reps, p_reps):
        """
        q_reps: [batch_size, hidden_size]                        # Query embeddings
        p_reps: [batch_size * (num_negatives + 1), hidden_size]  # Passage embeddings; num_negatives + 1 = train_group_size
        """
        if self.negatives_cross_device:
            # This gathers both negatives and positives.
            # It could likely be optimized by only gathering negatives.
            q_reps = self._dist_gather_tensor(q_reps)
            p_reps = self._dist_gather_tensor(p_reps)
        if self.hard_negatives_only:
            # Reshape `p_reps` to group passages for each query
            p_reps = p_reps.view(q_reps.size(0), p_reps.size(0) // q_reps.size(0), -1)  # [batch_size, num_negatives + 1, hidden_size]
        scores = self.compute_similarity(q_reps, p_reps) / self.temperature
        scores = scores.view(q_reps.size(0), -1)
        if self.hard_negatives_only:
            # Target is always 0 since the first passage in each group is the positive
            target = torch.zeros(scores.size(0), dtype=torch.long, device=scores.device)
        else:
            target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
            target *= (p_reps.size(0) // q_reps.size(0))
        return self.cross_entropy(scores, target)

    def _dist_gather_tensor(self, t: Optional[torch.Tensor]):
        if t is None:
            return None
        t = t.contiguous()
        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        # All tensors have the same shape, as pooling has already been applied to them
        dist.all_gather(all_tensors, t)
        all_tensors[self.rank] = t
        all_tensors = torch.cat(all_tensors, dim=0)
        return all_tensors

    def compute_similarity(self, q_reps, p_reps):
        if self.hard_negatives_only:
            # Query embeddings:   [batch_size, hidden_size] -unsqueeze-> [batch_size, 1, hidden_size]
            # Passage embeddings: [batch_size, num_negatives + 1, hidden_size] -transpose-> [batch_size, hidden_size, num_negatives + 1]
            # Resulting shape:    [batch_size, 1, num_negatives + 1] -squeeze-> [batch_size, num_negatives + 1]
            return torch.matmul(q_reps.unsqueeze(1), p_reps.transpose(-2, -1)).squeeze(1)
        if len(p_reps.size()) == 2:
            return torch.matmul(q_reps, p_reps.transpose(0, 1))
        return torch.matmul(q_reps, p_reps.transpose(-2, -1))
```
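In case it helps, a quick shape sanity check for the hard-negatives-only path above (toy tensors; assumes the class above is in scope, and the temperature value is arbitrary):

```python
import torch

batch_size, group_size, hidden = 4, 3, 8
q = torch.randn(batch_size, hidden)
p = torch.randn(batch_size * group_size, hidden)

loss_fn = DistributedContrastiveLoss(
    temperature=0.02, negatives_cross_device=True, hard_negatives_only=True
)
loss = loss_fn(q, p)  # internally: scores [batch_size, group_size], target all zeros
print(loss)           # scalar tensor
```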
Hi @Muennighoff, just checking in on my previous message regarding the code changes and the `compute_similarity` function. Would appreciate your thoughts when you get a chance. Let me know if you need any clarifications!
Maybe try training and check if it works?
I understand that GritLM fine-tuning uses both in-batch negatives and hard negatives for contrastive learning. We can use in-batch negatives only by setting the train group size to 1.
However, in my case I can only use hard negatives, not in-batch negatives. Is there a way to disable in-batch negatives? If not, could you kindly advise which part of the code I should modify to implement this myself?
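To make the intended behavior concrete, a toy illustration (not GritLM code) of the two scoring schemes:

```python
import torch

B, G, H = 4, 3, 8                    # batch_size, train_group_size, hidden_size
q = torch.randn(B, H)
p = torch.randn(B * G, H)

# In-batch + hard negatives: each query is scored against every passage in the batch.
scores_all = q @ p.T                 # [B, B*G]

# Hard negatives only: each query is scored only against its own group of G passages.
scores_own = torch.bmm(q.unsqueeze(1), p.view(B, G, H).transpose(1, 2)).squeeze(1)  # [B, G]
```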