Skip to content

Commit 7216c9e

Browse files
committed
Allow JSONLines files as input for CLI create-embeddings
Why these changes are being introduced:
Initially, the CLI command create-embeddings only supported reading input records from the TIMDEX dataset via TDA. While this is likely the way we'll get input records, supporting a JSONLines file as input is helpful for testing.

How this addresses that need:
* Adds a new --input-jsonl argument that reads a JSONLines file and uses those rows as input for creating embeddings.
* Args --dataset-location and --run-id are required when --input-jsonl is not set.

Side effects of this change:
* None

Relevant ticket(s):
* https://mitlibraries.atlassian.net/browse/USE-137
1 parent fc0bdea commit 7216c9e

File tree

6 files changed

+227
-44
lines changed

6 files changed

+227
-44
lines changed

embeddings/cli.py

Lines changed: 57 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import click
1111
import jsonlines
12+
import smart_open
1213
from timdex_dataset_api import TIMDEXDataset
1314

1415
from embeddings.config import configure_logger, configure_sentry
@@ -156,32 +157,41 @@ def test_model_load(ctx: click.Context) -> None:
156157
@click.pass_context
157158
@model_required
158159
@click.option(
159-
"-d",
160160
"--dataset-location",
161-
required=True,
161+
required=False,
162162
type=click.Path(),
163163
help="TIMDEX dataset location, e.g. 's3://timdex/dataset', to read records from.",
164164
)
165165
@click.option(
166166
"--run-id",
167-
required=True,
167+
required=False,
168168
type=str,
169169
help="TIMDEX ETL run id.",
170170
)
171171
@click.option(
172172
"--run-record-offset",
173-
required=True,
173+
required=False,
174174
type=int,
175175
default=0,
176176
help="TIMDEX ETL run record offset to start from, default = 0.",
177177
)
178178
@click.option(
179179
"--record-limit",
180-
required=True,
180+
required=False,
181181
type=int,
182182
default=None,
183183
help="Limit number of records after --run-record-offset, default = None (unlimited).",
184184
)
185+
@click.option(
186+
"--input-jsonl",
187+
required=False,
188+
type=str,
189+
default=None,
190+
help=(
191+
"Optional filepath to JSONLines file containing "
192+
"TIMDEX records to create embeddings from."
193+
),
194+
)
185195
@click.option(
186196
"--strategy",
187197
type=click.Choice(list(STRATEGY_REGISTRY.keys())),
@@ -205,50 +215,65 @@ def create_embeddings(
205215
run_id: str,
206216
run_record_offset: int,
207217
record_limit: int,
218+
input_jsonl: str,
208219
strategy: list[str],
209220
output_jsonl: str,
210221
) -> None:
211222
"""Create embeddings for TIMDEX records."""
212223
model: BaseEmbeddingModel = ctx.obj["model"]
213224
model.load()
214225

215-
# init TIMDEXDataset
216-
timdex_dataset = TIMDEXDataset(dataset_location)
217-
218-
# query TIMDEX dataset for an iterator of records
219-
timdex_records = timdex_dataset.read_dicts_iter(
220-
columns=[
221-
"timdex_record_id",
222-
"run_id",
223-
"run_record_offset",
224-
"transformed_record",
225-
],
226-
run_id=run_id,
227-
where=f"""run_record_offset >= {run_record_offset}""",
228-
limit=record_limit,
229-
action="index",
230-
)
226+
# read input records from TIMDEX dataset (default)
227+
# or a JSONLines file
228+
if input_jsonl:
229+
with (
230+
smart_open.open(input_jsonl, "r") as file_obj, # type: ignore[no-untyped-call]
231+
jsonlines.Reader(file_obj) as reader,
232+
):
233+
timdex_records = iter(list(reader))
234+
235+
else:
236+
if not dataset_location or not run_id:
237+
raise click.UsageError(
238+
"Both '--dataset-location' and '--run-id' are required arguments "
239+
"when reading input records from the TIMDEX dataset."
240+
)
241+
242+
# init TIMDEXDataset
243+
timdex_dataset = TIMDEXDataset(dataset_location)
244+
245+
# query TIMDEX dataset for an iterator of records
246+
timdex_records = timdex_dataset.read_dicts_iter(
247+
columns=[
248+
"timdex_record_id",
249+
"run_id",
250+
"run_record_offset",
251+
"transformed_record",
252+
],
253+
run_id=run_id,
254+
where=f"""run_record_offset >= {run_record_offset}""",
255+
limit=record_limit,
256+
action="index",
257+
)
231258

232259
# create an iterator of EmbeddingInputs applying all requested strategies
233260
input_records = create_embedding_inputs(timdex_records, list(strategy))
234261

235262
# create embeddings via the embedding model
236263
embeddings = model.create_embeddings(input_records)
237264

238-
# if requested, write embeddings to a local JSONLines file
265+
# write embeddings to TIMDEX dataset (default)
266+
# or to a JSONLines file
239267
if output_jsonl:
240-
with jsonlines.open(
241-
output_jsonl,
242-
mode="w",
243-
dumps=lambda obj: json.dumps(
244-
obj,
245-
default=str,
246-
),
247-
) as writer:
268+
with (
269+
smart_open.open(output_jsonl, "w") as s3_file, # type: ignore[no-untyped-call]
270+
jsonlines.Writer(
271+
s3_file,
272+
dumps=lambda obj: json.dumps(obj, default=str),
273+
) as writer,
274+
):
248275
for embedding in embeddings:
249276
writer.write(embedding.to_dict())
250-
251-
# else, default writing embeddings back to TIMDEX dataset
252277
else:
253278
# WIP NOTE: write via anticipated timdex_dataset.embeddings.write(...)
254279
# NOTE: will likely use an imported TIMDEXEmbedding class from TDA, which the

embeddings/strategies/processor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,12 @@ def create_embedding_inputs(
3333
for timdex_dataset_record in timdex_dataset_records:
3434

3535
# decode and parse the TIMDEX JSON record once for all requested strategies
36-
timdex_record = json.loads(timdex_dataset_record["transformed_record"].decode())
36+
transformed_record_raw = timdex_dataset_record["transformed_record"]
37+
timdex_record = json.loads(
38+
transformed_record_raw.decode()
39+
if isinstance(transformed_record_raw, bytes)
40+
else transformed_record_raw
41+
)
3742

3843
for transformer in transformers:
3944
# prepare text for embedding from transformer strategy

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ dependencies = [
1212
"huggingface-hub>=0.26.0",
1313
"jsonlines>=4.0.0",
1414
"sentry-sdk>=2.34.1",
15+
"smart-open[s3]>=7.4.4",
1516
"timdex-dataset-api",
1617
"torch>=2.9.0",
1718
"transformers>=4.57.1",
@@ -41,7 +42,10 @@ exclude = [
4142
]
4243

4344
[[tool.mypy.overrides]]
44-
module = ["timdex_dataset_api.*"]
45+
module = [
46+
"timdex_dataset_api.*",
47+
"smart_open.*",
48+
]
4549
follow_untyped_imports = true
4650

4751

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{"timdex_record_id": "record:1", "run_id": "abc123", "run_record_offset": 0, "transformed_record": "{\"title\":\"Record 1\",\"description\":\"This is a record about coffee in the mountains.\"}"}
2+
{"timdex_record_id": "record:2", "run_id": "abc123", "run_record_offset": 1, "transformed_record": "{\"title\":\"Record 2\",\"description\":\"Sometimes poetry is made accidentally by the fabrication of metadata.\"}"}
3+
{"timdex_record_id": "record:3", "run_id": "abc123", "run_record_offset": 2, "transformed_record": "{\"title\":\"Record 3\",\"description\":\"This is an oddball record, meant to evoke the peculiar nature of mathematics.\"}"}

tests/test_cli.py

Lines changed: 68 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,38 @@ def test_model_required_decorator_works_across_commands(
133133
assert "OK" in result.output
134134

135135

136+
def test_create_embeddings_requires_strategy(register_mock_model, runner):
137+
result = runner.invoke(
138+
main,
139+
[
140+
"create-embeddings",
141+
"--model-uri",
142+
"test/mock-model",
143+
"--dataset-location",
144+
"s3://test",
145+
"--run-id",
146+
"run-1",
147+
],
148+
)
149+
assert result.exit_code != 0
150+
assert "Missing option '--strategy'" in result.output
151+
152+
136153
def test_create_embeddings_requires_dataset_location(register_mock_model, runner):
137-
result = runner.invoke(main, ["create-embeddings", "--model-uri", "test/mock-model"])
154+
result = runner.invoke(
155+
main,
156+
[
157+
"create-embeddings",
158+
"--model-uri",
159+
"test/mock-model",
160+
"--run-id",
161+
"run-1",
162+
"--strategy",
163+
"full_record",
164+
],
165+
)
138166
assert result.exit_code != 0
139-
assert "--dataset-location" in result.output
167+
assert "Both '--dataset-location' and '--run-id' are required" in result.output
140168

141169

142170
def test_create_embeddings_requires_run_id(register_mock_model, runner):
@@ -148,24 +176,54 @@ def test_create_embeddings_requires_run_id(register_mock_model, runner):
148176
"test/mock-model",
149177
"--dataset-location",
150178
"s3://test",
179+
"--strategy",
180+
"full_record",
151181
],
152182
)
153183
assert result.exit_code != 0
154-
assert "Missing option '--run-id'" in result.output
184+
assert "Both '--dataset-location' and '--run-id' are required" in result.output
155185

156186

157-
def test_create_embeddings_requires_strategy(register_mock_model, runner):
187+
def test_create_embeddings_optional_input_jsonl(register_mock_model, runner, tmp_path):
188+
input_file = "tests/fixtures/cli_inputs/test-3-records.jsonl"
189+
output_file = tmp_path / "output.jsonl"
190+
158191
result = runner.invoke(
159192
main,
160193
[
161194
"create-embeddings",
162195
"--model-uri",
163196
"test/mock-model",
164-
"--dataset-location",
165-
"s3://test",
166-
"--run-id",
167-
"run-1",
197+
"--input-jsonl",
198+
input_file,
199+
"--strategy",
200+
"full_record",
201+
"--output-jsonl",
202+
str(output_file),
168203
],
169204
)
170-
assert result.exit_code != 0
171-
assert "Missing option '--strategy'" in result.output
205+
assert result.exit_code == 0
206+
assert output_file.exists()
207+
208+
209+
def test_create_embeddings_optional_input_jsonl_does_not_require_dataset_params(
210+
register_mock_model, runner, tmp_path
211+
):
212+
input_file = "tests/fixtures/cli_inputs/test-3-records.jsonl"
213+
output_file = tmp_path / "output.jsonl"
214+
215+
result = runner.invoke(
216+
main,
217+
[
218+
"create-embeddings",
219+
"--model-uri",
220+
"test/mock-model",
221+
"--input-jsonl",
222+
input_file,
223+
"--strategy",
224+
"full_record",
225+
"--output-jsonl",
226+
str(output_file),
227+
],
228+
)
229+
assert result.exit_code == 0

0 commit comments

Comments
 (0)