
Commit 296b870

Update commentary around model document encoding

1 parent: fc0bdea

6 files changed: +62 -85 lines

embeddings/cli.py

Lines changed: 2 additions & 2 deletions
@@ -230,10 +230,10 @@ def create_embeddings(
     )
 
     # create an iterator of EmbeddingInputs applying all requested strategies
-    input_records = create_embedding_inputs(timdex_records, list(strategy))
+    embedding_inputs = create_embedding_inputs(timdex_records, list(strategy))
 
     # create embeddings via the embedding model
-    embeddings = model.create_embeddings(input_records)
+    embeddings = model.create_embeddings(embedding_inputs)
 
     # if requested, write embeddings to a local JSONLines file
     if output_jsonl:

embeddings/models/base.py

Lines changed: 7 additions & 7 deletions
@@ -51,20 +51,20 @@ def load(self) -> None:
         """Load model from self.model_path."""
 
     @abstractmethod
-    def create_embedding(self, input_record: EmbeddingInput) -> Embedding:
+    def create_embedding(self, embedding_input: EmbeddingInput) -> Embedding:
         """Create an Embedding for an EmbeddingInput.
 
         Args:
-            input_record: EmbeddingInput instance
+            embedding_input: EmbeddingInput instance
         """
 
     def create_embeddings(
-        self, input_records: Iterator[EmbeddingInput]
+        self, embedding_inputs: Iterator[EmbeddingInput]
     ) -> Iterator[Embedding]:
-        """Yield Embeddings for an iterator of InputRecords.
+        """Yield Embeddings for a batch of EmbeddingInputs.
 
         Args:
-            input_records: iterator of InputRecords
+            embedding_inputs: iterator of EmbeddingInputs
         """
-        for input_text in input_records:
-            yield self.create_embedding(input_text)
+        for embedding_input in embedding_inputs:
+            yield self.create_embedding(embedding_input)
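
The base create_embeddings() shown above is a generator, so embeddings are produced lazily, one at a time, as the caller iterates. A minimal usage sketch (the model and EmbeddingInput values are hypothetical, not part of this commit):

    inputs = iter([embedding_input_one, embedding_input_two])  # hypothetical EmbeddingInputs
    embeddings = model.create_embeddings(inputs)
    first = next(embeddings)  # only the first input has been passed to create_embedding()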

embeddings/models/os_neural_sparse_doc_v3_gte.py

Lines changed: 40 additions & 63 deletions
@@ -173,32 +173,24 @@ def load(self) -> None:
             f"{time.perf_counter() - start_time:.2f}s"
         )
 
-    def create_embedding(self, input_record: EmbeddingInput) -> Embedding:
-        """Create sparse embeddings for the input text (document encoding).
-
-        This method generates sparse document embeddings.
-
-        Process follows the model card exactly:
-        1. Tokenize the document
-        2. Pass through the masked language model to get logits
-        3. Convert logits to sparse vector
-        6. Return both raw sparse vector and decoded token-weight pairs
+    def create_embedding(self, embedding_input: EmbeddingInput) -> Embedding:
+        """Create sparse vector and decoded token weight embeddings for an input text.
 
         Args:
-            input_record: The input containing text to embed
+            embedding_input: EmbeddingInput object with a .text attribute
        """
        # generate the sparse embeddings
-        sparse_vector, decoded_tokens = self._encode_documents([input_record.text])[0]
+        sparse_vector, decoded_tokens = self._encode_documents([embedding_input.text])[0]
 
         # coerce sparse vector tensor into list[float]
         sparse_vector_list = sparse_vector.cpu().numpy().tolist()
 
         return Embedding(
-            timdex_record_id=input_record.timdex_record_id,
-            run_id=input_record.run_id,
-            run_record_offset=input_record.run_record_offset,
+            timdex_record_id=embedding_input.timdex_record_id,
+            run_id=embedding_input.run_id,
+            run_record_offset=embedding_input.run_record_offset,
             model_uri=self.model_uri,
-            embedding_strategy=input_record.embedding_strategy,
+            embedding_strategy=embedding_input.embedding_strategy,
             embedding_vector=sparse_vector_list,
             embedding_token_weights=decoded_tokens,
         )
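
A usage sketch of the method above; the field values are invented for illustration, and model is assumed to be a loaded OSNeuralSparseDocV3GTE instance:

    embedding_input = EmbeddingInput(
        timdex_record_id="alma:123",  # invented example values
        run_id="run-1",
        run_record_offset=0,
        embedding_strategy="full_record",
        text="Geology of the Boston Basin",
    )
    embedding = model.create_embedding(embedding_input)
    embedding.embedding_vector         # sparse vector, one float per vocabulary token
    embedding.embedding_token_weights  # e.g. {"geology": 2.1, "boston": 1.8, ...}
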
@@ -212,53 +204,32 @@ def _encode_documents(
         This follows the pattern outlined on the HuggingFace model card for document
         encoding.
 
-        This method will accommodate a list of text inputs, and return a list of
-        embeddings, but the calling base method create_embeddings() is a singular input +
+        This method will accommodate MULTIPLE text inputs, and return a list of
+        embeddings, but the calling context of create_embedding() is a SINGULAR input +
         output. This method keeps the ability to handle multiple inputs + outputs, in the
-        event we want something like a create_multiple_embeddings() method in the future.
-
-        The following is a rough approximation of receiving logits back from the model
-        and converting this to a sparse vector which can then be decoded to token:weights:
+        event we want something like a create_multiple_embeddings() method in the future,
+        but only returns a single result.
 
-        ----------------------------------------------------------------------------------
-        Imagine your vocabulary is just 5 words: ["cat", "dog", "bird", "fish", "tree"]
-        Vocabulary indices:                      [  0,     1,      2,      3,     4  ]
+        At a very high level, the following is performed:
 
-        1. MODEL RETURNS LOGITS
-        Let's say you input the text: "cat and dog"
-        After tokenization, you have 3 tokens at 3 sequence positions
-        The model outputs logits - a score for EVERY vocab word at EVERY position:
+        1. We tokenize the input text into "features" using the model's tokenizer.
 
-        logits = [
-            # Position 0 (word "cat"): scores for each vocab word at this position
-            [9.2, 1.1, 0.3, 0.5, 0.2],  # "cat" gets high score (9.2)
-
-            # Position 1 (word "and" - not in our toy vocab, but tokenized somehow):
-            [2.1, 1.8, 0.4, 0.3, 0.9],  # moderate scores everywhere
-
-            # Position 2 (word "dog"):
-            [0.8, 8.7, 0.2, 0.4, 0.1],  # "dog" gets high score (8.7)
-        ]
-        Shape: (3 positions, 5 vocab words)
+        2. The features are fed to the model returning model output logits. These logits
+        are "dense" in the sense there are few zeros, but they are not "dense vectors"
+        (embeddings) in the sense that they meaningfully represent the input document in
+        geometric space; two logit tensors cannot be compared with something like cosine
+        similarity.
 
+        3. The logits are then converted into a sparse vector, which is a numeric
+        array of floats with the same number of values as the model's vocabulary. Each
+        value's position in the sparse array corresponds to the token id in the
+        vocabulary, and the value itself is the "weight" of this token in the input text.
 
-        2. PRODUCE SPARSE VECTORS FROM LOGITS
-        We collapse the sequence positions by taking the MAX score for each vocab word:
-
-        sparse_vector = [
-            max(9.2, 2.1, 0.8),  # "cat": take max across all 3 positions = 9.2
-            max(1.1, 1.8, 8.7),  # "dog": take max = 8.7
-            max(0.3, 0.4, 0.2),  # "bird": take max = 0.4
-            max(0.5, 0.3, 0.4),  # "fish": take max = 0.5
-            max(0.2, 0.9, 0.1),  # "tree": take max = 0.9
-        ]
-
-        Apply transformations (ReLU, double-log) to make it sparser:
-        sparse_vector = [5.1, 4.8, 0.0, 0.0, 0.0]  # smaller values become 0
-
-        Final result:
-        {"cat": 5.1, "dog": 4.8}  # Only the relevant words have non-zero weights
-        ----------------------------------------------------------------------------------
+        4. Lastly, we convert this sparse vector into a {token:weight} dictionary of the
+        actual token strings and their numerical weight. This dictionary may contain
+        tokens not present in the original text, but will be considerably shorter than
+        the model vocabulary length given all zero and low scoring tokens are dropped.
+        This is the final form that we will ultimately index into OpenSearch.
 
         Args:
             texts: list of strings to create embeddings for
@@ -278,14 +249,14 @@ def _encode_documents(
         # move to CPU or GPU device, depending on what's available
         features = {k: v.to(self._device) for k, v in features.items()}
 
-        # get model logits output
+        # pass features to the model and receive model output logits as a tensor
         with torch.no_grad():
             output = self._model(**features)[0]
 
-        # generate sparse vectors from model logits
+        # generate sparse vectors from model logits tensor
         sparse_vectors = self._get_sparse_vectors(features, output)
 
-        # decode to token-weight dictionaries
+        # decode sparse vectors to token-weight dictionaries
        decoded = self._decode_sparse_vectors(sparse_vectors)
 
         # return list of tuple(vector, decoded token weights) embedding results
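
A rough sketch of steps 1-4 described in the docstring above, following the HuggingFace model-card pattern for document encoding. The names encode_document, tokenizer, model, and special_token_ids are illustrative assumptions (the repo's _encode_documents() is batched and attached to the class); this collapses it to a single document:

    import torch

    def encode_document(tokenizer, model, special_token_ids, text):
        # 1. tokenize the input text into features
        features = tokenizer([text], padding=True, truncation=True, return_tensors="pt")

        # 2. forward pass; logits have shape (1, seq_len, vocab_size)
        with torch.no_grad():
            logits = model(**features)[0]

        # 3. collapse sequence positions with max pooling, sparsify with
        #    ReLU + double log, and zero out special tokens like [CLS]/[SEP]/[PAD]
        values, _ = torch.max(logits * features["attention_mask"].unsqueeze(-1), dim=1)
        values = torch.log(1 + torch.log(1 + torch.relu(values)))
        values[:, special_token_ids] = 0
        sparse_vector = values[0]  # length == model vocabulary size

        # 4. decode non-zero positions into a {token: weight} dictionary
        token_ids = torch.nonzero(sparse_vector).squeeze(-1).tolist()
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
        return {t: round(sparse_vector[i].item(), 4) for t, i in zip(tokens, token_ids)}
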
@@ -304,20 +275,26 @@ def _get_sparse_vectors(
         2. log(1 + log(1 + relu())) transformation
         3. Zero out special tokens
 
+        The end result is a sparse vector with a length of the model vocabulary, with each
+        position representing a token in the model vocabulary and each value representing
+        that token's weight relative to the input text.
+
         Args:
             features: Tokenizer output with attention_mask
             output: Model logits of shape (batch_size, seq_len, vocab_size)
 
         Returns:
             Sparse vectors of shape (batch_size, vocab_size)
         """
-        # max pooling with attention mask
+        # collapse sequence positions: take max logit for each vocab token across all
+        # positions (also masks out padding tokens)
         values, _ = torch.max(output * features["attention_mask"].unsqueeze(-1), dim=1)
 
-        # apply the v3 model activation
+        # compress values to create sparsity: ReLU removes negatives,
+        # double-log shrinks large values
         values = torch.log(1 + torch.log(1 + torch.relu(values)))
 
-        # zero out special tokens
+        # remove special tokens like [CLS], [SEP], [PAD]
         values[:, self._special_token_ids] = 0
 
         return values
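
To make the ReLU + double-log step concrete, a toy example with invented pooled logits (not the repo's data):

    import torch

    pooled = torch.tensor([9.2, 8.7, 0.4, -1.3, 0.9])
    weights = torch.log(1 + torch.log(1 + torch.relu(pooled)))
    # ~ tensor([1.20, 1.19, 0.29, 0.00, 0.50]): negatives vanish, large logits are
    # compressed, and low scores stay near zero until decoding drops them entirely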

tests/conftest.py

Lines changed: 5 additions & 5 deletions
@@ -45,12 +45,12 @@ def download(self) -> Path:
     def load(self) -> None:
         logger.info("Model loaded successfully, 1.5s")
 
-    def create_embedding(self, input_record: EmbeddingInput) -> Embedding:
+    def create_embedding(self, embedding_input: EmbeddingInput) -> Embedding:
         return Embedding(
-            timdex_record_id=input_record.timdex_record_id,
-            run_id=input_record.run_id,
-            run_record_offset=input_record.run_record_offset,
-            embedding_strategy=input_record.embedding_strategy,
+            timdex_record_id=embedding_input.timdex_record_id,
+            run_id=embedding_input.run_id,
+            run_record_offset=embedding_input.run_record_offset,
+            embedding_strategy=embedding_input.embedding_strategy,
             model_uri=self.model_uri,
             embedding_vector=[0.1, 0.2, 0.3],
             embedding_token_weights={"coffee": 0.9, "seattle": 0.5},

tests/test_models.py

Lines changed: 4 additions & 4 deletions
@@ -35,14 +35,14 @@ def test_mock_model_load(caplog, mock_model):
 
 
 def test_mock_model_create_embedding(mock_model):
-    input_record = EmbeddingInput(
+    embedding_input = EmbeddingInput(
         timdex_record_id="test-id",
         run_id="test-run",
         run_record_offset=42,
         embedding_strategy="full_record",
         text="test text",
     )
-    embedding = mock_model.create_embedding(input_record)
+    embedding = mock_model.create_embedding(embedding_input)
 
     assert embedding.timdex_record_id == "test-id"
     assert embedding.run_id == "test-run"
@@ -87,7 +87,7 @@ class InvalidModel(BaseEmbeddingModel):
 
 
 def test_base_model_create_embeddings_calls_create_embedding(mock_model):
-    input_records = [
+    embedding_inputs = [
         EmbeddingInput(
             timdex_record_id="id-1",
             run_id="run-1",
@@ -105,7 +105,7 @@ def test_base_model_create_embeddings_calls_create_embedding(mock_model):
     ]
 
     # create_embeddings should iterate and call create_embedding
-    embeddings = list(mock_model.create_embeddings(iter(input_records)))
+    embeddings = list(mock_model.create_embeddings(iter(embedding_inputs)))
 
     assert len(embeddings) == 2  # two input records
     assert embeddings[0].timdex_record_id == "id-1"

tests/test_os_neural_sparse_doc_v3_gte.py

Lines changed: 4 additions & 4 deletions
@@ -293,7 +293,7 @@ def test_load_sets_up_special_token_ids(
 def test_create_embedding_raises_error_if_model_not_loaded(tmp_path):
     """Test create_embedding raises RuntimeError if model not loaded."""
     model = OSNeuralSparseDocV3GTE(tmp_path / "model")
-    input_record = EmbeddingInput(
+    embedding_input = EmbeddingInput(
         timdex_record_id="test:123",
         run_id="run-456",
         run_record_offset=0,
@@ -302,7 +302,7 @@ def test_create_embedding_raises_error_if_model_not_loaded(tmp_path):
     )
 
     with pytest.raises(RuntimeError, match="Model not loaded"):
-        model.create_embedding(input_record)
+        model.create_embedding(embedding_input)
 
 
 def test_create_embedding_returns_embedding_object(tmp_path, monkeypatch):
@@ -317,15 +317,15 @@ def mock_encode_documents(texts):
 
     monkeypatch.setattr(model, "_encode_documents", mock_encode_documents)
 
-    input_record = EmbeddingInput(
+    embedding_input = EmbeddingInput(
         timdex_record_id="test:123",
         run_id="run-456",
         run_record_offset=42,
         embedding_strategy="title_only",
         text="test document",
     )
 
-    embedding = model.create_embedding(input_record)
+    embedding = model.create_embedding(embedding_input)
 
     assert embedding.timdex_record_id == "test:123"
     assert embedding.run_id == "run-456"
