@@ -173,32 +173,24 @@ def load(self) -> None:
173173 f"{ time .perf_counter () - start_time :.2f} s"
174174 )
175175
176- def create_embedding (self , input_record : EmbeddingInput ) -> Embedding :
177- """Create sparse embeddings for the input text (document encoding).
178-
179- This method generates sparse document embeddings.
180-
181- Process follows the model card exactly:
182- 1. Tokenize the document
183- 2. Pass through the masked language model to get logits
184- 3. Convert logits to sparse vector
185- 6. Return both raw sparse vector and decoded token-weight pairs
176+ def create_embedding (self , embedding_input : EmbeddingInput ) -> Embedding :
177+ """Create sparse vector and decoded token weight embeddings for an input text.
186178
187179 Args:
188- input_record: The input containing text to embed
180+ embedding_input: EmbeddingInput object with a . text attribute
189181 """
190182 # generate the sparse embeddings
191- sparse_vector , decoded_tokens = self ._encode_documents ([input_record .text ])[0 ]
183+ sparse_vector , decoded_tokens = self ._encode_documents ([embedding_input .text ])[0 ]
192184
193185 # coerce sparse vector tensor into list[float]
194186 sparse_vector_list = sparse_vector .cpu ().numpy ().tolist ()
195187
196188 return Embedding (
197- timdex_record_id = input_record .timdex_record_id ,
198- run_id = input_record .run_id ,
199- run_record_offset = input_record .run_record_offset ,
189+ timdex_record_id = embedding_input .timdex_record_id ,
190+ run_id = embedding_input .run_id ,
191+ run_record_offset = embedding_input .run_record_offset ,
200192 model_uri = self .model_uri ,
201- embedding_strategy = input_record .embedding_strategy ,
193+ embedding_strategy = embedding_input .embedding_strategy ,
202194 embedding_vector = sparse_vector_list ,
203195 embedding_token_weights = decoded_tokens ,
204196 )
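For context, here is a minimal usage sketch of the renamed `create_embedding()`. The `EmbeddingInput` constructor and the `model` instance below are assumptions based only on the attributes referenced in this diff, and the field values are purely illustrative.

```python
# Hypothetical usage sketch: EmbeddingInput's constructor and the `model`
# instance are assumed from the attributes referenced in this diff.
embedding_input = EmbeddingInput(
    timdex_record_id="example:123",  # illustrative value
    run_id="run-001",  # illustrative value
    run_record_offset=0,
    embedding_strategy="full_record",  # illustrative value
    text="Renewable energy policy in island nations.",
)

embedding = model.create_embedding(embedding_input)

# embedding.embedding_vector is a list[float] the length of the model vocabulary;
# embedding.embedding_token_weights is the decoded {token: weight} dictionary.
```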
@@ -212,53 +204,32 @@ def _encode_documents(
        This follows the pattern outlined on the HuggingFace model card for document
        encoding.

-        This method will accommodate a list of text inputs, and return a list of
-        embeddings, but the calling base method create_embeddings() is a singular input +
+        This method will accommodate MULTIPLE text inputs, and return a list of
+        embeddings, but the calling context of create_embedding() is a SINGULAR input +
        output. This method keeps the ability to handle multiple inputs + outputs, in the
-        event we want something like a create_multiple_embeddings() method in the future.
-
-        The following is a rough approximation of receiving logits back from the model
-        and converting this to a sparse vector which can then be decoded to token:weights:
+        event we want something like a create_multiple_embeddings() method in the future,
+        even though the current create_embedding() caller only needs a single result.

-        ----------------------------------------------------------------------------------
-        Imagine your vocabulary is just 5 words: ["cat", "dog", "bird", "fish", "tree"]
-        Vocabulary indices:                      [0, 1, 2, 3, 4]
+        At a very high level, the following is performed:

-        1. MODEL RETURNS LOGITS
-        Let's say you input the text: "cat and dog"
-        After tokenization, you have 3 tokens at 3 sequence positions
-        The model outputs logits - a score for EVERY vocab word at EVERY position:
+        1. We tokenize the input text into "features" using the model's tokenizer.

-        logits = [
-            # Position 0 (word "cat"): scores for each vocab word at this position
-            [9.2, 1.1, 0.3, 0.5, 0.2],  # "cat" gets high score (9.2)
-
-            # Position 1 (word "and" - not in our toy vocab, but tokenized somehow):
-            [2.1, 1.8, 0.4, 0.3, 0.9],  # moderate scores everywhere
-
-            # Position 2 (word "dog"):
-            [0.8, 8.7, 0.2, 0.4, 0.1],  # "dog" gets high score (8.7)
-        ]
-        Shape: (3 positions, 5 vocab words)
+        2. The features are fed to the model, which returns output logits. These
+        logits are "dense" in the sense that few values are zero, but they are not
+        "dense vectors" (embeddings) in the sense of meaningfully representing the
+        input document in geometric space; two logit tensors cannot be compared with
+        something like cosine similarity.

+        3. The logits are then converted into a sparse vector, which is a numeric
+        array of floats with the same number of values as the model's vocabulary. Each
+        value's position in the sparse array corresponds to the token id in the
+        vocabulary, and the value itself is the "weight" of this token in the input text.
-
-        2. PRODUCE SPARSE VECTORS FROM LOGITS
-        We collapse the sequence positions by taking the MAX score for each vocab word:
-
-        sparse_vector = [
-            max(9.2, 2.1, 0.8),  # "cat": take max across all 3 positions = 9.2
-            max(1.1, 1.8, 8.7),  # "dog": take max = 8.7
-            max(0.3, 0.4, 0.2),  # "bird": take max = 0.4
-            max(0.5, 0.3, 0.4),  # "fish": take max = 0.5
-            max(0.2, 0.9, 0.1),  # "tree": take max = 0.9
-        ]
-
-        Apply transformations (ReLU, double-log) to make it sparser:
-        sparse_vector = [5.1, 4.8, 0.0, 0.0, 0.0]  # smaller values become 0
-
-        Final result:
-        {"cat": 5.1, "dog": 4.8}  # Only the relevant words have non-zero weights
-        ----------------------------------------------------------------------------------
+        4. Lastly, we convert this sparse vector into a {token: weight} dictionary of the
+        actual token strings and their numerical weights. This dictionary may contain
+        tokens not present in the original text, but will be considerably shorter than
+        the model vocabulary length, given that all zero and low-scoring tokens are
+        dropped. This is the final form that we will ultimately index into OpenSearch.

        Args:
            texts: list of strings to create embeddings for
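As a rough illustration of steps 2-4 above (not the real model), the toy sketch below pushes hand-picked logits for a five-token vocabulary through the same max-pooling and `log(1 + log(1 + relu()))` transformation used in `_get_sparse_vectors()`, then keeps only higher-weight tokens; the `0.4` cutoff is an arbitrary threshold chosen for illustration.

```python
import torch

# Toy illustration only: a 5-token "vocabulary" and hand-picked logits,
# shaped (seq_len=3, vocab_size=5) rather than the real (batch, seq, vocab).
vocab = ["cat", "dog", "bird", "fish", "tree"]
logits = torch.tensor(
    [
        [9.2, 1.1, 0.3, 0.5, 0.2],  # position 0
        [2.1, 1.8, 0.4, 0.3, 0.9],  # position 1
        [0.8, 8.7, 0.2, 0.4, 0.1],  # position 2
    ]
)

# step 3: max-pool over sequence positions, then the sparsifying activation
values, _ = torch.max(logits, dim=0)
values = torch.log(1 + torch.log(1 + torch.relu(values)))

# step 4: decode to {token: weight}, dropping low-scoring tokens
# (0.4 is an arbitrary cutoff for this illustration)
token_weights = {
    token: round(weight, 2)
    for token, weight in zip(vocab, values.tolist())
    if weight > 0.4
}
print(token_weights)  # approximately {'cat': 1.2, 'dog': 1.19, 'tree': 0.5}
```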
@@ -278,14 +249,14 @@ def _encode_documents(
        # move to CPU or GPU device, depending on what's available
        features = {k: v.to(self._device) for k, v in features.items()}

-        # get model logits output
+        # pass features to the model and receive model output logits as a tensor
        with torch.no_grad():
            output = self._model(**features)[0]

-        # generate sparse vectors from model logits
+        # generate sparse vectors from model logits tensor
        sparse_vectors = self._get_sparse_vectors(features, output)

-        # decode to token-weight dictionaries
+        # decode sparse vectors to token-weight dictionaries
        decoded = self._decode_sparse_vectors(sparse_vectors)

        # return list of tuple(vector, decoded token weights) embedding results
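`_decode_sparse_vectors()` is not part of this diff, so the sketch below is only an approximation of what that decode step does: it maps the non-zero positions of a `(vocab_size,)` sparse vector back to token strings. The `id_to_token` argument is a hypothetical stand-in for however the real method resolves token ids from the tokenizer vocabulary.

```python
import torch


def decode_sparse_vector(
    sparse_vector: torch.Tensor, id_to_token: dict[int, str]
) -> dict[str, float]:
    """Approximate the decode step: non-zero positions -> {token: weight}."""
    token_weights: dict[str, float] = {}
    for token_id in torch.nonzero(sparse_vector).squeeze(-1).tolist():
        token_weights[id_to_token[token_id]] = float(sparse_vector[token_id])
    return token_weights
```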
@@ -304,20 +275,26 @@ def _get_sparse_vectors(
        2. log(1 + log(1 + relu())) transformation
        3. Zero out special tokens

+        The end result is a sparse vector with a length of the model vocabulary, with
+        each position representing a token in the model vocabulary and each value
+        representing that token's weight relative to the input text.
+
        Args:
            features: Tokenizer output with attention_mask
            output: Model logits of shape (batch_size, seq_len, vocab_size)

        Returns:
            Sparse vectors of shape (batch_size, vocab_size)
        """
-        # max pooling with attention mask
+        # collapse sequence positions: take max logit for each vocab token across all
+        # positions (also masks out padding tokens)
        values, _ = torch.max(output * features["attention_mask"].unsqueeze(-1), dim=1)

-        # apply the v3 model activation
+        # compress values to create sparsity: ReLU removes negatives,
+        # double-log shrinks large values
        values = torch.log(1 + torch.log(1 + torch.relu(values)))

-        # zero out special tokens
+        # zero out special tokens like [CLS], [SEP], [PAD]
        values[:, self._special_token_ids] = 0

        return values
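A small self-contained check of the pooling and activation above, using made-up shapes (batch of 1, four sequence positions where the last is padding, a six-token vocabulary); it demonstrates how the attention mask and the max over `dim=1` behave, not the real model or tokenizer.

```python
import torch

output = torch.randn(1, 4, 6)                  # fake logits: (batch, seq_len, vocab)
attention_mask = torch.tensor([[1, 1, 1, 0]])  # last sequence position is padding

# broadcasting the mask to (1, 4, 1) zeroes every logit at padded positions,
# so padding contributes nothing positive to the max over the sequence dimension
values, _ = torch.max(output * attention_mask.unsqueeze(-1), dim=1)
print(values.shape)  # torch.Size([1, 6]) -> (batch_size, vocab_size)

# same activation as above: negatives go to zero, large values are compressed
values = torch.log(1 + torch.log(1 + torch.relu(values)))
```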