fix: per-sequence token counts in batch embedding averaging#3792

Open
alvinttang wants to merge 1 commit into lm-sys:main from alvinttang:fix/batch-embedding-averaging

Conversation

@alvinttang

Summary

  • Bug: __process_embed_chunk computed token_num = torch.sum(attention_mask).item(), which sums tokens across the entire batch into a single scalar. When batch_size > 1, every sequence's mean-pooled embedding was divided by this aggregate count instead of its own token count, silently producing incorrect embeddings.
  • Fix: Replace with attention_mask.sum(dim=1, keepdim=True) to get per-sequence token counts (shape (batch, 1)), so each embedding is normalized correctly via broadcasting.
  • The ret["token_num"] metadata field (total tokens for billing/logging) is preserved as a scalar via .sum().item().
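The bullet points above can be sketched with a toy batch. This is an illustrative reproduction under stated assumptions (mean pooling over padded hidden states), not the actual `__process_embed_chunk` code; `hidden`, `masked`, and `summed` are hypothetical names:

```python
import torch

# Two sequences: real lengths 2 and 4, padded to 4; hidden size 3.
attention_mask = torch.tensor([[1, 1, 0, 0],
                               [1, 1, 1, 1]])
hidden = torch.ones(2, 4, 3)  # dummy hidden states, all ones
masked = hidden * attention_mask.unsqueeze(-1)  # zero out padding
summed = masked.sum(dim=1)  # shape (batch, hidden)

# Old (buggy): one scalar summing tokens across the whole batch.
token_num_old = torch.sum(attention_mask).item()  # 6
wrong = summed / token_num_old  # every row divided by 6

# New (fixed): per-sequence counts, shape (batch, 1), broadcast row-wise.
token_num = attention_mask.sum(dim=1, keepdim=True)  # [[2], [4]]
right = summed / token_num  # each row divided by its own length
```

With all-ones hidden states the correct mean-pooled embedding is all ones for both sequences; the old code instead yields rows scaled by 2/6 and 4/6.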

Details

The bug affects both code paths in get_embeddings:

  1. Truncate path (line 225): embedding / token_num now broadcasts correctly per sequence.
  2. Chunked path (lines 267-273): chunk_embeddings * token_num and the final / all_token_num both broadcast correctly with the (batch, 1) shape.
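The chunked path can be sketched as a token-weighted average of per-chunk means. The variable names `chunk_embeddings` and `all_token_num` mirror the PR description, but the surrounding loop is an assumption, not code copied from the repository:

```python
import torch

batch, hidden = 2, 3
# Each chunk: (per-chunk mean embedding, per-sequence token count (batch, 1)).
chunks = [
    (torch.randn(batch, hidden), torch.tensor([[5.0], [3.0]])),
    (torch.randn(batch, hidden), torch.tensor([[2.0], [4.0]])),
]

weighted_sum = torch.zeros(batch, hidden)
all_token_num = torch.zeros(batch, 1)
for chunk_embeddings, token_num in chunks:
    # token_num has shape (batch, 1), so both ops broadcast per sequence.
    weighted_sum += chunk_embeddings * token_num
    all_token_num += token_num

final = weighted_sum / all_token_num  # overall per-sequence mean embedding
```

With the old scalar `token_num`, every chunk would have been weighted by the batch-wide total, collapsing the per-sequence weighting entirely.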

When batch_size == 1, the old and new behavior produce identical results, which is why this went undetected.

Test plan

  • Verify with batch_size=1: results unchanged
  • Verify with batch_size>1 and variable-length sequences: each embedding should now differ from the old (incorrect) output
  • Confirm ret["token_num"] remains a scalar integer
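The first two test-plan items can be checked with a minimal standalone script; `mean_pool` here is a hypothetical stand-in for the model's pooling, with a flag selecting old vs. new normalization:

```python
import torch

def mean_pool(hidden, attention_mask, per_sequence=True):
    masked = hidden * attention_mask.unsqueeze(-1)
    summed = masked.sum(dim=1)
    if per_sequence:
        return summed / attention_mask.sum(dim=1, keepdim=True)  # fixed
    return summed / torch.sum(attention_mask).item()  # old behavior

# batch_size == 1: old and new normalization agree exactly.
hidden = torch.randn(1, 4, 8)
mask = torch.tensor([[1, 1, 1, 0]])
assert torch.allclose(mean_pool(hidden, mask, True),
                      mean_pool(hidden, mask, False))

# batch_size > 1 with variable lengths: outputs must differ.
hidden = torch.ones(2, 4, 8)
mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 1]])
assert not torch.allclose(mean_pool(hidden, mask, True),
                          mean_pool(hidden, mask, False))
```

The third item (scalar `ret["token_num"]`) is just `isinstance(attention_mask.sum().item(), int)` on an integer mask.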

🤖 Generated with Claude Code

`__process_embed_chunk` previously computed `token_num` as a single
scalar summing tokens across the entire batch. When batch_size > 1,
each sequence's mean-pooled embedding was divided by the aggregate
token count instead of its own, silently producing incorrect results.

Replace the scalar sum with per-sequence counts via
`attention_mask.sum(dim=1, keepdim=True)` so each embedding is
normalized by its own token count.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>