From 95a5dbd6cdc1be102efac911320ce2bfade55baf Mon Sep 17 00:00:00 2001 From: Donal Evans Date: Mon, 3 Nov 2025 13:35:06 -0800 Subject: [PATCH] [ML] Do not create inference endpoint if ID is used in existing mappings (#137055) When creating an inference endpoint, if the inference ID is used in incompatible semantic_text mappings, prevent the endpoint from being created. Closes #124272 - Check if existing semantic text fields have compatible model settings - Update and expand test coverage for the new behaviour - Improve existing test InferenceServiceExtension implementations - Move SemanticTextInfoExtractor from xpack.core.ml.utils to xpack.inference.common (cherry picked from commit 635fda1212c386e7d21a1bcf39e97820178087c2) # Conflicts: # x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java # x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java --- docs/changelog/137055.yaml | 6 + .../InferenceProcessorInfoExtractor.java | 5 +- .../ml/utils/SemanticTextInfoExtractor.java | 46 --- .../inference/InferenceBaseRestTest.java | 16 + .../xpack/inference/InferenceCrudIT.java | 71 +++- .../mock/TestCompletionServiceExtension.java | 2 +- .../TestDenseInferenceServiceExtension.java | 38 +- .../mock/TestRerankingServiceExtension.java | 5 +- .../TestSparseInferenceServiceExtension.java | 7 +- ...stStreamingCompletionServiceExtension.java | 13 +- .../CreateInferenceEndpointIT.java | 355 ++++++++++++++++++ ...ransportDeleteInferenceEndpointAction.java | 18 +- .../TransportPutInferenceModelAction.java | 75 +++- .../common/SemanticTextInfoExtractor.java | 75 ++++ .../inference/mapper/SemanticTextField.java | 6 +- .../mapper/SemanticTextFieldMapper.java | 2 +- .../settings/DefaultSecretSettings.java | 2 +- .../inference/LocalStateInferencePlugin.java | 8 +- ..._text_query_inference_endpoint_changes.yml | 95 +---- 19 files changed, 670 
insertions(+), 175 deletions(-) create mode 100644 docs/changelog/137055.yaml delete mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java create mode 100644 x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CreateInferenceEndpointIT.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/SemanticTextInfoExtractor.java diff --git a/docs/changelog/137055.yaml b/docs/changelog/137055.yaml new file mode 100644 index 0000000000000..e2e0581a5f5ed --- /dev/null +++ b/docs/changelog/137055.yaml @@ -0,0 +1,6 @@ +pr: 137055 +summary: Do not create inference endpoint if ID is used in existing mappings +area: Machine Learning +type: bug +issues: + - 124272 diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/InferenceProcessorInfoExtractor.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/InferenceProcessorInfoExtractor.java index 70c9ecf872e97..951d3f02d2cf5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/InferenceProcessorInfoExtractor.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/InferenceProcessorInfoExtractor.java @@ -116,13 +116,12 @@ public static Map> pipelineIdsByResource(ClusterState state, } /** - * @param state Current {@link ClusterState} + * @param metadata Current cluster state {@link Metadata} * @return a map from Model or Deployment IDs or Aliases to each pipeline referencing them. 
*/ - public static Set pipelineIdsForResource(ClusterState state, Set ids) { + public static Set pipelineIdsForResource(Metadata metadata, Set ids) { assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures"); Set pipelineIds = new HashSet<>(); - Metadata metadata = state.metadata(); if (metadata == null) { return pipelineIds; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java deleted file mode 100644 index d65e0117027a9..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- * - * this file was contributed to by a Generative AI - */ - -package org.elasticsearch.xpack.core.ml.utils; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.cluster.metadata.IndexMetadata; -import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; -import org.elasticsearch.cluster.metadata.Metadata; -import org.elasticsearch.transport.Transports; - -import java.util.HashSet; -import java.util.Map; -import java.util.Set; - -public class SemanticTextInfoExtractor { - private static final Logger logger = LogManager.getLogger(SemanticTextInfoExtractor.class); - - public static Set extractIndexesReferencingInferenceEndpoints(Metadata metadata, Set endpointIds) { - assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures"); - assert endpointIds.isEmpty() == false; - assert metadata != null; - - Set referenceIndices = new HashSet<>(); - - Map indices = metadata.getProject().indices(); - - indices.forEach((indexName, indexMetadata) -> { - Map inferenceFields = indexMetadata.getInferenceFields(); - if (inferenceFields.values() - .stream() - .anyMatch(im -> endpointIds.contains(im.getInferenceId()) || endpointIds.contains(im.getSearchInferenceId()))) { - referenceIndices.add(indexName); - } - }); - - return referenceIndices; - } -} diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 69256d49fe1d2..2c833186df0f0 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -171,6 +171,22 @@ static String 
mockDenseServiceModelConfig() { """; } + static String mockDenseServiceModelConfig(int dimensions) { + return Strings.format(""" + { + "task_type": "text_embedding", + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_dense_vector_model", + "api_key": "abc64", + "dimensions": %s + }, + "task_settings": { + } + } + """, dimensions); + } + static String mockRerankServiceModelConfig() { return """ { diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index 0a98787514010..bf5c10233119d 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference; import org.apache.http.util.EntityUtils; +import org.elasticsearch.client.Request; import org.elasticsearch.client.Response; import org.elasticsearch.client.ResponseException; import org.elasticsearch.common.Strings; @@ -211,7 +212,7 @@ public void testDeleteEndpointWhileReferencedBySemanticText() throws IOException final String endpointId = "endpoint_referenced_by_semantic_text"; final String searchEndpointId = "search_endpoint_referenced_by_semantic_text"; final String indexName = randomAlphaOfLength(10).toLowerCase(); - final Function buildErrorString = endpointName -> " Inference endpoint " + final Function buildErrorString = endpointName -> "Inference endpoint " + endpointName + " is being used in the mapping for indexes: " + Set.of(indexName) @@ -303,6 +304,74 @@ public void testDeleteEndpointWhileReferencedBySemanticTextAndPipeline() throws deleteIndex(indexName); } + public void 
testCreateEndpoint_withInferenceIdReferencedBySemanticText() throws IOException { + final String endpointId = "endpoint_referenced_by_semantic_text"; + final String otherEndpointId = "other_endpoint_referenced_by_semantic_text"; + final String indexName1 = randomAlphaOfLength(10).toLowerCase(); + final String indexName2 = randomValueOtherThan(indexName1, () -> randomAlphaOfLength(10).toLowerCase()); + + putModel(endpointId, mockDenseServiceModelConfig(128), TaskType.TEXT_EMBEDDING); + putModel(otherEndpointId, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING); + // Create two indices, one where the inference ID of the endpoint we'll be deleting and + // recreating is used for inference_id and one where it's used for search_inference_id + putSemanticText(endpointId, otherEndpointId, indexName1); + putSemanticText(otherEndpointId, endpointId, indexName2); + + // Confirm that we can create the endpoint with different settings if there + // are documents in the indices which do not use the semantic text field + var request = new Request("PUT", indexName1 + "/_create/1"); + request.setJsonEntity("{\"non_inference_field\": \"value\"}"); + assertStatusOkOrCreated(client().performRequest(request)); + + request = new Request("PUT", indexName2 + "/_create/1"); + request.setJsonEntity("{\"non_inference_field\": \"value\"}"); + assertStatusOkOrCreated(client().performRequest(request)); + + assertStatusOkOrCreated(client().performRequest(new Request("GET", "_refresh"))); + + deleteModel(endpointId, "force=true"); + putModel(endpointId, mockDenseServiceModelConfig(64), TaskType.TEXT_EMBEDDING); + + // Index a document with the semantic text field into each index + request = new Request("PUT", indexName1 + "/_create/2"); + request.setJsonEntity("{\"inference_field\": \"value\"}"); + assertStatusOkOrCreated(client().performRequest(request)); + + request = new Request("PUT", indexName2 + "/_create/2"); + request.setJsonEntity("{\"inference_field\": \"value\"}"); + 
assertStatusOkOrCreated(client().performRequest(request)); + + assertStatusOkOrCreated(client().performRequest(new Request("GET", "_refresh"))); + + deleteModel(endpointId, "force=true"); + + // Try to create an inference endpoint with the same ID but different dimensions + // from when the document with the semantic text field was indexed + ResponseException responseException = assertThrows( + ResponseException.class, + () -> putModel(endpointId, mockDenseServiceModelConfig(128), TaskType.TEXT_EMBEDDING) + ); + assertThat( + responseException.getMessage(), + containsString( + "Inference endpoint [" + + endpointId + + "] could not be created because the inference_id is being used in mappings with incompatible settings for indices: [" + ) + ); + assertThat(responseException.getMessage(), containsString(indexName1)); + assertThat(responseException.getMessage(), containsString(indexName2)); + assertThat( + responseException.getMessage(), + containsString("Please either use a different inference_id or update the index mappings to refer to a different inference_id.") + ); + + deleteIndex(indexName1); + deleteIndex(indexName2); + + deleteModel(otherEndpointId, "force=true"); + } + public void testUnsupportedStream() throws Exception { String modelId = "streaming"; putModel(modelId, mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING, "streaming_completion_test_service")); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestCompletionServiceExtension.java index 9c15ac77cc13f..58ccdd9fbfe30 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestCompletionServiceExtension.java +++ 
b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestCompletionServiceExtension.java @@ -49,7 +49,7 @@ public List getInferenceServiceFactories() { } public static class TestInferenceService extends AbstractTestInferenceService { - private static final String NAME = "completion_test_service"; + public static final String NAME = "completion_test_service"; private static final EnumSet SUPPORTED_TASK_TYPES = EnumSet.of(TaskType.COMPLETION); public TestInferenceService(InferenceServiceFactoryContext context) {} diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 044af0ab1d37d..efde0f201b604 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -170,7 +170,12 @@ public void chunkedInfer( private TextEmbeddingFloatResults makeResults(List input, ServiceSettings serviceSettings) { List embeddings = new ArrayList<>(); for (String inputString : input) { - List floatEmbeddings = generateEmbedding(inputString, serviceSettings.dimensions(), serviceSettings.elementType()); + List floatEmbeddings = generateEmbedding( + inputString, + serviceSettings.dimensions(), + serviceSettings.elementType(), + serviceSettings.similarity() + ); embeddings.add(TextEmbeddingFloatResults.Embedding.of(floatEmbeddings)); } return new TextEmbeddingFloatResults(embeddings); @@ -206,7 +211,7 @@ protected ServiceSettings getServiceSettingsFromMap(Map serviceS *
          *     <li>Unique to the input</li>
          *     <li>Reproducible (i.e given the same input, the same embedding should be generated)</li>
-         *     <li>Valid for the provided element type</li>
+         *     <li>Valid for the provided element type and similarity measure</li>
          * </ul>
          * <p>
          * The embedding is generated by:
@@ -216,6 +221,7 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
          *     <li>converting the hash code value to a string</li>
          *     <li>converting the string to a UTF-8 encoded byte array</li>
          *     <li>repeatedly appending the byte array to the embedding until the desired number of dimensions are populated</li>
+         *     <li>converting the embedding to a unit vector if the similarity measure requires that</li>
          * </ul>
          * <p>
    * Since the hash code value, when interpreted as a string, is guaranteed to only contain digits and the "-" character, the UTF-8 @@ -226,11 +232,17 @@ protected ServiceSettings getServiceSettingsFromMap(Map serviceS * embedding byte. *

    * - * @param input The input string - * @param dimensions The embedding dimension count + * @param input The input string + * @param dimensions The embedding dimension count + * @param similarityMeasure The similarity measure * @return An embedding */ - private static List generateEmbedding(String input, int dimensions, DenseVectorFieldMapper.ElementType elementType) { + private static List generateEmbedding( + String input, + int dimensions, + DenseVectorFieldMapper.ElementType elementType, + SimilarityMeasure similarityMeasure + ) { int embeddingLength = getEmbeddingLength(elementType, dimensions); List embedding = new ArrayList<>(embeddingLength); @@ -248,6 +260,9 @@ private static List generateEmbedding(String input, int dimensions, Dense if (remainingLength > 0) { embedding.addAll(embeddingValues.subList(0, remainingLength)); } + if (similarityMeasure == SimilarityMeasure.DOT_PRODUCT) { + embedding = toUnitVector(embedding); + } return embedding; } @@ -263,6 +278,11 @@ private static int getEmbeddingLength(DenseVectorFieldMapper.ElementType element }; } + private static List toUnitVector(List embedding) { + var magnitude = (float) Math.sqrt(embedding.stream().reduce(0f, (a, b) -> a + (b * b))); + return embedding.stream().map(v -> v / magnitude).toList(); + } + public static class Configuration { public static InferenceServiceConfiguration get() { return configuration.getOrCompute(); @@ -303,9 +323,13 @@ public record TestServiceSettings( public static TestServiceSettings fromMap(Map map) { ValidationException validationException = new ValidationException(); - String model = (String) map.remove("model"); + String model = (String) map.remove("model_id"); + if (model == null) { - validationException.addValidationError("missing model"); + model = (String) map.remove("model"); + if (model == null) { + validationException.addValidationError("missing model"); + } } Integer dimensions = (Integer) map.remove("dimensions"); diff --git 
a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java index b496ea783c002..d9a3c5c044623 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java @@ -300,7 +300,10 @@ public static TestServiceSettings fromMap(Map map) { String model = (String) map.remove("model_id"); if (model == null) { - validationException.addValidationError("missing model"); + model = (String) map.remove("model"); + if (model == null) { + validationException.addValidationError("missing model"); + } } if (validationException.validationErrors().isEmpty() == false) { diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index d8dd8b4e0e35b..74800df36493f 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -241,10 +241,13 @@ public record TestServiceSettings(String model, String hiddenField, boolean shou public static TestServiceSettings fromMap(Map map) { ValidationException validationException = new ValidationException(); - String model = (String) map.remove("model"); + String model = (String) map.remove("model_id"); if (model == null) { - 
validationException.addValidationError("missing model"); + model = (String) map.remove("model"); + if (model == null) { + validationException.addValidationError("missing model"); + } } String hiddenField = (String) map.remove("hidden_field"); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index f70e0884879ea..65505f943a1eb 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -59,7 +59,7 @@ public List getInferenceServiceFactories() { } public static class TestInferenceService extends AbstractTestInferenceService { - private static final String NAME = "streaming_completion_test_service"; + public static final String NAME = "streaming_completion_test_service"; private static final String ALIAS = "streaming_completion_test_service_alias"; private static final Set supportedStreamingTasks = Set.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); @@ -342,12 +342,15 @@ public TestServiceSettings(StreamInput in) throws IOException { } public static TestServiceSettings fromMap(Map map) { - var modelId = map.remove("model").toString(); + String modelId = (String) map.remove("model_id"); if (modelId == null) { - ValidationException validationException = new ValidationException(); - validationException.addValidationError("missing model id"); - throw validationException; + modelId = (String) map.remove("model"); + if (modelId == null) { + ValidationException validationException = new ValidationException(); + validationException.addValidationError("missing model id"); + 
throw validationException; + } } return new TestServiceSettings(modelId); diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CreateInferenceEndpointIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CreateInferenceEndpointIT.java new file mode 100644 index 0000000000000..860831ab968d6 --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CreateInferenceEndpointIT.java @@ -0,0 +1,355 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionFuture; +import org.elasticsearch.action.DocWriteResponse; +import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.license.LicenseSettings; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.reindex.ReindexPlugin; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction; +import 
org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings; +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.elasticsearch.xpack.inference.mapper.SemanticTextField; +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; +import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; +import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; +import org.elasticsearch.xpack.inference.mock.TestRerankingServiceExtension; +import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; +import org.elasticsearch.xpack.inference.mock.TestStreamingCompletionServiceExtension; + +import java.io.IOException; +import java.util.Collection; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.inference.ModelConfigurations.SERVICE; +import static org.elasticsearch.inference.ModelConfigurations.SERVICE_SETTINGS; +import static org.elasticsearch.inference.SimilarityMeasure.COSINE; +import static org.elasticsearch.inference.SimilarityMeasure.L2_NORM; +import static org.elasticsearch.inference.TaskType.ANY; +import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.ELEMENT_TYPE; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings.API_KEY; +import static org.hamcrest.CoreMatchers.equalTo; 
+import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.not; + +@ESTestCase.WithoutEntitlements // due to dependency issue ES-12435 +public class CreateInferenceEndpointIT extends ESIntegTestCase { + + public static final String INFERENCE_ID = "inference-id"; + public static final String NOT_MODIFIED_INFERENCE_ID = "not-modified-inference-id"; + public static final String SEMANTIC_TEXT_FIELD_NAME = "semantic-text-field"; + + @Override + protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { + return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build(); + } + + @Override + protected Collection> nodePlugins() { + return List.of(LocalStateInferencePlugin.class, TestInferenceServicePlugin.class, ReindexPlugin.class); + } + + public void testCreateInferenceEndpoint_fails_whenSemanticTextFieldUsingTheInferenceIdExists_andTaskTypeIsIncompatible() + throws IOException { + modifyEndpointAndAssertFailure(true, null); + } + + public void testCreateInferenceEndpoint_fails_whenSemanticTextFieldUsingTheInferenceIdExists_andDimensionsAreIncompatible() + throws IOException { + modifyEndpointAndAssertFailure(false, DIMENSIONS); + } + + public void testCreateInferenceEndpoint_fails_whenSemanticTextFieldUsingTheInferenceIdExists_andElementTypeIsIncompatible() + throws IOException { + modifyEndpointAndAssertFailure(false, ELEMENT_TYPE); + } + + public void testCreateInferenceEndpoint_fails_whenSemanticTextFieldUsingTheInferenceIdExists_andSimilarityIsIncompatible() + throws IOException { + modifyEndpointAndAssertFailure(false, SIMILARITY); + } + + public void testCreateInferenceEndpoint_succeeds_whenSemanticTextFieldUsingTheInferenceIdExists_andAllSettingsAreTheSame() + throws IOException { + modifyEndpointAndAssertSuccess(null, true); + } + + public void 
testCreateInferenceEndpoint_succeeds_whenSemanticTextFieldUsingTheInferenceIdExists_andModelIdIsDifferent() + throws IOException { + modifyEndpointAndAssertSuccess(MODEL_ID, true); + } + + public void testCreateInferenceEndpoint_succeeds_whenSemanticTextFieldUsingTheInferenceIdExists_andApiKeyIsDifferent() + throws IOException { + modifyEndpointAndAssertSuccess(API_KEY, true); + } + + public void testCreateInferenceEndpoint_succeeds_whenNoDocumentsUsingSemanticTextHaveBeenIndexed() throws IOException { + String fieldToModify = randomFrom(DIMENSIONS, ELEMENT_TYPE, SIMILARITY); + modifyEndpointAndAssertSuccess(fieldToModify, false); + } + + public void testCreateInferenceEndpoint_succeeds_whenIndexIsCreatedBeforeInferenceEndpoint() throws IOException { + String inferenceId = NOT_MODIFIED_INFERENCE_ID; + String indexName = createIndexWithSemanticTextMapping(inferenceId); + + assertEndpointCreationSuccessful(randomTaskType(), getRandomServiceSettings(), inferenceId); + + deleteIndex(client(), indexName); + } + + private void modifyEndpointAndAssertFailure(boolean modifyTaskType, String settingsFieldToModify) throws IOException { + TaskType taskType = TEXT_EMBEDDING; + Map serviceSettings = getRandomServiceSettings(); + Set indicesUsingInferenceId = new HashSet<>(); + + String indexNotUsingInferenceId = indexDocumentsAndDeleteEndpoint(taskType, serviceSettings, indicesUsingInferenceId, true); + + ElasticsearchStatusException statusException = expectThrows( + ElasticsearchStatusException.class, + () -> createEndpointWithModifiedSettings(modifyTaskType, settingsFieldToModify, taskType, serviceSettings) + ); + + assertThat(statusException.status(), is(RestStatus.BAD_REQUEST)); + assertThat( + statusException.getMessage(), + containsString( + "Inference endpoint [" + + INFERENCE_ID + + "] could not be created because the inference_id is being used in mappings with incompatible settings for indices: [" + ) + ); + + // Make sure we only report the indices that were using the 
inference ID + indicesUsingInferenceId.forEach(index -> assertThat(statusException.getMessage(), containsString(index))); + assertThat(statusException.getMessage(), not(containsString(indexNotUsingInferenceId))); + + indicesUsingInferenceId.forEach(index -> deleteIndex(client(), index)); + deleteIndex(client(), indexNotUsingInferenceId); + } + + private void modifyEndpointAndAssertSuccess(String fieldToModify, boolean documentHasSemanticText) throws IOException { + TaskType taskType = TEXT_EMBEDDING; + Map serviceSettings = getRandomServiceSettings(); + HashSet indicesUsingInferenceId = new HashSet<>(); + + String indexNotUsingInferenceId = indexDocumentsAndDeleteEndpoint( + taskType, + serviceSettings, + indicesUsingInferenceId, + documentHasSemanticText + ); + + PutInferenceModelAction.Response response = createEndpointWithModifiedSettings(false, fieldToModify, taskType, serviceSettings); + assertThat(response.getModel().getInferenceEntityId(), equalTo(INFERENCE_ID)); + + indicesUsingInferenceId.forEach(index -> deleteIndex(client(), index)); + deleteIndex(client(), indexNotUsingInferenceId); + } + + private String indexDocumentsAndDeleteEndpoint( + TaskType taskType, + Map serviceSettings, + Set indicesUsingInferenceId, + boolean documentHasSemanticText + ) throws IOException { + assertEndpointCreationSuccessful(taskType, serviceSettings, INFERENCE_ID); + assertEndpointCreationSuccessful(taskType, serviceSettings, NOT_MODIFIED_INFERENCE_ID); + + // Create several indices to confirm that we can identify them all in the error message + for (int i = 0; i < 5; ++i) { + String indexUsingInferenceId = createIndexWithSemanticTextMapping(INFERENCE_ID, indicesUsingInferenceId); + indexDocument(indexUsingInferenceId, documentHasSemanticText); + indicesUsingInferenceId.add(indexUsingInferenceId); + } + + // Also create a second endpoint which will not be deleted and recreated, and an index which is using it + String indexNotUsingInferenceId = 
createIndexWithSemanticTextMapping(NOT_MODIFIED_INFERENCE_ID, indicesUsingInferenceId); + indexDocument(indexNotUsingInferenceId, documentHasSemanticText); + + forceDeleteInferenceEndpoint(INFERENCE_ID, taskType); + return indexNotUsingInferenceId; + } + + private PutInferenceModelAction.Response createEndpointWithModifiedSettings( + boolean modifyTaskType, + String fieldToModify, + TaskType taskType, + Map serviceSettings + ) { + TaskType newTaskType = modifyTaskType ? randomValueOtherThan(taskType, CreateInferenceEndpointIT::randomTaskType) : taskType; + Map newSettings = fieldToModify != null ? modifyServiceSettings(serviceSettings, fieldToModify) : serviceSettings; + return createEndpoint(newTaskType, newSettings, INFERENCE_ID).actionGet(TEST_REQUEST_TIMEOUT); + } + + private void assertEndpointCreationSuccessful(TaskType taskType, Map serviceSettings, String inferenceId) { + assertThat( + createEndpoint(taskType, serviceSettings, inferenceId).actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(), + equalTo(inferenceId) + ); + } + + private ActionFuture createEndpoint( + TaskType taskType, + Map serviceSettings, + String inferenceId + ) { + final BytesReference content; + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + builder.startObject(); + builder.field(SERVICE, getServiceForTaskType(taskType)); + builder.field(SERVICE_SETTINGS, serviceSettings); + builder.endObject(); + content = BytesReference.bytes(builder); + } catch (IOException ex) { + throw new AssertionError(ex); + } + + var request = new PutInferenceModelAction.Request(taskType, inferenceId, content, XContentType.JSON, TEST_REQUEST_TIMEOUT); + return client().execute(PutInferenceModelAction.INSTANCE, request); + } + + private String getServiceForTaskType(TaskType taskType) { + return switch (taskType) { + case TEXT_EMBEDDING -> TestDenseInferenceServiceExtension.TestInferenceService.NAME; + case SPARSE_EMBEDDING -> 
TestSparseInferenceServiceExtension.TestInferenceService.NAME; + case RERANK -> TestRerankingServiceExtension.TestInferenceService.NAME; + case COMPLETION, CHAT_COMPLETION -> TestStreamingCompletionServiceExtension.TestInferenceService.NAME; + default -> throw new IllegalStateException("Unexpected value: " + taskType); + }; + } + + private static TaskType randomTaskType() { + EnumSet taskTypes = EnumSet.allOf(TaskType.class); + taskTypes.remove(ANY); + return randomFrom(taskTypes); + } + + private String createIndexWithSemanticTextMapping(String inferenceId) throws IOException { + return createIndexWithSemanticTextMapping(inferenceId, Set.of()); + } + + private String createIndexWithSemanticTextMapping(String inferenceId, Set existingIndexNames) throws IOException { + // Ensure that all index names are unique + String indexName = randomValueOtherThanMany( + existingIndexNames::contains, + () -> ESTestCase.randomAlphaOfLength(10).toLowerCase(Locale.ROOT) + ); + XContentBuilder mapping = XContentFactory.jsonBuilder().startObject().startObject(ElasticsearchMappings.PROPERTIES); + mapping.startObject(SEMANTIC_TEXT_FIELD_NAME); + mapping.field(ElasticsearchMappings.TYPE, SemanticTextFieldMapper.CONTENT_TYPE); + mapping.field(SemanticTextField.INFERENCE_ID_FIELD, inferenceId); + mapping.endObject().endObject().endObject(); + + assertAcked(prepareCreate(indexName).setMapping(mapping)); + return indexName; + } + + private static void indexDocument(String indexName, boolean withSemanticText) { + var source = new HashMap(); + source.put("field", "value"); + if (withSemanticText) { + source.put(SEMANTIC_TEXT_FIELD_NAME, randomAlphaOfLength(10)); + } + DocWriteResponse response = client().prepareIndex(indexName).setSource(source).get(TEST_REQUEST_TIMEOUT); + assertThat(response.getResult(), is(DocWriteResponse.Result.CREATED)); + client().admin().indices().prepareRefresh(indexName).get(); + } + + private static Map getRandomServiceSettings() { + Map settings = new HashMap<>(); 
+ settings.put(MODEL_ID, randomIdentifier()); + settings.put(API_KEY, randomIdentifier()); + // Always use a dimension that's a multiple of 8 because the BIT element type requires that + settings.put(DIMENSIONS, randomIntBetween(1, 32) * 8); + ElementType elementType = randomFrom(ElementType.values()); + settings.put(ELEMENT_TYPE, elementType.toString()); + if (elementType == ElementType.BIT) { + // The only supported similarity measure for BIT vectors is L2_NORM + settings.put(SIMILARITY, L2_NORM.toString()); + } else if (elementType == ElementType.BYTE) { + // DOT_PRODUCT similarity does not work with BYTE due to how TestDenseInferenceServiceExtension creates embeddings + settings.put(SIMILARITY, randomFrom(L2_NORM, COSINE).toString()); + } else { + settings.put(SIMILARITY, randomFrom(SimilarityMeasure.values()).toString()); + } + return settings; + } + + private static Map modifyServiceSettings(Map serviceSettings, String fieldToModify) { + var newServiceSettings = new HashMap<>(serviceSettings); + switch (fieldToModify) { + case MODEL_ID, API_KEY -> newServiceSettings.compute( + fieldToModify, + (k, value) -> randomValueOtherThan(value, ESTestCase::randomIdentifier) + ); + case DIMENSIONS -> newServiceSettings.compute( + DIMENSIONS, + (k, dimensions) -> randomValueOtherThan(dimensions, () -> randomIntBetween(8, 128) * 8) + ); + case ELEMENT_TYPE -> newServiceSettings.compute( + ELEMENT_TYPE, + (k, elementType) -> randomValueOtherThan(elementType, () -> randomFrom(ElementType.values()).toString()) + ); + case SIMILARITY -> newServiceSettings.compute( + SIMILARITY, + (k, similarity) -> randomValueOtherThan(similarity, () -> randomFrom(SimilarityMeasure.values()).toString()) + ); + default -> throw new AssertionError("Invalid service settings field " + fieldToModify); + } + return newServiceSettings; + } + + private void forceDeleteInferenceEndpoint(String inferenceId, TaskType taskType) { + var request = new DeleteInferenceEndpointAction.Request(inferenceId, 
taskType, true, false); + var responseFuture = client().execute(DeleteInferenceEndpointAction.INSTANCE, request); + responseFuture.actionGet(TEST_REQUEST_TIMEOUT); + } + + private static void deleteIndex(Client client, String indexName) { + assertAcked( + safeGet( + client.admin() + .indices() + .prepareDelete(indexName) + .setIndicesOptions( + IndicesOptions.builder().concreteTargetOptions(new IndicesOptions.ConcreteTargetOptions(true)).build() + ) + .execute() + ) + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java index d213111d82d9f..1147b261e4b22 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java @@ -30,15 +30,15 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction; -import org.elasticsearch.xpack.core.ml.utils.InferenceProcessorInfoExtractor; import org.elasticsearch.xpack.inference.common.InferenceExceptions; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.util.Set; import java.util.concurrent.Executor; -import static org.elasticsearch.xpack.core.ml.utils.SemanticTextInfoExtractor.extractIndexesReferencingInferenceEndpoints; +import static org.elasticsearch.xpack.core.ml.utils.InferenceProcessorInfoExtractor.pipelineIdsForResource; import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; +import static org.elasticsearch.xpack.inference.common.SemanticTextInfoExtractor.extractIndexesReferencingInferenceEndpoints; public class 
TransportDeleteInferenceEndpointAction extends TransportMasterNodeAction< DeleteInferenceEndpointAction.Request, @@ -166,12 +166,9 @@ private static void handleDryRun( ClusterState state, ActionListener masterListener ) { - Set pipelines = InferenceProcessorInfoExtractor.pipelineIdsForResource(state, Set.of(request.getInferenceEndpointId())); + Set pipelines = endpointIsReferencedInPipelines(state, request.getInferenceEndpointId()); - Set indexesReferencedBySemanticText = extractIndexesReferencingInferenceEndpoints( - state.getMetadata(), - Set.of(request.getInferenceEndpointId()) - ); + Set indexesReferencedBySemanticText = endpointIsReferencedInIndex(state, request.getInferenceEndpointId()); masterListener.onResponse( new DeleteInferenceEndpointAction.Response( @@ -212,7 +209,10 @@ private static String buildErrorString(String inferenceEndpointId, Set p } if (indexes.isEmpty() == false) { - errorString.append(" Inference endpoint ") + if (errorString.isEmpty() == false) { + errorString.append(" "); + } + errorString.append("Inference endpoint ") .append(inferenceEndpointId) .append(" is being used in the mapping for indexes: ") .append(indexes) @@ -229,7 +229,7 @@ private static Set endpointIsReferencedInIndex(final ClusterState state, } private static Set endpointIsReferencedInPipelines(final ClusterState state, final String inferenceEndpointId) { - return InferenceProcessorInfoExtractor.pipelineIdsForResource(state, Set.of(inferenceEndpointId)); + return pipelineIdsForResource(state.metadata(), Set.of(inferenceEndpointId)); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 80d57f888ef6e..f7a563d9bfed9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -13,18 +13,23 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.master.TransportMasterNodeAction; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.project.ProjectResolver; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.mapper.FieldMapper; import org.elasticsearch.index.mapper.StrictDynamicMappingException; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; @@ -49,11 +54,17 @@ import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.io.IOException; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE; +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; +import static 
org.elasticsearch.xpack.inference.common.SemanticTextInfoExtractor.getModelSettingsForIndicesReferencingInferenceEndpoints; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.canMergeModelSettings; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME; public class TransportPutInferenceModelAction extends TransportMasterNodeAction< @@ -65,6 +76,7 @@ public class TransportPutInferenceModelAction extends TransportMasterNodeAction< private final XPackLicenseState licenseState; private final ModelRegistry modelRegistry; private final InferenceServiceRegistry serviceRegistry; + private final OriginSettingClient client; private volatile boolean skipValidationAndStart; private final ProjectResolver projectResolver; @@ -78,7 +90,8 @@ public TransportPutInferenceModelAction( ModelRegistry modelRegistry, InferenceServiceRegistry serviceRegistry, Settings settings, - ProjectResolver projectResolver + ProjectResolver projectResolver, + Client client ) { super( PutInferenceModelAction.NAME, @@ -97,6 +110,7 @@ public TransportPutInferenceModelAction( clusterService.getClusterSettings() .addSettingsUpdateConsumer(InferencePlugin.SKIP_VALIDATE_AND_START, this::setSkipValidationAndStart); this.projectResolver = projectResolver; + this.client = new OriginSettingClient(client, INFERENCE_ORIGIN); } @Override @@ -181,7 +195,15 @@ protected void masterOperation( return; } - parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.getTimeout(), listener); + parseAndStoreModel( + service.get(), + request.getInferenceEntityId(), + resolvedTaskType, + requestAsMap, + request.getTimeout(), + state.metadata(), + listener + ); } private void parseAndStoreModel( @@ -190,6 +212,7 @@ private void parseAndStoreModel( TaskType taskType, Map config, TimeValue timeout, + Metadata metadata, ActionListener listener ) { ActionListener storeModelListener = 
listener.delegateFailureAndWrap( @@ -212,7 +235,7 @@ private void parseAndStoreModel( ) ); - ActionListener parsedModelListener = listener.delegateFailureAndWrap((delegate, model) -> { + ActionListener modelValidatingListener = listener.delegateFailureAndWrap((delegate, model) -> { if (skipValidationAndStart) { storeModelListener.onResponse(model); } else { @@ -221,7 +244,51 @@ private void parseAndStoreModel( } }); - service.parseRequestConfig(inferenceEntityId, taskType, config, parsedModelListener); + ActionListener existingUsesListener = listener.delegateFailureAndWrap((delegate, model) -> { + // Execute in another thread because checking for existing uses requires reading from indices + threadPool.executor(UTILITY_THREAD_POOL_NAME) + .execute(() -> checkForExistingUsesOfInferenceId(metadata, model, modelValidatingListener)); + }); + + service.parseRequestConfig(inferenceEntityId, taskType, config, existingUsesListener); + } + + private void checkForExistingUsesOfInferenceId(Metadata metadata, Model model, ActionListener modelValidatingListener) { + Set inferenceEntityIdSet = Set.of(model.getInferenceEntityId()); + Set indicesWithIncompatibleMappings = findIndicesWithIncompatibleMappings(model, metadata, inferenceEntityIdSet); + + if (indicesWithIncompatibleMappings.isEmpty()) { + modelValidatingListener.onResponse(model); + } else { + modelValidatingListener.onFailure( + new ElasticsearchStatusException( + buildErrorString(model.getInferenceEntityId(), indicesWithIncompatibleMappings), + RestStatus.BAD_REQUEST + ) + ); + } + } + + private Set findIndicesWithIncompatibleMappings(Model model, Metadata metadata, Set inferenceEntityIdSet) { + var serviceSettingsMap = getModelSettingsForIndicesReferencingInferenceEndpoints(metadata, inferenceEntityIdSet); + var incompatibleIndices = new HashSet(); + if (serviceSettingsMap.isEmpty() == false) { + MinimalServiceSettings newSettings = new MinimalServiceSettings(model); + serviceSettingsMap.forEach((indexName, 
existingSettings) -> { + if (canMergeModelSettings(existingSettings, newSettings, new FieldMapper.Conflicts("")) == false) { + incompatibleIndices.add(indexName); + } + }); + } + return incompatibleIndices; + } + + private static String buildErrorString(String inferenceId, Set indicesWithIncompatibleMappings) { + return "Inference endpoint [" + + inferenceId + + "] could not be created because the inference_id is being used in mappings with incompatible settings for indices: " + + indicesWithIncompatibleMappings + + ". Please either use a different inference_id or update the index mappings to refer to a different inference_id."; } private void startInferenceEndpoint( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/SemanticTextInfoExtractor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/SemanticTextInfoExtractor.java new file mode 100644 index 0000000000000..e4ca1c7dbfa93 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/SemanticTextInfoExtractor.java @@ -0,0 +1,75 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.common; + +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.cluster.metadata.MappingMetadata; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.inference.MinimalServiceSettings; +import org.elasticsearch.transport.Transports; +import org.elasticsearch.xcontent.ObjectPath; +import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings; +import org.elasticsearch.xpack.inference.mapper.SemanticTextField; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +public class SemanticTextInfoExtractor { + public static Set extractIndexesReferencingInferenceEndpoints(Metadata metadata, Set endpointIds) { + assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures"); + assert endpointIds.isEmpty() == false; + assert metadata != null; + + Set referenceIndices = new HashSet<>(); + + Map indices = metadata.getProject().indices(); + + indices.forEach((indexName, indexMetadata) -> { + Map inferenceFields = indexMetadata.getInferenceFields(); + if (inferenceFields.values() + .stream() + .anyMatch(im -> endpointIds.contains(im.getInferenceId()) || endpointIds.contains(im.getSearchInferenceId()))) { + referenceIndices.add(indexName); + } + }); + + return referenceIndices; + } + + public static Map getModelSettingsForIndicesReferencingInferenceEndpoints( + Metadata metadata, + Set endpointIds + ) { + assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures"); + assert endpointIds.isEmpty() == false; + assert metadata != null; + + Map serviceSettingsMap = new HashMap<>(); + + metadata.getProject().indices().forEach((indexName, indexMetadata) -> { + indexMetadata.getInferenceFields() + .values() + .stream() + .filter(field -> 
endpointIds.contains(field.getInferenceId()) || endpointIds.contains(field.getSearchInferenceId())) + .findFirst() // Assume that the model settings are the same for all fields using the inference endpoint + .ifPresent(field -> { + MappingMetadata mapping = indexMetadata.mapping(); + if (mapping != null) { + String[] pathArray = { ElasticsearchMappings.PROPERTIES, field.getName(), SemanticTextField.MODEL_SETTINGS_FIELD }; + Object modelSettings = ObjectPath.eval(pathArray, mapping.sourceAsMap()); + serviceSettingsMap.put(indexName, SemanticTextField.parseModelSettingsFromMap(modelSettings)); + } + }); + }); + + return serviceSettingsMap; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index b6652e499b9fc..b160667e95c31 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -62,7 +62,7 @@ public record SemanticTextField( static final String TEXT_FIELD = "text"; static final String INFERENCE_FIELD = "inference"; - static final String INFERENCE_ID_FIELD = "inference_id"; + public static final String INFERENCE_ID_FIELD = "inference_id"; static final String SEARCH_INFERENCE_ID_FIELD = "search_inference_id"; static final String CHUNKS_FIELD = "chunks"; static final String CHUNKED_EMBEDDINGS_FIELD = "embeddings"; @@ -70,7 +70,7 @@ public record SemanticTextField( static final String CHUNKED_OFFSET_FIELD = "offset"; static final String CHUNKED_START_OFFSET_FIELD = "start_offset"; static final String CHUNKED_END_OFFSET_FIELD = "end_offset"; - static final String MODEL_SETTINGS_FIELD = "model_settings"; + public static final String MODEL_SETTINGS_FIELD = "model_settings"; static final String CHUNKING_SETTINGS_FIELD = 
"chunking_settings"; public record InferenceResult( @@ -110,7 +110,7 @@ static SemanticTextField parse(XContentParser parser, ParserContext context) thr return SEMANTIC_TEXT_FIELD_PARSER.parse(parser, context); } - static MinimalServiceSettings parseModelSettingsFromMap(Object node) { + public static MinimalServiceSettings parseModelSettingsFromMap(Object node) { if (node == null) { return null; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index b718493a37790..8be02061fb371 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -1268,7 +1268,7 @@ static SemanticTextIndexOptions defaultBbqHnswSemanticTextIndexOptions() { ); } - private static boolean canMergeModelSettings(MinimalServiceSettings previous, MinimalServiceSettings current, Conflicts conflicts) { + public static boolean canMergeModelSettings(MinimalServiceSettings previous, MinimalServiceSettings current, Conflicts conflicts) { if (previous != null && current != null && previous.canMergeWith(current)) { return true; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/DefaultSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/DefaultSecretSettings.java index d076c946889ed..745d6f585a137 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/DefaultSecretSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/DefaultSecretSettings.java @@ -36,7 +36,7 @@ public record DefaultSecretSettings(SecureString apiKey) implements 
SecretSettings, ApiKeySecrets { public static final String NAME = "default_secret_settings"; - static final String API_KEY = "api_key"; + public static final String API_KEY = "api_key"; public static DefaultSecretSettings fromMap(@Nullable Map map) { if (map == null) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java index 5aa42520d74bd..3a14cf6a851ec 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java @@ -16,8 +16,11 @@ import org.elasticsearch.search.fetch.subphase.highlight.Highlighter; import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin; import org.elasticsearch.xpack.core.ssl.SSLService; +import org.elasticsearch.xpack.inference.mock.TestCompletionServiceExtension; import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; +import org.elasticsearch.xpack.inference.mock.TestRerankingServiceExtension; import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; +import org.elasticsearch.xpack.inference.mock.TestStreamingCompletionServiceExtension; import java.nio.file.Path; import java.util.Collection; @@ -47,7 +50,10 @@ protected XPackLicenseState getLicenseState() { public List getInferenceServiceFactories() { return List.of( TestSparseInferenceServiceExtension.TestInferenceService::new, - TestDenseInferenceServiceExtension.TestInferenceService::new + TestDenseInferenceServiceExtension.TestInferenceService::new, + TestRerankingServiceExtension.TestInferenceService::new, + TestCompletionServiceExtension.TestInferenceService::new, + TestStreamingCompletionServiceExtension.TestInferenceService::new ); } }; diff --git 
a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml index 51595d40737a3..01c91012beff7 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml @@ -77,89 +77,14 @@ setup: non_inference_field: "non inference test" refresh: true --- -"sparse_embedding changed to text_embedding": - - do: - inference.delete: - inference_id: sparse-inference-id - force: true - - - do: - inference.put: - task_type: text_embedding - inference_id: sparse-inference-id - body: > - { - "service": "text_embedding_test_service", - "service_settings": { - "model": "my_model", - "dimensions": 10, - "api_key": "abc64", - "similarity": "COSINE" - }, - "task_settings": { - } - } - - - do: - catch: bad_request - search: - index: test-sparse-index - body: - query: - semantic: - field: "inference_field" - query: "inference test" - - - match: { error.caused_by.type: "illegal_argument_exception" } - - match: { error.caused_by.reason: "Field [inference_field] expected query inference results to be of type - [text_expansion_result], got [text_embedding_result]. Has the configuration for - inference endpoint [sparse-inference-id] changed?" 
} - ---- -"text_embedding changed to sparse_embedding": +"create endpoint fails when the inference_id is used by a semantic text field and is incompatible": - do: inference.delete: inference_id: dense-inference-id force: true - - do: - inference.put: - task_type: sparse_embedding - inference_id: dense-inference-id - body: > - { - "service": "test_service", - "service_settings": { - "model": "my_model", - "api_key": "abc64" - }, - "task_settings": { - } - } - - do: catch: bad_request - search: - index: test-dense-index - body: - query: - semantic: - field: "inference_field" - query: "inference test" - - - match: { error.caused_by.type: "illegal_argument_exception" } - - match: { error.caused_by.reason: "Field [inference_field] expected query inference results to be of type - [text_embedding_result], got [text_expansion_result]. Has the configuration for - inference endpoint [dense-inference-id] changed?" } - ---- -"text_embedding dimension count changed": - - do: - inference.delete: - inference_id: dense-inference-id - force: true - - - do: inference.put: task_type: text_embedding inference_id: dense-inference-id @@ -176,17 +101,7 @@ setup: } } - - do: - catch: bad_request - search: - index: test-dense-index - body: - query: - semantic: - field: "inference_field" - query: "inference test" - - - match: { error.caused_by.type: "illegal_argument_exception" } - - match: { error.caused_by.reason: "Field [inference_field] expected query inference results with 10 dimensions, got - 20 dimensions. Has the configuration for inference endpoint [dense-inference-id] - changed?" } + - match: { error.reason: "Inference endpoint [dense-inference-id] could not be created because the inference_id + is being used in mappings with incompatible settings for indices: [test-dense-index]. + Please either use a different inference_id or update the index mappings to refer to a + different inference_id." }