From 092d2a2e89a81c25978c4e4b24629444d7a4948b Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Mon, 3 Nov 2025 20:21:44 +0100 Subject: [PATCH 1/8] Implement single pass, similar to ResolveUnionTypes --- .../PushDownVectorSimilarityFunctions.java | 43 ++++++++++++------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/PushDownVectorSimilarityFunctions.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/PushDownVectorSimilarityFunctions.java index a5dc214b442c9..3c810f216cc5e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/PushDownVectorSimilarityFunctions.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/PushDownVectorSimilarityFunctions.java @@ -18,13 +18,14 @@ import org.elasticsearch.xpack.esql.core.type.FunctionEsField; import org.elasticsearch.xpack.esql.expression.function.vector.VectorSimilarityFunction; import org.elasticsearch.xpack.esql.optimizer.LocalLogicalOptimizerContext; -import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; import org.elasticsearch.xpack.esql.plan.logical.Eval; import org.elasticsearch.xpack.esql.plan.logical.Filter; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject; +import org.elasticsearch.xpack.esql.rule.ParameterizedRule; +import org.elasticsearch.xpack.esql.stats.SearchStats; import java.util.ArrayList; import java.util.HashMap; @@ -38,30 +39,44 @@ * the similarity function during value loading, when one side of the function is a literal. * It also adds the new field function attribute to the EsRelation output, and adds a projection after it to remove it from the output. */ -public class PushDownVectorSimilarityFunctions extends OptimizerRules.ParameterizedOptimizerRule< +public class PushDownVectorSimilarityFunctions extends ParameterizedRule< + LogicalPlan, LogicalPlan, LocalLogicalOptimizerContext> { - public PushDownVectorSimilarityFunctions() { - super(OptimizerRules.TransformDirection.DOWN); - } @Override - protected LogicalPlan rule(LogicalPlan plan, LocalLogicalOptimizerContext context) { + public LogicalPlan apply(LogicalPlan plan, LocalLogicalOptimizerContext context) { + Map addedAttrs = new HashMap<>(); + return plan.transformUp(LogicalPlan.class, p -> doRule(p, context.searchStats(), addedAttrs)); + } + + + private LogicalPlan doRule(LogicalPlan plan, SearchStats searchStats, Map addedAttrs) { + // Collect field attributes from previous runs + int originalAddedAttrsSize = addedAttrs.size(); + if (plan instanceof EsRelation rel) { + addedAttrs.clear(); + for (Attribute attr : rel.output()) { + if (attr instanceof FieldAttribute fa && fa.field() instanceof FunctionEsField) { + addedAttrs.put(fa.ignoreId(), fa); + } + } + } + if (plan instanceof Eval || plan instanceof Filter || plan instanceof Aggregate) { - Map addedAttrs = new HashMap<>(); LogicalPlan transformedPlan = plan.transformExpressionsOnly( VectorSimilarityFunction.class, - similarityFunction -> replaceFieldsForFieldTransformations(similarityFunction, addedAttrs, context) + similarityFunction -> replaceFieldsForFieldTransformations(similarityFunction, addedAttrs, searchStats) ); - if (addedAttrs.isEmpty()) { + // No fields were added, return the original plan + if (addedAttrs.size() == originalAddedAttrsSize) { return plan; } List previousAttrs = transformedPlan.output(); - // Transforms EsRelation to extract the new attribute - + // Transforms EsRelation to extract the new attributes List addedAttrsList = addedAttrs.values().stream().toList(); transformedPlan = transformedPlan.transformDown(EsRelation.class, esRelation -> { AttributeSet updatedOutput = esRelation.outputSet().combine(AttributeSet.of(addedAttrsList)); @@ -82,8 +97,7 @@ protected LogicalPlan rule(LogicalPlan plan, LocalLogicalOptimizerContext contex private static Expression replaceFieldsForFieldTransformations( VectorSimilarityFunction similarityFunction, - Map addedAttrs, - LocalLogicalOptimizerContext context + Map addedAttrs, SearchStats searchStats ) { // Only replace if exactly one side is a literal and the other a field attribute if ((similarityFunction.left() instanceof Literal ^ similarityFunction.right() instanceof Literal) == false) { @@ -98,7 +112,7 @@ private static Expression replaceFieldsForFieldTransformations( fieldAttr = fa; } // We can push down also for doc values, requires handling that case on the field mapper - if (fieldAttr == null || context.searchStats().isIndexed(fieldAttr.fieldName()) == false) { + if (fieldAttr == null || searchStats.isIndexed(fieldAttr.fieldName()) == false) { return similarityFunction; } @@ -131,7 +145,6 @@ private static Expression replaceFieldsForFieldTransformations( ); Attribute.IdIgnoringWrapper key = newFunctionAttr.ignoreId(); if (addedAttrs.containsKey(key)) { - ; return addedAttrs.get(key); } From b5b1176d3606cebd39ce598c77089cc0d5f6791e Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Tue, 4 Nov 2025 08:23:30 +0100 Subject: [PATCH 2/8] Add test for replacing duplicates in multiple commands --- .../LocalLogicalPlanOptimizerTests.java | 54 +++++++++++-------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java index 7925fe5ef9242..795d9bd7ecc7f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java @@ -1421,20 +1421,20 @@ public void testVectorFunctionsWithDuplicateFunctions() { String query = """ from test_all | eval s1 = v_dot_product(dense_vector, [1.0, 2.0, 3.0]), s2 = v_dot_product(dense_vector, [1.0, 2.0, 3.0]) * 2 / 3 - | eval s3 = v_dot_product(dense_vector, [1.0, 2.0, 3.0]) + 5, r1 = v_dot_product(dense_vector, [4.0, 5.0, 6.0]) + | where v_dot_product(dense_vector, [1.0, 2.0, 3.0]) + 5 + v_dot_product(dense_vector, [4.0, 5.0, 6.0]) > 0 | eval r2 = v_dot_product(dense_vector, [4.0, 5.0, 6.0]) + v_cosine(dense_vector, [4.0, 5.0, 6.0]) - | keep s1, s2, r1, r2 + | keep s1, s2, r2 """; LogicalPlan plan = localPlan(plan(query, allTypesAnalyzer), TEST_SEARCH_STATS); - // EsqlProject[[s1{r}#5, s2{r}#8, r1{r}#14, r2{r}#18]] + // EsqlProject[[s1{r}#5, s2{r}#8, r2{r}#14]] var project = as(plan, EsqlProject.class); - assertThat(Expressions.names(project.projections()), contains("s1", "s2", "r1", "r2")); + assertThat(Expressions.names(project.projections()), contains("s1", "s2", "r2")); - // Eval with s1, s2, r1, r2 + // Eval with s1, s2, r2 var eval = as(project.child(), Eval.class); - assertThat(eval.fields(), hasSize(4)); + assertThat(eval.fields(), hasSize(3)); // Check s1 = $$dense_vector$DotProduct$... var s1Alias = as(eval.fields().getFirst(), Alias.class); @@ -1455,25 +1455,19 @@ public void testVectorFunctionsWithDuplicateFunctions() { var s2FieldAttr = as(s2Mul.left(), FieldAttribute.class); assertThat(s1FieldAttr, is(s2FieldAttr)); - // Check r1 = $$dense_vector$DotProduct$882900992 (vector [4.0, 5.0, 6.0]) - var r1Alias = as(eval.fields().get(2), Alias.class); - assertThat(r1Alias.name(), equalTo("r1")); - var r1FieldAttr = as(r1Alias.child(), FieldAttribute.class); - assertThat(r1FieldAttr.fieldName().string(), equalTo("dense_vector")); - assertThat(r1FieldAttr.name(), startsWith("$$dense_vector$DotProduct")); - var r1Field = as(r1FieldAttr.field(), FunctionEsField.class); - var r1Config = as(r1Field.functionConfig(), DenseVectorFieldMapper.VectorSimilarityFunctionConfig.class); - assertThat(r1Config.similarityFunction(), is(DotProduct.SIMILARITY_FUNCTION)); - assertThat(r1Config.vector(), equalTo(new float[] { 4.0f, 5.0f, 6.0f })); - // Check r2 = $$dense_vector$DotProduct$882900992 + $$dense_vector$CosineSimilarity$882900992 - var r2Alias = as(eval.fields().get(3), Alias.class); + var r2Alias = as(eval.fields().get(2), Alias.class); assertThat(r2Alias.name(), equalTo("r2")); var r2Add = as(r2Alias.child(), Add.class); - // Left side: DotProduct field (same as r1) + // Left side: DotProduct field with vector [4.0, 5.0, 6.0] var r2DotProductFieldAttr = as(r2Add.left(), FieldAttribute.class); - assertThat(r2DotProductFieldAttr, is(r1FieldAttr)); + assertThat(r2DotProductFieldAttr.fieldName().string(), equalTo("dense_vector")); + assertThat(r2DotProductFieldAttr.name(), startsWith("$$dense_vector$DotProduct")); + var r2DotProductField = as(r2DotProductFieldAttr.field(), FunctionEsField.class); + var r2DotProductConfig = as(r2DotProductField.functionConfig(), DenseVectorFieldMapper.VectorSimilarityFunctionConfig.class); + assertThat(r2DotProductConfig.similarityFunction(), is(DotProduct.SIMILARITY_FUNCTION)); + assertThat(r2DotProductConfig.vector(), equalTo(new float[] { 4.0f, 5.0f, 6.0f })); // Right side: CosineSimilarity field var r2CosineFieldAttr = as(r2Add.right(), FieldAttribute.class); @@ -1487,10 +1481,24 @@ public void testVectorFunctionsWithDuplicateFunctions() { // Limit[1000[INTEGER],false,false] var limit = as(eval.child(), Limit.class); - // EsRelation[test_all][!alias_integer, boolean{f}#24, byte{f}#25, constant..] - var esRelation = as(limit.child(), EsRelation.class); + // Filter[$$dense_vector$DotProduct$1606418432 + 5 + $$dense_vector$DotProduct$882900992 > 0] + var filter = as(limit.child(), Filter.class); + var greaterThan = as(filter.condition(), GreaterThan.class); + var filterAdd1 = as(greaterThan.left(), Add.class); + var filterAdd2 = as(filterAdd1.left(), Add.class); + + // Check filter uses s1 field (DotProduct with [1.0, 2.0, 3.0]) + var filterS1FieldAttr = as(filterAdd2.left(), FieldAttribute.class); + assertThat(filterS1FieldAttr, is(s1FieldAttr)); + + // Check filter uses r2's DotProduct field (DotProduct with [4.0, 5.0, 6.0]) + var filterR2FieldAttr = as(filterAdd1.right(), FieldAttribute.class); + assertThat(filterR2FieldAttr, is(r2DotProductFieldAttr)); + + // EsRelation[test_all][!alias_integer, boolean{f}#19, byte{f}#20, constant..] + var esRelation = as(filter.child(), EsRelation.class); assertTrue(esRelation.output().contains(s1FieldAttr)); - assertTrue(esRelation.output().contains(r1FieldAttr)); + assertTrue(esRelation.output().contains(r2DotProductFieldAttr)); assertTrue(esRelation.output().contains(r2CosineFieldAttr)); } From 516b4f97b0cb4aa1c0f5c5f89dbb521e0684a804 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Tue, 4 Nov 2025 09:18:55 +0100 Subject: [PATCH 3/8] Calculate name as part of the BlockLoaderFunctionConfig, so it can ensure no duplicate names --- .../index/mapper/MappedFieldType.java | 4 +++- .../mapper/vectors/DenseVectorFieldMapper.java | 11 +++++++++-- .../function/vector/CosineSimilarity.java | 5 +++++ .../expression/function/vector/DotProduct.java | 5 +++++ .../expression/function/vector/Hamming.java | 10 ++++++++++ .../expression/function/vector/L1Norm.java | 5 +++++ .../expression/function/vector/L2Norm.java | 5 +++++ .../PushDownVectorSimilarityFunctions.java | 18 ++++-------------- 8 files changed, 46 insertions(+), 17 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MappedFieldType.java b/server/src/main/java/org/elasticsearch/index/mapper/MappedFieldType.java index 1c5d0cc2dfa15..c35288ab427fd 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MappedFieldType.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MappedFieldType.java @@ -715,6 +715,8 @@ default BlockLoaderFunctionConfig blockLoaderFunctionConfig() { * Is retrievable from the {@link BlockLoaderContext}. The {@link MappedFieldType} can use this configuration to choose the appropriate * implementation for transforming loaded values into blocks. */ - public interface BlockLoaderFunctionConfig {} + public interface BlockLoaderFunctionConfig { + String name(); + } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index a67ffd77e97f2..01d1874b4eb3c 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -3167,6 +3167,8 @@ public interface IntBooleanConsumer { } public interface SimilarityFunction { + String name(); + float calculateSimilarity(float[] leftVector, float[] rightVector); float calculateSimilarity(byte[] leftVector, byte[] rightVector); @@ -3183,9 +3185,10 @@ public static class VectorSimilarityFunctionConfig implements MappedFieldType.Bl private byte[] vectorAsBytes; public VectorSimilarityFunctionConfig(SimilarityFunction similarityFunction, float[] vector) { + Objects.requireNonNull(vector); + assert vector.length > 0 : "vector length must be > 0"; this.similarityFunction = similarityFunction; this.vector = vector; - } /** @@ -3212,6 +3215,10 @@ public SimilarityFunction similarityFunction() { return similarityFunction; } + public String name() { + return similarityFunction.name() + "$" + Arrays.hashCode(vector); + } + @Override public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; @@ -3223,7 +3230,7 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(similarityFunction, Arrays.hashCode(vector), Arrays.hashCode(vectorAsBytes)); + return Objects.hash(similarityFunction, Arrays.hashCode(vector)); } } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarity.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarity.java index d74aefdd0d360..71d33cb70878a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarity.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarity.java @@ -31,6 +31,11 @@ public class CosineSimilarity extends VectorSimilarityFunction { CosineSimilarity::new ); public static final DenseVectorFieldMapper.SimilarityFunction SIMILARITY_FUNCTION = new DenseVectorFieldMapper.SimilarityFunction() { + @Override + public String name() { + return "CosineSimilarity"; + } + @Override public float calculateSimilarity(byte[] leftVector, byte[] rightVector) { return VectorUtil.cosine(leftVector, rightVector); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/DotProduct.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/DotProduct.java index e319b235e5eaa..d077ad777821b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/DotProduct.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/DotProduct.java @@ -32,6 +32,11 @@ public class DotProduct extends VectorSimilarityFunction { ); public static final DenseVectorFieldMapper.SimilarityFunction SIMILARITY_FUNCTION = new DenseVectorFieldMapper.SimilarityFunction() { + @Override + public String name() { + return "DotProduct"; + } + @Override public float calculateSimilarity(byte[] leftVector, byte[] rightVector) { return VectorUtil.dotProduct(leftVector, rightVector); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Hamming.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Hamming.java index 8ebd771221c48..219ddcff0b75f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Hamming.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Hamming.java @@ -27,6 +27,11 @@ public class Hamming extends VectorSimilarityFunction { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Hamming", Hamming::new); public static final DenseVectorFieldMapper.SimilarityFunction SIMILARITY_FUNCTION = new DenseVectorFieldMapper.SimilarityFunction() { + @Override + public String name() { + return "Hamming"; + } + @Override public float calculateSimilarity(byte[] leftVector, byte[] rightVector) { return Hamming.calculateSimilarity(leftVector, rightVector); @@ -39,6 +44,11 @@ public float calculateSimilarity(float[] leftVector, float[] rightVector) { }; public static final DenseVectorFieldMapper.SimilarityFunction EVALUATOR_SIMILARITY_FUNCTION = new DenseVectorFieldMapper.SimilarityFunction() { + @Override + public String name() { + return "Hamming_evaluator"; + } + @Override public float calculateSimilarity(byte[] leftVector, byte[] rightVector) { return Hamming.calculateSimilarity(leftVector, rightVector); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/L1Norm.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/L1Norm.java index 09bf2681f8559..2fddbaa4ec4b4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/L1Norm.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/L1Norm.java @@ -26,6 +26,11 @@ public class L1Norm extends VectorSimilarityFunction { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "L1Norm", L1Norm::new); public static final DenseVectorFieldMapper.SimilarityFunction SIMILARITY_FUNCTION = new DenseVectorFieldMapper.SimilarityFunction() { + @Override + public String name() { + return "L1Norm"; + } + @Override public float calculateSimilarity(byte[] leftVector, byte[] rightVector) { if (leftVector.length != rightVector.length) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/L2Norm.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/L2Norm.java index a6350fc98535c..afbd578d0234b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/L2Norm.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/L2Norm.java @@ -28,6 +28,11 @@ public class L2Norm extends VectorSimilarityFunction { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "L2Norm", L2Norm::new); public static final DenseVectorFieldMapper.SimilarityFunction SIMILARITY_FUNCTION = new DenseVectorFieldMapper.SimilarityFunction() { + @Override + public String name() { + return "L2Norm"; + } + @Override public float calculateSimilarity(byte[] leftVector, byte[] rightVector) { return (float) Math.sqrt(VectorUtil.squareDistance(leftVector, rightVector)); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/PushDownVectorSimilarityFunctions.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/PushDownVectorSimilarityFunctions.java index 3c810f216cc5e..7edc56b2798fd 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/PushDownVectorSimilarityFunctions.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/PushDownVectorSimilarityFunctions.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.esql.optimizer.rules.logical.local; +import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.AttributeSet; import org.elasticsearch.xpack.esql.core.expression.Expression; @@ -44,7 +45,6 @@ public class PushDownVectorSimilarityFunctions extends ParameterizedRule< LogicalPlan, LocalLogicalOptimizerContext> { - @Override public LogicalPlan apply(LogicalPlan plan, LocalLogicalOptimizerContext context) { Map addedAttrs = new HashMap<>(); @@ -104,7 +104,6 @@ private static Expression replaceFieldsForFieldTransformations( return similarityFunction; } - Literal literal = (Literal) (similarityFunction.left() instanceof Literal ? similarityFunction.left() : similarityFunction.right()); FieldAttribute fieldAttr = null; if (similarityFunction.left() instanceof FieldAttribute fa) { fieldAttr = fa; @@ -116,23 +115,14 @@ private static Expression replaceFieldsForFieldTransformations( return similarityFunction; } - @SuppressWarnings("unchecked") - List vectorList = (List) literal.value(); - float[] vectorArray = new float[vectorList.size()]; - int arrayHashCode = 0; - for (int i = 0; i < vectorList.size(); i++) { - vectorArray[i] = vectorList.get(i).floatValue(); - arrayHashCode = 31 * arrayHashCode + Float.floatToIntBits(vectorArray[i]); - } - // Change the similarity function to a reference of a transformation on the field + MappedFieldType.BlockLoaderFunctionConfig blockLoaderFunctionConfig = similarityFunction.getBlockLoaderFunctionConfig(); FunctionEsField functionEsField = new FunctionEsField( fieldAttr.field(), similarityFunction.dataType(), - similarityFunction.getBlockLoaderFunctionConfig() + blockLoaderFunctionConfig ); - var name = rawTemporaryName(fieldAttr.name(), similarityFunction.nodeName(), String.valueOf(arrayHashCode)); - // TODO: Check if exists before adding, retrieve the previous one + var name = rawTemporaryName(fieldAttr.name(), blockLoaderFunctionConfig.name()); var newFunctionAttr = new FieldAttribute( fieldAttr.source(), fieldAttr.parentName(), From 9466e317530314eaaca38ea666a9a89bc8b16e66 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Tue, 4 Nov 2025 09:34:25 +0100 Subject: [PATCH 4/8] Add javadoc --- .../java/org/elasticsearch/index/mapper/MappedFieldType.java | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MappedFieldType.java b/server/src/main/java/org/elasticsearch/index/mapper/MappedFieldType.java index c35288ab427fd..815de0157ebed 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MappedFieldType.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MappedFieldType.java @@ -716,6 +716,10 @@ default BlockLoaderFunctionConfig blockLoaderFunctionConfig() { * implementation for transforming loaded values into blocks. */ public interface BlockLoaderFunctionConfig { + /** + * Returns a representable name for this configuration that can be used as part of a pushed down attribute name. + * Configurations that are equal() must return the same name. + */ String name(); } From 207075a4783aacd452e61dd73b2fa3597c47be1e Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Tue, 4 Nov 2025 10:06:58 +0100 Subject: [PATCH 5/8] Implement canonicalize() and CanonicalizeVectorSimilarityFunctions --- .../vector/VectorSimilarityFunction.java | 12 ++++ .../esql/optimizer/LogicalPlanOptimizer.java | 2 + ...CanonicalizeVectorSimilarityFunctions.java | 27 +++++++ .../PushDownVectorSimilarityFunctions.java | 70 +++++++++---------- 4 files changed, 74 insertions(+), 37 deletions(-) create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CanonicalizeVectorSimilarityFunctions.java diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java index 94c3c54ec29fb..c4222079b5000 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java @@ -220,6 +220,18 @@ public MappedFieldType.BlockLoaderFunctionConfig getBlockLoaderFunctionConfig() return new DenseVectorFieldMapper.VectorSimilarityFunctionConfig(getSimilarityFunction(), vector); } + @Override + protected Expression canonicalize() { + VectorSimilarityFunction canonical = (VectorSimilarityFunction) super.canonicalize(); + + // Set literals to the right + if (canonical.left() instanceof Literal && canonical.right() instanceof Literal == false) { + return canonical.replaceChildren(right(), left()); + } + + return canonical; + } + interface VectorValueProvider extends Releasable { void eval(Page page); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java index 2ff42cbe8e4ef..cec17f3f4f4b7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java @@ -13,6 +13,7 @@ import org.elasticsearch.xpack.esql.optimizer.rules.PruneInlineJoinOnEmptyRightSide; import org.elasticsearch.xpack.esql.optimizer.rules.logical.BooleanFunctionEqualsElimination; import org.elasticsearch.xpack.esql.optimizer.rules.logical.BooleanSimplification; +import org.elasticsearch.xpack.esql.optimizer.rules.logical.CanonicalizeVectorSimilarityFunctions; import org.elasticsearch.xpack.esql.optimizer.rules.logical.CombineBinaryComparisons; import org.elasticsearch.xpack.esql.optimizer.rules.logical.CombineDisjunctions; import org.elasticsearch.xpack.esql.optimizer.rules.logical.CombineEvals; @@ -182,6 +183,7 @@ protected static Batch operators() { // boolean new BooleanSimplification(), new LiteralsOnTheRight(), + new CanonicalizeVectorSimilarityFunctions(), // needs to occur before BinaryComparison combinations (see class) new PropagateEquals(), new PropagateNullable(), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CanonicalizeVectorSimilarityFunctions.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CanonicalizeVectorSimilarityFunctions.java new file mode 100644 index 0000000000000..647de72d626ac --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CanonicalizeVectorSimilarityFunctions.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules.logical; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.expression.function.vector.VectorSimilarityFunction; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; + +/** + * Ensures that vector similarity functions are in their canonical form, with literals to the right. + */ +public class CanonicalizeVectorSimilarityFunctions extends OptimizerRules.OptimizerExpressionRule { + + public CanonicalizeVectorSimilarityFunctions() { + super(OptimizerRules.TransformDirection.UP); + } + + @Override + protected Expression rule(VectorSimilarityFunction vectorSimilarityFunction, LogicalOptimizerContext ctx) { + return vectorSimilarityFunction.canonical(); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/PushDownVectorSimilarityFunctions.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/PushDownVectorSimilarityFunctions.java index 7edc56b2798fd..276a6db49f247 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/PushDownVectorSimilarityFunctions.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/PushDownVectorSimilarityFunctions.java @@ -99,46 +99,42 @@ private static Expression replaceFieldsForFieldTransformations( VectorSimilarityFunction similarityFunction, Map addedAttrs, SearchStats searchStats ) { - // Only replace if exactly one side is a literal and the other a field attribute - if ((similarityFunction.left() instanceof Literal ^ similarityFunction.right() instanceof Literal) == false) { - return similarityFunction; - } + // Only replace if it consists of a literal and the other a field attribute. + // CanonicalizeVectorSimilarityFunctions ensures that if there is a literal, it will be on the right side. + if (similarityFunction.left() instanceof FieldAttribute fieldAttr && similarityFunction.right() instanceof Literal) { - FieldAttribute fieldAttr = null; - if (similarityFunction.left() instanceof FieldAttribute fa) { - fieldAttr = fa; - } else if (similarityFunction.right() instanceof FieldAttribute fa) { - fieldAttr = fa; - } - // We can push down also for doc values, requires handling that case on the field mapper - if (fieldAttr == null || searchStats.isIndexed(fieldAttr.fieldName()) == false) { - return similarityFunction; - } + // We can push down also for doc values, requires handling that case on the field mapper + if (searchStats.isIndexed(fieldAttr.fieldName()) == false) { + return similarityFunction; + } + + // Change the similarity function to a reference of a transformation on the field + MappedFieldType.BlockLoaderFunctionConfig blockLoaderFunctionConfig = similarityFunction.getBlockLoaderFunctionConfig(); + FunctionEsField functionEsField = new FunctionEsField( + fieldAttr.field(), + similarityFunction.dataType(), + blockLoaderFunctionConfig + ); + var name = rawTemporaryName(fieldAttr.name(), blockLoaderFunctionConfig.name()); + var newFunctionAttr = new FieldAttribute( + fieldAttr.source(), + fieldAttr.parentName(), + fieldAttr.qualifier(), + name, + functionEsField, + fieldAttr.nullable(), + new NameId(), + true + ); + Attribute.IdIgnoringWrapper key = newFunctionAttr.ignoreId(); + if (addedAttrs.containsKey(key)) { + return addedAttrs.get(key); + } - // Change the similarity function to a reference of a transformation on the field - MappedFieldType.BlockLoaderFunctionConfig blockLoaderFunctionConfig = similarityFunction.getBlockLoaderFunctionConfig(); - FunctionEsField functionEsField = new FunctionEsField( - fieldAttr.field(), - similarityFunction.dataType(), - blockLoaderFunctionConfig - ); - var name = rawTemporaryName(fieldAttr.name(), blockLoaderFunctionConfig.name()); - var newFunctionAttr = new FieldAttribute( - fieldAttr.source(), - fieldAttr.parentName(), - fieldAttr.qualifier(), - name, - functionEsField, - fieldAttr.nullable(), - new NameId(), - true - ); - Attribute.IdIgnoringWrapper key = newFunctionAttr.ignoreId(); - if (addedAttrs.containsKey(key)) { - return addedAttrs.get(key); + addedAttrs.put(key, newFunctionAttr); + return newFunctionAttr; } - addedAttrs.put(key, newFunctionAttr); - return newFunctionAttr; + return similarityFunction; } } From d15493131808eb19dde884ad6cdb04987a6aa26d Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Tue, 4 Nov 2025 11:26:06 +0100 Subject: [PATCH 6/8] Add randomized testing --- .../LocalLogicalPlanOptimizerTests.java | 255 +++++++++++------- 1 file changed, 165 insertions(+), 90 deletions(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java index 795d9bd7ecc7f..d77010af0bae5 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java @@ -36,6 +36,7 @@ import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.expression.Order; import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; +import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; import org.elasticsearch.xpack.esql.expression.function.aggregate.Min; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce; @@ -44,8 +45,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.RLikeList; import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLike; import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLikeList; -import org.elasticsearch.xpack.esql.expression.function.vector.CosineSimilarity; -import org.elasticsearch.xpack.esql.expression.function.vector.DotProduct; +import org.elasticsearch.xpack.esql.expression.function.vector.VectorSimilarityFunction; import org.elasticsearch.xpack.esql.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; @@ -76,10 +76,10 @@ import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation; import org.elasticsearch.xpack.esql.rule.RuleExecutor; import org.elasticsearch.xpack.esql.stats.SearchStats; -import org.hamcrest.Matchers; import org.junit.BeforeClass; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Locale; import java.util.Map; @@ -115,6 +115,7 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; @@ -223,7 +224,7 @@ public void testMissingFieldInProject() { var alias = as(eval.fields().get(0), Alias.class); var literal = as(alias.child(), Literal.class); assertThat(literal.value(), is(nullValue())); - assertThat(literal.dataType(), is(DataType.KEYWORD)); + assertThat(literal.dataType(), is(KEYWORD)); var limit = as(eval.child(), Limit.class); var source = as(limit.child(), EsRelation.class); @@ -256,7 +257,7 @@ public void testReassignedMissingFieldInProject() { var alias = as(eval.fields().get(0), Alias.class); var literal = as(alias.child(), Literal.class); assertThat(literal.value(), is(new BytesRef("foo"))); - assertThat(literal.dataType(), is(DataType.KEYWORD)); + assertThat(literal.dataType(), is(KEYWORD)); var limit = as(eval.child(), Limit.class); var source = as(limit.child(), EsRelation.class); @@ -322,7 +323,7 @@ public void testMissingFieldInMvExpand() { assertEquals(eval.fields().size(), 1); var lastName = eval.fields().get(0); assertEquals(lastName.name(), "last_name"); - assertEquals(lastName.child(), new Literal(EMPTY, null, DataType.KEYWORD)); + assertEquals(lastName.child(), new Literal(EMPTY, null, KEYWORD)); var limit2 = asLimit(eval.child(), 1000, false); var relation = as(limit2.child(), EsRelation.class); assertThat(Expressions.names(relation.output()), not(contains("last_name"))); @@ -380,7 +381,7 @@ public void testMissingFieldInNewCommand() { new FieldAttribute( EMPTY, "last_name", - new EsField("last_name", DataType.KEYWORD, Map.of(), true, EsField.TimeSeriesFieldType.NONE) + new EsField("last_name", KEYWORD, Map.of(), true, EsField.TimeSeriesFieldType.NONE) ) ), testStats @@ -414,7 +415,7 @@ public void testMissingFieldInNewCommand() { assertThat(Expressions.names(eval.fields()), contains("last_name")); var literal = as(eval.fields().get(0), Alias.class); - assertEquals(literal.child(), new Literal(EMPTY, null, DataType.KEYWORD)); + assertEquals(literal.child(), new Literal(EMPTY, null, KEYWORD)); assertThat(Expressions.names(relation.output()), not(contains("last_name"))); assertEquals(Expressions.names(initialRelation.output()), Expressions.names(project.output())); @@ -445,7 +446,7 @@ public void testMissingFieldInEval() { var alias = as(eval.fields().get(0), Alias.class); var literal = as(alias.child(), Literal.class); assertThat(literal.value(), is(nullValue())); - assertThat(literal.dataType(), is(DataType.INTEGER)); + assertThat(literal.dataType(), is(INTEGER)); var limit = as(eval.child(), Limit.class); var source = as(limit.child(), EsRelation.class); @@ -535,7 +536,7 @@ public void testSparseDocument() throws Exception { Map large = Maps.newLinkedHashMapWithExpectedSize(size); for (int i = 0; i < size; i++) { var name = String.format(Locale.ROOT, "field%03d", i); - large.put(name, new EsField(name, DataType.INTEGER, emptyMap(), true, false, EsField.TimeSeriesFieldType.NONE)); + large.put(name, new EsField(name, INTEGER, emptyMap(), true, false, EsField.TimeSeriesFieldType.NONE)); } SearchStats searchStats = statsForExistingField("field000", "field001", "field002", "field003", "field004"); @@ -568,7 +569,7 @@ public void testSparseDocument() throws Exception { var eval = as(project.child(), Eval.class); var field = eval.fields().get(0); assertThat(Expressions.name(field), is("field005")); - assertThat(Alias.unwrap(field).fold(FoldContext.small()), Matchers.nullValue()); + assertThat(Alias.unwrap(field).fold(FoldContext.small()), nullValue()); } // InferIsNotNull @@ -912,11 +913,11 @@ public void testGroupingByMissingFields() { Alias eval1 = eval.fields().get(0); Literal literal1 = as(eval1.child(), Literal.class); assertNull(literal1.value()); - assertThat(literal1.dataType(), is(DataType.KEYWORD)); + assertThat(literal1.dataType(), is(KEYWORD)); Alias eval2 = eval.fields().get(1); Literal literal2 = as(eval2.child(), Literal.class); assertNull(literal2.value()); - assertThat(literal2.dataType(), is(DataType.KEYWORD)); + assertThat(literal2.dataType(), is(KEYWORD)); assertThat(grouping1.id(), equalTo(eval1.id())); assertThat(grouping2.id(), equalTo(eval2.id())); as(eval.child(), EsRelation.class); @@ -987,8 +988,8 @@ protected LogicalPlan rule(Aggregate plan, LocalLogicalOptimizerContext context) // We only want to apply it once, so we use a static counter if (appliedCount.get() == 0) { appliedCount.set(appliedCount.get() + 1); - Literal additionalLiteral = new Literal(Source.EMPTY, "additional literal", INTEGER); - return new Eval(plan.source(), plan, List.of(new Alias(Source.EMPTY, "additionalAttribute", additionalLiteral))); + Literal additionalLiteral = new Literal(EMPTY, "additional literal", INTEGER); + return new Eval(plan.source(), plan, List.of(new Alias(EMPTY, "additionalAttribute", additionalLiteral))); } return plan; } @@ -1108,17 +1109,18 @@ public void testPruneLeftJoinOnNullMatchingFieldAndShadowingAttributes() { */ public void testVectorFunctionsReplaced() { assumeTrue("requires similarity functions", EsqlCapabilities.Cap.VECTOR_SIMILARITY_FUNCTIONS_PUSHDOWN.isEnabled()); - String query = """ + SimilarityFunctionTestCase testCase = SimilarityFunctionTestCase.random("dense_vector"); + String query = String.format(Locale.ROOT, """ from test_all - | eval s = v_dot_product(dense_vector, [1.0, 2.0, 3.0]) - """; + | eval s = %s + """, testCase.toQuery()); LogicalPlan plan = localPlan(plan(query, allTypesAnalyzer), TEST_SEARCH_STATS); // EsqlProject[[!alias_integer, boolean{f}#7, byte{f}#8, ... s{r}#5]] var project = as(plan, EsqlProject.class); // Does not contain the extracted field - assertFalse(Expressions.names(project.projections()).stream().anyMatch(s -> s.startsWith("$$dense_vector$DotProduct"))); + assertFalse(Expressions.names(project.projections()).stream().anyMatch(s -> s.startsWith(testCase.toFieldAttrName()))); // Eval[[$$dense_vector$DOTPRODUCT$27{f}#27 AS s#5]] var eval = as(project.child(), Eval.class); @@ -1129,11 +1131,11 @@ public void testVectorFunctionsReplaced() { // Check replaced field attribute FieldAttribute fieldAttr = (FieldAttribute) alias.child(); assertThat(fieldAttr.fieldName().string(), equalTo("dense_vector")); - assertThat(fieldAttr.name(), startsWith("$$dense_vector$DotProduct")); + assertThat(fieldAttr.name(), startsWith(testCase.toFieldAttrName())); var field = as(fieldAttr.field(), FunctionEsField.class); var blockLoaderFunctionConfig = as(field.functionConfig(), DenseVectorFieldMapper.VectorSimilarityFunctionConfig.class); - assertThat(blockLoaderFunctionConfig.similarityFunction(), is(DotProduct.SIMILARITY_FUNCTION)); - assertThat(blockLoaderFunctionConfig.vector(), equalTo(new float[] { 1.0f, 2.0f, 3.0f })); + assertThat(blockLoaderFunctionConfig.similarityFunction(), instanceOf(DenseVectorFieldMapper.SimilarityFunction.class)); + assertThat(blockLoaderFunctionConfig.vector(), equalTo(testCase.vector())); // Limit[1000[INTEGER],false,false] var limit = as(eval.child(), Limit.class); @@ -1153,13 +1155,14 @@ public void testVectorFunctionsReplaced() { */ public void testVectorFunctionsReplacedWithTopN() { assumeTrue("requires similarity functions", EsqlCapabilities.Cap.VECTOR_SIMILARITY_FUNCTIONS_PUSHDOWN.isEnabled()); - String query = """ + SimilarityFunctionTestCase testCase = SimilarityFunctionTestCase.random("dense_vector"); + String query = String.format(Locale.ROOT, """ from test_all - | eval s = v_dot_product(dense_vector, [1.0, 2.0, 3.0]) + | eval s = %s | sort s desc | limit 1 | keep s - """; + """, testCase.toQuery()); LogicalPlan plan = localPlan(plan(query, allTypesAnalyzer), TEST_SEARCH_STATS); @@ -1185,11 +1188,11 @@ public void testVectorFunctionsReplacedWithTopN() { // Check replaced field attribute FieldAttribute fieldAttr = (FieldAttribute) alias.child(); assertThat(fieldAttr.fieldName().string(), equalTo("dense_vector")); - assertThat(fieldAttr.name(), startsWith("$$dense_vector$DotProduct")); + assertThat(fieldAttr.name(), startsWith(testCase.toFieldAttrName())); var field = as(fieldAttr.field(), FunctionEsField.class); var blockLoaderFunctionConfig = as(field.functionConfig(), DenseVectorFieldMapper.VectorSimilarityFunctionConfig.class); - assertThat(blockLoaderFunctionConfig.similarityFunction(), is(DotProduct.SIMILARITY_FUNCTION)); - assertThat(blockLoaderFunctionConfig.vector(), equalTo(new float[] { 1.0f, 2.0f, 3.0f })); + assertThat(blockLoaderFunctionConfig.similarityFunction(), instanceOf(DenseVectorFieldMapper.SimilarityFunction.class)); + assertThat(blockLoaderFunctionConfig.vector(), equalTo(testCase.vector())); // EsRelation[types] var esRelation = as(eval.child(), EsRelation.class); @@ -1198,13 +1201,14 @@ public void testVectorFunctionsReplacedWithTopN() { public void testVectorFunctionsNotPushedDownWhenNotIndexed() { assumeTrue("requires similarity functions", EsqlCapabilities.Cap.VECTOR_SIMILARITY_FUNCTIONS_PUSHDOWN.isEnabled()); - String query = """ + SimilarityFunctionTestCase testCase = SimilarityFunctionTestCase.random("dense_vector"); + String query = String.format(Locale.ROOT, """ from test_all - | eval s = v_dot_product(dense_vector, [1.0, 2.0, 3.0]) + | eval s = %s | sort s desc | limit 1 | keep s - """; + """, testCase.toQuery()); LogicalPlan plan = localPlan(plan(query, allTypesAnalyzer), new EsqlTestUtils.TestSearchStats() { @Override @@ -1225,8 +1229,8 @@ public boolean isIndexed(FieldAttribute.FieldName field) { var alias = as(eval.fields().getFirst(), Alias.class); assertThat(alias.name(), equalTo("s")); - // Check similarly function field attribute is NOT replaced - as(alias.child(), DotProduct.class); + // Check similarity function field attribute is NOT replaced + as(alias.child(), VectorSimilarityFunction.class); // EsRelation does not contain a FunctionEsField var esRelation = as(eval.child(), EsRelation.class); @@ -1239,13 +1243,14 @@ public boolean isIndexed(FieldAttribute.FieldName field) { public void testVectorFunctionsWhenFieldMissing() { assumeTrue("requires similarity functions", EsqlCapabilities.Cap.VECTOR_SIMILARITY_FUNCTIONS_PUSHDOWN.isEnabled()); - String query = """ + SimilarityFunctionTestCase testCase = SimilarityFunctionTestCase.random("dense_vector"); + String query = String.format(Locale.ROOT, """ from test_all - | eval s = v_dot_product(dense_vector, [1.0, 2.0, 3.0]) + | eval s = %s | sort s desc | limit 1 | keep s - """; + """, testCase.toQuery()); LogicalPlan plan = localPlan(plan(query, allTypesAnalyzer), new EsqlTestUtils.TestSearchStats() { @Override @@ -1281,11 +1286,12 @@ public boolean exists(FieldAttribute.FieldName field) { public void testVectorFunctionsInWhere() { assumeTrue("requires similarity functions", EsqlCapabilities.Cap.VECTOR_SIMILARITY_FUNCTIONS_PUSHDOWN.isEnabled()); - String query = """ + SimilarityFunctionTestCase testCase = SimilarityFunctionTestCase.random("dense_vector"); + String query = String.format(Locale.ROOT, """ from test_all - | where v_dot_product(dense_vector, [1.0, 2.0, 3.0]) > 0.5 + | where %s > 0.5 | keep dense_vector - """; + """, testCase.toQuery()); LogicalPlan plan = localPlan(plan(query, allTypesAnalyzer), TEST_SEARCH_STATS); @@ -1300,11 +1306,11 @@ public void testVectorFunctionsInWhere() { // Check left side is the replaced field attribute var fieldAttr = as(greaterThan.left(), FieldAttribute.class); assertThat(fieldAttr.fieldName().string(), equalTo("dense_vector")); - assertThat(fieldAttr.name(), startsWith("$$dense_vector$DotProduct")); + assertThat(fieldAttr.name(), startsWith(testCase.toFieldAttrName())); var field = as(fieldAttr.field(), FunctionEsField.class); var blockLoaderFunctionConfig = as(field.functionConfig(), DenseVectorFieldMapper.VectorSimilarityFunctionConfig.class); - assertThat(blockLoaderFunctionConfig.similarityFunction(), is(DotProduct.SIMILARITY_FUNCTION)); - assertThat(blockLoaderFunctionConfig.vector(), equalTo(new float[] { 1.0f, 2.0f, 3.0f })); + assertThat(blockLoaderFunctionConfig.similarityFunction(), instanceOf(DenseVectorFieldMapper.SimilarityFunction.class)); + assertThat(blockLoaderFunctionConfig.vector(), equalTo(testCase.vector())); // Check right side is 0.5 var literal = as(greaterThan.right(), Literal.class); @@ -1319,10 +1325,11 @@ public void testVectorFunctionsInWhere() { public void testVectorFunctionsInStats() { assumeTrue("requires similarity functions", EsqlCapabilities.Cap.VECTOR_SIMILARITY_FUNCTIONS_PUSHDOWN.isEnabled()); - String query = """ + SimilarityFunctionTestCase testCase = SimilarityFunctionTestCase.random("dense_vector"); + String query = String.format(Locale.ROOT, """ from test_all - | stats count(*) where v_dot_product(dense_vector, [1.0, 2.0, 3.0]) > 0.5 - """; + | stats count(*) where %s > 0.5 + """, testCase.toQuery()); LogicalPlan plan = localPlan(plan(query, allTypesAnalyzer), TEST_SEARCH_STATS); @@ -1337,7 +1344,7 @@ public void testVectorFunctionsInStats() { // Check the Count aggregate with filter var countAlias = as(aggregate.aggregates().getFirst(), Alias.class); - var count = as(countAlias.child(), org.elasticsearch.xpack.esql.expression.function.aggregate.Count.class); + var count = as(countAlias.child(), Count.class); // Check the filter on the Count aggregate assertThat(count.filter(), equalTo(Literal.TRUE)); @@ -1349,11 +1356,11 @@ public void testVectorFunctionsInStats() { // Check left side is the replaced field attribute var fieldAttr = as(filterCondition.left(), FieldAttribute.class); assertThat(fieldAttr.fieldName().string(), equalTo("dense_vector")); - assertThat(fieldAttr.name(), startsWith("$$dense_vector$DotProduct")); + assertThat(fieldAttr.name(), startsWith(testCase.toFieldAttrName())); var field = as(fieldAttr.field(), FunctionEsField.class); var blockLoaderFunctionConfig = as(field.functionConfig(), DenseVectorFieldMapper.VectorSimilarityFunctionConfig.class); - assertThat(blockLoaderFunctionConfig.similarityFunction(), is(DotProduct.SIMILARITY_FUNCTION)); - assertThat(blockLoaderFunctionConfig.vector(), equalTo(new float[] { 1.0f, 2.0f, 3.0f })); + assertThat(blockLoaderFunctionConfig.similarityFunction(), instanceOf(DenseVectorFieldMapper.SimilarityFunction.class)); + assertThat(blockLoaderFunctionConfig.vector(), equalTo(testCase.vector())); // Verify the filter condition matches the aggregate filter var filterFieldAttr = as(filterCondition.left(), FieldAttribute.class); @@ -1366,14 +1373,15 @@ public void testVectorFunctionsInStats() { public void testVectorFunctionsUpdateIntermediateProjections() { assumeTrue("requires similarity functions", EsqlCapabilities.Cap.VECTOR_SIMILARITY_FUNCTIONS_PUSHDOWN.isEnabled()); - String query = """ + SimilarityFunctionTestCase testCase = SimilarityFunctionTestCase.random("dense_vector"); + String query = String.format(Locale.ROOT, """ from test_all | keep * | mv_expand keyword - | eval similarity = v_cosine(dense_vector, [0, 255, 255]) + | eval similarity = %s | sort similarity desc, keyword asc | limit 1 - """; + """, testCase.toQuery()); LogicalPlan plan = localPlan(plan(query, allTypesAnalyzer), TEST_SEARCH_STATS); @@ -1392,11 +1400,11 @@ public void testVectorFunctionsUpdateIntermediateProjections() { // Check replaced field attribute var fieldAttr = as(alias.child(), FieldAttribute.class); assertThat(fieldAttr.fieldName().string(), equalTo("dense_vector")); - assertThat(fieldAttr.name(), startsWith("$$dense_vector$CosineSimilarity")); + assertThat(fieldAttr.name(), startsWith(testCase.toFieldAttrName())); var field = as(fieldAttr.field(), FunctionEsField.class); var blockLoaderFunctionConfig = as(field.functionConfig(), DenseVectorFieldMapper.VectorSimilarityFunctionConfig.class); - assertThat(blockLoaderFunctionConfig.similarityFunction(), is(CosineSimilarity.SIMILARITY_FUNCTION)); - assertThat(blockLoaderFunctionConfig.vector(), equalTo(new float[] { 0.0f, 255.0f, 255.0f })); + assertThat(blockLoaderFunctionConfig.similarityFunction(), instanceOf(DenseVectorFieldMapper.SimilarityFunction.class)); + assertThat(blockLoaderFunctionConfig.vector(), equalTo(testCase.vector())); // MvExpand[keyword{f}#23,keyword{r}#32] var mvExpand = as(eval.child(), MvExpand.class); @@ -1408,7 +1416,7 @@ public void testVectorFunctionsUpdateIntermediateProjections() { assertTrue( innerProject.projections() .stream() - .anyMatch(p -> (p instanceof FieldAttribute fa) && fa.name().startsWith("$$dense_vector$CosineSimilarity")) + .anyMatch(p -> (p instanceof FieldAttribute fa) && fa.name().startsWith(testCase.toFieldAttrName())) ); // EsRelation[test_all][$$dense_vector$CosineSimilarity$33{f}#33, !alias_in..] @@ -1418,13 +1426,30 @@ public void testVectorFunctionsUpdateIntermediateProjections() { public void testVectorFunctionsWithDuplicateFunctions() { assumeTrue("requires similarity functions", EsqlCapabilities.Cap.VECTOR_SIMILARITY_FUNCTIONS_PUSHDOWN.isEnabled()); - String query = """ + // Generate two random test cases - one for duplicate usage, one for the second set + SimilarityFunctionTestCase testCase1 = SimilarityFunctionTestCase.random("dense_vector"); + SimilarityFunctionTestCase testCase2 = randomValueOtherThan(testCase1, () -> SimilarityFunctionTestCase.random("dense_vector")); + SimilarityFunctionTestCase testCase3 = randomValueOtherThanMany( + tc -> (tc.equals(testCase1) || tc.equals(testCase2)), + () -> SimilarityFunctionTestCase.random("dense_vector") + ); + + String query = String.format( + Locale.ROOT, + """ from test_all - | eval s1 = v_dot_product(dense_vector, [1.0, 2.0, 3.0]), s2 = v_dot_product(dense_vector, [1.0, 2.0, 3.0]) * 2 / 3 - | where v_dot_product(dense_vector, [1.0, 2.0, 3.0]) + 5 + v_dot_product(dense_vector, [4.0, 5.0, 6.0]) > 0 - | eval r2 = v_dot_product(dense_vector, [4.0, 5.0, 6.0]) + v_cosine(dense_vector, [4.0, 5.0, 6.0]) + | eval s1 = %s, s2 = %s * 2 / 3 + | where %s + 5 + %s > 0 + | eval r2 = %s + %s | keep s1, s2, r2 - """; + """, + testCase1.toQuery(), + testCase1.toQuery(), + testCase1.toQuery(), + testCase2.toQuery(), + testCase2.toQuery(), + testCase3.toQuery() + ); LogicalPlan plan = localPlan(plan(query, allTypesAnalyzer), TEST_SEARCH_STATS); @@ -1436,18 +1461,18 @@ public void testVectorFunctionsWithDuplicateFunctions() { var eval = as(project.child(), Eval.class); assertThat(eval.fields(), hasSize(3)); - // Check s1 = $$dense_vector$DotProduct$... + // Check s1 = $$dense_vector$Function1$... var s1Alias = as(eval.fields().getFirst(), Alias.class); assertThat(s1Alias.name(), equalTo("s1")); var s1FieldAttr = as(s1Alias.child(), FieldAttribute.class); assertThat(s1FieldAttr.fieldName().string(), equalTo("dense_vector")); - assertThat(s1FieldAttr.name(), startsWith("$$dense_vector$DotProduct")); + assertThat(s1FieldAttr.name(), startsWith(testCase1.toFieldAttrName())); var s1Field = as(s1FieldAttr.field(), FunctionEsField.class); var s1Config = as(s1Field.functionConfig(), DenseVectorFieldMapper.VectorSimilarityFunctionConfig.class); - assertThat(s1Config.similarityFunction(), is(DotProduct.SIMILARITY_FUNCTION)); - assertThat(s1Config.vector(), equalTo(new float[] { 1.0f, 2.0f, 3.0f })); + assertThat(s1Config.similarityFunction(), instanceOf(DenseVectorFieldMapper.SimilarityFunction.class)); + assertThat(s1Config.vector(), equalTo(testCase1.vector())); - // Check s2 = $$dense_vector$DotProduct$1606418432 * 2 / 3 (same field as s1) + // Check s2 = $$dense_vector$Function1$ * 2 / 3 (same field as s1) var s2Alias = as(eval.fields().get(1), Alias.class); assertThat(s2Alias.name(), equalTo("s2")); var s2Div = as(s2Alias.child(), Div.class); @@ -1455,51 +1480,101 @@ public void testVectorFunctionsWithDuplicateFunctions() { var s2FieldAttr = as(s2Mul.left(), FieldAttribute.class); assertThat(s1FieldAttr, is(s2FieldAttr)); - // Check r2 = $$dense_vector$DotProduct$882900992 + $$dense_vector$CosineSimilarity$882900992 + // Check r2 = $$dense_vector$Function2$ + $$dense_vector$Function2$ (deduplicated to same field) var r2Alias = as(eval.fields().get(2), Alias.class); assertThat(r2Alias.name(), equalTo("r2")); var r2Add = as(r2Alias.child(), Add.class); - // Left side: DotProduct field with vector [4.0, 5.0, 6.0] - var r2DotProductFieldAttr = as(r2Add.left(), FieldAttribute.class); - assertThat(r2DotProductFieldAttr.fieldName().string(), equalTo("dense_vector")); - assertThat(r2DotProductFieldAttr.name(), startsWith("$$dense_vector$DotProduct")); - var r2DotProductField = as(r2DotProductFieldAttr.field(), FunctionEsField.class); - var r2DotProductConfig = as(r2DotProductField.functionConfig(), DenseVectorFieldMapper.VectorSimilarityFunctionConfig.class); - assertThat(r2DotProductConfig.similarityFunction(), is(DotProduct.SIMILARITY_FUNCTION)); - assertThat(r2DotProductConfig.vector(), equalTo(new float[] { 4.0f, 5.0f, 6.0f })); - - // Right side: CosineSimilarity field - var r2CosineFieldAttr = as(r2Add.right(), FieldAttribute.class); - assertThat(r2CosineFieldAttr.fieldName().string(), equalTo("dense_vector")); - assertThat(r2CosineFieldAttr.name(), startsWith("$$dense_vector$CosineSimilarity")); - var r2CosineField = as(r2CosineFieldAttr.field(), FunctionEsField.class); - var r2CosineConfig = as(r2CosineField.functionConfig(), DenseVectorFieldMapper.VectorSimilarityFunctionConfig.class); - assertThat(r2CosineConfig.similarityFunction(), is(CosineSimilarity.SIMILARITY_FUNCTION)); - assertThat(r2CosineConfig.vector(), equalTo(new float[] { 4.0f, 5.0f, 6.0f })); + // Both sides should be the same function/vector (testCase2) + var r2LeftFieldAttr = as(r2Add.left(), FieldAttribute.class); + assertThat(r2LeftFieldAttr.fieldName().string(), equalTo("dense_vector")); + assertThat(r2LeftFieldAttr.name(), startsWith(testCase2.toFieldAttrName())); + var r2LeftField = as(r2LeftFieldAttr.field(), FunctionEsField.class); + var r2LeftConfig = as(r2LeftField.functionConfig(), DenseVectorFieldMapper.VectorSimilarityFunctionConfig.class); + assertThat(r2LeftConfig.similarityFunction(), instanceOf(DenseVectorFieldMapper.SimilarityFunction.class)); + assertThat(r2LeftConfig.vector(), equalTo(testCase2.vector())); + + var r2RightFieldAttr = as(r2Add.right(), FieldAttribute.class); + // Both sides should not be deduplicated to the same field attribute + assertThat(r2RightFieldAttr, not(is(r2LeftFieldAttr))); // Limit[1000[INTEGER],false,false] var limit = as(eval.child(), Limit.class); - // Filter[$$dense_vector$DotProduct$1606418432 + 5 + $$dense_vector$DotProduct$882900992 > 0] + // Filter[$$dense_vector$Function1$ + 5 + $$dense_vector$Function2$ > 0] var filter = as(limit.child(), Filter.class); var greaterThan = as(filter.condition(), GreaterThan.class); var filterAdd1 = as(greaterThan.left(), Add.class); var filterAdd2 = as(filterAdd1.left(), Add.class); - // Check filter uses s1 field (DotProduct with [1.0, 2.0, 3.0]) + // Check filter uses s1 field (testCase1) var filterS1FieldAttr = as(filterAdd2.left(), FieldAttribute.class); assertThat(filterS1FieldAttr, is(s1FieldAttr)); - // Check filter uses r2's DotProduct field (DotProduct with [4.0, 5.0, 6.0]) + // Check filter uses r2's field (testCase2) var filterR2FieldAttr = as(filterAdd1.right(), FieldAttribute.class); - assertThat(filterR2FieldAttr, is(r2DotProductFieldAttr)); + assertThat(filterR2FieldAttr, is(r2LeftFieldAttr)); // EsRelation[test_all][!alias_integer, boolean{f}#19, byte{f}#20, constant..] var esRelation = as(filter.child(), EsRelation.class); assertTrue(esRelation.output().contains(s1FieldAttr)); - assertTrue(esRelation.output().contains(r2DotProductFieldAttr)); - assertTrue(esRelation.output().contains(r2CosineFieldAttr)); + assertTrue(esRelation.output().contains(r2LeftFieldAttr)); + } + + private record SimilarityFunctionTestCase(String esqlFunction, String fieldName, float[] vector, String functionName) { + + public String toQuery() { + String params = randomBoolean() ? fieldName + ", " + Arrays.toString(vector) + : Arrays.toString(vector) + ", " + fieldName; + return esqlFunction + "(" + params + ")"; + } + + public String toFieldAttrName() { + return "$$" + fieldName + "$" + functionName; + } + + public static SimilarityFunctionTestCase random(String fieldName) { + float[] vector = new float[] { + randomFloat(), + randomFloat(), + randomFloat() + }; + // Only use DotProduct and CosineSimilarity as they have full pushdown support + // L1Norm, L2Norm, and Hamming are still in development + return switch(randomInt(4)) { + case 0 -> new SimilarityFunctionTestCase( + "v_dot_product", + fieldName, + vector, + "DotProduct" + ); + case 1 -> new SimilarityFunctionTestCase( + "v_cosine", + fieldName, + vector, + "CosineSimilarity" + ); + case 2 -> new SimilarityFunctionTestCase( + "v_l1_norm", + fieldName, + vector, + "L1Norm" + ); + case 3 -> new SimilarityFunctionTestCase( + "v_l2_norm", + fieldName, + vector, + "L2Norm" + ); + case 4 -> new SimilarityFunctionTestCase( + "v_hamming", + fieldName, + vector, + "Hamming" + ); + default -> throw new IllegalStateException("Unexpected value"); + }; + } } private IsNotNull isNotNull(Expression field) { From 7e59a92c81af8b1dc661e183a8c31e57036d85ae Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Tue, 4 Nov 2025 10:37:57 +0000 Subject: [PATCH 7/8] [CI] Auto commit changes from spotless --- .../PushDownVectorSimilarityFunctions.java | 9 +-- .../LocalLogicalPlanOptimizerTests.java | 64 +++++-------------- 2 files changed, 18 insertions(+), 55 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/PushDownVectorSimilarityFunctions.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/PushDownVectorSimilarityFunctions.java index 276a6db49f247..e47e4e576b78e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/PushDownVectorSimilarityFunctions.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/PushDownVectorSimilarityFunctions.java @@ -40,10 +40,7 @@ * the similarity function during value loading, when one side of the function is a literal. * It also adds the new field function attribute to the EsRelation output, and adds a projection after it to remove it from the output. */ -public class PushDownVectorSimilarityFunctions extends ParameterizedRule< - LogicalPlan, - LogicalPlan, - LocalLogicalOptimizerContext> { +public class PushDownVectorSimilarityFunctions extends ParameterizedRule { @Override public LogicalPlan apply(LogicalPlan plan, LocalLogicalOptimizerContext context) { @@ -51,7 +48,6 @@ public LogicalPlan apply(LogicalPlan plan, LocalLogicalOptimizerContext context) return plan.transformUp(LogicalPlan.class, p -> doRule(p, context.searchStats(), addedAttrs)); } - private LogicalPlan doRule(LogicalPlan plan, SearchStats searchStats, Map addedAttrs) { // Collect field attributes from previous runs int originalAddedAttrsSize = addedAttrs.size(); @@ -97,7 +93,8 @@ private LogicalPlan doRule(LogicalPlan plan, SearchStats searchStats, Map addedAttrs, SearchStats searchStats + Map addedAttrs, + SearchStats searchStats ) { // Only replace if it consists of a literal and the other a field attribute. // CanonicalizeVectorSimilarityFunctions ensures that if there is a literal, it will be on the right side. diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java index d77010af0bae5..3f59aa147ad53 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java @@ -378,11 +378,7 @@ public void testMissingFieldInNewCommand() { new MockFieldAttributeCommand( EMPTY, new Row(EMPTY, List.of()), - new FieldAttribute( - EMPTY, - "last_name", - new EsField("last_name", KEYWORD, Map.of(), true, EsField.TimeSeriesFieldType.NONE) - ) + new FieldAttribute(EMPTY, "last_name", new EsField("last_name", KEYWORD, Map.of(), true, EsField.TimeSeriesFieldType.NONE)) ), testStats ); @@ -1437,12 +1433,12 @@ public void testVectorFunctionsWithDuplicateFunctions() { String query = String.format( Locale.ROOT, """ - from test_all - | eval s1 = %s, s2 = %s * 2 / 3 - | where %s + 5 + %s > 0 - | eval r2 = %s + %s - | keep s1, s2, r2 - """, + from test_all + | eval s1 = %s, s2 = %s * 2 / 3 + | where %s + 5 + %s > 0 + | eval r2 = %s + %s + | keep s1, s2, r2 + """, testCase1.toQuery(), testCase1.toQuery(), testCase1.toQuery(), @@ -1524,8 +1520,7 @@ public void testVectorFunctionsWithDuplicateFunctions() { private record SimilarityFunctionTestCase(String esqlFunction, String fieldName, float[] vector, String functionName) { public String toQuery() { - String params = randomBoolean() ? fieldName + ", " + Arrays.toString(vector) - : Arrays.toString(vector) + ", " + fieldName; + String params = randomBoolean() ? fieldName + ", " + Arrays.toString(vector) : Arrays.toString(vector) + ", " + fieldName; return esqlFunction + "(" + params + ")"; } @@ -1534,44 +1529,15 @@ public String toFieldAttrName() { } public static SimilarityFunctionTestCase random(String fieldName) { - float[] vector = new float[] { - randomFloat(), - randomFloat(), - randomFloat() - }; + float[] vector = new float[] { randomFloat(), randomFloat(), randomFloat() }; // Only use DotProduct and CosineSimilarity as they have full pushdown support // L1Norm, L2Norm, and Hamming are still in development - return switch(randomInt(4)) { - case 0 -> new SimilarityFunctionTestCase( - "v_dot_product", - fieldName, - vector, - "DotProduct" - ); - case 1 -> new SimilarityFunctionTestCase( - "v_cosine", - fieldName, - vector, - "CosineSimilarity" - ); - case 2 -> new SimilarityFunctionTestCase( - "v_l1_norm", - fieldName, - vector, - "L1Norm" - ); - case 3 -> new SimilarityFunctionTestCase( - "v_l2_norm", - fieldName, - vector, - "L2Norm" - ); - case 4 -> new SimilarityFunctionTestCase( - "v_hamming", - fieldName, - vector, - "Hamming" - ); + return switch (randomInt(4)) { + case 0 -> new SimilarityFunctionTestCase("v_dot_product", fieldName, vector, "DotProduct"); + case 1 -> new SimilarityFunctionTestCase("v_cosine", fieldName, vector, "CosineSimilarity"); + case 2 -> new SimilarityFunctionTestCase("v_l1_norm", fieldName, vector, "L1Norm"); + case 3 -> new SimilarityFunctionTestCase("v_l2_norm", fieldName, vector, "L2Norm"); + case 4 -> new SimilarityFunctionTestCase("v_hamming", fieldName, vector, "Hamming"); default -> throw new IllegalStateException("Unexpected value"); }; } From 8ee21afe40672e0a9c832397d16b677449335435 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 6 Nov 2025 10:27:11 +0100 Subject: [PATCH 8/8] Fix test - dimensions checking must not use a null vector, or it won't fail --- .../xpack/esql/vector/VectorSimilarityFunctionsIT.java | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java index 539c673bb2a65..8f77afe70ef81 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java @@ -251,7 +251,7 @@ private Double calculateSimilarity( } public void testDifferentDimensions() { - var randomVector = randomVector(randomValueOtherThan(numDims, () -> randomIntBetween(32, 64) * 2)); + var randomVector = randomVector(randomValueOtherThan(numDims, () -> randomIntBetween(32, 64) * 2), false); var query = String.format(Locale.ROOT, """ FROM test | EVAL similarity = %s(left_vector, %s) @@ -322,8 +322,12 @@ private List randomVector() { } private List randomVector(int numDims) { + return randomVector(numDims, true); + } + + private List randomVector(int numDims, boolean allowNull) { assert numDims != 0 : "numDims must be set before calling randomVector()"; - if (rarely()) { + if (allowNull && rarely()) { return null; } List vector = new ArrayList<>(numDims);