docs/changelog/137055.yaml: 6 additions, 0 deletions
@@ -0,0 +1,6 @@
pr: 137055
summary: Do not create inference endpoint if ID is used in existing mappings
area: Machine Learning
type: bug
issues:
- 124272
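
For context, a minimal sketch of the behavior this fix enforces. It is illustrative only and mirrors the REST test added further down in this PR: the endpoint ID, the service settings, and the assumption that it runs inside an ESRestTestCase-style test are hypothetical, not taken verbatim from the change. Recreating an inference endpoint whose ID is already referenced by semantic_text mappings, but with incompatible settings, is now rejected instead of silently succeeding.

// Illustrative sketch only; assumes it runs inside an ESRestTestCase-style test.
// The endpoint ID and service settings below are hypothetical.
Request create = new Request("PUT", "_inference/text_embedding/endpoint_referenced_by_semantic_text");
create.setJsonEntity("""
    {
      "service": "text_embedding_test_service",
      "service_settings": { "model": "my_dense_vector_model", "api_key": "abc64", "dimensions": 128 },
      "task_settings": {}
    }
    """);
// With this change the request fails and the error names the offending indices.
ResponseException e = assertThrows(ResponseException.class, () -> client().performRequest(create));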
@@ -116,13 +116,12 @@ public static Map<String, Set<String>> pipelineIdsByResource(ClusterState state,
}

/**
* @param state Current {@link ClusterState}
* @param metadata Current cluster state {@link Metadata}
* @param ids Model or Deployment IDs or Aliases to look up
* @return the set of IDs of the pipelines referencing any of the given Model or Deployment IDs or Aliases.
*/
public static Set<String> pipelineIdsForResource(ClusterState state, Set<String> ids) {
public static Set<String> pipelineIdsForResource(Metadata metadata, Set<String> ids) {
assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures");
Set<String> pipelineIds = new HashSet<>();
Metadata metadata = state.metadata();
if (metadata == null) {
return pipelineIds;
}
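
For callers of this helper, a minimal sketch of how a call site changes with the new signature (hypothetical; the class name InferenceProcessorInfoExtractor and the endpoint ID are assumptions, not spelled out in this diff): the method now takes the cluster state's Metadata directly rather than the full ClusterState.

// Hypothetical call site, for illustration only: pass Metadata rather than ClusterState.
Metadata metadata = clusterState.metadata();
Set<String> pipelineIds = InferenceProcessorInfoExtractor.pipelineIdsForResource(metadata, Set.of("my-endpoint-id"));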

This file was deleted.

@@ -171,6 +171,22 @@ static String mockDenseServiceModelConfig() {
""";
}

static String mockDenseServiceModelConfig(int dimensions) {
return Strings.format("""
{
"task_type": "text_embedding",
"service": "text_embedding_test_service",
"service_settings": {
"model": "my_dense_vector_model",
"api_key": "abc64",
"dimensions": %s
},
"task_settings": {
}
}
""", dimensions);
}

static String mockRerankServiceModelConfig() {
return """
{
@@ -10,6 +10,7 @@
package org.elasticsearch.xpack.inference;

import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.common.Strings;
@@ -211,7 +212,7 @@ public void testDeleteEndpointWhileReferencedBySemanticText() throws IOException
final String endpointId = "endpoint_referenced_by_semantic_text";
final String searchEndpointId = "search_endpoint_referenced_by_semantic_text";
final String indexName = randomAlphaOfLength(10).toLowerCase();
final Function<String, String> buildErrorString = endpointName -> " Inference endpoint "
final Function<String, String> buildErrorString = endpointName -> "Inference endpoint "
+ endpointName
+ " is being used in the mapping for indexes: "
+ Set.of(indexName)
@@ -303,6 +304,74 @@ public void testDeleteEndpointWhileReferencedBySemanticTextAndPipeline() throws
deleteIndex(indexName);
}

public void testCreateEndpoint_withInferenceIdReferencedBySemanticText() throws IOException {
final String endpointId = "endpoint_referenced_by_semantic_text";
final String otherEndpointId = "other_endpoint_referenced_by_semantic_text";
final String indexName1 = randomAlphaOfLength(10).toLowerCase();
final String indexName2 = randomValueOtherThan(indexName1, () -> randomAlphaOfLength(10).toLowerCase());

putModel(endpointId, mockDenseServiceModelConfig(128), TaskType.TEXT_EMBEDDING);
putModel(otherEndpointId, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
// Create two indices, one where the inference ID of the endpoint we'll be deleting and
// recreating is used for inference_id and one where it's used for search_inference_id
putSemanticText(endpointId, otherEndpointId, indexName1);
putSemanticText(otherEndpointId, endpointId, indexName2);

// Confirm that we can create the endpoint with different settings if there
// are documents in the indices which do not use the semantic text field
var request = new Request("PUT", indexName1 + "/_create/1");
request.setJsonEntity("{\"non_inference_field\": \"value\"}");
assertStatusOkOrCreated(client().performRequest(request));

request = new Request("PUT", indexName2 + "/_create/1");
request.setJsonEntity("{\"non_inference_field\": \"value\"}");
assertStatusOkOrCreated(client().performRequest(request));

assertStatusOkOrCreated(client().performRequest(new Request("GET", "_refresh")));

deleteModel(endpointId, "force=true");
putModel(endpointId, mockDenseServiceModelConfig(64), TaskType.TEXT_EMBEDDING);

// Index a document with the semantic text field into each index
request = new Request("PUT", indexName1 + "/_create/2");
request.setJsonEntity("{\"inference_field\": \"value\"}");
assertStatusOkOrCreated(client().performRequest(request));

request = new Request("PUT", indexName2 + "/_create/2");
request.setJsonEntity("{\"inference_field\": \"value\"}");
assertStatusOkOrCreated(client().performRequest(request));

assertStatusOkOrCreated(client().performRequest(new Request("GET", "_refresh")));

deleteModel(endpointId, "force=true");

// Try to create an inference endpoint with the same ID but different dimensions
// from when the document with the semantic text field was indexed
ResponseException responseException = assertThrows(
ResponseException.class,
() -> putModel(endpointId, mockDenseServiceModelConfig(128), TaskType.TEXT_EMBEDDING)
);
assertThat(
responseException.getMessage(),
containsString(
"Inference endpoint ["
+ endpointId
+ "] could not be created because the inference_id is being used in mappings with incompatible settings for indices: ["
)
);
assertThat(responseException.getMessage(), containsString(indexName1));
assertThat(responseException.getMessage(), containsString(indexName2));
assertThat(
responseException.getMessage(),
containsString("Please either use a different inference_id or update the index mappings to refer to a different inference_id.")
);

deleteIndex(indexName1);
deleteIndex(indexName2);

deleteModel(otherEndpointId, "force=true");
}

public void testUnsupportedStream() throws Exception {
String modelId = "streaming";
putModel(modelId, mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING, "streaming_completion_test_service"));
@@ -49,7 +49,7 @@ public List<Factory> getInferenceServiceFactories() {
}

public static class TestInferenceService extends AbstractTestInferenceService {
private static final String NAME = "completion_test_service";
public static final String NAME = "completion_test_service";
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES = EnumSet.of(TaskType.COMPLETION);

public TestInferenceService(InferenceServiceFactoryContext context) {}
@@ -170,7 +170,12 @@ public void chunkedInfer(
private TextEmbeddingFloatResults makeResults(List<String> input, ServiceSettings serviceSettings) {
List<TextEmbeddingFloatResults.Embedding> embeddings = new ArrayList<>();
for (String inputString : input) {
List<Float> floatEmbeddings = generateEmbedding(inputString, serviceSettings.dimensions(), serviceSettings.elementType());
List<Float> floatEmbeddings = generateEmbedding(
inputString,
serviceSettings.dimensions(),
serviceSettings.elementType(),
serviceSettings.similarity()
);
embeddings.add(TextEmbeddingFloatResults.Embedding.of(floatEmbeddings));
}
return new TextEmbeddingFloatResults(embeddings);
@@ -206,7 +211,7 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
* <ul>
* <li>Unique to the input</li>
* <li>Reproducible (i.e. given the same input, the same embedding should be generated)</li>
* <li>Valid for the provided element type</li>
* <li>Valid for the provided element type and similarity measure</li>
* </ul>
* <p>
* The embedding is generated by:
@@ -216,6 +221,7 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
* <li>converting the hash code value to a string</li>
* <li>converting the string to a UTF-8 encoded byte array</li>
* <li>repeatedly appending the byte array to the embedding until the desired number of dimensions are populated</li>
* <li>converting the embedding to a unit vector if the similarity measure requires that</li>
* </ul>
* <p>
* Since the hash code value, when interpreted as a string, is guaranteed to only contain digits and the "-" character, the UTF-8
@@ -226,11 +232,17 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
* embedding byte.
* </p>
*
* @param input The input string
* @param dimensions The embedding dimension count
* @param input The input string
* @param dimensions The embedding dimension count
* @param similarityMeasure The similarity measure
* @return An embedding
*/
private static List<Float> generateEmbedding(String input, int dimensions, DenseVectorFieldMapper.ElementType elementType) {
private static List<Float> generateEmbedding(
String input,
int dimensions,
DenseVectorFieldMapper.ElementType elementType,
SimilarityMeasure similarityMeasure
) {
int embeddingLength = getEmbeddingLength(elementType, dimensions);
List<Float> embedding = new ArrayList<>(embeddingLength);

@@ -248,6 +260,9 @@ private static List<Float> generateEmbedding(String input, int dimensions, Dense
if (remainingLength > 0) {
embedding.addAll(embeddingValues.subList(0, remainingLength));
}
if (similarityMeasure == SimilarityMeasure.DOT_PRODUCT) {
embedding = toUnitVector(embedding);
}

return embedding;
}
@@ -263,6 +278,11 @@ private static int getEmbeddingLength(DenseVectorFieldMapper.ElementType element
};
}

private static List<Float> toUnitVector(List<Float> embedding) {
var magnitude = (float) Math.sqrt(embedding.stream().reduce(0f, (a, b) -> a + (b * b)));
return embedding.stream().map(v -> v / magnitude).toList();
}

public static class Configuration {
public static InferenceServiceConfiguration get() {
return configuration.getOrCompute();
@@ -303,9 +323,13 @@ public record TestServiceSettings(
public static TestServiceSettings fromMap(Map<String, Object> map) {
ValidationException validationException = new ValidationException();

String model = (String) map.remove("model");
String model = (String) map.remove("model_id");

if (model == null) {
validationException.addValidationError("missing model");
model = (String) map.remove("model");
if (model == null) {
validationException.addValidationError("missing model");
}
}

Integer dimensions = (Integer) map.remove("dimensions");
@@ -300,7 +300,10 @@ public static TestServiceSettings fromMap(Map<String, Object> map) {
String model = (String) map.remove("model_id");

if (model == null) {
validationException.addValidationError("missing model");
model = (String) map.remove("model");
if (model == null) {
validationException.addValidationError("missing model");
}
}

if (validationException.validationErrors().isEmpty() == false) {
@@ -241,10 +241,13 @@ public record TestServiceSettings(String model, String hiddenField, boolean shou
public static TestServiceSettings fromMap(Map<String, Object> map) {
ValidationException validationException = new ValidationException();

String model = (String) map.remove("model");
String model = (String) map.remove("model_id");

if (model == null) {
validationException.addValidationError("missing model");
model = (String) map.remove("model");
if (model == null) {
validationException.addValidationError("missing model");
}
}

String hiddenField = (String) map.remove("hidden_field");
@@ -59,7 +59,7 @@ public List<Factory> getInferenceServiceFactories() {
}

public static class TestInferenceService extends AbstractTestInferenceService {
private static final String NAME = "streaming_completion_test_service";
public static final String NAME = "streaming_completion_test_service";
private static final String ALIAS = "streaming_completion_test_service_alias";
private static final Set<TaskType> supportedStreamingTasks = Set.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);

@@ -342,12 +342,15 @@ public TestServiceSettings(StreamInput in) throws IOException {
}

public static TestServiceSettings fromMap(Map<String, Object> map) {
var modelId = map.remove("model").toString();
String modelId = (String) map.remove("model_id");

if (modelId == null) {
ValidationException validationException = new ValidationException();
validationException.addValidationError("missing model id");
throw validationException;
modelId = (String) map.remove("model");
if (modelId == null) {
ValidationException validationException = new ValidationException();
validationException.addValidationError("missing model id");
throw validationException;
}
}

return new TestServiceSettings(modelId);