Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,12 @@ default BlockLoaderFunctionConfig blockLoaderFunctionConfig() {
* Is retrievable from the {@link BlockLoaderContext}. The {@link MappedFieldType} can use this configuration to choose the appropriate
* implementation for transforming loaded values into blocks.
*/
public interface BlockLoaderFunctionConfig {}
public interface BlockLoaderFunctionConfig {
/**
* Returns a name for this configuration that can be used as part of a pushed-down attribute name.
* Configurations that are {@code equals()} to each other must return the same name.
*
* @return the name identifying this configuration
*/
String name();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -3370,6 +3370,8 @@ public interface IntBooleanConsumer {
}

public interface SimilarityFunction {
String name();

float calculateSimilarity(float[] leftVector, float[] rightVector);

float calculateSimilarity(byte[] leftVector, byte[] rightVector);
Expand All @@ -3386,9 +3388,10 @@ public static class VectorSimilarityFunctionConfig implements MappedFieldType.Bl
private byte[] vectorAsBytes;

/**
 * Creates a configuration that pairs a similarity function with the vector it is evaluated against.
 *
 * @param similarityFunction the similarity function to apply
 * @param vector the vector operand; must not be {@code null}, and is asserted to be non-empty
 * @throws NullPointerException if {@code vector} is {@code null}
 */
public VectorSimilarityFunctionConfig(SimilarityFunction similarityFunction, float[] vector) {
    // Null check and assignment in one step; nothing else runs if the vector is null.
    this.vector = Objects.requireNonNull(vector);
    assert vector.length > 0 : "vector length must be > 0";
    this.similarityFunction = similarityFunction;
}

/**
Expand All @@ -3415,6 +3418,10 @@ public SimilarityFunction similarityFunction() {
return similarityFunction;
}

/**
 * Returns the name of this configuration, built from the similarity function's name and a hash of the vector contents.
 *
 * @return the similarity function name, a {@code $} separator, and {@code Arrays.hashCode} of the vector
 */
@Override
public String name() {
    return similarityFunction.name() + "$" + Arrays.hashCode(vector);
}

@Override
public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
Expand All @@ -3426,7 +3433,7 @@ public boolean equals(Object o) {

/**
 * Hashes the similarity function together with the contents of the float vector.
 *
 * NOTE(review): {@code vectorAsBytes} is not assigned by the constructor, so it looks like a derived value that is
 * filled in later - confirm. It is left out of the hash so that the hash code does not change once it is populated.
 *
 * @return a hash code derived from {@code similarityFunction} and {@code vector}
 */
@Override
public int hashCode() {
    return Objects.hash(similarityFunction, Arrays.hashCode(vector));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ private Double calculateSimilarity(
}

public void testDifferentDimensions() {
var randomVector = randomVector(randomValueOtherThan(numDims, () -> randomIntBetween(32, 64) * 2));
var randomVector = randomVector(randomValueOtherThan(numDims, () -> randomIntBetween(32, 64) * 2), false);
var query = String.format(Locale.ROOT, """
FROM test
| EVAL similarity = %s(left_vector, %s)
Expand Down Expand Up @@ -322,8 +322,12 @@ private List<Number> randomVector() {
}

// Convenience overload: builds a random vector with numDims dimensions, allowing an occasional null result.
private List<Number> randomVector(int numDims) {
return randomVector(numDims, true);
}

private List<Number> randomVector(int numDims, boolean allowNull) {
assert numDims != 0 : "numDims must be set before calling randomVector()";
if (rarely()) {
if (allowNull && rarely()) {
return null;
}
List<Number> vector = new ArrayList<>(numDims);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ public class CosineSimilarity extends VectorSimilarityFunction {
CosineSimilarity::new
);
public static final DenseVectorFieldMapper.SimilarityFunction SIMILARITY_FUNCTION = new DenseVectorFieldMapper.SimilarityFunction() {
@Override
public String name() {
// Name reported for the cosine similarity implementation.
return "CosineSimilarity";
}

@Override
public float calculateSimilarity(byte[] leftVector, byte[] rightVector) {
return VectorUtil.cosine(leftVector, rightVector);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ public class DotProduct extends VectorSimilarityFunction {
);

public static final DenseVectorFieldMapper.SimilarityFunction SIMILARITY_FUNCTION = new DenseVectorFieldMapper.SimilarityFunction() {
@Override
public String name() {
// Name reported for the dot product implementation.
return "DotProduct";
}

@Override
public float calculateSimilarity(byte[] leftVector, byte[] rightVector) {
return VectorUtil.dotProduct(leftVector, rightVector);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ public class Hamming extends VectorSimilarityFunction {

public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Hamming", Hamming::new);
public static final DenseVectorFieldMapper.SimilarityFunction SIMILARITY_FUNCTION = new DenseVectorFieldMapper.SimilarityFunction() {
@Override
public String name() {
// Name reported for the Hamming implementation held in SIMILARITY_FUNCTION.
return "Hamming";
}

@Override
public float calculateSimilarity(byte[] leftVector, byte[] rightVector) {
return Hamming.calculateSimilarity(leftVector, rightVector);
Expand All @@ -39,6 +44,11 @@ public float calculateSimilarity(float[] leftVector, float[] rightVector) {
};
public static final DenseVectorFieldMapper.SimilarityFunction EVALUATOR_SIMILARITY_FUNCTION =
new DenseVectorFieldMapper.SimilarityFunction() {
@Override
public String name() {
// Name reported for EVALUATOR_SIMILARITY_FUNCTION; differs from the "Hamming" name used by SIMILARITY_FUNCTION.
return "Hamming_evaluator";
}

@Override
public float calculateSimilarity(byte[] leftVector, byte[] rightVector) {
return Hamming.calculateSimilarity(leftVector, rightVector);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ public class L1Norm extends VectorSimilarityFunction {

public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "L1Norm", L1Norm::new);
public static final DenseVectorFieldMapper.SimilarityFunction SIMILARITY_FUNCTION = new DenseVectorFieldMapper.SimilarityFunction() {
@Override
public String name() {
// Name reported for the L1 norm implementation.
return "L1Norm";
}

@Override
public float calculateSimilarity(byte[] leftVector, byte[] rightVector) {
if (leftVector.length != rightVector.length) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ public class L2Norm extends VectorSimilarityFunction {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "L2Norm", L2Norm::new);

public static final DenseVectorFieldMapper.SimilarityFunction SIMILARITY_FUNCTION = new DenseVectorFieldMapper.SimilarityFunction() {
@Override
public String name() {
// Name reported for the L2 norm implementation.
return "L2Norm";
}

@Override
public float calculateSimilarity(byte[] leftVector, byte[] rightVector) {
return (float) Math.sqrt(VectorUtil.squareDistance(leftVector, rightVector));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,18 @@ public MappedFieldType.BlockLoaderFunctionConfig getBlockLoaderFunctionConfig()
return new DenseVectorFieldMapper.VectorSimilarityFunctionConfig(getSimilarityFunction(), vector);
}

/**
 * Canonicalizes this function and, when only the left argument is a literal, swaps the two arguments so that the
 * literal ends up on the right.
 *
 * @return the canonical form of this function, with a lone literal argument placed on the right
 */
@Override
protected Expression canonicalize() {
    VectorSimilarityFunction canonical = (VectorSimilarityFunction) super.canonicalize();

    // Set literals to the right. Swap the children of the canonical instance rather than those of this
    // instance, so that whatever super.canonicalize() did to the children is preserved in the result.
    if (canonical.left() instanceof Literal && canonical.right() instanceof Literal == false) {
        return canonical.replaceChildren(canonical.right(), canonical.left());
    }

    return canonical;
}

interface VectorValueProvider extends Releasable {

void eval(Page page);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.xpack.esql.optimizer.rules.PruneInlineJoinOnEmptyRightSide;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.BooleanFunctionEqualsElimination;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.BooleanSimplification;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.CanonicalizeVectorSimilarityFunctions;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.CombineBinaryComparisons;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.CombineDisjunctions;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.CombineEvals;
Expand Down Expand Up @@ -182,6 +183,7 @@ protected static Batch<LogicalPlan> operators() {
// boolean
new BooleanSimplification(),
new LiteralsOnTheRight(),
new CanonicalizeVectorSimilarityFunctions(),
// needs to occur before BinaryComparison combinations (see class)
new PropagateEquals(),
new PropagateNullable(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.optimizer.rules.logical;

import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.expression.function.vector.VectorSimilarityFunction;
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;

/**
* Ensures that vector similarity functions are in their canonical form, with literals to the right.
*/
public class CanonicalizeVectorSimilarityFunctions extends OptimizerRules.OptimizerExpressionRule<VectorSimilarityFunction> {

public CanonicalizeVectorSimilarityFunctions() {
// Configure the base expression rule with the UP transform direction.
super(OptimizerRules.TransformDirection.UP);
}

/**
* Replaces a vector similarity function with its canonical form.
*
* @param vectorSimilarityFunction the function to canonicalize
* @param ctx the optimizer context; not used by this rule
* @return the result of calling {@code canonical()} on the function
*/
@Override
protected Expression rule(VectorSimilarityFunction vectorSimilarityFunction, LogicalOptimizerContext ctx) {
return vectorSimilarityFunction.canonical();
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had assumed there was already a rule for this!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There isn't. Given that we don't canonicalize other expressions, I'm thinking of removing this rule and going back to the previous code for checking field and literal - this is adding complexity and coupling between the two rules.

WDYT?

Copy link
Contributor

@julian-elastic julian-elastic Nov 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do have a LiteralsOnTheRight rule, but it only works for BinaryOperator. It seems VectorSimilarityFunction is not a BinaryOperator, and it might be hard to make it one.

Alternatively, we swap left and right in the surrogate method for spatial functions. Then you don't need a new rule and the code is much simpler. See SpatialContains.surrogate() for an example of how to do it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't SurrogateExpressions used in the context of aggregations? Would it make sense to make the VectorSimilarityFunctions a SurrogateExpression?

I don't see any practical reason for doing that other than simplifying the check that is done in order to push down the vector similarity functions. I think it does not pay off - we're expecting a rule to act in order to be able to simplify an expression that should be able to understand when it should be pushable or not.

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.esql.optimizer.rules.logical.local;

import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.Expression;
Expand All @@ -18,13 +19,14 @@
import org.elasticsearch.xpack.esql.core.type.FunctionEsField;
import org.elasticsearch.xpack.esql.expression.function.vector.VectorSimilarityFunction;
import org.elasticsearch.xpack.esql.optimizer.LocalLogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
import org.elasticsearch.xpack.esql.plan.logical.Filter;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject;
import org.elasticsearch.xpack.esql.rule.ParameterizedRule;
import org.elasticsearch.xpack.esql.stats.SearchStats;

import java.util.ArrayList;
import java.util.HashMap;
Expand All @@ -38,30 +40,39 @@
* the similarity function during value loading, when one side of the function is a literal.
* It also adds the new field function attribute to the EsRelation output, and adds a projection after it to remove it from the output.
*/
public class PushDownVectorSimilarityFunctions extends OptimizerRules.ParameterizedOptimizerRule<
LogicalPlan,
LocalLogicalOptimizerContext> {
public class PushDownVectorSimilarityFunctions extends ParameterizedRule<LogicalPlan, LogicalPlan, LocalLogicalOptimizerContext> {

public PushDownVectorSimilarityFunctions() {
super(OptimizerRules.TransformDirection.DOWN);
/**
* Applies the push-down to the whole plan by running {@code doRule} on each node via {@code transformUp}.
*
* @param plan the plan to transform
* @param context the local optimizer context, from which the search stats are read
* @return the plan returned by the traversal
*/
@Override
public LogicalPlan apply(LogicalPlan plan, LocalLogicalOptimizerContext context) {
// A single map instance is shared by every doRule call made during this traversal.
Map<Attribute.IdIgnoringWrapper, Attribute> addedAttrs = new HashMap<>();
return plan.transformUp(LogicalPlan.class, p -> doRule(p, context.searchStats(), addedAttrs));
}

@Override
protected LogicalPlan rule(LogicalPlan plan, LocalLogicalOptimizerContext context) {
private LogicalPlan doRule(LogicalPlan plan, SearchStats searchStats, Map<Attribute.IdIgnoringWrapper, Attribute> addedAttrs) {
// Collect field attributes from previous runs
int originalAddedAttrsSize = addedAttrs.size();
if (plan instanceof EsRelation rel) {
addedAttrs.clear();
for (Attribute attr : rel.output()) {
if (attr instanceof FieldAttribute fa && fa.field() instanceof FunctionEsField) {
addedAttrs.put(fa.ignoreId(), fa);
}
}
}

if (plan instanceof Eval || plan instanceof Filter || plan instanceof Aggregate) {
Map<Attribute.IdIgnoringWrapper, Attribute> addedAttrs = new HashMap<>();
LogicalPlan transformedPlan = plan.transformExpressionsOnly(
VectorSimilarityFunction.class,
similarityFunction -> replaceFieldsForFieldTransformations(similarityFunction, addedAttrs, context)
similarityFunction -> replaceFieldsForFieldTransformations(similarityFunction, addedAttrs, searchStats)
);

if (addedAttrs.isEmpty()) {
// No fields were added, return the original plan
if (addedAttrs.size() == originalAddedAttrsSize) {
return plan;
}

List<Attribute> previousAttrs = transformedPlan.output();
// Transforms EsRelation to extract the new attribute

// Transforms EsRelation to extract the new attributes
List<Attribute> addedAttrsList = addedAttrs.values().stream().toList();
transformedPlan = transformedPlan.transformDown(EsRelation.class, esRelation -> {
AttributeSet updatedOutput = esRelation.outputSet().combine(AttributeSet.of(addedAttrsList));
Expand All @@ -83,59 +94,44 @@ protected LogicalPlan rule(LogicalPlan plan, LocalLogicalOptimizerContext contex
private static Expression replaceFieldsForFieldTransformations(
VectorSimilarityFunction similarityFunction,
Map<Attribute.IdIgnoringWrapper, Attribute> addedAttrs,
LocalLogicalOptimizerContext context
SearchStats searchStats
) {
// Only replace if exactly one side is a literal and the other a field attribute
if ((similarityFunction.left() instanceof Literal ^ similarityFunction.right() instanceof Literal) == false) {
return similarityFunction;
}
// Only replace when the left side is a field attribute and the right side is a literal.
// CanonicalizeVectorSimilarityFunctions ensures that if there is a literal, it will be on the right side.
if (similarityFunction.left() instanceof FieldAttribute fieldAttr && similarityFunction.right() instanceof Literal) {

Literal literal = (Literal) (similarityFunction.left() instanceof Literal ? similarityFunction.left() : similarityFunction.right());
FieldAttribute fieldAttr = null;
if (similarityFunction.left() instanceof FieldAttribute fa) {
fieldAttr = fa;
} else if (similarityFunction.right() instanceof FieldAttribute fa) {
fieldAttr = fa;
}
// We can push down also for doc values, requires handling that case on the field mapper
if (fieldAttr == null || context.searchStats().isIndexed(fieldAttr.fieldName()) == false) {
return similarityFunction;
}
// We can push down also for doc values, requires handling that case on the field mapper
if (searchStats.isIndexed(fieldAttr.fieldName()) == false) {
return similarityFunction;
}

@SuppressWarnings("unchecked")
List<Number> vectorList = (List<Number>) literal.value();
float[] vectorArray = new float[vectorList.size()];
int arrayHashCode = 0;
for (int i = 0; i < vectorList.size(); i++) {
vectorArray[i] = vectorList.get(i).floatValue();
arrayHashCode = 31 * arrayHashCode + Float.floatToIntBits(vectorArray[i]);
}
// Change the similarity function to a reference of a transformation on the field
MappedFieldType.BlockLoaderFunctionConfig blockLoaderFunctionConfig = similarityFunction.getBlockLoaderFunctionConfig();
FunctionEsField functionEsField = new FunctionEsField(
fieldAttr.field(),
similarityFunction.dataType(),
blockLoaderFunctionConfig
);
var name = rawTemporaryName(fieldAttr.name(), blockLoaderFunctionConfig.name());
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rely on blockLoaderFunctionConfig.name() to get a unique name for different functions

var newFunctionAttr = new FieldAttribute(
fieldAttr.source(),
fieldAttr.parentName(),
fieldAttr.qualifier(),
name,
functionEsField,
fieldAttr.nullable(),
new NameId(),
true
);
Attribute.IdIgnoringWrapper key = newFunctionAttr.ignoreId();
if (addedAttrs.containsKey(key)) {
return addedAttrs.get(key);
}

// Change the similarity function to a reference of a transformation on the field
FunctionEsField functionEsField = new FunctionEsField(
fieldAttr.field(),
similarityFunction.dataType(),
similarityFunction.getBlockLoaderFunctionConfig()
);
var name = rawTemporaryName(fieldAttr.name(), similarityFunction.nodeName(), String.valueOf(arrayHashCode));
// TODO: Check if exists before adding, retrieve the previous one
var newFunctionAttr = new FieldAttribute(
fieldAttr.source(),
fieldAttr.parentName(),
fieldAttr.qualifier(),
name,
functionEsField,
fieldAttr.nullable(),
new NameId(),
true
);
Attribute.IdIgnoringWrapper key = newFunctionAttr.ignoreId();
if (addedAttrs.containsKey(key)) {
;
return addedAttrs.get(key);
addedAttrs.put(key, newFunctionAttr);
return newFunctionAttr;
}

addedAttrs.put(key, newFunctionAttr);
return newFunctionAttr;
return similarityFunction;
}
}
Loading