Skip to content

Commit cbea187

Browse files
authored
[ML] Do not create inference endpoint if ID is used in existing mappings (#137587)
When creating an inference endpoint, if the inference ID is used in incompatible semantic_text mappings, prevent the endpoint from being created.

Closes #124272

- Check if existing semantic text fields have compatible model settings
- Update and expand test coverage for the new behaviour
- Improve existing test InferenceServiceExtension implementations
- Move SemanticTextInfoExtractor from xpack.core.ml.utils to xpack.inference.common

(cherry picked from commit 635fda1)
1 parent 397b5f9 commit cbea187

File tree

19 files changed

+663
-178
lines changed

19 files changed

+663
-178
lines changed

docs/changelog/137055.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 137055
2+
summary: Do not create inference endpoint if ID is used in existing mappings
3+
area: Machine Learning
4+
type: bug
5+
issues:
6+
- 124272

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/InferenceProcessorInfoExtractor.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,13 +116,12 @@ public static Map<String, Set<String>> pipelineIdsByResource(ClusterState state,
116116
}
117117

118118
/**
119-
* @param state Current {@link ClusterState}
119+
* @param metadata Current cluster state {@link Metadata}
120120
* @return the set of IDs of the pipelines that reference any of the given Model or Deployment IDs or Aliases.
121121
*/
122-
public static Set<String> pipelineIdsForResource(ClusterState state, Set<String> ids) {
122+
public static Set<String> pipelineIdsForResource(Metadata metadata, Set<String> ids) {
123123
assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures");
124124
Set<String> pipelineIds = new HashSet<>();
125-
Metadata metadata = state.metadata();
126125
if (metadata == null) {
127126
return pipelineIds;
128127
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java

Lines changed: 0 additions & 46 deletions
This file was deleted.

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,22 @@ static String mockDenseServiceModelConfig() {
171171
""";
172172
}
173173

174+
static String mockDenseServiceModelConfig(int dimensions) {
175+
return Strings.format("""
176+
{
177+
"task_type": "text_embedding",
178+
"service": "text_embedding_test_service",
179+
"service_settings": {
180+
"model": "my_dense_vector_model",
181+
"api_key": "abc64",
182+
"dimensions": %s
183+
},
184+
"task_settings": {
185+
}
186+
}
187+
""", dimensions);
188+
}
189+
174190
static String mockRerankServiceModelConfig() {
175191
return """
176192
{

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package org.elasticsearch.xpack.inference;
1111

1212
import org.apache.http.util.EntityUtils;
13+
import org.elasticsearch.client.Request;
1314
import org.elasticsearch.client.Response;
1415
import org.elasticsearch.client.ResponseException;
1516
import org.elasticsearch.common.Strings;
@@ -211,7 +212,7 @@ public void testDeleteEndpointWhileReferencedBySemanticText() throws IOException
211212
final String endpointId = "endpoint_referenced_by_semantic_text";
212213
final String searchEndpointId = "search_endpoint_referenced_by_semantic_text";
213214
final String indexName = randomAlphaOfLength(10).toLowerCase();
214-
final Function<String, String> buildErrorString = endpointName -> " Inference endpoint "
215+
final Function<String, String> buildErrorString = endpointName -> "Inference endpoint "
215216
+ endpointName
216217
+ " is being used in the mapping for indexes: "
217218
+ Set.of(indexName)
@@ -303,6 +304,74 @@ public void testDeleteEndpointWhileReferencedBySemanticTextAndPipeline() throws
303304
deleteIndex(indexName);
304305
}
305306

307+
public void testCreateEndpoint_withInferenceIdReferencedBySemanticText() throws IOException {
308+
final String endpointId = "endpoint_referenced_by_semantic_text";
309+
final String otherEndpointId = "other_endpoint_referenced_by_semantic_text";
310+
final String indexName1 = randomAlphaOfLength(10).toLowerCase();
311+
final String indexName2 = randomValueOtherThan(indexName1, () -> randomAlphaOfLength(10).toLowerCase());
312+
313+
putModel(endpointId, mockDenseServiceModelConfig(128), TaskType.TEXT_EMBEDDING);
314+
putModel(otherEndpointId, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
315+
// Create two indices, one where the inference ID of the endpoint we'll be deleting and
316+
// recreating is used for inference_id and one where it's used for search_inference_id
317+
putSemanticText(endpointId, otherEndpointId, indexName1);
318+
putSemanticText(otherEndpointId, endpointId, indexName2);
319+
320+
// Confirm that we can create the endpoint with different settings if there
321+
// are documents in the indices which do not use the semantic text field
322+
var request = new Request("PUT", indexName1 + "/_create/1");
323+
request.setJsonEntity("{\"non_inference_field\": \"value\"}");
324+
assertStatusOkOrCreated(client().performRequest(request));
325+
326+
request = new Request("PUT", indexName2 + "/_create/1");
327+
request.setJsonEntity("{\"non_inference_field\": \"value\"}");
328+
assertStatusOkOrCreated(client().performRequest(request));
329+
330+
assertStatusOkOrCreated(client().performRequest(new Request("GET", "_refresh")));
331+
332+
deleteModel(endpointId, "force=true");
333+
putModel(endpointId, mockDenseServiceModelConfig(64), TaskType.TEXT_EMBEDDING);
334+
335+
// Index a document with the semantic text field into each index
336+
request = new Request("PUT", indexName1 + "/_create/2");
337+
request.setJsonEntity("{\"inference_field\": \"value\"}");
338+
assertStatusOkOrCreated(client().performRequest(request));
339+
340+
request = new Request("PUT", indexName2 + "/_create/2");
341+
request.setJsonEntity("{\"inference_field\": \"value\"}");
342+
assertStatusOkOrCreated(client().performRequest(request));
343+
344+
assertStatusOkOrCreated(client().performRequest(new Request("GET", "_refresh")));
345+
346+
deleteModel(endpointId, "force=true");
347+
348+
// Try to create an inference endpoint with the same ID but different dimensions
349+
// from when the document with the semantic text field was indexed
350+
ResponseException responseException = assertThrows(
351+
ResponseException.class,
352+
() -> putModel(endpointId, mockDenseServiceModelConfig(128), TaskType.TEXT_EMBEDDING)
353+
);
354+
assertThat(
355+
responseException.getMessage(),
356+
containsString(
357+
"Inference endpoint ["
358+
+ endpointId
359+
+ "] could not be created because the inference_id is being used in mappings with incompatible settings for indices: ["
360+
)
361+
);
362+
assertThat(responseException.getMessage(), containsString(indexName1));
363+
assertThat(responseException.getMessage(), containsString(indexName2));
364+
assertThat(
365+
responseException.getMessage(),
366+
containsString("Please either use a different inference_id or update the index mappings to refer to a different inference_id.")
367+
);
368+
369+
deleteIndex(indexName1);
370+
deleteIndex(indexName2);
371+
372+
deleteModel(otherEndpointId, "force=true");
373+
}
374+
306375
public void testUnsupportedStream() throws Exception {
307376
String modelId = "streaming";
308377
putModel(modelId, mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING, "streaming_completion_test_service"));

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestCompletionServiceExtension.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public List<Factory> getInferenceServiceFactories() {
4949
}
5050

5151
public static class TestInferenceService extends AbstractTestInferenceService {
52-
private static final String NAME = "completion_test_service";
52+
public static final String NAME = "completion_test_service";
5353
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES = EnumSet.of(TaskType.COMPLETION);
5454

5555
public TestInferenceService(InferenceServiceFactoryContext context) {}

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,12 @@ public void chunkedInfer(
170170
private TextEmbeddingFloatResults makeResults(List<String> input, ServiceSettings serviceSettings) {
171171
List<TextEmbeddingFloatResults.Embedding> embeddings = new ArrayList<>();
172172
for (String inputString : input) {
173-
List<Float> floatEmbeddings = generateEmbedding(inputString, serviceSettings.dimensions(), serviceSettings.elementType());
173+
List<Float> floatEmbeddings = generateEmbedding(
174+
inputString,
175+
serviceSettings.dimensions(),
176+
serviceSettings.elementType(),
177+
serviceSettings.similarity()
178+
);
174179
embeddings.add(TextEmbeddingFloatResults.Embedding.of(floatEmbeddings));
175180
}
176181
return new TextEmbeddingFloatResults(embeddings);
@@ -206,7 +211,7 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
206211
* <ul>
207212
* <li>Unique to the input</li>
208213
* <li>Reproducible (i.e given the same input, the same embedding should be generated)</li>
209-
* <li>Valid for the provided element type</li>
214+
* <li>Valid for the provided element type and similarity measure</li>
210215
* </ul>
211216
* <p>
212217
* The embedding is generated by:
@@ -216,6 +221,7 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
216221
* <li>converting the hash code value to a string</li>
217222
* <li>converting the string to a UTF-8 encoded byte array</li>
218223
* <li>repeatedly appending the byte array to the embedding until the desired number of dimensions are populated</li>
224+
* <li>converting the embedding to a unit vector if the similarity measure requires that</li>
219225
* </ul>
220226
* <p>
221227
* Since the hash code value, when interpreted as a string, is guaranteed to only contain digits and the "-" character, the UTF-8
@@ -226,11 +232,17 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
226232
* embedding byte.
227233
* </p>
228234
*
229-
* @param input The input string
230-
* @param dimensions The embedding dimension count
235+
* @param input The input string
236+
* @param dimensions The embedding dimension count
237+
* @param similarityMeasure The similarity measure
231238
* @return An embedding
232239
*/
233-
private static List<Float> generateEmbedding(String input, int dimensions, DenseVectorFieldMapper.ElementType elementType) {
240+
private static List<Float> generateEmbedding(
241+
String input,
242+
int dimensions,
243+
DenseVectorFieldMapper.ElementType elementType,
244+
SimilarityMeasure similarityMeasure
245+
) {
234246
int embeddingLength = getEmbeddingLength(elementType, dimensions);
235247
List<Float> embedding = new ArrayList<>(embeddingLength);
236248

@@ -248,6 +260,9 @@ private static List<Float> generateEmbedding(String input, int dimensions, Dense
248260
if (remainingLength > 0) {
249261
embedding.addAll(embeddingValues.subList(0, remainingLength));
250262
}
263+
if (similarityMeasure == SimilarityMeasure.DOT_PRODUCT) {
264+
embedding = toUnitVector(embedding);
265+
}
251266

252267
return embedding;
253268
}
@@ -263,6 +278,11 @@ private static int getEmbeddingLength(DenseVectorFieldMapper.ElementType element
263278
};
264279
}
265280

281+
private static List<Float> toUnitVector(List<Float> embedding) {
282+
var magnitude = (float) Math.sqrt(embedding.stream().reduce(0f, (a, b) -> a + (b * b)));
283+
return embedding.stream().map(v -> v / magnitude).toList();
284+
}
285+
266286
public static class Configuration {
267287
public static InferenceServiceConfiguration get() {
268288
return configuration.getOrCompute();
@@ -303,9 +323,13 @@ public record TestServiceSettings(
303323
public static TestServiceSettings fromMap(Map<String, Object> map) {
304324
ValidationException validationException = new ValidationException();
305325

306-
String model = (String) map.remove("model");
326+
String model = (String) map.remove("model_id");
327+
307328
if (model == null) {
308-
validationException.addValidationError("missing model");
329+
model = (String) map.remove("model");
330+
if (model == null) {
331+
validationException.addValidationError("missing model");
332+
}
309333
}
310334

311335
Integer dimensions = (Integer) map.remove("dimensions");

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,10 @@ public static TestServiceSettings fromMap(Map<String, Object> map) {
300300
String model = (String) map.remove("model_id");
301301

302302
if (model == null) {
303-
validationException.addValidationError("missing model");
303+
model = (String) map.remove("model");
304+
if (model == null) {
305+
validationException.addValidationError("missing model");
306+
}
304307
}
305308

306309
if (validationException.validationErrors().isEmpty() == false) {

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,10 +241,13 @@ public record TestServiceSettings(String model, String hiddenField, boolean shou
241241
public static TestServiceSettings fromMap(Map<String, Object> map) {
242242
ValidationException validationException = new ValidationException();
243243

244-
String model = (String) map.remove("model");
244+
String model = (String) map.remove("model_id");
245245

246246
if (model == null) {
247-
validationException.addValidationError("missing model");
247+
model = (String) map.remove("model");
248+
if (model == null) {
249+
validationException.addValidationError("missing model");
250+
}
248251
}
249252

250253
String hiddenField = (String) map.remove("hidden_field");

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ public List<Factory> getInferenceServiceFactories() {
5959
}
6060

6161
public static class TestInferenceService extends AbstractTestInferenceService {
62-
private static final String NAME = "streaming_completion_test_service";
62+
public static final String NAME = "streaming_completion_test_service";
6363
private static final String ALIAS = "streaming_completion_test_service_alias";
6464
private static final Set<TaskType> supportedStreamingTasks = Set.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
6565

@@ -342,12 +342,15 @@ public TestServiceSettings(StreamInput in) throws IOException {
342342
}
343343

344344
public static TestServiceSettings fromMap(Map<String, Object> map) {
345-
var modelId = map.remove("model").toString();
345+
String modelId = (String) map.remove("model_id");
346346

347347
if (modelId == null) {
348-
ValidationException validationException = new ValidationException();
349-
validationException.addValidationError("missing model id");
350-
throw validationException;
348+
modelId = (String) map.remove("model");
349+
if (modelId == null) {
350+
ValidationException validationException = new ValidationException();
351+
validationException.addValidationError("missing model id");
352+
throw validationException;
353+
}
351354
}
352355

353356
return new TestServiceSettings(modelId);

0 commit comments

Comments
 (0)