From a8153bedaab41a5cba9e753a0172daf035f38d70 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Thu, 12 Jun 2025 08:55:50 +0200 Subject: [PATCH 01/24] Prepare issue branch. move --- pom.xml | 2 +- spring-data-mongodb-distribution/pom.xml | 2 +- spring-data-mongodb/pom.xml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pom.xml b/pom.xml index 3daab4d790..63a1959063 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-mongodb-parent - 5.0.0-SNAPSHOT + 5.0.x-GH-5004-SNAPSHOT pom Spring Data MongoDB diff --git a/spring-data-mongodb-distribution/pom.xml b/spring-data-mongodb-distribution/pom.xml index fc88571622..66a68de39f 100644 --- a/spring-data-mongodb-distribution/pom.xml +++ b/spring-data-mongodb-distribution/pom.xml @@ -15,7 +15,7 @@ org.springframework.data spring-data-mongodb-parent - 5.0.0-SNAPSHOT + 5.0.x-GH-5004-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb/pom.xml b/spring-data-mongodb/pom.xml index 6f34da5660..102427d19a 100644 --- a/spring-data-mongodb/pom.xml +++ b/spring-data-mongodb/pom.xml @@ -13,7 +13,7 @@ org.springframework.data spring-data-mongodb-parent - 5.0.0-SNAPSHOT + 5.0.x-GH-5004-SNAPSHOT ../pom.xml From 529a9ce795eb2f1bb7d0f23c95baa2c61ba345d7 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Thu, 12 Jun 2025 14:39:37 +0200 Subject: [PATCH 02/24] $near with list of entities returned --- .../core/query/CriteriaDefinition.java | 24 ++++--- .../repository/aot/AotQueryCreator.java | 46 +++++++++++-- .../aot/MongoRepositoryContributor.java | 8 +-- .../repository/query/MongoQueryCreator.java | 65 ++++++++++--------- .../data/mongodb/util/BsonUtils.java | 42 +++++++++++- .../src/test/java/example/aot/Location.java | 26 ++++++++ .../src/test/java/example/aot/User.java | 10 +++ .../test/java/example/aot/UserRepository.java | 3 + .../aot/MongoRepositoryContributorTests.java | 28 +++++++- 9 files changed, 199 insertions(+), 53 deletions(-) create mode 100644 spring-data-mongodb/src/test/java/example/aot/Location.java diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/CriteriaDefinition.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/CriteriaDefinition.java index 7777e5f554..4400baa6d6 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/CriteriaDefinition.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/CriteriaDefinition.java @@ -46,9 +46,7 @@ public interface CriteriaDefinition { * @since 5.0 * @author Christoph Strobl */ - class Placeholder { - - private final Object expression; + interface Placeholder { /** * Create a new placeholder for index bindable parameter. @@ -56,23 +54,29 @@ class Placeholder { * @param position the index of the parameter to bind. * @return new instance of {@link Placeholder}. */ - public static Placeholder indexed(int position) { - return new Placeholder("?%s".formatted(position)); + static Placeholder indexed(int position) { + return new PlaceholderImpl("?%s".formatted(position)); } - public static Placeholder placeholder(String expression) { - return new Placeholder(expression); + static Placeholder placeholder(String expression) { + return new PlaceholderImpl(expression); } - Placeholder(Object value) { - this.expression = value; + Object getValue(); + } + + static class PlaceholderImpl implements Placeholder { + private final Object expression; + + public PlaceholderImpl(Object expression) { + this.expression = expression; } + @Override public Object getValue() { return expression; } - @Override public String toString() { return getValue().toString(); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java index 17c19ad951..acf0148f56 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java @@ -15,6 +15,7 @@ */ package org.springframework.data.mongodb.repository.aot; +import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.stream.Collectors; @@ -42,11 +43,16 @@ import org.springframework.data.mongodb.core.query.UpdateDefinition; import org.springframework.data.mongodb.repository.query.ConvertingParameterAccessor; import org.springframework.data.mongodb.repository.query.MongoParameterAccessor; +import org.springframework.data.mongodb.repository.query.MongoParameters; import org.springframework.data.mongodb.repository.query.MongoQueryCreator; +import org.springframework.data.repository.query.Parameter; +import org.springframework.data.repository.query.Parameters; +import org.springframework.data.repository.query.QueryMethod; import org.springframework.data.repository.query.parser.PartTree; import org.springframework.data.util.TypeInformation; import com.mongodb.DBRef; +import org.springframework.util.ClassUtils; /** * @author Christoph Strobl @@ -68,10 +74,10 @@ public AotQueryCreator() { } @SuppressWarnings("NullAway") - StringQuery createQuery(PartTree partTree, int parameterCount) { + StringQuery createQuery(PartTree partTree, QueryMethod queryMethod) { Query query = new MongoQueryCreator(partTree, - new PlaceholderConvertingParameterAccessor(new PlaceholderParameterAccessor(parameterCount)), mappingContext) + new PlaceholderConvertingParameterAccessor(new PlaceholderParameterAccessor(queryMethod)), mappingContext) .createQuery(); if (partTree.isLimiting()) { @@ -118,17 +124,25 @@ static class PlaceholderParameterAccessor implements MongoParameterAccessor { private final List placeholders; - public PlaceholderParameterAccessor(int parameterCount) { - if (parameterCount == 0) { + public PlaceholderParameterAccessor(QueryMethod queryMethod) { + if (queryMethod.getParameters().getNumberOfParameters() == 0) { placeholders = List.of(); } else { - placeholders = IntStream.range(0, parameterCount).mapToObj(Placeholder::indexed).collect(Collectors.toList()); + placeholders = new ArrayList<>(); + Parameters parameters = queryMethod.getParameters(); + for(Parameter parameter : parameters.toList()) { + if(ClassUtils.isAssignable(Point.class, parameter.getType())) { + placeholders.add(parameter.getIndex(), new GeoPlaceholder(parameter.getIndex())); + } else { + placeholders.add(parameter.getIndex(), Placeholder.indexed(parameter.getIndex())); + } + } } } @Override public Range getDistanceRange() { - return null; + return Range.unbounded(); } @Override @@ -207,4 +221,24 @@ public Iterator iterator() { return ((List) placeholders).iterator(); } } + + static class GeoPlaceholder extends Point implements Placeholder { + + int index; + + public GeoPlaceholder(int index) { + super(Double.NaN, Double.NaN); + this.index = index; + } + + @Override + public Object getValue() { + return "?" + index; + } + + @Override + public String toString() { + return getValue().toString(); + } + } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java index 424d067d74..d25d2f7f9a 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java @@ -154,8 +154,8 @@ private QueryInteraction createStringQuery(RepositoryInformation repositoryInfor } else { PartTree partTree = new PartTree(queryMethod.getName(), repositoryInformation.getDomainType()); - query = new QueryInteraction(queryCreator.createQuery(partTree, parameterCount), partTree.isCountProjection(), - partTree.isDelete(), partTree.isExistsProjection()); + query = new QueryInteraction(queryCreator.createQuery(partTree, queryMethod), + partTree.isCountProjection(), partTree.isDelete(), partTree.isExistsProjection()); } if (queryAnnotation != null && StringUtils.hasText(queryAnnotation.sort())) { @@ -171,8 +171,8 @@ private QueryInteraction createStringQuery(RepositoryInformation repositoryInfor private static boolean backoff(MongoQueryMethod method) { // TODO: namedQuery, Regex queries, queries accepting Shapes (e.g. within) or returning arrays. - boolean skip = method.isGeoNearQuery() || method.isSearchQuery() - || method.getName().toLowerCase(Locale.ROOT).contains("regex") || method.getReturnType().getType().isArray(); + boolean skip = method.isSearchQuery() || method.getName().toLowerCase(Locale.ROOT).contains("regex") + || method.getReturnType().getType().isArray(); if (skip && logger.isDebugEnabled()) { logger.debug("Skipping AOT generation for [%s]. Method is either returning an array or a geo-near, regex query" diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java index ba7394ec17..97712b61cb 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java @@ -54,6 +54,7 @@ import org.springframework.data.repository.query.parser.Part.Type; import org.springframework.data.repository.query.parser.PartTree; import org.springframework.data.util.Streamable; +import org.springframework.lang.NonNull; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.ObjectUtils; @@ -235,35 +236,7 @@ private Criteria from(Part part, MongoPersistentProperty property, Criteria crit return criteria.is(false); case NEAR: - Range range = accessor.getDistanceRange(); - Optional distance = range.getUpperBound().getValue(); - Optional minDistance = range.getLowerBound().getValue(); - - Point point = accessor.getGeoNearLocation(); - Point pointToUse = point == null ? nextAs(parameters, Point.class) : point; - - boolean isSpherical = isSpherical(property); - - return distance.map(it -> { - - if (isSpherical || !Metrics.NEUTRAL.equals(it.getMetric())) { - criteria.nearSphere(pointToUse); - } else { - criteria.near(pointToUse); - } - - if (pointToUse instanceof GeoJson) { // using GeoJson distance is in meters. - - criteria.maxDistance(MetricConversion.getDistanceInMeters(it)); - minDistance.map(MetricConversion::getDistanceInMeters).ifPresent(criteria::minDistance); - } else { - criteria.maxDistance(it.getNormalizedValue()); - minDistance.map(Distance::getNormalizedValue).ifPresent(criteria::minDistance); - } - - return criteria; - - }).orElseGet(() -> isSpherical ? criteria.nearSphere(pointToUse) : criteria.near(pointToUse)); + return createNearCriteria(property, criteria, parameters); case WITHIN: @@ -283,6 +256,40 @@ private Criteria from(Part part, MongoPersistentProperty property, Criteria crit } } + @NonNull + private Criteria createNearCriteria(MongoPersistentProperty property, Criteria criteria, Iterator parameters) { + + + Range range = accessor.getDistanceRange(); + Optional distance = range.getUpperBound().getValue(); + Optional minDistance = range.getLowerBound().getValue(); + + Point point = accessor.getGeoNearLocation(); + Point pointToUse = point == null ? nextAs(parameters, Point.class) : point; + + boolean isSpherical = isSpherical(property); + + return distance.map(it -> { + + if (isSpherical || !Metrics.NEUTRAL.equals(it.getMetric())) { + criteria.nearSphere(pointToUse); + } else { + criteria.near(pointToUse); + } + + if (pointToUse instanceof GeoJson) { // using GeoJson distance is in meters. + + criteria.maxDistance(MetricConversion.getDistanceInMeters(it)); + minDistance.map(MetricConversion::getDistanceInMeters).ifPresent(criteria::minDistance); + } else { + criteria.maxDistance(it.getNormalizedValue()); + minDistance.map(Distance::getNormalizedValue).ifPresent(criteria::minDistance); + } + + return criteria; + }).orElseGet(() -> isSpherical ? criteria.nearSphere(pointToUse) : criteria.near(pointToUse)); + } + private boolean isSimpleComparisonPossible(Part part) { return switch (part.shouldIgnoreCase()) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/BsonUtils.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/BsonUtils.java index dc51da84ed..e23bf537b8 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/BsonUtils.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/BsonUtils.java @@ -60,9 +60,11 @@ import org.bson.codecs.DocumentCodec; import org.bson.codecs.EncoderContext; import org.bson.codecs.configuration.CodecConfigurationException; +import org.bson.codecs.configuration.CodecProvider; import org.bson.codecs.configuration.CodecRegistries; import org.bson.codecs.configuration.CodecRegistry; import org.bson.conversions.Bson; +import org.bson.internal.ProvidersCodecRegistry; import org.bson.json.JsonParseException; import org.bson.types.Binary; import org.bson.types.Decimal128; @@ -74,6 +76,7 @@ import org.springframework.data.mongodb.core.mapping.FieldName; import org.springframework.data.mongodb.core.mapping.FieldName.Type; import org.springframework.data.mongodb.core.query.CriteriaDefinition.Placeholder; +import org.springframework.data.mongodb.core.query.CriteriaDefinition.PlaceholderImpl; import org.springframework.lang.Contract; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; @@ -103,7 +106,7 @@ public class BsonUtils { public static final Document EMPTY_DOCUMENT = new EmptyDocument(); private static final CodecRegistry JSON_CODEC_REGISTRY = CodecRegistries.fromRegistries( - MongoClientSettings.getDefaultCodecRegistry(), CodecRegistries.fromCodecs(new PlaceholderCodec())); + MongoClientSettings.getDefaultCodecRegistry(), CodecRegistries.fromProviders(new PlaceholderCodecProvider())); @SuppressWarnings("unchecked") @Contract("null, _ -> null") @@ -377,7 +380,7 @@ public static BsonValue simpleToBsonValue(@Nullable Object source) { @Contract("null, _ -> !null") public static BsonValue simpleToBsonValue(@Nullable Object source, CodecRegistry codecRegistry) { - if(source == null) { + if (source == null) { return BsonNull.VALUE; } @@ -1031,6 +1034,19 @@ public void flush() { } } + @NullUnmarked + public static class PlaceholderCodecProvider implements CodecProvider { + + PlaceholderCodec placeholderCodec = new PlaceholderCodec(); + + @Override + public Codec get(Class clazz, CodecRegistry registry) { + if(!ClassUtils.isAssignable(Placeholder.class, clazz)) { + return null; + } + return (Codec) placeholderCodec; + } + } /** * Internal {@link Codec} implementation to write * {@link org.springframework.data.mongodb.core.query.CriteriaDefinition.Placeholder placeholders}. @@ -1060,4 +1076,26 @@ public Class getEncoderClass() { return Placeholder.class; } } + +// @NullUnmarked +// static class PlaceholderImplCodec implements Codec { +// +// PlaceholderCodec delegate = new PlaceholderCodec(); +// +// @Override +// public PlaceholderImpl decode(BsonReader reader, DecoderContext decoderContext) { +// return null; +// } +// +// @Override +// public void encode(BsonWriter writer, PlaceholderImpl value, EncoderContext encoderContext) { +// delegate.encode(writer, value, encoderContext); +// +// } +// +// @Override +// public Class getEncoderClass() { +// return PlaceholderImpl.class; +// } +// } } diff --git a/spring-data-mongodb/src/test/java/example/aot/Location.java b/spring-data-mongodb/src/test/java/example/aot/Location.java new file mode 100644 index 0000000000..210e9e0ce6 --- /dev/null +++ b/spring-data-mongodb/src/test/java/example/aot/Location.java @@ -0,0 +1,26 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package example.aot; + +import org.springframework.data.geo.Point; + +/** + * @param planet + * @param coordinates + * @author Christoph Strobl + */ +public record Location(String planet, Point coordinates) { +} diff --git a/spring-data-mongodb/src/test/java/example/aot/User.java b/spring-data-mongodb/src/test/java/example/aot/User.java index 06022c0a55..25514a518c 100644 --- a/spring-data-mongodb/src/test/java/example/aot/User.java +++ b/spring-data-mongodb/src/test/java/example/aot/User.java @@ -32,6 +32,8 @@ public class User { @Field("last_name") String lastname; + Location location; + Instant registrationDate; Instant lastSeen; Long visits; @@ -91,4 +93,12 @@ public Long getVisits() { public void setVisits(Long visits) { this.visits = visits; } + + public Location getLocation() { + return location; + } + + public void setLocation(Location location) { + this.location = location; + } } diff --git a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java index 5eb9fed686..250d651e8a 100644 --- a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java +++ b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java @@ -32,6 +32,7 @@ import org.springframework.data.domain.Slice; import org.springframework.data.domain.Sort; import org.springframework.data.domain.Window; +import org.springframework.data.geo.Point; import org.springframework.data.mongodb.core.aggregation.AggregationResults; import org.springframework.data.mongodb.repository.Aggregation; import org.springframework.data.mongodb.repository.Hint; @@ -103,6 +104,8 @@ public interface UserRepository extends CrudRepository { Window findTop2WindowByLastnameStartingWithOrderByUsername(String lastname, ScrollPosition scrollPosition); + List findByLocationCoordinatesNear(Point location); + // TODO: GeoQueries // TODO: TextSearch diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java index a2840ec268..e11f726107 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java @@ -15,7 +15,9 @@ */ package org.springframework.data.mongodb.repository.aot; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatException; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import example.aot.User; import example.aot.UserProjection; @@ -27,10 +29,10 @@ import java.util.Optional; import org.bson.Document; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; - import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -43,6 +45,7 @@ import org.springframework.data.domain.Slice; import org.springframework.data.domain.Sort; import org.springframework.data.domain.Window; +import org.springframework.data.geo.Point; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.MongoTemplate; import org.springframework.data.mongodb.core.aggregation.AggregationResults; @@ -53,6 +56,7 @@ import org.springframework.util.StringUtils; import com.mongodb.client.MongoClient; +import com.mongodb.client.model.IndexOptions; /** * Integration tests for the {@link UserRepository} AOT fragment. @@ -82,6 +86,14 @@ MongoOperations mongoOperations() { } } + @BeforeAll + static void beforeAll() { + String idx = client.getDatabase(DB_NAME).getCollection("user").createIndex(new Document("location.coordinates", "2d"), +// String idx = client.getDatabase(DB_NAME).getCollection("user").createIndex(new Document("location.coordinates", "2dsphere"), + new IndexOptions()); + System.out.println("idx: " + idx); + } + @BeforeEach void beforeEach() { @@ -592,6 +604,12 @@ void testAggregationWithCollation() { .withMessageContaining("'locale' is invalid"); } + @Test + void testGeoNear() { + List users = fragment.findByLocationCoordinatesNear(new Point(-73.99171, 40.738868)); + assertThat(users).extracting(User::getUsername).containsExactly("leia"); + } + private static void initUsers() { Document luke = Document.parse(""" @@ -621,6 +639,12 @@ private static void initUsers() { "username": "leia", "first_name": "Leia", "last_name": "Organa", + "location" : { + "planet" : "Coruscant", + "coordinates" : { + "x" : -73.99171, "y" : 40.738868 + } + }, "_class": "example.springdata.aot.User" }"""); From a0dd23e92ef5373bc0ac010bc099001597d9da88 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Fri, 13 Jun 2025 07:43:25 +0200 Subject: [PATCH 03/24] tmp save --- .../repository/aot/AotQueryCreator.java | 37 ++++++++-- .../data/mongodb/util/BsonUtils.java | 69 ++++++++++++------- .../data/mongodb/util/SpringJsonWriter.java | 2 +- .../test/java/example/aot/UserRepository.java | 4 ++ .../aot/MongoRepositoryContributorTests.java | 25 +++++-- .../src/test/resources/logback.xml | 3 +- 6 files changed, 101 insertions(+), 39 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java index acf0148f56..63b95d079f 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java @@ -18,8 +18,6 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.IntStream; import org.bson.conversions.Bson; import org.jspecify.annotations.NullUnmarked; @@ -30,8 +28,11 @@ import org.springframework.data.domain.ScrollPosition; import org.springframework.data.domain.Sort; import org.springframework.data.domain.Vector; +import org.springframework.data.geo.Circle; import org.springframework.data.geo.Distance; +import org.springframework.data.geo.Metrics; import org.springframework.data.geo.Point; +import org.springframework.data.geo.Shape; import org.springframework.data.mongodb.core.convert.MongoCustomConversions; import org.springframework.data.mongodb.core.convert.MongoWriter; import org.springframework.data.mongodb.core.mapping.MongoMappingContext; @@ -43,7 +44,6 @@ import org.springframework.data.mongodb.core.query.UpdateDefinition; import org.springframework.data.mongodb.repository.query.ConvertingParameterAccessor; import org.springframework.data.mongodb.repository.query.MongoParameterAccessor; -import org.springframework.data.mongodb.repository.query.MongoParameters; import org.springframework.data.mongodb.repository.query.MongoQueryCreator; import org.springframework.data.repository.query.Parameter; import org.springframework.data.repository.query.Parameters; @@ -132,8 +132,12 @@ public PlaceholderParameterAccessor(QueryMethod queryMethod) { Parameters parameters = queryMethod.getParameters(); for(Parameter parameter : parameters.toList()) { if(ClassUtils.isAssignable(Point.class, parameter.getType())) { - placeholders.add(parameter.getIndex(), new GeoPlaceholder(parameter.getIndex())); - } else { + placeholders.add(parameter.getIndex(), new PointPlaceholder(parameter.getIndex())); + } else if(ClassUtils.isAssignable(Circle.class, parameter.getType())) { + placeholders.add(parameter.getIndex(), new CirclePlaceholder(parameter.getIndex())); + } + + else { placeholders.add(parameter.getIndex(), Placeholder.indexed(parameter.getIndex())); } } @@ -222,11 +226,30 @@ public Iterator iterator() { } } - static class GeoPlaceholder extends Point implements Placeholder { + static class CirclePlaceholder extends Circle implements Placeholder { + + int index; + public CirclePlaceholder(int index) { + super(new PointPlaceholder(index), Distance.of(1, Metrics.NEUTRAL)); // + this.index = index; + } + + @Override + public Object getValue() { + return "?%s".formatted(index); + } + + @Override + public String toString() { + return getValue().toString(); + } + } + + static class PointPlaceholder extends Point implements Placeholder { int index; - public GeoPlaceholder(int index) { + public PointPlaceholder(int index) { super(Double.NaN, Double.NaN); this.index = index; } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/BsonUtils.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/BsonUtils.java index e23bf537b8..2298a8b74e 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/BsonUtils.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/BsonUtils.java @@ -64,7 +64,6 @@ import org.bson.codecs.configuration.CodecRegistries; import org.bson.codecs.configuration.CodecRegistry; import org.bson.conversions.Bson; -import org.bson.internal.ProvidersCodecRegistry; import org.bson.json.JsonParseException; import org.bson.types.Binary; import org.bson.types.Decimal128; @@ -72,11 +71,12 @@ import org.jspecify.annotations.NullUnmarked; import org.jspecify.annotations.Nullable; import org.springframework.core.convert.converter.Converter; +import org.springframework.data.geo.Circle; import org.springframework.data.mongodb.CodecRegistryProvider; import org.springframework.data.mongodb.core.mapping.FieldName; import org.springframework.data.mongodb.core.mapping.FieldName.Type; import org.springframework.data.mongodb.core.query.CriteriaDefinition.Placeholder; -import org.springframework.data.mongodb.core.query.CriteriaDefinition.PlaceholderImpl; +import org.springframework.data.mongodb.core.query.GeoCommand; import org.springframework.lang.Contract; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; @@ -1038,15 +1038,21 @@ public void flush() { public static class PlaceholderCodecProvider implements CodecProvider { PlaceholderCodec placeholderCodec = new PlaceholderCodec(); + GeoCommandCodec geoCommandCodec = new GeoCommandCodec(); @Override public Codec get(Class clazz, CodecRegistry registry) { - if(!ClassUtils.isAssignable(Placeholder.class, clazz)) { - return null; + if (ClassUtils.isAssignable(Placeholder.class, clazz)) { + return (Codec) placeholderCodec; + } + if (ClassUtils.isAssignable(GeoCommand.class, clazz)) { + return (Codec) geoCommandCodec; } - return (Codec) placeholderCodec; + return null; + } } + /** * Internal {@link Codec} implementation to write * {@link org.springframework.data.mongodb.core.query.CriteriaDefinition.Placeholder placeholders}. @@ -1077,25 +1083,36 @@ public Class getEncoderClass() { } } -// @NullUnmarked -// static class PlaceholderImplCodec implements Codec { -// -// PlaceholderCodec delegate = new PlaceholderCodec(); -// -// @Override -// public PlaceholderImpl decode(BsonReader reader, DecoderContext decoderContext) { -// return null; -// } -// -// @Override -// public void encode(BsonWriter writer, PlaceholderImpl value, EncoderContext encoderContext) { -// delegate.encode(writer, value, encoderContext); -// -// } -// -// @Override -// public Class getEncoderClass() { -// return PlaceholderImpl.class; -// } -// } + static class GeoCommandCodec implements Codec { + + @Override + public GeoCommand decode(BsonReader reader, DecoderContext decoderContext) { + return null; + } + + @Override + public void encode(BsonWriter writer, GeoCommand value, EncoderContext encoderContext) { + + if (writer instanceof SpringJsonWriter sjw) { + writer.writeStartDocument(); + writer.writeName(value.getCommand()); + if (value.getShape() instanceof Placeholder p) { // maybe we should wrap input to use geo command object + sjw.writePlaceholder(p.toString()); +// Circle c = null; +// List.of(c.getCenter(), c.getRadius()) +// ; + +// createQuery("{'location.coordinates':{'$geoWithin':{'$center':?0}}}", new Object[]{ List.of(circle.getCenter(), circle.getRadius())) + } + writer.writeEndDocument(); + } else { + writer.writeString(value.getCommand(), value.getShape().toString()); + } + } + + @Override + public Class getEncoderClass() { + return null; + } + } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/SpringJsonWriter.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/SpringJsonWriter.java index 07eab92a01..98dbc3a682 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/SpringJsonWriter.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/SpringJsonWriter.java @@ -463,7 +463,7 @@ public void writePlaceholder(String placeholder) { write(placeholder); } - private void write(String str) { + public void write(String str) { buffer.append(str); } diff --git a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java index 250d651e8a..f466589d9d 100644 --- a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java +++ b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java @@ -32,10 +32,12 @@ import org.springframework.data.domain.Slice; import org.springframework.data.domain.Sort; import org.springframework.data.domain.Window; +import org.springframework.data.geo.Circle; import org.springframework.data.geo.Point; import org.springframework.data.mongodb.core.aggregation.AggregationResults; import org.springframework.data.mongodb.repository.Aggregation; import org.springframework.data.mongodb.repository.Hint; +import org.springframework.data.mongodb.repository.Person; import org.springframework.data.mongodb.repository.Query; import org.springframework.data.mongodb.repository.ReadPreference; import org.springframework.data.mongodb.repository.Update; @@ -106,6 +108,8 @@ public interface UserRepository extends CrudRepository { List findByLocationCoordinatesNear(Point location); + List findByLocationCoordinatesWithin(Circle circle); + // TODO: GeoQueries // TODO: TextSearch diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java index e11f726107..e0d00b04c8 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java @@ -45,10 +45,12 @@ import org.springframework.data.domain.Slice; import org.springframework.data.domain.Sort; import org.springframework.data.domain.Window; +import org.springframework.data.geo.Circle; import org.springframework.data.geo.Point; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.MongoTemplate; import org.springframework.data.mongodb.core.aggregation.AggregationResults; +import org.springframework.data.mongodb.core.geo.GeoJsonPoint; import org.springframework.data.mongodb.test.util.Client; import org.springframework.data.mongodb.test.util.MongoClientExtension; import org.springframework.data.mongodb.test.util.MongoTestUtils; @@ -88,10 +90,8 @@ MongoOperations mongoOperations() { @BeforeAll static void beforeAll() { - String idx = client.getDatabase(DB_NAME).getCollection("user").createIndex(new Document("location.coordinates", "2d"), -// String idx = client.getDatabase(DB_NAME).getCollection("user").createIndex(new Document("location.coordinates", "2dsphere"), + client.getDatabase(DB_NAME).getCollection("user").createIndex(new Document("location.coordinates", "2d"), new IndexOptions()); - System.out.println("idx: " + idx); } @BeforeEach @@ -605,11 +605,28 @@ void testAggregationWithCollation() { } @Test - void testGeoNear() { + void testNear() { + List users = fragment.findByLocationCoordinatesNear(new Point(-73.99171, 40.738868)); assertThat(users).extracting(User::getUsername).containsExactly("leia"); } + @Test + void testNearWithGeoJson() { + + List users = fragment.findByLocationCoordinatesNear(new GeoJsonPoint(-73.99171, 40.738868)); + assertThat(users).extracting(User::getUsername).containsExactly("leia"); + } + + @Test + void testGeoWithin() { + + List users = fragment.findByLocationCoordinatesWithin(new Circle(-78.99171, 45.738868, 170)); + assertThat(users).extracting(User::getUsername).containsExactly("leia"); + } + + //List result = repository.findByLocationWithin(); + private static void initUsers() { Document luke = Document.parse(""" diff --git a/spring-data-mongodb/src/test/resources/logback.xml b/spring-data-mongodb/src/test/resources/logback.xml index 55e4309a36..d0907937fa 100644 --- a/spring-data-mongodb/src/test/resources/logback.xml +++ b/spring-data-mongodb/src/test/resources/logback.xml @@ -20,8 +20,9 @@ - + + From 17fc18479210ef900b90359eba51931f78abb0f6 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Fri, 13 Jun 2025 08:22:55 +0200 Subject: [PATCH 04/24] ok this works for geo shapes --- .../mongodb/repository/aot/MongoCodeBlocks.java | 13 ++++++++++++- .../mongodb/repository/query/MongoParameters.java | 2 +- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java index 999391f5ec..6e6af37f77 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java @@ -28,6 +28,7 @@ import org.springframework.core.annotation.MergedAnnotation; import org.springframework.data.domain.SliceImpl; import org.springframework.data.domain.Sort.Order; +import org.springframework.data.geo.Circle; import org.springframework.data.mongodb.core.ExecutableFindOperation.FindWithQuery; import org.springframework.data.mongodb.core.ExecutableRemoveOperation.ExecutableRemove; import org.springframework.data.mongodb.core.ExecutableUpdateOperation.ExecutableUpdate; @@ -44,6 +45,7 @@ import org.springframework.data.mongodb.repository.Hint; import org.springframework.data.mongodb.repository.Meta; import org.springframework.data.mongodb.repository.ReadPreference; +import org.springframework.data.mongodb.repository.query.MongoParameters.MongoParameter; import org.springframework.data.mongodb.repository.query.MongoQueryExecution.DeleteExecution; import org.springframework.data.mongodb.repository.query.MongoQueryExecution.PagedExecution; import org.springframework.data.mongodb.repository.query.MongoQueryExecution.SlicedExecution; @@ -686,7 +688,16 @@ static class QueryCodeBlockBuilder { QueryCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { this.context = context; - this.arguments = context.getBindableParameterNames(); + this.arguments = new ArrayList<>(); + + for(MongoParameter parameter : queryMethod.getParameters().getBindableParameters()) { + String parameterName = context.getParameterName(parameter.getIndex()); + if(ClassUtils.isAssignable(Circle.class, parameter.getType())) { + parameterName = "List.of(%s.getCenter(), %s.getRadius().getNormalizedValue())".formatted(parameterName, parameterName); + } + arguments.add(parameterName); + } + this.queryMethod = queryMethod; } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java index 94acef17ce..0aa9ad5fdf 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java @@ -292,7 +292,7 @@ private int getTypeIndex(List> parameterTypes, Class type, * * @author Oliver Gierke */ - static class MongoParameter extends Parameter { + public static class MongoParameter extends Parameter { private final MethodParameter parameter; private final @Nullable Integer nearIndex; From fb4ef82e371625257b0e31b7d9d255ea0b8973be Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Fri, 13 Jun 2025 08:43:48 +0200 Subject: [PATCH 05/24] switch arguments to code blocks --- .../repository/aot/MongoCodeBlocks.java | 54 +++++++++++++------ 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java index 6e6af37f77..13e132f8c4 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java @@ -16,15 +16,16 @@ package org.springframework.data.mongodb.repository.aot; import java.util.ArrayList; +import java.util.Iterator; import java.util.List; import java.util.Optional; import java.util.regex.Pattern; +import java.util.stream.Collectors; import java.util.stream.Stream; import org.bson.Document; import org.jspecify.annotations.NullUnmarked; import org.jspecify.annotations.Nullable; - import org.springframework.core.annotation.MergedAnnotation; import org.springframework.data.domain.SliceImpl; import org.springframework.data.domain.Sort.Order; @@ -470,14 +471,14 @@ static class AggregationCodeBlockBuilder { private final MongoQueryMethod queryMethod; private AggregationInteraction source; - private final List arguments; + private final List arguments; private String aggregationVariableName; private boolean pipelineOnly; AggregationCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { this.context = context; - this.arguments = context.getBindableParameterNames(); + this.arguments = context.getBindableParameterNames().stream().map(CodeBlock::of).collect(Collectors.toList()); this.queryMethod = queryMethod; } @@ -605,7 +606,7 @@ private CodeBlock aggregationOptions(String aggregationVariableName) { } private CodeBlock aggregationStages(String stageListVariableName, Iterable stages, int stageCount, - List arguments) { + List arguments) { Builder builder = CodeBlock.builder(); builder.addStatement("$T<$T> $L = new $T($L)", List.class, Object.class, stageListVariableName, ArrayList.class, @@ -682,20 +683,22 @@ static class QueryCodeBlockBuilder { private final MongoQueryMethod queryMethod; private QueryInteraction source; - private final List arguments; + private final List arguments; private String queryVariableName; QueryCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { this.context = context; - this.arguments = new ArrayList<>(); - for(MongoParameter parameter : queryMethod.getParameters().getBindableParameters()) { + this.arguments = new ArrayList<>(); + for (MongoParameter parameter : queryMethod.getParameters().getBindableParameters()) { String parameterName = context.getParameterName(parameter.getIndex()); - if(ClassUtils.isAssignable(Circle.class, parameter.getType())) { - parameterName = "List.of(%s.getCenter(), %s.getRadius().getNormalizedValue())".formatted(parameterName, parameterName); + if (ClassUtils.isAssignable(Circle.class, parameter.getType())) { + arguments.add(CodeBlock.builder().add("$T.of($L.getCenter(), $L.getRadius().getNormalizedValue())", + List.class, parameterName, parameterName).build()); + } else { + arguments.add(CodeBlock.of(parameterName)); } - arguments.add(parameterName); } this.queryMethod = queryMethod; @@ -797,8 +800,15 @@ private CodeBlock renderExpressionToQuery(@Nullable String source, String variab builder.addStatement("$T $L = new $T($T.parse($S))", BasicQuery.class, variableName, BasicQuery.class, Document.class, source); } else { - builder.addStatement("$T $L = createQuery($S, new $T[]{ $L })", BasicQuery.class, variableName, source, - Object.class, StringUtils.collectionToDelimitedString(arguments, ", ")); + builder.add("$T $L = createQuery($S, new $T[]{ ", BasicQuery.class, variableName, source, Object.class); + Iterator iterator = arguments.iterator(); + while (iterator.hasNext()) { + builder.add(iterator.next()); + if (iterator.hasNext()) { + builder.add(", "); + } + } + builder.add("});\n"); } return builder.build(); @@ -809,11 +819,11 @@ private CodeBlock renderExpressionToQuery(@Nullable String source, String variab static class UpdateCodeBlockBuilder { private UpdateInteraction source; - private List arguments; + private List arguments; private String updateVariableName; public UpdateCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { - this.arguments = context.getBindableParameterNames(); + this.arguments = context.getBindableParameterNames().stream().map(CodeBlock::of).collect(Collectors.toList()); } public UpdateCodeBlockBuilder update(UpdateInteraction update) { @@ -841,7 +851,7 @@ CodeBlock build() { } private static CodeBlock renderExpressionToDocument(@Nullable String source, String variableName, - List arguments) { + List arguments) { Builder builder = CodeBlock.builder(); if (!StringUtils.hasText(source)) { @@ -849,8 +859,18 @@ private static CodeBlock renderExpressionToDocument(@Nullable String source, Str } else if (!containsPlaceholder(source)) { builder.addStatement("$T $L = $T.parse($S)", Document.class, variableName, Document.class, source); } else { - builder.addStatement("$T $L = bindParameters($S, new $T[]{ $L })", Document.class, variableName, source, - Object.class, StringUtils.collectionToDelimitedString(arguments, ", ")); + + builder.add("$T $L = bindParameters($S, new $T[]{ ", Document.class, variableName, source, Object.class); + Iterator iterator = arguments.iterator(); + while (iterator.hasNext()) { + builder.add(iterator.next()); + if (iterator.hasNext()) { + builder.add(", "); + } + } + builder.add("});\n"); + // builder.addStatement("$T $L = bindParameters($S, new $T[]{ $L })", Document.class, variableName, source, + // Object.class, StringUtils.collectionToDelimitedString(arguments, ", ")); } return builder.build(); } From 62c86e37b31ff062627e6a2bafb8ce1e6bfb2cb8 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Fri, 13 Jun 2025 09:17:41 +0200 Subject: [PATCH 06/24] more geo shapes --- .../repository/aot/AotQueryCreator.java | 66 +++++++++++++++++++ .../repository/aot/MongoCodeBlocks.java | 31 ++++++++- .../test/java/example/aot/UserRepository.java | 7 +- .../aot/MongoRepositoryContributorTests.java | 27 +++++++- 4 files changed, 125 insertions(+), 6 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java index 63b95d079f..6904c868ce 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java @@ -28,13 +28,16 @@ import org.springframework.data.domain.ScrollPosition; import org.springframework.data.domain.Sort; import org.springframework.data.domain.Vector; +import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; import org.springframework.data.geo.Distance; import org.springframework.data.geo.Metrics; import org.springframework.data.geo.Point; +import org.springframework.data.geo.Polygon; import org.springframework.data.geo.Shape; import org.springframework.data.mongodb.core.convert.MongoCustomConversions; import org.springframework.data.mongodb.core.convert.MongoWriter; +import org.springframework.data.mongodb.core.geo.Sphere; import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty; import org.springframework.data.mongodb.core.query.Collation; @@ -135,6 +138,12 @@ public PlaceholderParameterAccessor(QueryMethod queryMethod) { placeholders.add(parameter.getIndex(), new PointPlaceholder(parameter.getIndex())); } else if(ClassUtils.isAssignable(Circle.class, parameter.getType())) { placeholders.add(parameter.getIndex(), new CirclePlaceholder(parameter.getIndex())); + } else if(ClassUtils.isAssignable(Box.class, parameter.getType())) { + placeholders.add(parameter.getIndex(), new BoxPlaceholder(parameter.getIndex())); + } else if(ClassUtils.isAssignable(Sphere.class, parameter.getType())) { + placeholders.add(parameter.getIndex(), new SpherePlaceholder(parameter.getIndex())); + } else if(ClassUtils.isAssignable(Polygon.class, parameter.getType())) { + placeholders.add(parameter.getIndex(), new PolygonPlaceholder(parameter.getIndex())); } else { @@ -245,6 +254,63 @@ public String toString() { } } + static class SpherePlaceholder extends Sphere implements Placeholder { + + int index; + public SpherePlaceholder(int index) { + super(new PointPlaceholder(index), Distance.of(1, Metrics.NEUTRAL)); // + this.index = index; + } + + @Override + public Object getValue() { + return "?%s".formatted(index); + } + + @Override + public String toString() { + return getValue().toString(); + } + } + + static class BoxPlaceholder extends Box implements Placeholder { + int index; + + public BoxPlaceholder(int index) { + super(new PointPlaceholder(index), new PointPlaceholder(index)); + this.index = index; + } + + @Override + public Object getValue() { + return "?%s".formatted(index); + } + + @Override + public String toString() { + return getValue().toString(); + } + } + + static class PolygonPlaceholder extends Polygon implements Placeholder { + int index; + + public PolygonPlaceholder(int index) { + super(new PointPlaceholder(index), new PointPlaceholder(index), new PointPlaceholder(index), new PointPlaceholder(index)); + this.index = index; + } + + @Override + public Object getValue() { + return "?%s".formatted(index); + } + + @Override + public String toString() { + return getValue().toString(); + } + } + static class PointPlaceholder extends Point implements Placeholder { int index; diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java index 13e132f8c4..9153498129 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java @@ -29,7 +29,9 @@ import org.springframework.core.annotation.MergedAnnotation; import org.springframework.data.domain.SliceImpl; import org.springframework.data.domain.Sort.Order; +import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; +import org.springframework.data.geo.Polygon; import org.springframework.data.mongodb.core.ExecutableFindOperation.FindWithQuery; import org.springframework.data.mongodb.core.ExecutableRemoveOperation.ExecutableRemove; import org.springframework.data.mongodb.core.ExecutableUpdateOperation.ExecutableUpdate; @@ -39,6 +41,7 @@ import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; import org.springframework.data.mongodb.core.aggregation.AggregationResults; import org.springframework.data.mongodb.core.aggregation.TypedAggregation; +import org.springframework.data.mongodb.core.geo.Sphere; import org.springframework.data.mongodb.core.mapping.MongoSimpleTypes; import org.springframework.data.mongodb.core.query.BasicQuery; import org.springframework.data.mongodb.core.query.BasicUpdate; @@ -694,9 +697,31 @@ static class QueryCodeBlockBuilder { for (MongoParameter parameter : queryMethod.getParameters().getBindableParameters()) { String parameterName = context.getParameterName(parameter.getIndex()); if (ClassUtils.isAssignable(Circle.class, parameter.getType())) { - arguments.add(CodeBlock.builder().add("$T.of($L.getCenter(), $L.getRadius().getNormalizedValue())", - List.class, parameterName, parameterName).build()); - } else { + arguments.add(CodeBlock.builder() + .add("$T.of($T.of($L.getCenter().getX(), $L.getCenter().getY()), $L.getRadius().getNormalizedValue())", + List.class, List.class, parameterName, parameterName, parameterName) + .build()); + } else if (ClassUtils.isAssignable(Box.class, parameter.getType())) { + + // { $geoWithin: { $box: [ [ , ], [ , ] ] } + arguments.add(CodeBlock.builder().add( + "$T.of($T.of($L.getFirst().getX(), $L.getFirst().getY()), $T.of($L.getSecond().getX(), $L.getSecond().getY()))", + List.class, List.class, parameterName, parameterName, List.class, parameterName, parameterName).build()); + } else if (ClassUtils.isAssignable(Sphere.class, parameter.getType())) { + // { $centerSphere: [ [ , ], ] } + arguments.add(CodeBlock.builder() + .add("$T.of($T.of($L.getCenter().getX(), $L.getCenter().getY()), $L.getRadius().getNormalizedValue())", + List.class, List.class, parameterName, parameterName, parameterName) + .build()); + } else if (ClassUtils.isAssignable(Polygon.class, parameter.getType())) { + // $polygon: [ [ , ], [ , ], [ , ], ... ] + String localVar = context.localVariable("_p"); + arguments + .add(CodeBlock.builder().add("$L.getPoints().stream().map($L -> $T.of($L.getX(), $L.getY())).toList()", + parameterName, localVar, List.class, localVar, localVar).build()); + } + + else { arguments.add(CodeBlock.of(parameterName)); } } diff --git a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java index f466589d9d..8493b6b7b4 100644 --- a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java +++ b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java @@ -32,12 +32,13 @@ import org.springframework.data.domain.Slice; import org.springframework.data.domain.Sort; import org.springframework.data.domain.Window; +import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; import org.springframework.data.geo.Point; +import org.springframework.data.geo.Polygon; import org.springframework.data.mongodb.core.aggregation.AggregationResults; import org.springframework.data.mongodb.repository.Aggregation; import org.springframework.data.mongodb.repository.Hint; -import org.springframework.data.mongodb.repository.Person; import org.springframework.data.mongodb.repository.Query; import org.springframework.data.mongodb.repository.ReadPreference; import org.springframework.data.mongodb.repository.Update; @@ -110,6 +111,10 @@ public interface UserRepository extends CrudRepository { List findByLocationCoordinatesWithin(Circle circle); + List findByLocationCoordinatesWithin(Box box); + + List findByLocationCoordinatesWithin(Polygon polygon); + // TODO: GeoQueries // TODO: TextSearch diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java index e0d00b04c8..c9c20259a7 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java @@ -45,8 +45,10 @@ import org.springframework.data.domain.Slice; import org.springframework.data.domain.Sort; import org.springframework.data.domain.Window; +import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; import org.springframework.data.geo.Point; +import org.springframework.data.geo.Polygon; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.MongoTemplate; import org.springframework.data.mongodb.core.aggregation.AggregationResults; @@ -619,13 +621,34 @@ void testNearWithGeoJson() { } @Test - void testGeoWithin() { + void testGeoWithinCircle() { List users = fragment.findByLocationCoordinatesWithin(new Circle(-78.99171, 45.738868, 170)); assertThat(users).extracting(User::getUsername).containsExactly("leia"); } - //List result = repository.findByLocationWithin(); + @Test + void testWithinBox() { + + Box box = new Box(new Point(-78.99171, 35.738868), new Point(-68.99171, 45.738868)); + + List result = fragment.findByLocationCoordinatesWithin(box); + assertThat(result).extracting(User::getUsername).containsExactly("leia"); + } + + @Test + void findsPeopleByLocationWithinPolygon() { + + Point first = new Point(-78.99171, 35.738868); + Point second = new Point(-78.99171, 45.738868); + Point third = new Point(-68.99171, 45.738868); + Point fourth = new Point(-68.99171, 35.738868); + + List result = fragment.findByLocationCoordinatesWithin(new Polygon(first, second, third, fourth)); + assertThat(result).extracting(User::getUsername).containsExactly("leia"); + } + + private static void initUsers() { From 871bea6b129abccd770c96f3a31e412bc58690ad Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Fri, 13 Jun 2025 11:02:19 +0200 Subject: [PATCH 07/24] GeoNearExecution --- .../repository/aot/MongoCodeBlocks.java | 73 ++++++++++++++++++- .../aot/MongoRepositoryContributor.java | 19 +++++ .../repository/aot/NearQueryInteraction.java | 52 +++++++++++++ .../test/java/example/aot/UserRepository.java | 5 ++ .../aot/MongoRepositoryContributorTests.java | 10 +++ 5 files changed, 158 insertions(+), 1 deletion(-) create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/NearQueryInteraction.java diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java index 9153498129..eb851748c8 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java @@ -31,6 +31,7 @@ import org.springframework.data.domain.Sort.Order; import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; +import org.springframework.data.geo.Distance; import org.springframework.data.geo.Polygon; import org.springframework.data.mongodb.core.ExecutableFindOperation.FindWithQuery; import org.springframework.data.mongodb.core.ExecutableRemoveOperation.ExecutableRemove; @@ -46,6 +47,7 @@ import org.springframework.data.mongodb.core.query.BasicQuery; import org.springframework.data.mongodb.core.query.BasicUpdate; import org.springframework.data.mongodb.core.query.Collation; +import org.springframework.data.mongodb.core.query.NearQuery; import org.springframework.data.mongodb.repository.Hint; import org.springframework.data.mongodb.repository.Meta; import org.springframework.data.mongodb.repository.ReadPreference; @@ -151,6 +153,12 @@ static AggregationCodeBlockBuilder aggregationBlockBuilder(AotQueryMethodGenerat return new AggregationCodeBlockBuilder(context, queryMethod); } + static GeoNearCodeBlockBuilder geoNearBlockBuilder(AotQueryMethodGenerationContext context, + MongoQueryMethod queryMethod) { + + return new GeoNearCodeBlockBuilder(context, queryMethod); + } + /** * Builder for generating aggregation execution {@link CodeBlock}. * @@ -467,14 +475,77 @@ CodeBlock build() { } } + static class GeoNearCodeBlockBuilder { + + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + private final List arguments; + + private String variableName; + + GeoNearCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + + this.context = context; + this.arguments = context.getBindableParameterNames().stream().map(CodeBlock::of).collect(Collectors.toList()); + this.queryMethod = queryMethod; + } + + CodeBlock build() { + + CodeBlock.Builder builder = CodeBlock.builder(); + builder.add("\n"); + + String locationParameterName = context.getParameterName(queryMethod.getParameters().getNearIndex()); + + builder.addStatement("$1T $2L = $1T.near($3L)", NearQuery.class, variableName, locationParameterName); + + if (queryMethod.getParameters().getRangeIndex() != -1) { + + String rangeParametername = context.getParameterName(queryMethod.getParameters().getRangeIndex()); + String minVarName = context.localVariable("min"); + String maxVarName = context.localVariable("max"); + + builder.beginControlFlow("if($L.getLowerBound().isPresent())", rangeParametername); + builder.addStatement("$1T $2L = $3L.getLowerBound().get()", Distance.class, minVarName, rangeParametername); + builder.addStatement("$1L.minDistance($2L.getValue()).in($2L.getMetric())", variableName, minVarName); + builder.endControlFlow(); + + builder.beginControlFlow("if($L.getUpperBound().isPresent())", rangeParametername); + builder.addStatement("$1T $2L = $3L.getUpperBound().get()", Distance.class, maxVarName, rangeParametername); + builder.addStatement("$1L.maxDistance($2L.getValue()).in($2L.getMetric())", variableName, maxVarName); + builder.endControlFlow(); + } else { + + String distanceParametername = context.getParameterName(queryMethod.getParameters().getMaxDistanceIndex()); + builder.addStatement("$1L.maxDistance($2L.getValue()).in($2L.getMetric())", variableName, + distanceParametername); + } + + if (context.getPageableParameterName() != null) { + builder.addStatement("$L.with($L)", variableName, context.getPageableParameterName()); + } + + builder.add("\n"); + builder.addStatement("return $L.query($T.class).near($L).all()", context.fieldNameOf(MongoOperations.class), + context.getRepositoryInformation().getDomainType(), variableName); + return builder.build(); + } + + public GeoNearCodeBlockBuilder usingQueryVariableName(String variableName) { + this.variableName = variableName; + return this; + } + } + @NullUnmarked static class AggregationCodeBlockBuilder { private final AotQueryMethodGenerationContext context; private final MongoQueryMethod queryMethod; + private final List arguments; private AggregationInteraction source; - private final List arguments; + private String aggregationVariableName; private boolean pipelineOnly; diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java index d25d2f7f9a..bd31d5ea1a 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java @@ -95,6 +95,11 @@ protected void customizeConstructor(AotRepositoryConstructorBuilder constructorB return aggregationMethodContributor(queryMethod, aggregation); } + if(queryMethod.isGeoNearQuery()) { + NearQueryInteraction near = new NearQueryInteraction(); + return nearQueryMethodContributor(queryMethod, near); + } + QueryInteraction query = createStringQuery(getRepositoryInformation(), queryMethod, AnnotatedElementUtils.findMergedAnnotation(method, Query.class), method.getParameterCount()); @@ -181,6 +186,20 @@ private static boolean backoff(MongoQueryMethod method) { return skip; } + private static MethodContributor nearQueryMethodContributor(MongoQueryMethod queryMethod, + NearQueryInteraction interaction) { + + return MethodContributor.forQueryMethod(queryMethod).withMetadata(interaction).contribute(context -> { + + CodeBlock.Builder builder = CodeBlock.builder(); + + builder.add(geoNearBlockBuilder(context, queryMethod).usingQueryVariableName("nearQuery").build()); +// builder.add(aggregationExecutionBlockBuilder(context, queryMethod).referencing("aggregation").build()); + + return builder.build(); + }); + } + private static MethodContributor aggregationMethodContributor(MongoQueryMethod queryMethod, AggregationInteraction aggregation) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/NearQueryInteraction.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/NearQueryInteraction.java new file mode 100644 index 0000000000..23551abddc --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/NearQueryInteraction.java @@ -0,0 +1,52 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import java.util.LinkedHashMap; +import java.util.Map; + +import org.springframework.data.repository.aot.generate.QueryMetadata; +import org.springframework.util.StringUtils; + +/** + * An {@link MongoInteraction} to execute a query. + * + * @author Christoph Strobl + * @since 5.0 + */ +class NearQueryInteraction extends MongoInteraction implements QueryMetadata { + + private final InteractionType interactionType; + + NearQueryInteraction() { + interactionType = InteractionType.QUERY; + } + + @Override + InteractionType getExecutionType() { + return interactionType; + } + + @Override + public Map serialize() { + + Map serialized = new LinkedHashMap<>(); + + + + return serialized; + } +} diff --git a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java index 8493b6b7b4..868203a103 100644 --- a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java +++ b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java @@ -34,11 +34,14 @@ import org.springframework.data.domain.Window; import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; +import org.springframework.data.geo.Distance; +import org.springframework.data.geo.GeoResults; import org.springframework.data.geo.Point; import org.springframework.data.geo.Polygon; import org.springframework.data.mongodb.core.aggregation.AggregationResults; import org.springframework.data.mongodb.repository.Aggregation; import org.springframework.data.mongodb.repository.Hint; +import org.springframework.data.mongodb.repository.Person; import org.springframework.data.mongodb.repository.Query; import org.springframework.data.mongodb.repository.ReadPreference; import org.springframework.data.mongodb.repository.Update; @@ -115,6 +118,8 @@ public interface UserRepository extends CrudRepository { List findByLocationCoordinatesWithin(Polygon polygon); + GeoResults findByLocationCoordinatesNear(Point point, Distance maxDistance); + // TODO: GeoQueries // TODO: TextSearch diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java index c9c20259a7..d9e7de2cc5 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java @@ -47,6 +47,9 @@ import org.springframework.data.domain.Window; import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; +import org.springframework.data.geo.Distance; +import org.springframework.data.geo.GeoResults; +import org.springframework.data.geo.Metrics; import org.springframework.data.geo.Point; import org.springframework.data.geo.Polygon; import org.springframework.data.mongodb.core.MongoOperations; @@ -649,6 +652,13 @@ void findsPeopleByLocationWithinPolygon() { } + @Test + void testNearWithGeoResult() { + + GeoResults users = fragment.findByLocationCoordinatesNear(new Point(-73.99, 40.73), Distance.of(2000, Metrics.KILOMETERS)); + System.out.println("users: " + users); + assertThat(users).isNotEmpty(); + } private static void initUsers() { From 25ae2fa1aa6a4a1186fe83a78f3721b507c56b55 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Fri, 13 Jun 2025 11:19:15 +0200 Subject: [PATCH 08/24] use positional parameters to simplify javapoet statements --- .../repository/aot/MongoCodeBlocks.java | 80 +++++++++---------- 1 file changed, 37 insertions(+), 43 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java index eb851748c8..f0ac290f96 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java @@ -203,8 +203,8 @@ CodeBlock build() { Object actualReturnType = isProjecting ? context.getActualReturnType().getType() : domainType; builder.add("\n"); - builder.addStatement("$T<$T> $L = $L.remove($T.class)", ExecutableRemove.class, domainType, - context.localVariable("remover"), mongoOpsRef, domainType); + builder.addStatement("$1T<$2T> $3L = $4L.remove($2T.class)", ExecutableRemove.class, domainType, + context.localVariable("remover"), mongoOpsRef); DeleteExecution.Type type = DeleteExecution.Type.FIND_AND_REMOVE_ALL; if (!queryMethod.isCollectionQuery()) { @@ -270,8 +270,8 @@ CodeBlock build() { String updateReference = updateVariableName; Class domainType = context.getRepositoryInformation().getDomainType(); - builder.addStatement("$T<$T> $L = $L.update($T.class)", ExecutableUpdate.class, domainType, - context.localVariable("updater"), mongoOpsRef, domainType); + builder.addStatement("$1T<$2T> $3L = $4L.update($2T.class)", ExecutableUpdate.class, domainType, + context.localVariable("updater"), mongoOpsRef); Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); if (ReflectionUtils.isVoid(returnType)) { @@ -344,16 +344,17 @@ CodeBlock build() { builder.addStatement("$T<$T> $L = $L.aggregateStream($L, $T.class)", Stream.class, Document.class, context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); - builder.addStatement("return $L.map(it -> ($T) convertSimpleRawResult($T.class, it))", - context.localVariable("results"), returnType, returnType); + builder.addStatement("return $1L.map(it -> ($2T) convertSimpleRawResult($2T.class, it))", + context.localVariable("results"), returnType); } else { builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class, context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); if (!queryMethod.isCollectionQuery()) { - builder.addStatement("return $T.<$T>firstElement(convertSimpleRawResults($T.class, $L.getMappedResults()))", - CollectionUtils.class, returnType, returnType, context.localVariable("results")); + builder.addStatement( + "return $1T.<$2T>firstElement(convertSimpleRawResults($2T.class, $3L.getMappedResults()))", + CollectionUtils.class, returnType, context.localVariable("results")); } else { builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType, context.localVariable("results")); @@ -366,10 +367,9 @@ CodeBlock build() { builder.addStatement("boolean $L = $L.getMappedResults().size() > $L.getPageSize()", context.localVariable("hasNext"), context.localVariable("results"), context.getPageableParameterName()); builder.addStatement( - "return new $T<>($L ? $L.getMappedResults().subList(0, $L.getPageSize()) : $L.getMappedResults(), $L, $L)", + "return new $1T<>($2L ? $3L.getMappedResults().subList(0, $4L.getPageSize()) : $3L.getMappedResults(), $4L, $2L)", SliceImpl.class, context.localVariable("hasNext"), context.localVariable("results"), - context.getPageableParameterName(), context.localVariable("results"), context.getPageableParameterName(), - context.localVariable("hasNext")); + context.getPageableParameterName()); } else { if (queryMethod.isStreamQuery()) { @@ -584,9 +584,9 @@ CodeBlock build() { if (!pipelineOnly) { - builder.addStatement("$T<$T> $L = $T.newAggregation($T.class, $L.getOperations())", TypedAggregation.class, - context.getRepositoryInformation().getDomainType(), aggregationVariableName, Aggregation.class, - context.getRepositoryInformation().getDomainType(), pipelineName); + builder.addStatement("$1T<$2T> $3L = $4T.newAggregation($2T.class, $5L.getOperations())", + TypedAggregation.class, context.getRepositoryInformation().getDomainType(), aggregationVariableName, + Aggregation.class, pipelineName); builder.add(aggregationOptions(aggregationVariableName)); } @@ -662,8 +662,8 @@ private CodeBlock aggregationOptions(String aggregationVariableName) { if (!options.isEmpty()) { Builder optionsBuilder = CodeBlock.builder(); - optionsBuilder.add("$T $L = $T.builder()\n", AggregationOptions.class, - context.localVariable("aggregationOptions"), AggregationOptions.class); + optionsBuilder.add("$1T $2L = $1T.builder()\n", AggregationOptions.class, + context.localVariable("aggregationOptions")); optionsBuilder.indent(); for (CodeBlock optionBlock : options) { optionsBuilder.add(optionBlock); @@ -673,7 +673,7 @@ private CodeBlock aggregationOptions(String aggregationVariableName) { optionsBuilder.unindent(); builder.add(optionsBuilder.build()); - builder.addStatement("$L = $L.withOptions($L)", aggregationVariableName, aggregationVariableName, + builder.addStatement("$1L = $1L.withOptions($2L)", aggregationVariableName, context.localVariable("aggregationOptions")); } return builder.build(); @@ -701,10 +701,10 @@ private CodeBlock sortingStage(String sortProvider) { Builder builder = CodeBlock.builder(); builder.beginControlFlow("if ($L.isSorted())", sortProvider); - builder.addStatement("$T $L = new $T()", Document.class, context.localVariable("sortDocument"), Document.class); + builder.addStatement("$1T $2L = new $1T()", Document.class, context.localVariable("sortDocument")); builder.beginControlFlow("for ($T $L : $L)", Order.class, context.localVariable("order"), sortProvider); - builder.addStatement("$L.append($L.getProperty(), $L.isAscending() ? 1 : -1);", - context.localVariable("sortDocument"), context.localVariable("order"), context.localVariable("order")); + builder.addStatement("$1L.append($2L.getProperty(), $2L.isAscending() ? 1 : -1);", + context.localVariable("sortDocument"), context.localVariable("order")); builder.endControlFlow(); builder.addStatement("stages.add(new $T($S, $L))", Document.class, "$sort", context.localVariable("sortDocument")); @@ -768,28 +768,26 @@ static class QueryCodeBlockBuilder { for (MongoParameter parameter : queryMethod.getParameters().getBindableParameters()) { String parameterName = context.getParameterName(parameter.getIndex()); if (ClassUtils.isAssignable(Circle.class, parameter.getType())) { - arguments.add(CodeBlock.builder() - .add("$T.of($T.of($L.getCenter().getX(), $L.getCenter().getY()), $L.getRadius().getNormalizedValue())", - List.class, List.class, parameterName, parameterName, parameterName) - .build()); + arguments.add(CodeBlock.builder().add( + "$1T.of($1T.of($2L.getCenter().getX(), $2L.getCenter().getY()), $2L.getRadius().getNormalizedValue())", + List.class, parameterName).build()); } else if (ClassUtils.isAssignable(Box.class, parameter.getType())) { // { $geoWithin: { $box: [ [ , ], [ , ] ] } arguments.add(CodeBlock.builder().add( - "$T.of($T.of($L.getFirst().getX(), $L.getFirst().getY()), $T.of($L.getSecond().getX(), $L.getSecond().getY()))", - List.class, List.class, parameterName, parameterName, List.class, parameterName, parameterName).build()); + "$1T.of($1T.of($2L.getFirst().getX(), $2L.getFirst().getY()), $1T.of($2L.getSecond().getX(), $2L.getSecond().getY()))", + List.class, parameterName).build()); } else if (ClassUtils.isAssignable(Sphere.class, parameter.getType())) { // { $centerSphere: [ [ , ], ] } - arguments.add(CodeBlock.builder() - .add("$T.of($T.of($L.getCenter().getX(), $L.getCenter().getY()), $L.getRadius().getNormalizedValue())", - List.class, List.class, parameterName, parameterName, parameterName) - .build()); + arguments.add(CodeBlock.builder().add( + "$1T.of($1T.of($2L.getCenter().getX(), $2L.getCenter().getY()), $2L.getRadius().getNormalizedValue())", + List.class, parameterName).build()); } else if (ClassUtils.isAssignable(Polygon.class, parameter.getType())) { // $polygon: [ [ , ], [ , ], [ , ], ... ] String localVar = context.localVariable("_p"); - arguments - .add(CodeBlock.builder().add("$L.getPoints().stream().map($L -> $T.of($L.getX(), $L.getY())).toList()", - parameterName, localVar, List.class, localVar, localVar).build()); + arguments.add( + CodeBlock.builder().add("$1L.getPoints().stream().map($2L -> $3T.of($2L.getX(), $2L.getY())).toList()", + parameterName, localVar, List.class).build()); } else { @@ -890,11 +888,10 @@ private CodeBlock renderExpressionToQuery(@Nullable String source, String variab Builder builder = CodeBlock.builder(); if (!StringUtils.hasText(source)) { - builder.addStatement("$T $L = new $T(new $T())", BasicQuery.class, variableName, BasicQuery.class, - Document.class); + builder.addStatement("$1T $2L = new $1T(new $3T())", BasicQuery.class, variableName, Document.class); } else if (!containsPlaceholder(source)) { - builder.addStatement("$T $L = new $T($T.parse($S))", BasicQuery.class, variableName, BasicQuery.class, - Document.class, source); + builder.addStatement("$1T $2L = new $1T($3T.parse($4S))", BasicQuery.class, variableName, Document.class, + source); } else { builder.add("$T $L = createQuery($S, new $T[]{ ", BasicQuery.class, variableName, source, Object.class); Iterator iterator = arguments.iterator(); @@ -939,8 +936,7 @@ CodeBlock build() { builder.add("\n"); String tmpVariableName = updateVariableName + "Document"; builder.add(renderExpressionToDocument(source.getUpdate().getUpdateString(), tmpVariableName, arguments)); - builder.addStatement("$T $L = new $T($L)", BasicUpdate.class, updateVariableName, BasicUpdate.class, - tmpVariableName); + builder.addStatement("$1T $2L = new $1T($3L)", BasicUpdate.class, updateVariableName, tmpVariableName); return builder.build(); } @@ -951,9 +947,9 @@ private static CodeBlock renderExpressionToDocument(@Nullable String source, Str Builder builder = CodeBlock.builder(); if (!StringUtils.hasText(source)) { - builder.addStatement("$T $L = new $T()", Document.class, variableName, Document.class); + builder.addStatement("$1T $2L = new $1T()", Document.class, variableName); } else if (!containsPlaceholder(source)) { - builder.addStatement("$T $L = $T.parse($S)", Document.class, variableName, Document.class, source); + builder.addStatement("$1T $2L = $1T.parse($3S)", Document.class, variableName, source); } else { builder.add("$T $L = bindParameters($S, new $T[]{ ", Document.class, variableName, source, Object.class); @@ -965,8 +961,6 @@ private static CodeBlock renderExpressionToDocument(@Nullable String source, Str } } builder.add("});\n"); - // builder.addStatement("$T $L = bindParameters($S, new $T[]{ $L })", Document.class, variableName, source, - // Object.class, StringUtils.collectionToDelimitedString(arguments, ", ")); } return builder.build(); } From 7689e5ba56da6777b043840bea1e740cc8869e99 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Fri, 13 Jun 2025 13:05:02 +0200 Subject: [PATCH 09/24] geo resut conversion -> still need to have the paging variant --- .../repository/aot/MongoCodeBlocks.java | 39 +++++++++--- .../aot/MongoRepositoryContributor.java | 2 +- .../test/java/example/aot/UserRepository.java | 10 ++- .../aot/MongoRepositoryContributorTests.java | 62 ++++++++++++++++--- 4 files changed, 91 insertions(+), 22 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java index f0ac290f96..e210cc8a2e 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java @@ -32,6 +32,8 @@ import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; import org.springframework.data.geo.Distance; +import org.springframework.data.geo.GeoPage; +import org.springframework.data.geo.GeoResults; import org.springframework.data.geo.Polygon; import org.springframework.data.mongodb.core.ExecutableFindOperation.FindWithQuery; import org.springframework.data.mongodb.core.ExecutableRemoveOperation.ExecutableRemove; @@ -505,20 +507,21 @@ CodeBlock build() { String minVarName = context.localVariable("min"); String maxVarName = context.localVariable("max"); - builder.beginControlFlow("if($L.getLowerBound().isPresent())", rangeParametername); - builder.addStatement("$1T $2L = $3L.getLowerBound().get()", Distance.class, minVarName, rangeParametername); - builder.addStatement("$1L.minDistance($2L.getValue()).in($2L.getMetric())", variableName, minVarName); + builder.beginControlFlow("if($L.getLowerBound().isBounded())", rangeParametername); + builder.addStatement("$1T $2L = $3L.getLowerBound().getValue().get()", Distance.class, minVarName, + rangeParametername); + builder.addStatement("$1L.minDistance($2L).in($2L.getMetric())", variableName, minVarName); builder.endControlFlow(); - builder.beginControlFlow("if($L.getUpperBound().isPresent())", rangeParametername); - builder.addStatement("$1T $2L = $3L.getUpperBound().get()", Distance.class, maxVarName, rangeParametername); - builder.addStatement("$1L.maxDistance($2L.getValue()).in($2L.getMetric())", variableName, maxVarName); + builder.beginControlFlow("if($L.getUpperBound().isBounded())", rangeParametername); + builder.addStatement("$1T $2L = $3L.getUpperBound().getValue().get()", Distance.class, maxVarName, + rangeParametername); + builder.addStatement("$1L.maxDistance($2L).in($2L.getMetric())", variableName, maxVarName); builder.endControlFlow(); } else { String distanceParametername = context.getParameterName(queryMethod.getParameters().getMaxDistanceIndex()); - builder.addStatement("$1L.maxDistance($2L.getValue()).in($2L.getMetric())", variableName, - distanceParametername); + builder.addStatement("$1L.maxDistance($2L).in($2L.getMetric())", variableName, distanceParametername); } if (context.getPageableParameterName() != null) { @@ -526,8 +529,24 @@ CodeBlock build() { } builder.add("\n"); - builder.addStatement("return $L.query($T.class).near($L).all()", context.fieldNameOf(MongoOperations.class), - context.getRepositoryInformation().getDomainType(), variableName); + + // TODO: move the section below into dedicated executor builder + if (ClassUtils.isAssignable(GeoPage.class, context.getReturnType().getRawClass())) { + builder.addStatement("return new $T<>($L.query($T.class).near($L).all())", GeoPage.class, + context.fieldNameOf(MongoOperations.class), context.getRepositoryInformation().getDomainType(), + variableName); + } + + else if (ClassUtils.isAssignable(GeoResults.class, context.getReturnType().getRawClass())) { + + builder.addStatement("return $L.query($T.class).near($L).all()", context.fieldNameOf(MongoOperations.class), + context.getRepositoryInformation().getDomainType(), variableName); + } else { + builder.addStatement("return $L.query($T.class).near($L).all().getContent()", + context.fieldNameOf(MongoOperations.class), context.getRepositoryInformation().getDomainType(), + variableName); + } + return builder.build(); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java index bd31d5ea1a..00815f97d9 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java @@ -95,7 +95,7 @@ protected void customizeConstructor(AotRepositoryConstructorBuilder constructorB return aggregationMethodContributor(queryMethod, aggregation); } - if(queryMethod.isGeoNearQuery()) { + if(queryMethod.isGeoNearQuery() || (queryMethod.getParameters().getMaxDistanceIndex() != -1 && queryMethod.getReturnType().isCollectionLike())) { NearQueryInteraction near = new NearQueryInteraction(); return nearQueryMethodContributor(queryMethod, near); } diff --git a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java index 868203a103..e1fa7dbb4f 100644 --- a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java +++ b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java @@ -28,6 +28,7 @@ import org.springframework.data.domain.Limit; import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Range; import org.springframework.data.domain.ScrollPosition; import org.springframework.data.domain.Slice; import org.springframework.data.domain.Sort; @@ -35,6 +36,8 @@ import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; import org.springframework.data.geo.Distance; +import org.springframework.data.geo.GeoPage; +import org.springframework.data.geo.GeoResult; import org.springframework.data.geo.GeoResults; import org.springframework.data.geo.Point; import org.springframework.data.geo.Polygon; @@ -120,7 +123,12 @@ public interface UserRepository extends CrudRepository { GeoResults findByLocationCoordinatesNear(Point point, Distance maxDistance); - // TODO: GeoQueries + List> findUserAsListByLocationCoordinatesNear(Point point, Distance maxDistance); + + GeoResults findByLocationCoordinatesNear(Point point, Range distance); + + GeoPage findByLocationCoordinatesNear(Point point, Distance maxDistance, Pageable pageable); + // TODO: TextSearch /* Annotated Queries */ diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java index d9e7de2cc5..1e1729c2bc 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java @@ -41,6 +41,7 @@ import org.springframework.data.domain.OffsetScrollPosition; import org.springframework.data.domain.Page; import org.springframework.data.domain.PageRequest; +import org.springframework.data.domain.Range; import org.springframework.data.domain.ScrollPosition; import org.springframework.data.domain.Slice; import org.springframework.data.domain.Sort; @@ -48,6 +49,8 @@ import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; import org.springframework.data.geo.Distance; +import org.springframework.data.geo.GeoPage; +import org.springframework.data.geo.GeoResult; import org.springframework.data.geo.GeoResults; import org.springframework.data.geo.Metrics; import org.springframework.data.geo.Point; @@ -96,7 +99,7 @@ MongoOperations mongoOperations() { @BeforeAll static void beforeAll() { client.getDatabase(DB_NAME).getCollection("user").createIndex(new Document("location.coordinates", "2d"), - new IndexOptions()); + new IndexOptions()); } @BeforeEach @@ -613,21 +616,21 @@ void testAggregationWithCollation() { void testNear() { List users = fragment.findByLocationCoordinatesNear(new Point(-73.99171, 40.738868)); - assertThat(users).extracting(User::getUsername).containsExactly("leia"); + assertThat(users).extracting(User::getUsername).containsExactly("leia", "vader"); } @Test void testNearWithGeoJson() { List users = fragment.findByLocationCoordinatesNear(new GeoJsonPoint(-73.99171, 40.738868)); - assertThat(users).extracting(User::getUsername).containsExactly("leia"); + assertThat(users).extracting(User::getUsername).containsExactly("leia", "vader"); } @Test void testGeoWithinCircle() { List users = fragment.findByLocationCoordinatesWithin(new Circle(-78.99171, 45.738868, 170)); - assertThat(users).extracting(User::getUsername).containsExactly("leia"); + assertThat(users).extracting(User::getUsername).containsExactly("leia", "vader"); } @Test @@ -636,7 +639,7 @@ void testWithinBox() { Box box = new Box(new Point(-78.99171, 35.738868), new Point(-68.99171, 45.738868)); List result = fragment.findByLocationCoordinatesWithin(box); - assertThat(result).extracting(User::getUsername).containsExactly("leia"); + assertThat(result).extracting(User::getUsername).containsExactly("leia", "vader"); } @Test @@ -648,18 +651,51 @@ void findsPeopleByLocationWithinPolygon() { Point fourth = new Point(-68.99171, 35.738868); List result = fragment.findByLocationCoordinatesWithin(new Polygon(first, second, third, fourth)); - assertThat(result).extracting(User::getUsername).containsExactly("leia"); + assertThat(result).extracting(User::getUsername).containsExactly("leia", "vader"); } - @Test void testNearWithGeoResult() { - GeoResults users = fragment.findByLocationCoordinatesNear(new Point(-73.99, 40.73), Distance.of(2000, Metrics.KILOMETERS)); - System.out.println("users: " + users); - assertThat(users).isNotEmpty(); + GeoResults users = fragment.findByLocationCoordinatesNear(new Point(-73.99, 40.73), + Distance.of(5, Metrics.KILOMETERS)); + assertThat(users).extracting(GeoResult::getContent).extracting(User::getUsername).containsExactly("leia"); + } + + @Test + void testNearReturningListOfGeoResult() { + + List> users = fragment.findUserAsListByLocationCoordinatesNear(new Point(-73.99, 40.73), + Distance.of(5, Metrics.KILOMETERS)); + assertThat(users).extracting(GeoResult::getContent).extracting(User::getUsername).containsExactly("leia"); + } + + @Test + void testNearWithRange() { + + Range range = Distance.between(Distance.of(5, Metrics.KILOMETERS), Distance.of(2000, Metrics.KILOMETERS)); + GeoResults users = fragment.findByLocationCoordinatesNear(new Point(-73.99, 40.73), range); + + assertThat(users).extracting(GeoResult::getContent).extracting(User::getUsername).containsExactly("vader"); } + @Test + void testNearReturningGeoPage() { + + // TODO: still need to create the count and extract the total elements + GeoPage page1 = fragment.findByLocationCoordinatesNear(new Point(-73.99, 40.73), + Distance.of(2000, Metrics.KILOMETERS), PageRequest.of(0, 1)); + + assertThat(page1.hasNext()).isTrue(); + + GeoPage page2 = fragment.findByLocationCoordinatesNear(new Point(-73.99, 40.73), + Distance.of(2000, Metrics.KILOMETERS), page1.nextPageable()); + assertThat(page2.hasNext()).isFalse(); + } + + /** + * GeoResults results = repository.findPersonByLocationNear(new Point(-73.99, 40.73), range); + */ private static void initUsers() { Document luke = Document.parse(""" @@ -753,6 +789,12 @@ private static void initUsers() { "username": "vader", "first_name": "Anakin", "last_name": "Skywalker", + "location" : { + "planet" : "Death Star", + "coordinates" : { + "x" : -73.9, "y" : 40.7 + } + }, "visits" : 50, "posts": [ { From e951a3bc93fa2c3c58325be1b596655cc980eab2 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Fri, 13 Jun 2025 18:36:04 +0200 Subject: [PATCH 10/24] GeoJsonPolygon et al --- .../data/mongodb/core/query/GeoCommand.java | 4 ++ .../repository/aot/AotQueryCreator.java | 55 ++++++++++++++++--- .../repository/aot/MongoCodeBlocks.java | 10 ++-- .../aot/MongoRepositoryContributor.java | 25 ++++++--- .../data/mongodb/util/BsonUtils.java | 22 ++++---- .../test/java/example/aot/UserRepository.java | 5 +- .../aot/MongoRepositoryContributorTests.java | 16 ++++++ 7 files changed, 105 insertions(+), 32 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/GeoCommand.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/GeoCommand.java index 19ecd94e23..4b8f81ef2b 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/GeoCommand.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/GeoCommand.java @@ -22,6 +22,7 @@ import org.springframework.data.geo.Circle; import org.springframework.data.geo.Polygon; import org.springframework.data.geo.Shape; +import org.springframework.data.mongodb.core.geo.GeoJson; import org.springframework.data.mongodb.core.geo.Sphere; import org.springframework.util.Assert; @@ -75,6 +76,9 @@ private String getCommand(Shape shape) { Assert.notNull(shape, "Shape must not be null"); + if(shape instanceof GeoJson) { + return "$geometry"; + } if (shape instanceof Box) { return "$box"; } else if (shape instanceof Circle) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java index 6904c868ce..2b22550026 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java @@ -37,6 +37,7 @@ import org.springframework.data.geo.Shape; import org.springframework.data.mongodb.core.convert.MongoCustomConversions; import org.springframework.data.mongodb.core.convert.MongoWriter; +import org.springframework.data.mongodb.core.geo.GeoJson; import org.springframework.data.mongodb.core.geo.Sphere; import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty; @@ -53,9 +54,9 @@ import org.springframework.data.repository.query.QueryMethod; import org.springframework.data.repository.query.parser.PartTree; import org.springframework.data.util.TypeInformation; +import org.springframework.util.ClassUtils; import com.mongodb.DBRef; -import org.springframework.util.ClassUtils; /** * @author Christoph Strobl @@ -133,19 +134,21 @@ public PlaceholderParameterAccessor(QueryMethod queryMethod) { } else { placeholders = new ArrayList<>(); Parameters parameters = queryMethod.getParameters(); - for(Parameter parameter : parameters.toList()) { - if(ClassUtils.isAssignable(Point.class, parameter.getType())) { + for (Parameter parameter : parameters.toList()) { + if (ClassUtils.isAssignable(GeoJson.class, parameter.getType())) { + placeholders.add(parameter.getIndex(), new GeoJsonPlaceholder(parameter.getIndex(), "")); + } + else if (ClassUtils.isAssignable(Point.class, parameter.getType())) { placeholders.add(parameter.getIndex(), new PointPlaceholder(parameter.getIndex())); - } else if(ClassUtils.isAssignable(Circle.class, parameter.getType())) { + } else if (ClassUtils.isAssignable(Circle.class, parameter.getType())) { placeholders.add(parameter.getIndex(), new CirclePlaceholder(parameter.getIndex())); - } else if(ClassUtils.isAssignable(Box.class, parameter.getType())) { + } else if (ClassUtils.isAssignable(Box.class, parameter.getType())) { placeholders.add(parameter.getIndex(), new BoxPlaceholder(parameter.getIndex())); - } else if(ClassUtils.isAssignable(Sphere.class, parameter.getType())) { + } else if (ClassUtils.isAssignable(Sphere.class, parameter.getType())) { placeholders.add(parameter.getIndex(), new SpherePlaceholder(parameter.getIndex())); - } else if(ClassUtils.isAssignable(Polygon.class, parameter.getType())) { + } else if (ClassUtils.isAssignable(Polygon.class, parameter.getType())) { placeholders.add(parameter.getIndex(), new PolygonPlaceholder(parameter.getIndex())); } - else { placeholders.add(parameter.getIndex(), Placeholder.indexed(parameter.getIndex())); } @@ -238,6 +241,7 @@ public Iterator iterator() { static class CirclePlaceholder extends Circle implements Placeholder { int index; + public CirclePlaceholder(int index) { super(new PointPlaceholder(index), Distance.of(1, Metrics.NEUTRAL)); // this.index = index; @@ -257,6 +261,7 @@ public String toString() { static class SpherePlaceholder extends Sphere implements Placeholder { int index; + public SpherePlaceholder(int index) { super(new PointPlaceholder(index), Distance.of(1, Metrics.NEUTRAL)); // this.index = index; @@ -273,6 +278,37 @@ public String toString() { } } + static class GeoJsonPlaceholder implements Placeholder, GeoJson>, Shape { + + int index; + String type; + + public GeoJsonPlaceholder(int index, String type) { + this.index = index; + this.type = type; + } + + @Override + public Object getValue() { + return "?%s".formatted(index); + } + + @Override + public String toString() { + return getValue().toString(); + } + + @Override + public String getType() { + return type; + } + + @Override + public List getCoordinates() { + return List.of(); + } + } + static class BoxPlaceholder extends Box implements Placeholder { int index; @@ -296,7 +332,8 @@ static class PolygonPlaceholder extends Polygon implements Placeholder { int index; public PolygonPlaceholder(int index) { - super(new PointPlaceholder(index), new PointPlaceholder(index), new PointPlaceholder(index), new PointPlaceholder(index)); + super(new PointPlaceholder(index), new PointPlaceholder(index), new PointPlaceholder(index), + new PointPlaceholder(index)); this.index = index; } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java index e210cc8a2e..331768e0b8 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java @@ -44,6 +44,7 @@ import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; import org.springframework.data.mongodb.core.aggregation.AggregationResults; import org.springframework.data.mongodb.core.aggregation.TypedAggregation; +import org.springframework.data.mongodb.core.geo.GeoJson; import org.springframework.data.mongodb.core.geo.Sphere; import org.springframework.data.mongodb.core.mapping.MongoSimpleTypes; import org.springframework.data.mongodb.core.query.BasicQuery; @@ -786,7 +787,9 @@ static class QueryCodeBlockBuilder { this.arguments = new ArrayList<>(); for (MongoParameter parameter : queryMethod.getParameters().getBindableParameters()) { String parameterName = context.getParameterName(parameter.getIndex()); - if (ClassUtils.isAssignable(Circle.class, parameter.getType())) { + if (ClassUtils.isAssignable(GeoJson.class, parameter.getType())) { + arguments.add(CodeBlock.of(parameterName)); + } else if (ClassUtils.isAssignable(Circle.class, parameter.getType())) { arguments.add(CodeBlock.builder().add( "$1T.of($1T.of($2L.getCenter().getX(), $2L.getCenter().getY()), $2L.getRadius().getNormalizedValue())", List.class, parameterName).build()); @@ -802,14 +805,13 @@ static class QueryCodeBlockBuilder { "$1T.of($1T.of($2L.getCenter().getX(), $2L.getCenter().getY()), $2L.getRadius().getNormalizedValue())", List.class, parameterName).build()); } else if (ClassUtils.isAssignable(Polygon.class, parameter.getType())) { + // $polygon: [ [ , ], [ , ], [ , ], ... ] String localVar = context.localVariable("_p"); arguments.add( CodeBlock.builder().add("$1L.getPoints().stream().map($2L -> $3T.of($2L.getX(), $2L.getY())).toList()", parameterName, localVar, List.class).build()); - } - - else { + } else { arguments.add(CodeBlock.of(parameterName)); } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java index 00815f97d9..c510ab8e16 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java @@ -15,7 +15,15 @@ */ package org.springframework.data.mongodb.repository.aot; -import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.*; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.QueryCodeBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.aggregationBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.aggregationExecutionBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.deleteExecutionBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.geoNearBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.queryBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.queryExecutionBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.updateBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.updateExecutionBlockBuilder; import java.lang.reflect.Method; import java.util.Locale; @@ -25,6 +33,7 @@ import org.apache.commons.logging.LogFactory; import org.jspecify.annotations.Nullable; import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.data.geo.GeoPage; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.aggregation.AggregationUpdate; import org.springframework.data.mongodb.core.mapping.MongoMappingContext; @@ -42,6 +51,7 @@ import org.springframework.data.repository.query.parser.PartTree; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.TypeName; +import org.springframework.util.ClassUtils; import org.springframework.util.ObjectUtils; import org.springframework.util.StringUtils; @@ -95,7 +105,8 @@ protected void customizeConstructor(AotRepositoryConstructorBuilder constructorB return aggregationMethodContributor(queryMethod, aggregation); } - if(queryMethod.isGeoNearQuery() || (queryMethod.getParameters().getMaxDistanceIndex() != -1 && queryMethod.getReturnType().isCollectionLike())) { + if (queryMethod.isGeoNearQuery() || (queryMethod.getParameters().getMaxDistanceIndex() != -1 + && queryMethod.getReturnType().isCollectionLike())) { NearQueryInteraction near = new NearQueryInteraction(); return nearQueryMethodContributor(queryMethod, near); } @@ -159,8 +170,8 @@ private QueryInteraction createStringQuery(RepositoryInformation repositoryInfor } else { PartTree partTree = new PartTree(queryMethod.getName(), repositoryInformation.getDomainType()); - query = new QueryInteraction(queryCreator.createQuery(partTree, queryMethod), - partTree.isCountProjection(), partTree.isDelete(), partTree.isExistsProjection()); + query = new QueryInteraction(queryCreator.createQuery(partTree, queryMethod), partTree.isCountProjection(), + partTree.isDelete(), partTree.isExistsProjection()); } if (queryAnnotation != null && StringUtils.hasText(queryAnnotation.sort())) { @@ -177,7 +188,7 @@ private static boolean backoff(MongoQueryMethod method) { // TODO: namedQuery, Regex queries, queries accepting Shapes (e.g. within) or returning arrays. boolean skip = method.isSearchQuery() || method.getName().toLowerCase(Locale.ROOT).contains("regex") - || method.getReturnType().getType().isArray(); + || method.getReturnType().getType().isArray() || ClassUtils.isAssignable(GeoPage.class, method.getReturnType().getType()); if (skip && logger.isDebugEnabled()) { logger.debug("Skipping AOT generation for [%s]. Method is either returning an array or a geo-near, regex query" @@ -187,14 +198,14 @@ private static boolean backoff(MongoQueryMethod method) { } private static MethodContributor nearQueryMethodContributor(MongoQueryMethod queryMethod, - NearQueryInteraction interaction) { + NearQueryInteraction interaction) { return MethodContributor.forQueryMethod(queryMethod).withMetadata(interaction).contribute(context -> { CodeBlock.Builder builder = CodeBlock.builder(); builder.add(geoNearBlockBuilder(context, queryMethod).usingQueryVariableName("nearQuery").build()); -// builder.add(aggregationExecutionBlockBuilder(context, queryMethod).referencing("aggregation").build()); + // builder.add(aggregationExecutionBlockBuilder(context, queryMethod).referencing("aggregation").build()); return builder.build(); }); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/BsonUtils.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/BsonUtils.java index 2298a8b74e..eb052da9a4 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/BsonUtils.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/BsonUtils.java @@ -71,7 +71,6 @@ import org.jspecify.annotations.NullUnmarked; import org.jspecify.annotations.Nullable; import org.springframework.core.convert.converter.Converter; -import org.springframework.data.geo.Circle; import org.springframework.data.mongodb.CodecRegistryProvider; import org.springframework.data.mongodb.core.mapping.FieldName; import org.springframework.data.mongodb.core.mapping.FieldName.Type; @@ -1094,17 +1093,18 @@ public GeoCommand decode(BsonReader reader, DecoderContext decoderContext) { public void encode(BsonWriter writer, GeoCommand value, EncoderContext encoderContext) { if (writer instanceof SpringJsonWriter sjw) { - writer.writeStartDocument(); - writer.writeName(value.getCommand()); - if (value.getShape() instanceof Placeholder p) { // maybe we should wrap input to use geo command object - sjw.writePlaceholder(p.toString()); -// Circle c = null; -// List.of(c.getCenter(), c.getRadius()) -// ; - -// createQuery("{'location.coordinates':{'$geoWithin':{'$center':?0}}}", new Object[]{ List.of(circle.getCenter(), circle.getRadius())) + if (!value.getCommand().equals("$geometry")) { + writer.writeStartDocument(); + writer.writeName(value.getCommand()); + if (value.getShape() instanceof Placeholder p) { // maybe we should wrap input to use geo command object + sjw.writePlaceholder(p.toString()); + } + writer.writeEndDocument(); + } else { + if (value.getShape() instanceof Placeholder p) { // maybe we should wrap input to use geo command object + sjw.writePlaceholder(p.toString()); + } } - writer.writeEndDocument(); } else { writer.writeString(value.getCommand(), value.getShape().toString()); } diff --git a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java index e1fa7dbb4f..2fc0787a67 100644 --- a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java +++ b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java @@ -42,6 +42,7 @@ import org.springframework.data.geo.Point; import org.springframework.data.geo.Polygon; import org.springframework.data.mongodb.core.aggregation.AggregationResults; +import org.springframework.data.mongodb.core.geo.GeoJsonPolygon; import org.springframework.data.mongodb.repository.Aggregation; import org.springframework.data.mongodb.repository.Hint; import org.springframework.data.mongodb.repository.Person; @@ -121,6 +122,8 @@ public interface UserRepository extends CrudRepository { List findByLocationCoordinatesWithin(Polygon polygon); + List findByLocationCoordinatesWithin(GeoJsonPolygon polygon); + GeoResults findByLocationCoordinatesNear(Point point, Distance maxDistance); List> findUserAsListByLocationCoordinatesNear(Point point, Distance maxDistance); @@ -128,7 +131,7 @@ public interface UserRepository extends CrudRepository { GeoResults findByLocationCoordinatesNear(Point point, Range distance); GeoPage findByLocationCoordinatesNear(Point point, Distance maxDistance, Pageable pageable); - + // TODO: TextSearch /* Annotated Queries */ diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java index 1e1729c2bc..6ac9a61458 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java @@ -31,6 +31,7 @@ import org.bson.Document; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.springframework.beans.factory.annotation.Autowired; @@ -59,6 +60,7 @@ import org.springframework.data.mongodb.core.MongoTemplate; import org.springframework.data.mongodb.core.aggregation.AggregationResults; import org.springframework.data.mongodb.core.geo.GeoJsonPoint; +import org.springframework.data.mongodb.core.geo.GeoJsonPolygon; import org.springframework.data.mongodb.test.util.Client; import org.springframework.data.mongodb.test.util.MongoClientExtension; import org.springframework.data.mongodb.test.util.MongoTestUtils; @@ -654,6 +656,19 @@ void findsPeopleByLocationWithinPolygon() { assertThat(result).extracting(User::getUsername).containsExactly("leia", "vader"); } + @Test + void findsPeopleByLocationWithinGeoJsonPolygon() { + + Point first = new Point(-78.99171, 35.738868); + Point second = new Point(-78.99171, 45.738868); + Point third = new Point(-68.99171, 45.738868); + Point fourth = new Point(-68.99171, 35.738868); + + List result = fragment + .findByLocationCoordinatesWithin(new GeoJsonPolygon(first, second, third, fourth, first)); + assertThat(result).extracting(User::getUsername).containsExactly("leia", "vader"); + } + @Test void testNearWithGeoResult() { @@ -680,6 +695,7 @@ void testNearWithRange() { } @Test + @Disabled("too complicated") void testNearReturningGeoPage() { // TODO: still need to create the count and extract the total elements From 10ba08d45f7747d5396057eb30dc862d10f480ed Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Mon, 16 Jun 2025 10:26:09 +0200 Subject: [PATCH 11/24] refactor a tiny bit --- .../repository/aot/AggregationBlocks.java | 360 ++++++++ .../mongodb/repository/aot/DeleteBlocks.java | 100 ++ .../mongodb/repository/aot/GeoBlocks.java | 140 +++ .../repository/aot/MongoCodeBlocks.java | 868 +----------------- .../aot/MongoRepositoryContributor.java | 16 +- .../mongodb/repository/aot/QueryBlocks.java | 308 +++++++ .../mongodb/repository/aot/UpdateBlocks.java | 145 +++ 7 files changed, 1097 insertions(+), 840 deletions(-) create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/DeleteBlocks.java create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateBlocks.java diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java new file mode 100644 index 0000000000..11ef3a4822 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java @@ -0,0 +1,360 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.bson.Document; +import org.jspecify.annotations.NullUnmarked; +import org.springframework.core.annotation.MergedAnnotation; +import org.springframework.data.domain.SliceImpl; +import org.springframework.data.domain.Sort.Order; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.aggregation.Aggregation; +import org.springframework.data.mongodb.core.aggregation.AggregationOptions; +import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; +import org.springframework.data.mongodb.core.aggregation.AggregationResults; +import org.springframework.data.mongodb.core.aggregation.TypedAggregation; +import org.springframework.data.mongodb.core.mapping.MongoSimpleTypes; +import org.springframework.data.mongodb.core.query.Collation; +import org.springframework.data.mongodb.repository.Hint; +import org.springframework.data.mongodb.repository.ReadPreference; +import org.springframework.data.mongodb.repository.query.MongoQueryMethod; +import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; +import org.springframework.data.util.ReflectionUtils; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.CodeBlock.Builder; +import org.springframework.util.ClassUtils; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +/** + * @author Christoph Strobl + * @since 5.0 + */ +class AggregationBlocks { + + @NullUnmarked + static class AggregationExecutionCodeBlockBuilder { + + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + private String aggregationVariableName; + + AggregationExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + + this.context = context; + this.queryMethod = queryMethod; + } + + AggregationExecutionCodeBlockBuilder referencing(String aggregationVariableName) { + + this.aggregationVariableName = aggregationVariableName; + return this; + } + + CodeBlock build() { + + String mongoOpsRef = context.fieldNameOf(MongoOperations.class); + Builder builder = CodeBlock.builder(); + + builder.add("\n"); + + Class outputType = queryMethod.getReturnedObjectType(); + if (MongoSimpleTypes.HOLDER.isSimpleType(outputType)) { + outputType = Document.class; + } else if (ClassUtils.isAssignable(AggregationResults.class, outputType)) { + outputType = queryMethod.getReturnType().getComponentType().getType(); + } + + if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) { + builder.addStatement("$L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType); + return builder.build(); + } + + if (ClassUtils.isAssignable(AggregationResults.class, context.getMethod().getReturnType())) { + builder.addStatement("return $L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType); + return builder.build(); + } + + if (outputType == Document.class) { + + Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); + + if (queryMethod.isStreamQuery()) { + + builder.addStatement("$T<$T> $L = $L.aggregateStream($L, $T.class)", Stream.class, Document.class, + context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); + + builder.addStatement("return $1L.map(it -> ($2T) convertSimpleRawResult($2T.class, it))", + context.localVariable("results"), returnType); + } else { + + builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class, + context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); + + if (!queryMethod.isCollectionQuery()) { + builder.addStatement( + "return $1T.<$2T>firstElement(convertSimpleRawResults($2T.class, $3L.getMappedResults()))", + CollectionUtils.class, returnType, context.localVariable("results")); + } else { + builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType, + context.localVariable("results")); + } + } + } else { + if (queryMethod.isSliceQuery()) { + builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class, + context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); + builder.addStatement("boolean $L = $L.getMappedResults().size() > $L.getPageSize()", + context.localVariable("hasNext"), context.localVariable("results"), context.getPageableParameterName()); + builder.addStatement( + "return new $1T<>($2L ? $3L.getMappedResults().subList(0, $4L.getPageSize()) : $3L.getMappedResults(), $4L, $2L)", + SliceImpl.class, context.localVariable("hasNext"), context.localVariable("results"), + context.getPageableParameterName()); + } else { + + if (queryMethod.isStreamQuery()) { + builder.addStatement("return $L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName, + outputType); + } else { + + builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef, + aggregationVariableName, outputType); + } + } + } + + return builder.build(); + } + } + + @NullUnmarked + static class AggregationCodeBlockBuilder { + + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + private final List arguments; + + private AggregationInteraction source; + + private String aggregationVariableName; + private boolean pipelineOnly; + + AggregationCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + + this.context = context; + this.arguments = context.getBindableParameterNames().stream().map(CodeBlock::of).collect(Collectors.toList()); + this.queryMethod = queryMethod; + } + + AggregationCodeBlockBuilder stages(AggregationInteraction aggregation) { + + this.source = aggregation; + return this; + } + + AggregationCodeBlockBuilder usingAggregationVariableName(String aggregationVariableName) { + + this.aggregationVariableName = aggregationVariableName; + return this; + } + + AggregationCodeBlockBuilder pipelineOnly(boolean pipelineOnly) { + + this.pipelineOnly = pipelineOnly; + return this; + } + + CodeBlock build() { + + Builder builder = CodeBlock.builder(); + builder.add("\n"); + + String pipelineName = context.localVariable(aggregationVariableName + (pipelineOnly ? "" : "Pipeline")); + builder.add(pipeline(pipelineName)); + + if (!pipelineOnly) { + + builder.addStatement("$1T<$2T> $3L = $4T.newAggregation($2T.class, $5L.getOperations())", + TypedAggregation.class, context.getRepositoryInformation().getDomainType(), aggregationVariableName, + Aggregation.class, pipelineName); + + builder.add(aggregationOptions(aggregationVariableName)); + } + + return builder.build(); + } + + private CodeBlock pipeline(String pipelineVariableName) { + + String sortParameter = context.getSortParameterName(); + String limitParameter = context.getLimitParameterName(); + String pageableParameter = context.getPageableParameterName(); + + boolean mightBeSorted = StringUtils.hasText(sortParameter); + boolean mightBeLimited = StringUtils.hasText(limitParameter); + boolean mightBePaged = StringUtils.hasText(pageableParameter); + + int stageCount = source.stages().size(); + if (mightBeSorted) { + stageCount++; + } + if (mightBeLimited) { + stageCount++; + } + if (mightBePaged) { + stageCount += 3; + } + + Builder builder = CodeBlock.builder(); + builder.add(aggregationStages(context.localVariable("stages"), source.stages(), stageCount, arguments)); + + if (mightBeSorted) { + builder.add(sortingStage(sortParameter)); + } + + if (mightBeLimited) { + builder.add(limitingStage(limitParameter)); + } + + if (mightBePaged) { + builder.add(pagingStage(pageableParameter, queryMethod.isSliceQuery())); + } + + builder.addStatement("$T $L = createPipeline($L)", AggregationPipeline.class, pipelineVariableName, + context.localVariable("stages")); + return builder.build(); + } + + private CodeBlock aggregationOptions(String aggregationVariableName) { + + Builder builder = CodeBlock.builder(); + List options = new ArrayList<>(5); + if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) { + options.add(CodeBlock.of(".skipOutput()")); + } + + MergedAnnotation hintAnnotation = context.getAnnotation(Hint.class); + String hint = hintAnnotation.isPresent() ? hintAnnotation.getString("value") : null; + if (StringUtils.hasText(hint)) { + options.add(CodeBlock.of(".hint($S)", hint)); + } + + MergedAnnotation readPreferenceAnnotation = context.getAnnotation(ReadPreference.class); + String readPreference = readPreferenceAnnotation.isPresent() ? readPreferenceAnnotation.getString("value") : null; + if (StringUtils.hasText(readPreference)) { + options.add(CodeBlock.of(".readPreference($T.valueOf($S))", com.mongodb.ReadPreference.class, readPreference)); + } + + if (queryMethod.hasAnnotatedCollation()) { + options.add(CodeBlock.of(".collation($T.parse($S))", Collation.class, queryMethod.getAnnotatedCollation())); + } + + if (!options.isEmpty()) { + + Builder optionsBuilder = CodeBlock.builder(); + optionsBuilder.add("$1T $2L = $1T.builder()\n", AggregationOptions.class, + context.localVariable("aggregationOptions")); + optionsBuilder.indent(); + for (CodeBlock optionBlock : options) { + optionsBuilder.add(optionBlock); + optionsBuilder.add("\n"); + } + optionsBuilder.add(".build();\n"); + optionsBuilder.unindent(); + builder.add(optionsBuilder.build()); + + builder.addStatement("$1L = $1L.withOptions($2L)", aggregationVariableName, + context.localVariable("aggregationOptions")); + } + return builder.build(); + } + + private CodeBlock aggregationStages(String stageListVariableName, Iterable stages, int stageCount, + List arguments) { + + Builder builder = CodeBlock.builder(); + builder.addStatement("$T<$T> $L = new $T($L)", List.class, Object.class, stageListVariableName, ArrayList.class, + stageCount); + int stageCounter = 0; + + for (String stage : stages) { + String stageName = context.localVariable("stage_%s".formatted(stageCounter++)); + builder.add(MongoCodeBlocks.renderExpressionToDocument(stage, stageName, arguments)); + builder.addStatement("$L.add($L)", context.localVariable("stages"), stageName); + } + + return builder.build(); + } + + private CodeBlock sortingStage(String sortProvider) { + + Builder builder = CodeBlock.builder(); + + builder.beginControlFlow("if ($L.isSorted())", sortProvider); + builder.addStatement("$1T $2L = new $1T()", Document.class, context.localVariable("sortDocument")); + builder.beginControlFlow("for ($T $L : $L)", Order.class, context.localVariable("order"), sortProvider); + builder.addStatement("$1L.append($2L.getProperty(), $2L.isAscending() ? 1 : -1);", + context.localVariable("sortDocument"), context.localVariable("order")); + builder.endControlFlow(); + builder.addStatement("stages.add(new $T($S, $L))", Document.class, "$sort", + context.localVariable("sortDocument")); + builder.endControlFlow(); + + return builder.build(); + } + + private CodeBlock pagingStage(String pageableProvider, boolean slice) { + + Builder builder = CodeBlock.builder(); + + builder.add(sortingStage(pageableProvider + ".getSort()")); + + builder.beginControlFlow("if ($L.isPaged())", pageableProvider); + builder.beginControlFlow("if ($L.getOffset() > 0)", pageableProvider); + builder.addStatement("$L.add($T.skip($L.getOffset()))", context.localVariable("stages"), Aggregation.class, + pageableProvider); + builder.endControlFlow(); + if (slice) { + builder.addStatement("$L.add($T.limit($L.getPageSize() + 1))", context.localVariable("stages"), + Aggregation.class, pageableProvider); + } else { + builder.addStatement("$L.add($T.limit($L.getPageSize()))", context.localVariable("stages"), Aggregation.class, + pageableProvider); + } + builder.endControlFlow(); + + return builder.build(); + } + + private CodeBlock limitingStage(String limitProvider) { + + Builder builder = CodeBlock.builder(); + + builder.beginControlFlow("if ($L.isLimited())", limitProvider); + builder.addStatement("$L.add($T.limit($L.max()))", context.localVariable("stages"), Aggregation.class, + limitProvider); + builder.endControlFlow(); + + return builder.build(); + } + + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/DeleteBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/DeleteBlocks.java new file mode 100644 index 0000000000..1d009f3085 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/DeleteBlocks.java @@ -0,0 +1,100 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import java.util.Optional; + +import org.jspecify.annotations.NullUnmarked; +import org.springframework.data.mongodb.core.ExecutableRemoveOperation.ExecutableRemove; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.repository.query.MongoQueryExecution.DeleteExecution; +import org.springframework.data.mongodb.repository.query.MongoQueryMethod; +import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.CodeBlock.Builder; +import org.springframework.javapoet.TypeName; +import org.springframework.util.ClassUtils; +import org.springframework.util.ObjectUtils; + +/** + * @author Christoph Strobl + * @since 5.0 + */ +class DeleteBlocks { + + @NullUnmarked + static class DeleteExecutionCodeBlockBuilder { + + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + private String queryVariableName; + + DeleteExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + + this.context = context; + this.queryMethod = queryMethod; + } + + DeleteExecutionCodeBlockBuilder referencing(String queryVariableName) { + + this.queryVariableName = queryVariableName; + return this; + } + + CodeBlock build() { + + String mongoOpsRef = context.fieldNameOf(MongoOperations.class); + Builder builder = CodeBlock.builder(); + + Class domainType = context.getRepositoryInformation().getDomainType(); + boolean isProjecting = context.getActualReturnType() != null + && !ObjectUtils.nullSafeEquals(TypeName.get(domainType), context.getActualReturnType()); + + Object actualReturnType = isProjecting ? context.getActualReturnType().getType() : domainType; + + builder.add("\n"); + builder.addStatement("$1T<$2T> $3L = $4L.remove($2T.class)", ExecutableRemove.class, domainType, + context.localVariable("remover"), mongoOpsRef); + + DeleteExecution.Type type = DeleteExecution.Type.FIND_AND_REMOVE_ALL; + if (!queryMethod.isCollectionQuery()) { + if (!ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType())) { + type = DeleteExecution.Type.FIND_AND_REMOVE_ONE; + } else { + type = DeleteExecution.Type.ALL; + } + } + + actualReturnType = ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType()) + ? TypeName.get(context.getMethod().getReturnType()) + : queryMethod.isCollectionQuery() ? context.getReturnTypeName() : actualReturnType; + + if (ClassUtils.isVoidType(context.getMethod().getReturnType())) { + builder.addStatement("new $T($L, $T.$L).execute($L)", DeleteExecution.class, context.localVariable("remover"), + DeleteExecution.Type.class, type.name(), queryVariableName); + } else if (context.getMethod().getReturnType() == Optional.class) { + builder.addStatement("return $T.ofNullable(($T) new $T($L, $T.$L).execute($L))", Optional.class, + actualReturnType, DeleteExecution.class, context.localVariable("remover"), DeleteExecution.Type.class, + type.name(), queryVariableName); + } else { + builder.addStatement("return ($T) new $T($L, $T.$L).execute($L)", actualReturnType, DeleteExecution.class, + context.localVariable("remover"), DeleteExecution.Type.class, type.name(), queryVariableName); + } + + return builder.build(); + } + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java new file mode 100644 index 0000000000..ecf111433b --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java @@ -0,0 +1,140 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import java.util.List; +import java.util.stream.Collectors; + +import org.springframework.data.geo.Distance; +import org.springframework.data.geo.GeoPage; +import org.springframework.data.geo.GeoResults; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.query.NearQuery; +import org.springframework.data.mongodb.repository.query.MongoQueryMethod; +import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; +import org.springframework.javapoet.CodeBlock; +import org.springframework.util.ClassUtils; + +/** + * @author Christoph Strobl + * @since 5.0 + */ +class GeoBlocks { + + static class GeoNearCodeBlockBuilder { + + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + private final List arguments; + + private String variableName; + + GeoNearCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + + this.context = context; + this.arguments = context.getBindableParameterNames().stream().map(CodeBlock::of).collect(Collectors.toList()); + this.queryMethod = queryMethod; + } + + CodeBlock build() { + + CodeBlock.Builder builder = CodeBlock.builder(); + builder.add("\n"); + + String locationParameterName = context.getParameterName(queryMethod.getParameters().getNearIndex()); + + builder.addStatement("$1T $2L = $1T.near($3L)", NearQuery.class, variableName, locationParameterName); + + if (queryMethod.getParameters().getRangeIndex() != -1) { + + String rangeParametername = context.getParameterName(queryMethod.getParameters().getRangeIndex()); + String minVarName = context.localVariable("min"); + String maxVarName = context.localVariable("max"); + + builder.beginControlFlow("if($L.getLowerBound().isBounded())", rangeParametername); + builder.addStatement("$1T $2L = $3L.getLowerBound().getValue().get()", Distance.class, minVarName, + rangeParametername); + builder.addStatement("$1L.minDistance($2L).in($2L.getMetric())", variableName, minVarName); + builder.endControlFlow(); + + builder.beginControlFlow("if($L.getUpperBound().isBounded())", rangeParametername); + builder.addStatement("$1T $2L = $3L.getUpperBound().getValue().get()", Distance.class, maxVarName, + rangeParametername); + builder.addStatement("$1L.maxDistance($2L).in($2L.getMetric())", variableName, maxVarName); + builder.endControlFlow(); + } else { + + String distanceParametername = context.getParameterName(queryMethod.getParameters().getMaxDistanceIndex()); + builder.addStatement("$1L.maxDistance($2L).in($2L.getMetric())", variableName, distanceParametername); + } + + if (context.getPageableParameterName() != null) { + builder.addStatement("$L.with($L)", variableName, context.getPageableParameterName()); + } + + return builder.build(); + } + + public GeoNearCodeBlockBuilder usingQueryVariableName(String variableName) { + this.variableName = variableName; + return this; + } + } + + static class GeoNearExecutionCodeBlockBuilder { + + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + private String queryVariableName; + + GeoNearExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + + this.context = context; + this.queryMethod = queryMethod; + } + + GeoNearExecutionCodeBlockBuilder referencing(String queryVariableName) { + + this.queryVariableName = queryVariableName; + return this; + } + + CodeBlock build() { + + CodeBlock.Builder builder = CodeBlock.builder(); + builder.add("\n"); + + // TODO: move the section below into dedicated executor builder + if (ClassUtils.isAssignable(GeoPage.class, context.getReturnType().getRawClass())) { + builder.addStatement("return new $T<>($L.query($T.class).near($L).all())", GeoPage.class, + context.fieldNameOf(MongoOperations.class), context.getRepositoryInformation().getDomainType(), + queryVariableName); + } + + else if (ClassUtils.isAssignable(GeoResults.class, context.getReturnType().getRawClass())) { + + builder.addStatement("return $L.query($T.class).near($L).all()", context.fieldNameOf(MongoOperations.class), + context.getRepositoryInformation().getDomainType(), queryVariableName); + } else { + builder.addStatement("return $L.query($T.class).near($L).all().getContent()", + context.fieldNameOf(MongoOperations.class), context.getRepositoryInformation().getDomainType(), + queryVariableName); + } + return builder.build(); + } + + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java index 331768e0b8..2a51bc81b2 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java @@ -15,59 +15,25 @@ */ package org.springframework.data.mongodb.repository.aot; -import java.util.ArrayList; import java.util.Iterator; import java.util.List; -import java.util.Optional; import java.util.regex.Pattern; -import java.util.stream.Collectors; -import java.util.stream.Stream; import org.bson.Document; -import org.jspecify.annotations.NullUnmarked; import org.jspecify.annotations.Nullable; -import org.springframework.core.annotation.MergedAnnotation; -import org.springframework.data.domain.SliceImpl; -import org.springframework.data.domain.Sort.Order; -import org.springframework.data.geo.Box; -import org.springframework.data.geo.Circle; -import org.springframework.data.geo.Distance; -import org.springframework.data.geo.GeoPage; -import org.springframework.data.geo.GeoResults; -import org.springframework.data.geo.Polygon; -import org.springframework.data.mongodb.core.ExecutableFindOperation.FindWithQuery; -import org.springframework.data.mongodb.core.ExecutableRemoveOperation.ExecutableRemove; -import org.springframework.data.mongodb.core.ExecutableUpdateOperation.ExecutableUpdate; -import org.springframework.data.mongodb.core.MongoOperations; -import org.springframework.data.mongodb.core.aggregation.Aggregation; -import org.springframework.data.mongodb.core.aggregation.AggregationOptions; -import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; -import org.springframework.data.mongodb.core.aggregation.AggregationResults; -import org.springframework.data.mongodb.core.aggregation.TypedAggregation; -import org.springframework.data.mongodb.core.geo.GeoJson; -import org.springframework.data.mongodb.core.geo.Sphere; -import org.springframework.data.mongodb.core.mapping.MongoSimpleTypes; -import org.springframework.data.mongodb.core.query.BasicQuery; -import org.springframework.data.mongodb.core.query.BasicUpdate; -import org.springframework.data.mongodb.core.query.Collation; -import org.springframework.data.mongodb.core.query.NearQuery; -import org.springframework.data.mongodb.repository.Hint; -import org.springframework.data.mongodb.repository.Meta; -import org.springframework.data.mongodb.repository.ReadPreference; -import org.springframework.data.mongodb.repository.query.MongoParameters.MongoParameter; -import org.springframework.data.mongodb.repository.query.MongoQueryExecution.DeleteExecution; -import org.springframework.data.mongodb.repository.query.MongoQueryExecution.PagedExecution; -import org.springframework.data.mongodb.repository.query.MongoQueryExecution.SlicedExecution; +import org.springframework.data.mongodb.repository.aot.AggregationBlocks.AggregationCodeBlockBuilder; +import org.springframework.data.mongodb.repository.aot.AggregationBlocks.AggregationExecutionCodeBlockBuilder; +import org.springframework.data.mongodb.repository.aot.DeleteBlocks.DeleteExecutionCodeBlockBuilder; +import org.springframework.data.mongodb.repository.aot.GeoBlocks.GeoNearCodeBlockBuilder; +import org.springframework.data.mongodb.repository.aot.GeoBlocks.GeoNearExecutionCodeBlockBuilder; +import org.springframework.data.mongodb.repository.aot.QueryBlocks.QueryCodeBlockBuilder; +import org.springframework.data.mongodb.repository.aot.QueryBlocks.QueryExecutionCodeBlockBuilder; +import org.springframework.data.mongodb.repository.aot.UpdateBlocks.UpdateCodeBlockBuilder; +import org.springframework.data.mongodb.repository.aot.UpdateBlocks.UpdateExecutionCodeBlockBuilder; import org.springframework.data.mongodb.repository.query.MongoQueryMethod; import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; -import org.springframework.data.util.ReflectionUtils; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock.Builder; -import org.springframework.javapoet.TypeName; -import org.springframework.util.ClassUtils; -import org.springframework.util.CollectionUtils; -import org.springframework.util.NumberUtils; -import org.springframework.util.ObjectUtils; import org.springframework.util.StringUtils; /** @@ -89,6 +55,7 @@ class MongoCodeBlocks { */ static QueryCodeBlockBuilder queryBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + return new QueryCodeBlockBuilder(context, queryMethod); } @@ -127,6 +94,7 @@ static DeleteExecutionCodeBlockBuilder deleteExecutionBlockBuilder(AotQueryMetho */ static UpdateCodeBlockBuilder updateBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + return new UpdateCodeBlockBuilder(context, queryMethod); } @@ -156,12 +124,6 @@ static AggregationCodeBlockBuilder aggregationBlockBuilder(AotQueryMethodGenerat return new AggregationCodeBlockBuilder(context, queryMethod); } - static GeoNearCodeBlockBuilder geoNearBlockBuilder(AotQueryMethodGenerationContext context, - MongoQueryMethod queryMethod) { - - return new GeoNearCodeBlockBuilder(context, queryMethod); - } - /** * Builder for generating aggregation execution {@link CodeBlock}. * @@ -175,796 +137,34 @@ static AggregationExecutionCodeBlockBuilder aggregationExecutionBlockBuilder(Aot return new AggregationExecutionCodeBlockBuilder(context, queryMethod); } - @NullUnmarked - static class DeleteExecutionCodeBlockBuilder { - - private final AotQueryMethodGenerationContext context; - private final MongoQueryMethod queryMethod; - private String queryVariableName; - - DeleteExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { - - this.context = context; - this.queryMethod = queryMethod; - } - - DeleteExecutionCodeBlockBuilder referencing(String queryVariableName) { - - this.queryVariableName = queryVariableName; - return this; - } - - CodeBlock build() { - - String mongoOpsRef = context.fieldNameOf(MongoOperations.class); - Builder builder = CodeBlock.builder(); - - Class domainType = context.getRepositoryInformation().getDomainType(); - boolean isProjecting = context.getActualReturnType() != null - && !ObjectUtils.nullSafeEquals(TypeName.get(domainType), context.getActualReturnType()); - - Object actualReturnType = isProjecting ? context.getActualReturnType().getType() : domainType; - - builder.add("\n"); - builder.addStatement("$1T<$2T> $3L = $4L.remove($2T.class)", ExecutableRemove.class, domainType, - context.localVariable("remover"), mongoOpsRef); - - DeleteExecution.Type type = DeleteExecution.Type.FIND_AND_REMOVE_ALL; - if (!queryMethod.isCollectionQuery()) { - if (!ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType())) { - type = DeleteExecution.Type.FIND_AND_REMOVE_ONE; - } else { - type = DeleteExecution.Type.ALL; - } - } - - actualReturnType = ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType()) - ? TypeName.get(context.getMethod().getReturnType()) - : queryMethod.isCollectionQuery() ? context.getReturnTypeName() : actualReturnType; - - if (ClassUtils.isVoidType(context.getMethod().getReturnType())) { - builder.addStatement("new $T($L, $T.$L).execute($L)", DeleteExecution.class, context.localVariable("remover"), - DeleteExecution.Type.class, type.name(), queryVariableName); - } else if (context.getMethod().getReturnType() == Optional.class) { - builder.addStatement("return $T.ofNullable(($T) new $T($L, $T.$L).execute($L))", Optional.class, - actualReturnType, DeleteExecution.class, context.localVariable("remover"), DeleteExecution.Type.class, - type.name(), queryVariableName); - } else { - builder.addStatement("return ($T) new $T($L, $T.$L).execute($L)", actualReturnType, DeleteExecution.class, - context.localVariable("remover"), DeleteExecution.Type.class, type.name(), queryVariableName); - } - - return builder.build(); - } - } - - @NullUnmarked - static class UpdateExecutionCodeBlockBuilder { - - private final AotQueryMethodGenerationContext context; - private final MongoQueryMethod queryMethod; - private String queryVariableName; - private String updateVariableName; - - UpdateExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { - - this.context = context; - this.queryMethod = queryMethod; - } - - UpdateExecutionCodeBlockBuilder withFilter(String queryVariableName) { - - this.queryVariableName = queryVariableName; - return this; - } - - UpdateExecutionCodeBlockBuilder referencingUpdate(String updateVariableName) { - - this.updateVariableName = updateVariableName; - return this; - } - - CodeBlock build() { - - String mongoOpsRef = context.fieldNameOf(MongoOperations.class); - Builder builder = CodeBlock.builder(); - - builder.add("\n"); - - String updateReference = updateVariableName; - Class domainType = context.getRepositoryInformation().getDomainType(); - builder.addStatement("$1T<$2T> $3L = $4L.update($2T.class)", ExecutableUpdate.class, domainType, - context.localVariable("updater"), mongoOpsRef); - - Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); - if (ReflectionUtils.isVoid(returnType)) { - builder.addStatement("$L.matching($L).apply($L).all()", context.localVariable("updater"), queryVariableName, - updateReference); - } else if (ClassUtils.isAssignable(Long.class, returnType)) { - builder.addStatement("return $L.matching($L).apply($L).all().getModifiedCount()", - context.localVariable("updater"), queryVariableName, updateReference); - } else { - builder.addStatement("$T $L = $L.matching($L).apply($L).all().getModifiedCount()", Long.class, - context.localVariable("modifiedCount"), context.localVariable("updater"), queryVariableName, - updateReference); - builder.addStatement("return $T.convertNumberToTargetClass($L, $T.class)", NumberUtils.class, - context.localVariable("modifiedCount"), returnType); - } - - return builder.build(); - } - } - - @NullUnmarked - static class AggregationExecutionCodeBlockBuilder { - - private final AotQueryMethodGenerationContext context; - private final MongoQueryMethod queryMethod; - private String aggregationVariableName; - - AggregationExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { - - this.context = context; - this.queryMethod = queryMethod; - } - - AggregationExecutionCodeBlockBuilder referencing(String aggregationVariableName) { - - this.aggregationVariableName = aggregationVariableName; - return this; - } - - CodeBlock build() { - - String mongoOpsRef = context.fieldNameOf(MongoOperations.class); - Builder builder = CodeBlock.builder(); - - builder.add("\n"); - - Class outputType = queryMethod.getReturnedObjectType(); - if (MongoSimpleTypes.HOLDER.isSimpleType(outputType)) { - outputType = Document.class; - } else if (ClassUtils.isAssignable(AggregationResults.class, outputType)) { - outputType = queryMethod.getReturnType().getComponentType().getType(); - } - - if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) { - builder.addStatement("$L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType); - return builder.build(); - } - - if (ClassUtils.isAssignable(AggregationResults.class, context.getMethod().getReturnType())) { - builder.addStatement("return $L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType); - return builder.build(); - } - - if (outputType == Document.class) { - - Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); - - if (queryMethod.isStreamQuery()) { - - builder.addStatement("$T<$T> $L = $L.aggregateStream($L, $T.class)", Stream.class, Document.class, - context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); - - builder.addStatement("return $1L.map(it -> ($2T) convertSimpleRawResult($2T.class, it))", - context.localVariable("results"), returnType); - } else { - - builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class, - context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); - - if (!queryMethod.isCollectionQuery()) { - builder.addStatement( - "return $1T.<$2T>firstElement(convertSimpleRawResults($2T.class, $3L.getMappedResults()))", - CollectionUtils.class, returnType, context.localVariable("results")); - } else { - builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType, - context.localVariable("results")); - } - } - } else { - if (queryMethod.isSliceQuery()) { - builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class, - context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); - builder.addStatement("boolean $L = $L.getMappedResults().size() > $L.getPageSize()", - context.localVariable("hasNext"), context.localVariable("results"), context.getPageableParameterName()); - builder.addStatement( - "return new $1T<>($2L ? $3L.getMappedResults().subList(0, $4L.getPageSize()) : $3L.getMappedResults(), $4L, $2L)", - SliceImpl.class, context.localVariable("hasNext"), context.localVariable("results"), - context.getPageableParameterName()); - } else { - - if (queryMethod.isStreamQuery()) { - builder.addStatement("return $L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName, - outputType); - } else { - - builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef, - aggregationVariableName, outputType); - } - } - } - - return builder.build(); - } - } - - @NullUnmarked - static class QueryExecutionCodeBlockBuilder { - - private final AotQueryMethodGenerationContext context; - private final MongoQueryMethod queryMethod; - private QueryInteraction query; - - QueryExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { - - this.context = context; - this.queryMethod = queryMethod; - } - - QueryExecutionCodeBlockBuilder forQuery(QueryInteraction query) { - - this.query = query; - return this; - } - - CodeBlock build() { - - String mongoOpsRef = context.fieldNameOf(MongoOperations.class); - - Builder builder = CodeBlock.builder(); - - boolean isProjecting = context.getReturnedType().isProjecting(); - Class domainType = context.getRepositoryInformation().getDomainType(); - Object actualReturnType = queryMethod.getParameters().hasDynamicProjection() || isProjecting - ? TypeName.get(context.getActualReturnType().getType()) - : domainType; - - builder.add("\n"); - - if (queryMethod.getParameters().hasDynamicProjection()) { - builder.addStatement("$T<$T> $L = $L.query($T.class).as($L)", FindWithQuery.class, actualReturnType, - context.localVariable("finder"), mongoOpsRef, domainType, context.getDynamicProjectionParameterName()); - } else if (isProjecting) { - builder.addStatement("$T<$T> $L = $L.query($T.class).as($T.class)", FindWithQuery.class, actualReturnType, - context.localVariable("finder"), mongoOpsRef, domainType, actualReturnType); - } else { - - builder.addStatement("$T<$T> $L = $L.query($T.class)", FindWithQuery.class, actualReturnType, - context.localVariable("finder"), mongoOpsRef, domainType); - } - - String terminatingMethod; - - if (queryMethod.isCollectionQuery() || queryMethod.isPageQuery() || queryMethod.isSliceQuery()) { - terminatingMethod = "all()"; - } else if (query.isCount()) { - terminatingMethod = "count()"; - } else if (query.isExists()) { - terminatingMethod = "exists()"; - } else if (queryMethod.isStreamQuery()) { - terminatingMethod = "stream()"; - } else { - terminatingMethod = Optional.class.isAssignableFrom(context.getReturnType().toClass()) ? "one()" : "oneValue()"; - } - - if (queryMethod.isPageQuery()) { - builder.addStatement("return new $T($L, $L).execute($L)", PagedExecution.class, context.localVariable("finder"), - context.getPageableParameterName(), query.name()); - } else if (queryMethod.isSliceQuery()) { - builder.addStatement("return new $T($L, $L).execute($L)", SlicedExecution.class, - context.localVariable("finder"), context.getPageableParameterName(), query.name()); - } else if (queryMethod.isScrollQuery()) { - - String scrollPositionParameterName = context.getScrollPositionParameterName(); - - builder.addStatement("return $L.matching($L).scroll($L)", context.localVariable("finder"), query.name(), - scrollPositionParameterName); - } else { - if (query.isCount() && !ClassUtils.isAssignable(Long.class, context.getActualReturnType().getRawClass())) { - - Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); - builder.addStatement("return $T.convertNumberToTargetClass($L.matching($L).$L, $T.class)", NumberUtils.class, - context.localVariable("finder"), query.name(), terminatingMethod, returnType); - - } else { - builder.addStatement("return $L.matching($L).$L", context.localVariable("finder"), query.name(), - terminatingMethod); - } - } - - return builder.build(); - } - } - - static class GeoNearCodeBlockBuilder { - - private final AotQueryMethodGenerationContext context; - private final MongoQueryMethod queryMethod; - private final List arguments; - - private String variableName; - - GeoNearCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { - - this.context = context; - this.arguments = context.getBindableParameterNames().stream().map(CodeBlock::of).collect(Collectors.toList()); - this.queryMethod = queryMethod; - } - - CodeBlock build() { - - CodeBlock.Builder builder = CodeBlock.builder(); - builder.add("\n"); - - String locationParameterName = context.getParameterName(queryMethod.getParameters().getNearIndex()); - - builder.addStatement("$1T $2L = $1T.near($3L)", NearQuery.class, variableName, locationParameterName); - - if (queryMethod.getParameters().getRangeIndex() != -1) { - - String rangeParametername = context.getParameterName(queryMethod.getParameters().getRangeIndex()); - String minVarName = context.localVariable("min"); - String maxVarName = context.localVariable("max"); - - builder.beginControlFlow("if($L.getLowerBound().isBounded())", rangeParametername); - builder.addStatement("$1T $2L = $3L.getLowerBound().getValue().get()", Distance.class, minVarName, - rangeParametername); - builder.addStatement("$1L.minDistance($2L).in($2L.getMetric())", variableName, minVarName); - builder.endControlFlow(); - - builder.beginControlFlow("if($L.getUpperBound().isBounded())", rangeParametername); - builder.addStatement("$1T $2L = $3L.getUpperBound().getValue().get()", Distance.class, maxVarName, - rangeParametername); - builder.addStatement("$1L.maxDistance($2L).in($2L.getMetric())", variableName, maxVarName); - builder.endControlFlow(); - } else { - - String distanceParametername = context.getParameterName(queryMethod.getParameters().getMaxDistanceIndex()); - builder.addStatement("$1L.maxDistance($2L).in($2L.getMetric())", variableName, distanceParametername); - } - - if (context.getPageableParameterName() != null) { - builder.addStatement("$L.with($L)", variableName, context.getPageableParameterName()); - } - - builder.add("\n"); - - // TODO: move the section below into dedicated executor builder - if (ClassUtils.isAssignable(GeoPage.class, context.getReturnType().getRawClass())) { - builder.addStatement("return new $T<>($L.query($T.class).near($L).all())", GeoPage.class, - context.fieldNameOf(MongoOperations.class), context.getRepositoryInformation().getDomainType(), - variableName); - } - - else if (ClassUtils.isAssignable(GeoResults.class, context.getReturnType().getRawClass())) { - - builder.addStatement("return $L.query($T.class).near($L).all()", context.fieldNameOf(MongoOperations.class), - context.getRepositoryInformation().getDomainType(), variableName); - } else { - builder.addStatement("return $L.query($T.class).near($L).all().getContent()", - context.fieldNameOf(MongoOperations.class), context.getRepositoryInformation().getDomainType(), - variableName); - } - - return builder.build(); - } - - public GeoNearCodeBlockBuilder usingQueryVariableName(String variableName) { - this.variableName = variableName; - return this; - } - } - - @NullUnmarked - static class AggregationCodeBlockBuilder { - - private final AotQueryMethodGenerationContext context; - private final MongoQueryMethod queryMethod; - private final List arguments; - - private AggregationInteraction source; - - private String aggregationVariableName; - private boolean pipelineOnly; - - AggregationCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { - - this.context = context; - this.arguments = context.getBindableParameterNames().stream().map(CodeBlock::of).collect(Collectors.toList()); - this.queryMethod = queryMethod; - } - - AggregationCodeBlockBuilder stages(AggregationInteraction aggregation) { - - this.source = aggregation; - return this; - } - - AggregationCodeBlockBuilder usingAggregationVariableName(String aggregationVariableName) { - - this.aggregationVariableName = aggregationVariableName; - return this; - } - - AggregationCodeBlockBuilder pipelineOnly(boolean pipelineOnly) { - - this.pipelineOnly = pipelineOnly; - return this; - } - - CodeBlock build() { - - CodeBlock.Builder builder = CodeBlock.builder(); - builder.add("\n"); - - String pipelineName = context.localVariable(aggregationVariableName + (pipelineOnly ? "" : "Pipeline")); - builder.add(pipeline(pipelineName)); - - if (!pipelineOnly) { - - builder.addStatement("$1T<$2T> $3L = $4T.newAggregation($2T.class, $5L.getOperations())", - TypedAggregation.class, context.getRepositoryInformation().getDomainType(), aggregationVariableName, - Aggregation.class, pipelineName); - - builder.add(aggregationOptions(aggregationVariableName)); - } - - return builder.build(); - } - - private CodeBlock pipeline(String pipelineVariableName) { - - String sortParameter = context.getSortParameterName(); - String limitParameter = context.getLimitParameterName(); - String pageableParameter = context.getPageableParameterName(); - - boolean mightBeSorted = StringUtils.hasText(sortParameter); - boolean mightBeLimited = StringUtils.hasText(limitParameter); - boolean mightBePaged = StringUtils.hasText(pageableParameter); - - int stageCount = source.stages().size(); - if (mightBeSorted) { - stageCount++; - } - if (mightBeLimited) { - stageCount++; - } - if (mightBePaged) { - stageCount += 3; - } - - Builder builder = CodeBlock.builder(); - builder.add(aggregationStages(context.localVariable("stages"), source.stages(), stageCount, arguments)); - - if (mightBeSorted) { - builder.add(sortingStage(sortParameter)); - } - - if (mightBeLimited) { - builder.add(limitingStage(limitParameter)); - } - - if (mightBePaged) { - builder.add(pagingStage(pageableParameter, queryMethod.isSliceQuery())); - } - - builder.addStatement("$T $L = createPipeline($L)", AggregationPipeline.class, pipelineVariableName, - context.localVariable("stages")); - return builder.build(); - } - - private CodeBlock aggregationOptions(String aggregationVariableName) { - - Builder builder = CodeBlock.builder(); - List options = new ArrayList<>(5); - if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) { - options.add(CodeBlock.of(".skipOutput()")); - } - - MergedAnnotation hintAnnotation = context.getAnnotation(Hint.class); - String hint = hintAnnotation.isPresent() ? hintAnnotation.getString("value") : null; - if (StringUtils.hasText(hint)) { - options.add(CodeBlock.of(".hint($S)", hint)); - } - - MergedAnnotation readPreferenceAnnotation = context.getAnnotation(ReadPreference.class); - String readPreference = readPreferenceAnnotation.isPresent() ? readPreferenceAnnotation.getString("value") : null; - if (StringUtils.hasText(readPreference)) { - options.add(CodeBlock.of(".readPreference($T.valueOf($S))", com.mongodb.ReadPreference.class, readPreference)); - } - - if (queryMethod.hasAnnotatedCollation()) { - options.add(CodeBlock.of(".collation($T.parse($S))", Collation.class, queryMethod.getAnnotatedCollation())); - } - - if (!options.isEmpty()) { - - Builder optionsBuilder = CodeBlock.builder(); - optionsBuilder.add("$1T $2L = $1T.builder()\n", AggregationOptions.class, - context.localVariable("aggregationOptions")); - optionsBuilder.indent(); - for (CodeBlock optionBlock : options) { - optionsBuilder.add(optionBlock); - optionsBuilder.add("\n"); - } - optionsBuilder.add(".build();\n"); - optionsBuilder.unindent(); - builder.add(optionsBuilder.build()); - - builder.addStatement("$1L = $1L.withOptions($2L)", aggregationVariableName, - context.localVariable("aggregationOptions")); - } - return builder.build(); - } - - private CodeBlock aggregationStages(String stageListVariableName, Iterable stages, int stageCount, - List arguments) { - - Builder builder = CodeBlock.builder(); - builder.addStatement("$T<$T> $L = new $T($L)", List.class, Object.class, stageListVariableName, ArrayList.class, - stageCount); - int stageCounter = 0; - - for (String stage : stages) { - String stageName = context.localVariable("stage_%s".formatted(stageCounter++)); - builder.add(renderExpressionToDocument(stage, stageName, arguments)); - builder.addStatement("$L.add($L)", context.localVariable("stages"), stageName); - } - - return builder.build(); - } - - private CodeBlock sortingStage(String sortProvider) { - - Builder builder = CodeBlock.builder(); - - builder.beginControlFlow("if ($L.isSorted())", sortProvider); - builder.addStatement("$1T $2L = new $1T()", Document.class, context.localVariable("sortDocument")); - builder.beginControlFlow("for ($T $L : $L)", Order.class, context.localVariable("order"), sortProvider); - builder.addStatement("$1L.append($2L.getProperty(), $2L.isAscending() ? 1 : -1);", - context.localVariable("sortDocument"), context.localVariable("order")); - builder.endControlFlow(); - builder.addStatement("stages.add(new $T($S, $L))", Document.class, "$sort", - context.localVariable("sortDocument")); - builder.endControlFlow(); - - return builder.build(); - } - - private CodeBlock pagingStage(String pageableProvider, boolean slice) { - - Builder builder = CodeBlock.builder(); - - builder.add(sortingStage(pageableProvider + ".getSort()")); - - builder.beginControlFlow("if ($L.isPaged())", pageableProvider); - builder.beginControlFlow("if ($L.getOffset() > 0)", pageableProvider); - builder.addStatement("$L.add($T.skip($L.getOffset()))", context.localVariable("stages"), Aggregation.class, - pageableProvider); - builder.endControlFlow(); - if (slice) { - builder.addStatement("$L.add($T.limit($L.getPageSize() + 1))", context.localVariable("stages"), - Aggregation.class, pageableProvider); - } else { - builder.addStatement("$L.add($T.limit($L.getPageSize()))", context.localVariable("stages"), Aggregation.class, - pageableProvider); - } - builder.endControlFlow(); - - return builder.build(); - } - - private CodeBlock limitingStage(String limitProvider) { - - Builder builder = CodeBlock.builder(); - - builder.beginControlFlow("if ($L.isLimited())", limitProvider); - builder.addStatement("$L.add($T.limit($L.max()))", context.localVariable("stages"), Aggregation.class, - limitProvider); - builder.endControlFlow(); - - return builder.build(); - } - - } - - @NullUnmarked - static class QueryCodeBlockBuilder { - - private final AotQueryMethodGenerationContext context; - private final MongoQueryMethod queryMethod; - - private QueryInteraction source; - private final List arguments; - private String queryVariableName; - - QueryCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { - - this.context = context; - - this.arguments = new ArrayList<>(); - for (MongoParameter parameter : queryMethod.getParameters().getBindableParameters()) { - String parameterName = context.getParameterName(parameter.getIndex()); - if (ClassUtils.isAssignable(GeoJson.class, parameter.getType())) { - arguments.add(CodeBlock.of(parameterName)); - } else if (ClassUtils.isAssignable(Circle.class, parameter.getType())) { - arguments.add(CodeBlock.builder().add( - "$1T.of($1T.of($2L.getCenter().getX(), $2L.getCenter().getY()), $2L.getRadius().getNormalizedValue())", - List.class, parameterName).build()); - } else if (ClassUtils.isAssignable(Box.class, parameter.getType())) { - - // { $geoWithin: { $box: [ [ , ], [ , ] ] } - arguments.add(CodeBlock.builder().add( - "$1T.of($1T.of($2L.getFirst().getX(), $2L.getFirst().getY()), $1T.of($2L.getSecond().getX(), $2L.getSecond().getY()))", - List.class, parameterName).build()); - } else if (ClassUtils.isAssignable(Sphere.class, parameter.getType())) { - // { $centerSphere: [ [ , ], ] } - arguments.add(CodeBlock.builder().add( - "$1T.of($1T.of($2L.getCenter().getX(), $2L.getCenter().getY()), $2L.getRadius().getNormalizedValue())", - List.class, parameterName).build()); - } else if (ClassUtils.isAssignable(Polygon.class, parameter.getType())) { - - // $polygon: [ [ , ], [ , ], [ , ], ... ] - String localVar = context.localVariable("_p"); - arguments.add( - CodeBlock.builder().add("$1L.getPoints().stream().map($2L -> $3T.of($2L.getX(), $2L.getY())).toList()", - parameterName, localVar, List.class).build()); - } else { - arguments.add(CodeBlock.of(parameterName)); - } - } - - this.queryMethod = queryMethod; - } - - QueryCodeBlockBuilder filter(QueryInteraction query) { - - this.source = query; - return this; - } - - QueryCodeBlockBuilder usingQueryVariableName(String queryVariableName) { - this.queryVariableName = queryVariableName; - return this; - } - - CodeBlock build() { - - CodeBlock.Builder builder = CodeBlock.builder(); - - builder.add("\n"); - builder.add(renderExpressionToQuery(source.getQuery().getQueryString(), queryVariableName)); - - if (StringUtils.hasText(source.getQuery().getFieldsString())) { - - builder.add(renderExpressionToDocument(source.getQuery().getFieldsString(), "fields", arguments)); - builder.addStatement("$L.setFieldsObject(fields)", queryVariableName); - } - - String sortParameter = context.getSortParameterName(); - if (StringUtils.hasText(sortParameter)) { - builder.addStatement("$L.with($L)", queryVariableName, sortParameter); - } else if (StringUtils.hasText(source.getQuery().getSortString())) { - - builder.add(renderExpressionToDocument(source.getQuery().getSortString(), "sort", arguments)); - builder.addStatement("$L.setSortObject(sort)", queryVariableName); - } - - String limitParameter = context.getLimitParameterName(); - if (StringUtils.hasText(limitParameter)) { - builder.addStatement("$L.limit($L)", queryVariableName, limitParameter); - } else if (context.getPageableParameterName() == null && source.getQuery().isLimited()) { - builder.addStatement("$L.limit($L)", queryVariableName, source.getQuery().getLimit()); - } - - String pageableParameter = context.getPageableParameterName(); - if (StringUtils.hasText(pageableParameter) && !queryMethod.isPageQuery() && !queryMethod.isSliceQuery()) { - builder.addStatement("$L.with($L)", queryVariableName, pageableParameter); - } - - MergedAnnotation hintAnnotation = context.getAnnotation(Hint.class); - String hint = hintAnnotation.isPresent() ? hintAnnotation.getString("value") : null; - - if (StringUtils.hasText(hint)) { - builder.addStatement("$L.withHint($S)", queryVariableName, hint); - } - - MergedAnnotation readPreferenceAnnotation = context.getAnnotation(ReadPreference.class); - String readPreference = readPreferenceAnnotation.isPresent() ? readPreferenceAnnotation.getString("value") : null; - - if (StringUtils.hasText(readPreference)) { - builder.addStatement("$L.withReadPreference($T.valueOf($S))", queryVariableName, - com.mongodb.ReadPreference.class, readPreference); - } - - MergedAnnotation metaAnnotation = context.getAnnotation(Meta.class); - - if (metaAnnotation.isPresent()) { - - long maxExecutionTimeMs = metaAnnotation.getLong("maxExecutionTimeMs"); - if (maxExecutionTimeMs != -1) { - builder.addStatement("$L.maxTimeMsec($L)", queryVariableName, maxExecutionTimeMs); - } - - int cursorBatchSize = metaAnnotation.getInt("cursorBatchSize"); - if (cursorBatchSize != 0) { - builder.addStatement("$L.cursorBatchSize($L)", queryVariableName, cursorBatchSize); - } - - String comment = metaAnnotation.getString("comment"); - if (StringUtils.hasText("comment")) { - builder.addStatement("$L.comment($S)", queryVariableName, comment); - } - } - - // TODO: Meta annotation: Disk usage - - return builder.build(); - } - - private CodeBlock renderExpressionToQuery(@Nullable String source, String variableName) { - - Builder builder = CodeBlock.builder(); - if (!StringUtils.hasText(source)) { - - builder.addStatement("$1T $2L = new $1T(new $3T())", BasicQuery.class, variableName, Document.class); - } else if (!containsPlaceholder(source)) { - builder.addStatement("$1T $2L = new $1T($3T.parse($4S))", BasicQuery.class, variableName, Document.class, - source); - } else { - builder.add("$T $L = createQuery($S, new $T[]{ ", BasicQuery.class, variableName, source, Object.class); - Iterator iterator = arguments.iterator(); - while (iterator.hasNext()) { - builder.add(iterator.next()); - if (iterator.hasNext()) { - builder.add(", "); - } - } - builder.add("});\n"); - } + /** + * Builder for generating {@link org.springframework.data.mongodb.core.query.NearQuery} {@link CodeBlock}. + * + * @param context + * @param queryMethod + * @return + */ + static GeoNearCodeBlockBuilder geoNearBlockBuilder(AotQueryMethodGenerationContext context, + MongoQueryMethod queryMethod) { - return builder.build(); - } + return new GeoNearCodeBlockBuilder(context, queryMethod); } - @NullUnmarked - static class UpdateCodeBlockBuilder { - - private UpdateInteraction source; - private List arguments; - private String updateVariableName; - - public UpdateCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { - this.arguments = context.getBindableParameterNames().stream().map(CodeBlock::of).collect(Collectors.toList()); - } - - public UpdateCodeBlockBuilder update(UpdateInteraction update) { - this.source = update; - return this; - } - - public UpdateCodeBlockBuilder usingUpdateVariableName(String updateVariableName) { - this.updateVariableName = updateVariableName; - return this; - } - - CodeBlock build() { - - CodeBlock.Builder builder = CodeBlock.builder(); - - builder.add("\n"); - String tmpVariableName = updateVariableName + "Document"; - builder.add(renderExpressionToDocument(source.getUpdate().getUpdateString(), tmpVariableName, arguments)); - builder.addStatement("$1T $2L = new $1T($3L)", BasicUpdate.class, updateVariableName, tmpVariableName); + /** + * Builder for generating {@link org.springframework.data.mongodb.core.query.NearQuery} execution {@link CodeBlock} + * that can return {@link org.springframework.data.geo.GeoResults}. + * + * @param context + * @param queryMethod + * @return + */ + static GeoNearExecutionCodeBlockBuilder geoNearExecutionBlockBuilder(AotQueryMethodGenerationContext context, + MongoQueryMethod queryMethod) { - return builder.build(); - } + return new GeoNearExecutionCodeBlockBuilder(context, queryMethod); } - private static CodeBlock renderExpressionToDocument(@Nullable String source, String variableName, - List arguments) { + static CodeBlock renderExpressionToDocument(@Nullable String source, String variableName, List arguments) { Builder builder = CodeBlock.builder(); if (!StringUtils.hasText(source)) { @@ -986,7 +186,7 @@ private static CodeBlock renderExpressionToDocument(@Nullable String source, Str return builder.build(); } - private static boolean containsPlaceholder(String source) { + static boolean containsPlaceholder(String source) { return PARAMETER_BINDING_PATTERN.matcher(source).find(); } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java index c510ab8e16..9b4d3b4b83 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java @@ -15,15 +15,16 @@ */ package org.springframework.data.mongodb.repository.aot; -import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.QueryCodeBlockBuilder; import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.aggregationBlockBuilder; import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.aggregationExecutionBlockBuilder; import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.deleteExecutionBlockBuilder; import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.geoNearBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.geoNearExecutionBlockBuilder; import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.queryBlockBuilder; import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.queryExecutionBlockBuilder; import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.updateBlockBuilder; import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.updateExecutionBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.QueryBlocks.QueryCodeBlockBuilder; import java.lang.reflect.Method; import java.util.Locale; @@ -188,7 +189,8 @@ private static boolean backoff(MongoQueryMethod method) { // TODO: namedQuery, Regex queries, queries accepting Shapes (e.g. within) or returning arrays. boolean skip = method.isSearchQuery() || method.getName().toLowerCase(Locale.ROOT).contains("regex") - || method.getReturnType().getType().isArray() || ClassUtils.isAssignable(GeoPage.class, method.getReturnType().getType()); + || method.getReturnType().getType().isArray() + || ClassUtils.isAssignable(GeoPage.class, method.getReturnType().getType()); if (skip && logger.isDebugEnabled()) { logger.debug("Skipping AOT generation for [%s]. Method is either returning an array or a geo-near, regex query" @@ -204,8 +206,9 @@ private static MethodContributor nearQueryMethodContributor(Mo CodeBlock.Builder builder = CodeBlock.builder(); - builder.add(geoNearBlockBuilder(context, queryMethod).usingQueryVariableName("nearQuery").build()); - // builder.add(aggregationExecutionBlockBuilder(context, queryMethod).referencing("aggregation").build()); + String variableName = "nearQuery"; + builder.add(geoNearBlockBuilder(context, queryMethod).usingQueryVariableName(variableName).build()); + builder.add(geoNearExecutionBlockBuilder(context, queryMethod).referencing(variableName).build()); return builder.build(); }); @@ -218,9 +221,10 @@ private static MethodContributor aggregationMethodContributor( CodeBlock.Builder builder = CodeBlock.builder(); + String variableName = "aggregation"; builder.add(aggregationBlockBuilder(context, queryMethod).stages(aggregation) - .usingAggregationVariableName("aggregation").build()); - builder.add(aggregationExecutionBlockBuilder(context, queryMethod).referencing("aggregation").build()); + .usingAggregationVariableName(variableName).build()); + builder.add(aggregationExecutionBlockBuilder(context, queryMethod).referencing(variableName).build()); return builder.build(); }); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java new file mode 100644 index 0000000000..50aed69d33 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java @@ -0,0 +1,308 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; + +import org.bson.Document; +import org.jspecify.annotations.NullUnmarked; +import org.jspecify.annotations.Nullable; +import org.springframework.core.annotation.MergedAnnotation; +import org.springframework.data.geo.Box; +import org.springframework.data.geo.Circle; +import org.springframework.data.geo.Polygon; +import org.springframework.data.mongodb.core.ExecutableFindOperation.FindWithQuery; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.geo.GeoJson; +import org.springframework.data.mongodb.core.geo.Sphere; +import org.springframework.data.mongodb.core.query.BasicQuery; +import org.springframework.data.mongodb.repository.Hint; +import org.springframework.data.mongodb.repository.Meta; +import org.springframework.data.mongodb.repository.ReadPreference; +import org.springframework.data.mongodb.repository.query.MongoParameters.MongoParameter; +import org.springframework.data.mongodb.repository.query.MongoQueryExecution.PagedExecution; +import org.springframework.data.mongodb.repository.query.MongoQueryExecution.SlicedExecution; +import org.springframework.data.mongodb.repository.query.MongoQueryMethod; +import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.CodeBlock.Builder; +import org.springframework.javapoet.TypeName; +import org.springframework.util.ClassUtils; +import org.springframework.util.NumberUtils; +import org.springframework.util.StringUtils; + +/** + * @author Christoph Strobl + * @since 5.0 + */ +class QueryBlocks { + + @NullUnmarked + static class QueryExecutionCodeBlockBuilder { + + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + private QueryInteraction query; + + QueryExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + + this.context = context; + this.queryMethod = queryMethod; + } + + QueryExecutionCodeBlockBuilder forQuery(QueryInteraction query) { + + this.query = query; + return this; + } + + CodeBlock build() { + + String mongoOpsRef = context.fieldNameOf(MongoOperations.class); + + Builder builder = CodeBlock.builder(); + + boolean isProjecting = context.getReturnedType().isProjecting(); + Class domainType = context.getRepositoryInformation().getDomainType(); + Object actualReturnType = queryMethod.getParameters().hasDynamicProjection() || isProjecting + ? TypeName.get(context.getActualReturnType().getType()) + : domainType; + + builder.add("\n"); + + if (queryMethod.getParameters().hasDynamicProjection()) { + builder.addStatement("$T<$T> $L = $L.query($T.class).as($L)", FindWithQuery.class, actualReturnType, + context.localVariable("finder"), mongoOpsRef, domainType, context.getDynamicProjectionParameterName()); + } else if (isProjecting) { + builder.addStatement("$T<$T> $L = $L.query($T.class).as($T.class)", FindWithQuery.class, actualReturnType, + context.localVariable("finder"), mongoOpsRef, domainType, actualReturnType); + } else { + + builder.addStatement("$T<$T> $L = $L.query($T.class)", FindWithQuery.class, actualReturnType, + context.localVariable("finder"), mongoOpsRef, domainType); + } + + String terminatingMethod; + + if (queryMethod.isCollectionQuery() || queryMethod.isPageQuery() || queryMethod.isSliceQuery()) { + terminatingMethod = "all()"; + } else if (query.isCount()) { + terminatingMethod = "count()"; + } else if (query.isExists()) { + terminatingMethod = "exists()"; + } else if (queryMethod.isStreamQuery()) { + terminatingMethod = "stream()"; + } else { + terminatingMethod = Optional.class.isAssignableFrom(context.getReturnType().toClass()) ? "one()" : "oneValue()"; + } + + if (queryMethod.isPageQuery()) { + builder.addStatement("return new $T($L, $L).execute($L)", PagedExecution.class, context.localVariable("finder"), + context.getPageableParameterName(), query.name()); + } else if (queryMethod.isSliceQuery()) { + builder.addStatement("return new $T($L, $L).execute($L)", SlicedExecution.class, + context.localVariable("finder"), context.getPageableParameterName(), query.name()); + } else if (queryMethod.isScrollQuery()) { + + String scrollPositionParameterName = context.getScrollPositionParameterName(); + + builder.addStatement("return $L.matching($L).scroll($L)", context.localVariable("finder"), query.name(), + scrollPositionParameterName); + } else { + if (query.isCount() && !ClassUtils.isAssignable(Long.class, context.getActualReturnType().getRawClass())) { + + Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); + builder.addStatement("return $T.convertNumberToTargetClass($L.matching($L).$L, $T.class)", NumberUtils.class, + context.localVariable("finder"), query.name(), terminatingMethod, returnType); + + } else { + builder.addStatement("return $L.matching($L).$L", context.localVariable("finder"), query.name(), + terminatingMethod); + } + } + + return builder.build(); + } + } + + @NullUnmarked + static class QueryCodeBlockBuilder { + + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + + private QueryInteraction source; + private final List arguments; + private String queryVariableName; + + QueryCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + + this.context = context; + + this.arguments = new ArrayList<>(); + this.queryMethod = queryMethod; + collectArguments(context); + + } + + private void collectArguments(AotQueryMethodGenerationContext context) { + + for (MongoParameter parameter : queryMethod.getParameters().getBindableParameters()) { + String parameterName = context.getParameterName(parameter.getIndex()); + if (ClassUtils.isAssignable(GeoJson.class, parameter.getType())) { + + // renders as generic $geometry, thus can be handled by the converter when parsing + arguments.add(CodeBlock.of(parameterName)); + } else if (ClassUtils.isAssignable(Circle.class, parameter.getType()) + || ClassUtils.isAssignable(Sphere.class, parameter.getType())) { + + // $center | $centerSphere : [ [ , ], ] + arguments.add(CodeBlock.builder().add( + "$1T.of($1T.of($2L.getCenter().getX(), $2L.getCenter().getY()), $2L.getRadius().getNormalizedValue())", + List.class, parameterName).build()); + } else if (ClassUtils.isAssignable(Box.class, parameter.getType())) { + + // $box: [ [ , ], [ , ] ] + arguments.add(CodeBlock.builder().add( + "$1T.of($1T.of($2L.getFirst().getX(), $2L.getFirst().getY()), $1T.of($2L.getSecond().getX(), $2L.getSecond().getY()))", + List.class, parameterName).build()); + } else if (ClassUtils.isAssignable(Polygon.class, parameter.getType())) { + + // $polygon: [ [ , ], [ , ], [ , ], ... ] + String localVar = context.localVariable("_p"); + arguments.add( + CodeBlock.builder().add("$1L.getPoints().stream().map($2L -> $3T.of($2L.getX(), $2L.getY())).toList()", + parameterName, localVar, List.class).build()); + } else { + arguments.add(CodeBlock.of(parameterName)); + } + } + } + + QueryCodeBlockBuilder filter(QueryInteraction query) { + + this.source = query; + return this; + } + + QueryCodeBlockBuilder usingQueryVariableName(String queryVariableName) { + + this.queryVariableName = queryVariableName; + return this; + } + + CodeBlock build() { + + Builder builder = CodeBlock.builder(); + + builder.add("\n"); + builder.add(renderExpressionToQuery(source.getQuery().getQueryString(), queryVariableName)); + + if (StringUtils.hasText(source.getQuery().getFieldsString())) { + + builder + .add(MongoCodeBlocks.renderExpressionToDocument(source.getQuery().getFieldsString(), "fields", arguments)); + builder.addStatement("$L.setFieldsObject(fields)", queryVariableName); + } + + String sortParameter = context.getSortParameterName(); + if (StringUtils.hasText(sortParameter)) { + builder.addStatement("$L.with($L)", queryVariableName, sortParameter); + } else if (StringUtils.hasText(source.getQuery().getSortString())) { + + builder.add(MongoCodeBlocks.renderExpressionToDocument(source.getQuery().getSortString(), "sort", arguments)); + builder.addStatement("$L.setSortObject(sort)", queryVariableName); + } + + String limitParameter = context.getLimitParameterName(); + if (StringUtils.hasText(limitParameter)) { + builder.addStatement("$L.limit($L)", queryVariableName, limitParameter); + } else if (context.getPageableParameterName() == null && source.getQuery().isLimited()) { + builder.addStatement("$L.limit($L)", queryVariableName, source.getQuery().getLimit()); + } + + String pageableParameter = context.getPageableParameterName(); + if (StringUtils.hasText(pageableParameter) && !queryMethod.isPageQuery() && !queryMethod.isSliceQuery()) { + builder.addStatement("$L.with($L)", queryVariableName, pageableParameter); + } + + MergedAnnotation hintAnnotation = context.getAnnotation(Hint.class); + String hint = hintAnnotation.isPresent() ? hintAnnotation.getString("value") : null; + + if (StringUtils.hasText(hint)) { + builder.addStatement("$L.withHint($S)", queryVariableName, hint); + } + + MergedAnnotation readPreferenceAnnotation = context.getAnnotation(ReadPreference.class); + String readPreference = readPreferenceAnnotation.isPresent() ? readPreferenceAnnotation.getString("value") : null; + + if (StringUtils.hasText(readPreference)) { + builder.addStatement("$L.withReadPreference($T.valueOf($S))", queryVariableName, + com.mongodb.ReadPreference.class, readPreference); + } + + MergedAnnotation metaAnnotation = context.getAnnotation(Meta.class); + if (metaAnnotation.isPresent()) { + + long maxExecutionTimeMs = metaAnnotation.getLong("maxExecutionTimeMs"); + if (maxExecutionTimeMs != -1) { + builder.addStatement("$L.maxTimeMsec($L)", queryVariableName, maxExecutionTimeMs); + } + + int cursorBatchSize = metaAnnotation.getInt("cursorBatchSize"); + if (cursorBatchSize != 0) { + builder.addStatement("$L.cursorBatchSize($L)", queryVariableName, cursorBatchSize); + } + + String comment = metaAnnotation.getString("comment"); + if (StringUtils.hasText("comment")) { + builder.addStatement("$L.comment($S)", queryVariableName, comment); + } + } + + // TODO: Meta annotation: Disk usage + + return builder.build(); + } + + private CodeBlock renderExpressionToQuery(@Nullable String source, String variableName) { + + Builder builder = CodeBlock.builder(); + if (!StringUtils.hasText(source)) { + + builder.addStatement("$1T $2L = new $1T(new $3T())", BasicQuery.class, variableName, Document.class); + } else if (!MongoCodeBlocks.containsPlaceholder(source)) { + builder.addStatement("$1T $2L = new $1T($3T.parse($4S))", BasicQuery.class, variableName, Document.class, + source); + } else { + builder.add("$T $L = createQuery($S, new $T[]{ ", BasicQuery.class, variableName, source, Object.class); + Iterator iterator = arguments.iterator(); + while (iterator.hasNext()) { + builder.add(iterator.next()); + if (iterator.hasNext()) { + builder.add(", "); + } + } + builder.add("});\n"); + } + + return builder.build(); + } + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateBlocks.java new file mode 100644 index 0000000000..c4c9fb62c9 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateBlocks.java @@ -0,0 +1,145 @@ +/* + * Copyright 2025. the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import java.util.List; +import java.util.stream.Collectors; + +import org.jspecify.annotations.NullUnmarked; +import org.springframework.data.mongodb.core.ExecutableUpdateOperation.ExecutableUpdate; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.query.BasicUpdate; +import org.springframework.data.mongodb.repository.query.MongoQueryMethod; +import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; +import org.springframework.data.util.ReflectionUtils; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.CodeBlock.Builder; +import org.springframework.util.ClassUtils; +import org.springframework.util.NumberUtils; + +/** + * @author Christoph Strobl + * @since 2025/06 + */ +class UpdateBlocks { + + @NullUnmarked + static class UpdateExecutionCodeBlockBuilder { + + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + private String queryVariableName; + private String updateVariableName; + + UpdateExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + + this.context = context; + this.queryMethod = queryMethod; + } + + UpdateExecutionCodeBlockBuilder withFilter(String queryVariableName) { + + this.queryVariableName = queryVariableName; + return this; + } + + UpdateExecutionCodeBlockBuilder referencingUpdate(String updateVariableName) { + + this.updateVariableName = updateVariableName; + return this; + } + + CodeBlock build() { + + String mongoOpsRef = context.fieldNameOf(MongoOperations.class); + Builder builder = CodeBlock.builder(); + + builder.add("\n"); + + String updateReference = updateVariableName; + Class domainType = context.getRepositoryInformation().getDomainType(); + builder.addStatement("$1T<$2T> $3L = $4L.update($2T.class)", ExecutableUpdate.class, domainType, + context.localVariable("updater"), mongoOpsRef); + + Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); + if (ReflectionUtils.isVoid(returnType)) { + builder.addStatement("$L.matching($L).apply($L).all()", context.localVariable("updater"), queryVariableName, + updateReference); + } else if (ClassUtils.isAssignable(Long.class, returnType)) { + builder.addStatement("return $L.matching($L).apply($L).all().getModifiedCount()", + context.localVariable("updater"), queryVariableName, updateReference); + } else { + builder.addStatement("$T $L = $L.matching($L).apply($L).all().getModifiedCount()", Long.class, + context.localVariable("modifiedCount"), context.localVariable("updater"), queryVariableName, + updateReference); + builder.addStatement("return $T.convertNumberToTargetClass($L, $T.class)", NumberUtils.class, + context.localVariable("modifiedCount"), returnType); + } + + return builder.build(); + } + } + + @NullUnmarked + static class UpdateCodeBlockBuilder { + + private UpdateInteraction source; + private List arguments; + private String updateVariableName; + + public UpdateCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + this.arguments = context.getBindableParameterNames().stream().map(CodeBlock::of).collect(Collectors.toList()); + } + + public UpdateCodeBlockBuilder update(UpdateInteraction update) { + this.source = update; + return this; + } + + public UpdateCodeBlockBuilder usingUpdateVariableName(String updateVariableName) { + this.updateVariableName = updateVariableName; + return this; + } + + CodeBlock build() { + + Builder builder = CodeBlock.builder(); + + builder.add("\n"); + String tmpVariableName = updateVariableName + "Document"; + builder.add(MongoCodeBlocks.renderExpressionToDocument(source.getUpdate().getUpdateString(), tmpVariableName, arguments)); + builder.addStatement("$1T $2L = new $1T($3L)", BasicUpdate.class, updateVariableName, tmpVariableName); + + return builder.build(); + } + } +} From 364fa66babad6cc897da4fa5ab09641ba7c3cc8d Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Mon, 16 Jun 2025 10:31:36 +0200 Subject: [PATCH 12/24] make sure things are skipped for real --- .../aot/MongoRepositoryContributor.java | 8 ++++--- .../aot/MongoRepositoryContributorTests.java | 23 ++++++++++--------- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java index 9b4d3b4b83..df5dc66bb9 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java @@ -101,6 +101,10 @@ protected void customizeConstructor(AotRepositoryConstructorBuilder constructorB MongoQueryMethod queryMethod = new MongoQueryMethod(method, getRepositoryInformation(), getProjectionFactory(), mappingContext); + if (backoff(queryMethod)) { + return null; + } + if (queryMethod.hasAnnotatedAggregation()) { AggregationInteraction aggregation = new AggregationInteraction(queryMethod.getAnnotatedAggregation()); return aggregationMethodContributor(queryMethod, aggregation); @@ -127,9 +131,7 @@ protected void customizeConstructor(AotRepositoryConstructorBuilder constructorB } } - if (backoff(queryMethod)) { - return null; - } + if (query.isDelete()) { return deleteMethodContributor(queryMethod, query); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java index 6ac9a61458..6a30389705 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java @@ -31,7 +31,6 @@ import org.bson.Document; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.springframework.beans.factory.annotation.Autowired; @@ -50,7 +49,6 @@ import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; import org.springframework.data.geo.Distance; -import org.springframework.data.geo.GeoPage; import org.springframework.data.geo.GeoResult; import org.springframework.data.geo.GeoResults; import org.springframework.data.geo.Metrics; @@ -695,18 +693,21 @@ void testNearWithRange() { } @Test - @Disabled("too complicated") void testNearReturningGeoPage() { - // TODO: still need to create the count and extract the total elements - GeoPage page1 = fragment.findByLocationCoordinatesNear(new Point(-73.99, 40.73), - Distance.of(2000, Metrics.KILOMETERS), PageRequest.of(0, 1)); - - assertThat(page1.hasNext()).isTrue(); + assertThatExceptionOfType(NoSuchMethodException.class) + .isThrownBy(() -> fragment.findByLocationCoordinatesNear(new Point(-73.99, 40.73), + Distance.of(2000, Metrics.KILOMETERS), PageRequest.of(0, 1))); - GeoPage page2 = fragment.findByLocationCoordinatesNear(new Point(-73.99, 40.73), - Distance.of(2000, Metrics.KILOMETERS), page1.nextPageable()); - assertThat(page2.hasNext()).isFalse(); + // TODO: still need to create the count and extract the total elements + // GeoPage page1 = fragment.findByLocationCoordinatesNear(new Point(-73.99, 40.73), + // Distance.of(2000, Metrics.KILOMETERS), PageRequest.of(0, 1)); + // + // assertThat(page1.hasNext()).isTrue(); + // + // GeoPage page2 = fragment.findByLocationCoordinatesNear(new Point(-73.99, 40.73), + // Distance.of(2000, Metrics.KILOMETERS), page1.nextPageable()); + // assertThat(page2.hasNext()).isFalse(); } /** From 435ceab0ec65fee438196b120378cb6a99464445 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Mon, 16 Jun 2025 12:57:15 +0200 Subject: [PATCH 13/24] ReadPreference for geo queries --- .../mongodb/repository/aot/GeoBlocks.java | 2 + .../repository/aot/MongoCodeBlocks.java | 13 ++ .../aot/MongoRepositoryContributor.java | 12 +- .../mongodb/repository/aot/QueryBlocks.java | 11 +- .../test/java/example/aot/UserRepository.java | 3 + .../aot/QueryMethodContributionUnitTests.java | 156 ++++++++++++++++++ 6 files changed, 182 insertions(+), 15 deletions(-) create mode 100644 spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java index ecf111433b..0dfbe25401 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java @@ -85,6 +85,8 @@ CodeBlock build() { builder.addStatement("$L.with($L)", variableName, context.getPageableParameterName()); } + MongoCodeBlocks.appendReadPreference(context, builder, variableName); + return builder.build(); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java index 2a51bc81b2..9a87d4afe8 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java @@ -21,6 +21,8 @@ import org.bson.Document; import org.jspecify.annotations.Nullable; +import org.springframework.core.annotation.MergedAnnotation; +import org.springframework.data.mongodb.repository.ReadPreference; import org.springframework.data.mongodb.repository.aot.AggregationBlocks.AggregationCodeBlockBuilder; import org.springframework.data.mongodb.repository.aot.AggregationBlocks.AggregationExecutionCodeBlockBuilder; import org.springframework.data.mongodb.repository.aot.DeleteBlocks.DeleteExecutionCodeBlockBuilder; @@ -189,4 +191,15 @@ static CodeBlock renderExpressionToDocument(@Nullable String source, String vari static boolean containsPlaceholder(String source) { return PARAMETER_BINDING_PATTERN.matcher(source).find(); } + + static void appendReadPreference(AotQueryMethodGenerationContext context, Builder builder, String queryVariableName) { + + MergedAnnotation readPreferenceAnnotation = context.getAnnotation(ReadPreference.class); + String readPreference = readPreferenceAnnotation.isPresent() ? readPreferenceAnnotation.getString("value") : null; + + if (StringUtils.hasText(readPreference)) { + builder.addStatement("$L.withReadPreference($T.valueOf($S))", queryVariableName, + com.mongodb.ReadPreference.class, readPreference); + } + } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java index df5dc66bb9..1077cc0716 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java @@ -131,8 +131,6 @@ protected void customizeConstructor(AotRepositoryConstructorBuilder constructorB } } - - if (query.isDelete()) { return deleteMethodContributor(queryMethod, query); } @@ -216,7 +214,7 @@ private static MethodContributor nearQueryMethodContributor(Mo }); } - private static MethodContributor aggregationMethodContributor(MongoQueryMethod queryMethod, + static MethodContributor aggregationMethodContributor(MongoQueryMethod queryMethod, AggregationInteraction aggregation) { return MethodContributor.forQueryMethod(queryMethod).withMetadata(aggregation).contribute(context -> { @@ -232,7 +230,7 @@ private static MethodContributor aggregationMethodContributor( }); } - private static MethodContributor updateMethodContributor(MongoQueryMethod queryMethod, + static MethodContributor updateMethodContributor(MongoQueryMethod queryMethod, UpdateInteraction update) { return MethodContributor.forQueryMethod(queryMethod).withMetadata(update).contribute(context -> { @@ -261,7 +259,7 @@ private static MethodContributor updateMethodContributor(Mongo }); } - private static MethodContributor aggregationUpdateMethodContributor(MongoQueryMethod queryMethod, + static MethodContributor aggregationUpdateMethodContributor(MongoQueryMethod queryMethod, AggregationUpdateInteraction update) { return MethodContributor.forQueryMethod(queryMethod).withMetadata(update).contribute(context -> { @@ -287,7 +285,7 @@ private static MethodContributor aggregationUpdateMethodContri }); } - private static MethodContributor deleteMethodContributor(MongoQueryMethod queryMethod, + static MethodContributor deleteMethodContributor(MongoQueryMethod queryMethod, QueryInteraction query) { return MethodContributor.forQueryMethod(queryMethod).withMetadata(query).contribute(context -> { @@ -302,7 +300,7 @@ private static MethodContributor deleteMethodContributor(Mongo }); } - private static MethodContributor queryMethodContributor(MongoQueryMethod queryMethod, + static MethodContributor queryMethodContributor(MongoQueryMethod queryMethod, QueryInteraction query) { return MethodContributor.forQueryMethod(queryMethod).withMetadata(query).contribute(context -> { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java index 50aed69d33..97a53f921d 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java @@ -34,7 +34,6 @@ import org.springframework.data.mongodb.core.query.BasicQuery; import org.springframework.data.mongodb.repository.Hint; import org.springframework.data.mongodb.repository.Meta; -import org.springframework.data.mongodb.repository.ReadPreference; import org.springframework.data.mongodb.repository.query.MongoParameters.MongoParameter; import org.springframework.data.mongodb.repository.query.MongoQueryExecution.PagedExecution; import org.springframework.data.mongodb.repository.query.MongoQueryExecution.SlicedExecution; @@ -249,13 +248,7 @@ CodeBlock build() { builder.addStatement("$L.withHint($S)", queryVariableName, hint); } - MergedAnnotation readPreferenceAnnotation = context.getAnnotation(ReadPreference.class); - String readPreference = readPreferenceAnnotation.isPresent() ? readPreferenceAnnotation.getString("value") : null; - - if (StringUtils.hasText(readPreference)) { - builder.addStatement("$L.withReadPreference($T.valueOf($S))", queryVariableName, - com.mongodb.ReadPreference.class, readPreference); - } + MongoCodeBlocks.appendReadPreference(context, builder, queryVariableName); MergedAnnotation metaAnnotation = context.getAnnotation(Meta.class); if (metaAnnotation.isPresent()) { @@ -297,6 +290,8 @@ private CodeBlock renderExpressionToQuery(@Nullable String source, String variab builder.add(iterator.next()); if (iterator.hasNext()) { builder.add(", "); + } else { + builder.add(" "); } } builder.add("});\n"); diff --git a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java index 2fc0787a67..d2288477fd 100644 --- a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java +++ b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java @@ -43,6 +43,7 @@ import org.springframework.data.geo.Polygon; import org.springframework.data.mongodb.core.aggregation.AggregationResults; import org.springframework.data.mongodb.core.geo.GeoJsonPolygon; +import org.springframework.data.mongodb.core.geo.Sphere; import org.springframework.data.mongodb.repository.Aggregation; import org.springframework.data.mongodb.repository.Hint; import org.springframework.data.mongodb.repository.Person; @@ -118,6 +119,8 @@ public interface UserRepository extends CrudRepository { List findByLocationCoordinatesWithin(Circle circle); + List findByLocationCoordinatesWithin(Sphere circle); + List findByLocationCoordinatesWithin(Box box); List findByLocationCoordinatesWithin(Polygon polygon); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java new file mode 100644 index 0000000000..531efe3308 --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java @@ -0,0 +1,156 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import static org.assertj.core.api.Assertions.assertThat; + +import example.aot.User; +import example.aot.UserRepository; + +import java.lang.reflect.Method; + +import javax.lang.model.element.Modifier; + +import org.junit.jupiter.api.Test; +import org.springframework.data.geo.Box; +import org.springframework.data.geo.Circle; +import org.springframework.data.geo.Distance; +import org.springframework.data.geo.GeoResults; +import org.springframework.data.geo.Point; +import org.springframework.data.geo.Polygon; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.geo.Sphere; +import org.springframework.data.mongodb.repository.ReadPreference; +import org.springframework.data.repository.Repository; +import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; +import org.springframework.data.repository.aot.generate.AotRepositoryFragmentMetadata; +import org.springframework.data.repository.aot.generate.MethodContributor; +import org.springframework.data.repository.core.RepositoryInformation; +import org.springframework.data.repository.query.QueryMethod; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.FieldSpec; +import org.springframework.javapoet.MethodSpec; + +/** + * @author Christoph Strobl + */ +public class QueryMethodContributionUnitTests { + + @Test + void rendersQueryForNearUsingPoint() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesNear", Point.class); + + assertThat(methodSpec.toString()) // + .contains("{'location.coordinates':{'$near':?0}}") // + .contains("Object[]{ location }") // + .contains("return finder.matching(filterQuery).all()"); + } + + @Test + void rendersQueryForWithinUsingCircle() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesWithin", Circle.class); + + assertThat(methodSpec.toString()) // + .contains("{'location.coordinates':{'$geoWithin':{'$center':?0}}") // + .contains( + "List.of(circle.getCenter().getX(), circle.getCenter().getY()), circle.getRadius().getNormalizedValue())") // + .contains("return finder.matching(filterQuery).all()"); + } + + @Test + void rendersQueryForWithinUsingSphere() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesWithin", Sphere.class); + + assertThat(methodSpec.toString()) // + .contains("{'location.coordinates':{'$geoWithin':{'$centerSphere':?0}}") // + .contains( + "List.of(circle.getCenter().getX(), circle.getCenter().getY()), circle.getRadius().getNormalizedValue())") // + .contains("return finder.matching(filterQuery).all()"); + } + + @Test + void rendersQueryForWithinUsingBox() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesWithin", Box.class); + + assertThat(methodSpec.toString()) // + .contains("{'location.coordinates':{'$geoWithin':{'$box':?0}}") // + .contains("List.of(box.getFirst().getX(), box.getFirst().getY())") // + .contains("List.of(box.getSecond().getX(), box.getSecond().getY())") // + .contains("return finder.matching(filterQuery).all()"); + } + + @Test + void rendersQueryForWithinUsingPolygon() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesWithin", Polygon.class); + + assertThat(methodSpec.toString()) // + .contains("{'location.coordinates':{'$geoWithin':{'$polygon':?0}}") // + .contains("polygon.getPoints().stream().map(_p ->") // + .contains("List.of(_p.getX(), _p.getY())") // + .contains("return finder.matching(filterQuery).all()"); + } + + @Test + void rendersNearQueryForGeoResults() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepoWithMeta.class, "findByLocationCoordinatesNear", Point.class, + Distance.class); + + assertThat(methodSpec.toString()) // + .contains("NearQuery.near(point)") // + .contains("nearQuery.maxDistance(maxDistance).in(maxDistance.getMetric())") // + .contains(".withReadPreference(com.mongodb.ReadPreference.valueOf(\"NEAREST\")") // + .contains(".near(nearQuery).all()"); + } + + private static MethodSpec codeOf(Class repository, String methodName, Class... args) + throws NoSuchMethodException { + + Method method = repository.getMethod(methodName, args); + + TestMongoAotRepositoryContext repoContext = new TestMongoAotRepositoryContext(repository, null); + MongoRepositoryContributor contributor = new MongoRepositoryContributor(repoContext); + MethodContributor methodContributor = contributor.contributeQueryMethod(method); + + AotRepositoryFragmentMetadata metadata = new AotRepositoryFragmentMetadata(ClassName.get(UserRepository.class)); + metadata.addField( + FieldSpec.builder(MongoOperations.class, "mongoOperations", Modifier.PRIVATE, Modifier.FINAL).build()); + + TestQueryMethodGenerationContext methodContext = new TestQueryMethodGenerationContext( + repoContext.getRepositoryInformation(), method, methodContributor.getQueryMethod(), metadata); + return methodContributor.contribute(methodContext); + } + + static class TestQueryMethodGenerationContext extends AotQueryMethodGenerationContext { + + protected TestQueryMethodGenerationContext(RepositoryInformation repositoryInformation, Method method, + QueryMethod queryMethod, AotRepositoryFragmentMetadata targetTypeMetadata) { + super(repositoryInformation, method, queryMethod, targetTypeMetadata); + } + } + + interface UserRepoWithMeta extends Repository { + + @ReadPreference("NEAREST") + GeoResults findByLocationCoordinatesNear(Point point, Distance maxDistance); + + } +} From fbceb946a47e112f4c8122dee9e825af0ea77c64 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Mon, 16 Jun 2025 14:47:48 +0200 Subject: [PATCH 14/24] GeoNear query with additional filter --- .../repository/aot/AotQueryCreator.java | 6 ++- .../aot/MongoRepositoryContributor.java | 18 ++++++-- .../repository/aot/NearQueryInteraction.java | 9 +++- .../test/java/example/aot/UserRepository.java | 2 + ...tractPersonRepositoryIntegrationTests.java | 46 +++++++++++++++---- .../mongodb/repository/PersonRepository.java | 1 + .../AotFragmentTestConfigurationSupport.java | 9 +++- .../aot/MongoRepositoryContributorTests.java | 11 ++++- .../aot/QueryMethodContributionUnitTests.java | 16 ++++++- 9 files changed, 99 insertions(+), 19 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java index 2b22550026..11d6e8bdd2 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java @@ -49,6 +49,7 @@ import org.springframework.data.mongodb.repository.query.ConvertingParameterAccessor; import org.springframework.data.mongodb.repository.query.MongoParameterAccessor; import org.springframework.data.mongodb.repository.query.MongoQueryCreator; +import org.springframework.data.mongodb.repository.query.MongoQueryMethod; import org.springframework.data.repository.query.Parameter; import org.springframework.data.repository.query.Parameters; import org.springframework.data.repository.query.QueryMethod; @@ -80,8 +81,11 @@ public AotQueryCreator() { @SuppressWarnings("NullAway") StringQuery createQuery(PartTree partTree, QueryMethod queryMethod) { + + boolean geoNear = queryMethod instanceof MongoQueryMethod mqm ? mqm.isGeoNearQuery() : false; + Query query = new MongoQueryCreator(partTree, - new PlaceholderConvertingParameterAccessor(new PlaceholderParameterAccessor(queryMethod)), mappingContext) + new PlaceholderConvertingParameterAccessor(new PlaceholderParameterAccessor(queryMethod)), mappingContext, geoNear, queryMethod.isSearchQuery()) .createQuery(); if (partTree.isLimiting()) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java index 1077cc0716..ad70df271c 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java @@ -110,15 +110,15 @@ protected void customizeConstructor(AotRepositoryConstructorBuilder constructorB return aggregationMethodContributor(queryMethod, aggregation); } + QueryInteraction query = createStringQuery(getRepositoryInformation(), queryMethod, + AnnotatedElementUtils.findMergedAnnotation(method, Query.class), method.getParameterCount()); + if (queryMethod.isGeoNearQuery() || (queryMethod.getParameters().getMaxDistanceIndex() != -1 && queryMethod.getReturnType().isCollectionLike())) { - NearQueryInteraction near = new NearQueryInteraction(); + NearQueryInteraction near = new NearQueryInteraction(query); return nearQueryMethodContributor(queryMethod, near); } - QueryInteraction query = createStringQuery(getRepositoryInformation(), queryMethod, - AnnotatedElementUtils.findMergedAnnotation(method, Query.class), method.getParameterCount()); - if (queryMethod.hasAnnotatedQuery()) { if (StringUtils.hasText(queryMethod.getAnnotatedQuery()) && Pattern.compile("[\\?:][#$]\\{.*\\}").matcher(queryMethod.getAnnotatedQuery()).find()) { @@ -206,8 +206,16 @@ private static MethodContributor nearQueryMethodContributor(Mo CodeBlock.Builder builder = CodeBlock.builder(); - String variableName = "nearQuery"; + String variableName = context.localVariable("nearQuery"); builder.add(geoNearBlockBuilder(context, queryMethod).usingQueryVariableName(variableName).build()); + + if (!context.getBindableParameterNames().isEmpty()) { + String filterQueryVariableName = context.localVariable("filterQuery"); + builder.add(queryBlockBuilder(context, queryMethod).usingQueryVariableName(filterQueryVariableName) + .filter(interaction.getQuery()).build()); + builder.addStatement("$L.query($L)", variableName, filterQueryVariableName); + } + builder.add(geoNearExecutionBlockBuilder(context, queryMethod).referencing(variableName).build()); return builder.build(); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/NearQueryInteraction.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/NearQueryInteraction.java index 23551abddc..30609ea15d 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/NearQueryInteraction.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/NearQueryInteraction.java @@ -18,6 +18,7 @@ import java.util.LinkedHashMap; import java.util.Map; +import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.repository.aot.generate.QueryMetadata; import org.springframework.util.StringUtils; @@ -30,9 +31,11 @@ class NearQueryInteraction extends MongoInteraction implements QueryMetadata { private final InteractionType interactionType; + private final QueryInteraction query; - NearQueryInteraction() { + NearQueryInteraction(QueryInteraction query) { interactionType = InteractionType.QUERY; + this.query = query; } @Override @@ -40,6 +43,10 @@ InteractionType getExecutionType() { return interactionType; } + public QueryInteraction getQuery() { + return query; + } + @Override public Map serialize() { diff --git a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java index d2288477fd..2044b3108a 100644 --- a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java +++ b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java @@ -129,6 +129,8 @@ public interface UserRepository extends CrudRepository { GeoResults findByLocationCoordinatesNear(Point point, Distance maxDistance); + GeoResults findByLocationCoordinatesNearAndLastname(Point point, Distance maxDistance, String lastname); + List> findUserAsListByLocationCoordinatesNear(Point point, Distance maxDistance); GeoResults findByLocationCoordinatesNear(Point point, Range distance); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java index c2cb6cacf8..9f7eda351d 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java @@ -15,10 +15,12 @@ */ package org.springframework.data.mongodb.repository; -import static java.util.Arrays.*; -import static org.assertj.core.api.Assertions.*; -import static org.assertj.core.api.Assumptions.*; -import static org.springframework.data.geo.Metrics.*; +import static java.util.Arrays.asList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assumptions.assumeThat; +import static org.springframework.data.geo.Metrics.KILOMETERS; import java.util.ArrayList; import java.util.Arrays; @@ -38,13 +40,22 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.extension.ExtendWith; - import org.springframework.beans.factory.annotation.Autowired; import org.springframework.dao.DuplicateKeyException; import org.springframework.dao.IncorrectResultSizeDataAccessException; -import org.springframework.data.domain.*; +import org.springframework.data.domain.Example; +import org.springframework.data.domain.ExampleMatcher; import org.springframework.data.domain.ExampleMatcher.GenericPropertyMatcher; +import org.springframework.data.domain.Limit; +import org.springframework.data.domain.Page; +import org.springframework.data.domain.PageRequest; +import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.ScrollPosition; +import org.springframework.data.domain.Slice; +import org.springframework.data.domain.Sort; import org.springframework.data.domain.Sort.Direction; +import org.springframework.data.domain.Window; import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; import org.springframework.data.geo.Distance; @@ -216,8 +227,8 @@ void appliesScrollPositionCorrectly() { @Test // GH-4397 void appliesLimitToScrollingCorrectly() { - Window page = repository.findByLastnameLikeOrderByLastnameAscFirstnameAsc("*a*", - ScrollPosition.keyset(), Limit.of(2)); + Window page = repository.findByLastnameLikeOrderByLastnameAscFirstnameAsc("*a*", ScrollPosition.keyset(), + Limit.of(2)); assertThat(page.isLast()).isFalse(); assertThat(page.size()).isEqualTo(2); @@ -250,7 +261,8 @@ void executesPagedFinderCorrectly() { @Test // GH-4397 void executesFinderCorrectlyWithSortAndLimit() { - List page = repository.findByLastnameLike("*a*", Sort.by(Direction.ASC, "lastname", "firstname"), Limit.of(2)); + List page = repository.findByLastnameLike("*a*", Sort.by(Direction.ASC, "lastname", "firstname"), + Limit.of(2)); assertThat(page).containsExactly(carter, stefan); } @@ -462,6 +474,22 @@ void executesGeoNearQueryForResultsCorrectly() { assertThat(results.getContent()).isNotEmpty(); } + @Test + void executesGeoNearQueryWithAdditionalFilterCorrectly() { + + Point point = new Point(-73.99171, 40.738868); + dave.setLocation(point); + repository.save(dave); + + Person p2 = new Person("fn", "ln", 42, Sex.MALE); + p2.setLocation(point); + repository.save(p2); + + GeoResults results = repository.findByLocationNearAndLastname(new Point(-73.99, 40.73), + Distance.of(2000, Metrics.KILOMETERS), "ln"); + assertThat(results.getContent()).hasSize(1); + } + @Test void executesGeoPageQueryForResultsCorrectly() { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java index 1f4f682ebc..6c40ab622a 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java @@ -221,6 +221,7 @@ Window findByLastnameLikeOrderByLastnameAscFirstnameAsc(String lastname, List findByNamedQuery(String firstname); GeoResults findByLocationNear(Point point, Distance maxDistance); + GeoResults findByLocationNearAndLastname(Point point, Distance maxDistance, String Lastname); // DATAMONGO-1110 GeoResults findPersonByLocationNear(Point point, Range distance); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/AotFragmentTestConfigurationSupport.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/AotFragmentTestConfigurationSupport.java index eba08ecc2e..5b86acdace 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/AotFragmentTestConfigurationSupport.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/AotFragmentTestConfigurationSupport.java @@ -93,7 +93,7 @@ private Object getFragmentFacadeProxy(Object fragment) { Method target = ReflectionUtils.findMethod(fragment.getClass(), method.getName(), method.getParameterTypes()); if (target == null) { - throw new NoSuchMethodException("Method [%s] is not implemented by [%s]".formatted(method, target)); + throw new MethodNotImplementedException("Method [%s] is not implemented by [%s]".formatted(method, target)); } try { @@ -127,4 +127,11 @@ public ProjectionFactory getProjectionFactory() { } }; } + + public static class MethodNotImplementedException extends RuntimeException { + + public MethodNotImplementedException(String message) { + super(message); + } + } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java index 6a30389705..971cce7c32 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java @@ -59,6 +59,7 @@ import org.springframework.data.mongodb.core.aggregation.AggregationResults; import org.springframework.data.mongodb.core.geo.GeoJsonPoint; import org.springframework.data.mongodb.core.geo.GeoJsonPolygon; +import org.springframework.data.mongodb.repository.aot.AotFragmentTestConfigurationSupport.MethodNotImplementedException; import org.springframework.data.mongodb.test.util.Client; import org.springframework.data.mongodb.test.util.MongoClientExtension; import org.springframework.data.mongodb.test.util.MongoTestUtils; @@ -675,6 +676,14 @@ void testNearWithGeoResult() { assertThat(users).extracting(GeoResult::getContent).extracting(User::getUsername).containsExactly("leia"); } + @Test + void testNearWithAdditionalFilterQueryAsGeoResult() { + + GeoResults users = fragment.findByLocationCoordinatesNearAndLastname(new Point(-73.99, 40.73), + Distance.of(50, Metrics.KILOMETERS), "Organa"); + assertThat(users).extracting(GeoResult::getContent).extracting(User::getUsername).containsExactly("leia"); + } + @Test void testNearReturningListOfGeoResult() { @@ -695,7 +704,7 @@ void testNearWithRange() { @Test void testNearReturningGeoPage() { - assertThatExceptionOfType(NoSuchMethodException.class) + assertThatExceptionOfType(MethodNotImplementedException.class) .isThrownBy(() -> fragment.findByLocationCoordinatesNear(new Point(-73.99, 40.73), Distance.of(2000, Metrics.KILOMETERS), PageRequest.of(0, 1))); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java index 531efe3308..b9c598dfce 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java @@ -118,6 +118,21 @@ void rendersNearQueryForGeoResults() throws NoSuchMethodException { .contains("NearQuery.near(point)") // .contains("nearQuery.maxDistance(maxDistance).in(maxDistance.getMetric())") // .contains(".withReadPreference(com.mongodb.ReadPreference.valueOf(\"NEAREST\")") // + .doesNotContain(".query(") // + .contains(".near(nearQuery).all()"); + } + + @Test + void rendersNearQueryWithFilterForGeoResults() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesNearAndLastname", Point.class, + Distance.class, String.class); + + assertThat(methodSpec.toString()) // + .contains("NearQuery.near(point)") // + .contains("nearQuery.maxDistance(maxDistance).in(maxDistance.getMetric())") // + .contains("filterQuery = createQuery(\"{'lastname':?0}\", new java.lang.Object[]{ lastname })") // + .contains("nearQuery.query(filterQuery)") // .contains(".near(nearQuery).all()"); } @@ -151,6 +166,5 @@ interface UserRepoWithMeta extends Repository { @ReadPreference("NEAREST") GeoResults findByLocationCoordinatesNear(Point point, Distance maxDistance); - } } From 253a78caa4b472b1421d47fb793850c4a78530dd Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Tue, 17 Jun 2025 14:43:18 +0200 Subject: [PATCH 15/24] Support GeoPage --- .../mongodb/core/ExecutableFindOperation.java | 8 +++ .../core/ExecutableFindOperationSupport.java | 5 ++ .../data/mongodb/core/MongoTemplate.java | 26 ++++++++++ .../core/aggregation/GeoNearOperation.java | 51 ++++++++++++++++--- .../data/mongodb/core/query/NearQuery.java | 2 +- .../mongodb/repository/aot/GeoBlocks.java | 37 +++++++------- .../aot/MongoRepositoryContributor.java | 5 +- .../repository/query/MongoQueryExecution.java | 23 +++++---- .../ExecutableFindOperationSupportTests.java | 24 ++++++--- ...tractPersonRepositoryIntegrationTests.java | 25 +++++++++ .../aot/MongoRepositoryContributorTests.java | 23 ++++----- .../aot/QueryMethodContributionUnitTests.java | 47 +++++++++++++++-- .../query/MongoQueryExecutionUnitTests.java | 4 +- 13 files changed, 214 insertions(+), 66 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ExecutableFindOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ExecutableFindOperation.java index 43c0d521c3..47fea8a02f 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ExecutableFindOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ExecutableFindOperation.java @@ -225,6 +225,14 @@ interface TerminatingFindNear { * @return never {@literal null}. */ GeoResults all(); + + /** + * Count matching elements. + * + * @return number of elements matching the query. + * @since 5.0 + */ + long count(); } /** diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ExecutableFindOperationSupport.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ExecutableFindOperationSupport.java index 46289ecfa4..39f4affd35 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ExecutableFindOperationSupport.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ExecutableFindOperationSupport.java @@ -243,6 +243,11 @@ public TerminatingFindNear map(QueryResultConverter all() { return template.doGeoNear(nearQuery, domainType, getCollectionName(), returnType, resultConverter); } + + @Override + public long count() { + return template.doGeoNearCount(nearQuery, domainType, getCollectionName()); + } } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java index 8682f77ec8..03c0bb7682 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java @@ -48,6 +48,7 @@ import org.springframework.dao.support.PersistenceExceptionTranslator; import org.springframework.data.convert.EntityReader; import org.springframework.data.domain.OffsetScrollPosition; +import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Window; import org.springframework.data.geo.Distance; import org.springframework.data.geo.GeoResult; @@ -1044,6 +1045,31 @@ public GeoResults geoNear(NearQuery near, Class domainType, String col return doGeoNear(near, domainType, collectionName, returnType, QueryResultConverter.entity()); } + long doGeoNearCount(NearQuery near, Class domainType, String collectionName) { + + Builder optionsBuilder = AggregationOptions.builder().collation(near.getCollation()); + + if (near.hasReadPreference()) { + optionsBuilder.readPreference(near.getReadPreference()); + } + + if (near.hasReadConcern()) { + optionsBuilder.readConcern(near.getReadConcern()); + } + + String distanceField = operations.nearQueryDistanceFieldName(domainType); + Aggregation $geoNear = TypedAggregation.newAggregation(domainType, + Aggregation.geoNear(near, distanceField).skip(-1).limit(-1), Aggregation.count().as("_totalCount")) + .withOptions(optionsBuilder.build()); + + AggregationResults results = doAggregate($geoNear, collectionName, Document.class, + queryOperations.createAggregation($geoNear, (AggregationOperationContext) null)); + Iterator iterator = results.iterator(); + return iterator.hasNext() + ? NumberUtils.convertNumberToTargetClass(iterator.next().get("_totalCount", Integer.class), Long.class) + : 0L; + } + GeoResults doGeoNear(NearQuery near, Class domainType, String collectionName, Class returnType, QueryResultConverter resultConverter) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GeoNearOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GeoNearOperation.java index bcfc64f2b4..04b793f839 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GeoNearOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GeoNearOperation.java @@ -42,6 +42,8 @@ public class GeoNearOperation implements AggregationOperation { private final NearQuery nearQuery; private final String distanceField; private final @Nullable String indexKey; + private final @Nullable Long skip; + private final @Nullable Integer limit; /** * Creates a new {@link GeoNearOperation} from the given {@link NearQuery} and the given distance field. The @@ -51,7 +53,7 @@ public class GeoNearOperation implements AggregationOperation { * @param distanceField must not be {@literal null}. */ public GeoNearOperation(NearQuery nearQuery, String distanceField) { - this(nearQuery, distanceField, null); + this(nearQuery, distanceField, null, nearQuery.getSkip(), null); } /** @@ -63,7 +65,8 @@ public GeoNearOperation(NearQuery nearQuery, String distanceField) { * @param indexKey can be {@literal null}; * @since 2.1 */ - private GeoNearOperation(NearQuery nearQuery, String distanceField, @Nullable String indexKey) { + private GeoNearOperation(NearQuery nearQuery, String distanceField, @Nullable String indexKey, @Nullable Long skip, + @Nullable Integer limit) { Assert.notNull(nearQuery, "NearQuery must not be null"); Assert.hasLength(distanceField, "Distance field must not be null or empty"); @@ -71,6 +74,8 @@ private GeoNearOperation(NearQuery nearQuery, String distanceField, @Nullable St this.nearQuery = nearQuery; this.distanceField = distanceField; this.indexKey = indexKey; + this.skip = skip; + this.limit = limit; } /** @@ -83,7 +88,30 @@ private GeoNearOperation(NearQuery nearQuery, String distanceField, @Nullable St */ @Contract("_ -> new") public GeoNearOperation useIndex(String key) { - return new GeoNearOperation(nearQuery, distanceField, key); + return new GeoNearOperation(nearQuery, distanceField, key, skip, limit); + } + + /** + * Override potential skip applied via {@link NearQuery#getSkip()}. Adds an additional {@link SkipOperation} if value + * is non negative. + * + * @param skip + * @return new instance of {@link GeoNearOperation}. + * @since 5.0 + */ + public GeoNearOperation skip(long skip) { + return new GeoNearOperation(nearQuery, distanceField, indexKey, skip, limit); + } + + /** + * Override potential limit value. Adds an additional {@link LimitOperation} if value is non negative. + * + * @param limit + * @return new instance of {@link GeoNearOperation}. + * @since 5.0 + */ + public GeoNearOperation limit(Integer limit) { + return new GeoNearOperation(nearQuery, distanceField, indexKey, skip, limit); } @Override @@ -92,7 +120,13 @@ public Document toDocument(AggregationOperationContext context) { Document command = context.getMappedObject(nearQuery.toDocument()); if (command.containsKey("query")) { - command.replace("query", context.getMappedObject(command.get("query", Document.class))); + Document query = command.get("query", Document.class); + if (query == null || query.isEmpty()) { + command.remove("query"); + } else { + command.replace("query", context.getMappedObject(query)); + } + } command.remove("collation"); @@ -115,15 +149,18 @@ public List toPipelineStages(AggregationOperationContext context) { Document command = toDocument(context); Number limit = (Number) command.get("$geoNear", Document.class).remove("num"); + if (limit != null && this.limit != null) { + limit = this.limit; + } List stages = new ArrayList<>(3); stages.add(command); - if (nearQuery.getSkip() != null && nearQuery.getSkip() > 0) { - stages.add(new Document("$skip", nearQuery.getSkip())); + if (this.skip != null && this.skip > 0) { + stages.add(new Document("$skip", this.skip)); } - if (limit != null) { + if (limit != null && limit.longValue() > 0) { stages.add(new Document("$limit", limit.longValue())); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/NearQuery.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/NearQuery.java index 88d7dc5c1d..4f42437704 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/NearQuery.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/NearQuery.java @@ -671,7 +671,7 @@ public Document toDocument() { document.put("distanceMultiplier", getDistanceMultiplier()); } - if (limit != null) { + if (limit != null && limit > 0) { document.put("num", limit); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java index 0dfbe25401..b94f55adc2 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java @@ -15,9 +15,6 @@ */ package org.springframework.data.mongodb.repository.aot; -import java.util.List; -import java.util.stream.Collectors; - import org.springframework.data.geo.Distance; import org.springframework.data.geo.GeoPage; import org.springframework.data.geo.GeoResults; @@ -25,6 +22,7 @@ import org.springframework.data.mongodb.core.query.NearQuery; import org.springframework.data.mongodb.repository.query.MongoQueryMethod; import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; +import org.springframework.data.support.PageableExecutionUtils; import org.springframework.javapoet.CodeBlock; import org.springframework.util.ClassUtils; @@ -38,14 +36,12 @@ static class GeoNearCodeBlockBuilder { private final AotQueryMethodGenerationContext context; private final MongoQueryMethod queryMethod; - private final List arguments; private String variableName; GeoNearCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { this.context = context; - this.arguments = context.getBindableParameterNames().stream().map(CodeBlock::of).collect(Collectors.toList()); this.queryMethod = queryMethod; } @@ -119,24 +115,31 @@ CodeBlock build() { CodeBlock.Builder builder = CodeBlock.builder(); builder.add("\n"); - // TODO: move the section below into dedicated executor builder + String executorVar = context.localVariable("nearFinder"); + builder.addStatement("var $L = $L.query($T.class).near($L)", executorVar, + context.fieldNameOf(MongoOperations.class), context.getRepositoryInformation().getDomainType(), + queryVariableName); + if (ClassUtils.isAssignable(GeoPage.class, context.getReturnType().getRawClass())) { - builder.addStatement("return new $T<>($L.query($T.class).near($L).all())", GeoPage.class, - context.fieldNameOf(MongoOperations.class), context.getRepositoryInformation().getDomainType(), - queryVariableName); - } - else if (ClassUtils.isAssignable(GeoResults.class, context.getReturnType().getRawClass())) { + String geoResultVar = context.localVariable("geoResult"); + builder.addStatement("var $L = $L.all()", geoResultVar, executorVar); - builder.addStatement("return $L.query($T.class).near($L).all()", context.fieldNameOf(MongoOperations.class), - context.getRepositoryInformation().getDomainType(), queryVariableName); + builder.beginControlFlow("if($L.isUnpaged())", context.getPageableParameterName()); + builder.addStatement("return new $T<>($L)", GeoPage.class, geoResultVar); + builder.endControlFlow(); + + String pageVar = context.localVariable("resultPage"); + builder.addStatement("var $L = $T.getPage($L.getContent(), $L, () -> $L.count())", pageVar, + PageableExecutionUtils.class, geoResultVar, context.getPageableParameterName(), executorVar); + builder.addStatement("return new $T<>($L, $L, $L.getTotalElements())", GeoPage.class, geoResultVar, + context.getPageableParameterName(), pageVar); + } else if (ClassUtils.isAssignable(GeoResults.class, context.getReturnType().getRawClass())) { + builder.addStatement("return $L.all()", executorVar); } else { - builder.addStatement("return $L.query($T.class).near($L).all().getContent()", - context.fieldNameOf(MongoOperations.class), context.getRepositoryInformation().getDomainType(), - queryVariableName); + builder.addStatement("return $L.all().getContent()", executorVar); } return builder.build(); } - } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java index ad70df271c..e5ec0836c9 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java @@ -34,7 +34,6 @@ import org.apache.commons.logging.LogFactory; import org.jspecify.annotations.Nullable; import org.springframework.core.annotation.AnnotatedElementUtils; -import org.springframework.data.geo.GeoPage; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.aggregation.AggregationUpdate; import org.springframework.data.mongodb.core.mapping.MongoMappingContext; @@ -52,7 +51,6 @@ import org.springframework.data.repository.query.parser.PartTree; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.TypeName; -import org.springframework.util.ClassUtils; import org.springframework.util.ObjectUtils; import org.springframework.util.StringUtils; @@ -189,8 +187,7 @@ private static boolean backoff(MongoQueryMethod method) { // TODO: namedQuery, Regex queries, queries accepting Shapes (e.g. within) or returning arrays. boolean skip = method.isSearchQuery() || method.getName().toLowerCase(Locale.ROOT).contains("regex") - || method.getReturnType().getType().isArray() - || ClassUtils.isAssignable(GeoPage.class, method.getReturnType().getType()); + || method.getReturnType().getType().isArray(); if (skip && logger.isDebugEnabled()) { logger.debug("Skipping AOT generation for [%s]. Method is either returning an array or a geo-near, regex query" diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java index c0531e0e19..acf80db214 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java @@ -186,8 +186,11 @@ public Object execute(Query query) { return isListOfGeoResult(method.getReturnType()) ? results.getContent() : results; } - @SuppressWarnings({ "unchecked", "NullAway" }) GeoResults doExecuteQuery(Query query) { + return doExecuteQuery(nearQuery(query)); + } + + NearQuery nearQuery(Query query) { Point nearLocation = accessor.getGeoNearLocation(); Assert.notNull(nearLocation, "[query.location] must not be null"); @@ -205,9 +208,12 @@ GeoResults doExecuteQuery(Query query) { distances.getUpperBound().getValue().ifPresent(it -> nearQuery.maxDistance(it).in(it.getMetric())); Pageable pageable = accessor.getPageable(); - nearQuery.with(pageable); + return nearQuery.with(pageable); + } - return (GeoResults) operation.near(nearQuery).all(); + @SuppressWarnings({ "unchecked", "NullAway" }) + GeoResults doExecuteQuery(NearQuery query) { + return (GeoResults) operation.near(query).all(); } private static boolean isListOfGeoResult(TypeInformation returnType) { @@ -324,16 +330,11 @@ final class PagingGeoNearExecution extends GeoNearExecution { @Override public Object execute(Query query) { - GeoResults geoResults = doExecuteQuery(query); + NearQuery nearQuery = nearQuery(query); + GeoResults geoResults = doExecuteQuery(nearQuery); Page> page = PageableExecutionUtils.getPage(geoResults.getContent(), accessor.getPageable(), - () -> { - - Query countQuery = mongoQuery.createCountQuery(accessor); - countQuery = mongoQuery.applyQueryMetaAttributesWhenPresent(countQuery); - - return operation.matching(countQuery).count(); - }); + () -> operation.near(nearQuery).count()); // transform to GeoPage after applying optimization return new GeoPage<>(geoResults, accessor.getPageable(), page.getTotalElements()); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ExecutableFindOperationSupportTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ExecutableFindOperationSupportTests.java index 835367990a..3c95a5a8ea 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ExecutableFindOperationSupportTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ExecutableFindOperationSupportTests.java @@ -15,10 +15,13 @@ */ package org.springframework.data.mongodb.core; -import static org.assertj.core.api.Assertions.*; -import static org.springframework.data.mongodb.core.query.Criteria.*; -import static org.springframework.data.mongodb.core.query.Query.*; -import static org.springframework.data.mongodb.test.util.DirtiesStateExtension.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.springframework.data.mongodb.core.query.Criteria.where; +import static org.springframework.data.mongodb.core.query.Query.query; +import static org.springframework.data.mongodb.test.util.DirtiesStateExtension.DirtiesState; +import static org.springframework.data.mongodb.test.util.DirtiesStateExtension.StateFunctions; import java.util.Date; import java.util.List; @@ -47,6 +50,7 @@ import org.springframework.data.mongodb.core.mapping.Field; import org.springframework.data.mongodb.core.query.BasicQuery; import org.springframework.data.mongodb.core.query.NearQuery; +import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.test.util.DirtiesStateExtension; import org.springframework.data.mongodb.test.util.MongoTemplateExtension; import org.springframework.data.mongodb.test.util.MongoTestTemplate; @@ -81,7 +85,7 @@ public void clear() { @Override public void setupState() { - template.indexOps(Planet.class).ensureIndex( + template.indexOps(Planet.class).createIndex( new GeospatialIndex("coordinates").typed(GeoSpatialIndexType.GEO_2DSPHERE).named("planet-coordinate-idx")); initPersons(); @@ -162,7 +166,7 @@ void findAllByWithCollection() { void findAllAsDocument() { assertThat( template.query(Document.class).inCollection(STAR_WARS).matching(query(where("firstname").is("luke"))).all()) - .hasSize(1); + .hasSize(1); } @Test // DATAMONGO-1563 @@ -324,6 +328,14 @@ void findAllNearBy() { assertThat(results.getContent().get(0).getDistance()).isNotNull(); } + @Test + void countResultsOfNearQuery() { + + Long count = template.query(Planet.class) + .near(NearQuery.near(-73.9667, 40.78).spherical(true).query(new Query(where("name").is("alderan")))).count(); + assertThat(count).isEqualTo(1); + } + @Test // DATAMONGO-1563 void findAllNearByWithCollectionAndProjection() { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java index 9f7eda351d..493a23e4e5 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java @@ -666,6 +666,7 @@ void executesGeoPageQueryForWithPageRequestForPageInBetween() { assertThat(results.getContent()).isNotEmpty(); assertThat(results.getNumberOfElements()).isEqualTo(2); + assertThat(results.getTotalElements()).isEqualTo(5); assertThat(results.isFirst()).isFalse(); assertThat(results.isLast()).isFalse(); assertThat(results.getAverageDistance().getMetric()).isEqualTo(Metrics.KILOMETERS); @@ -725,6 +726,30 @@ void executesGeoPageQueryForWithPageRequestForJustOneElementEmptyPage() { assertThat(results.getAverageDistance().getMetric()).isEqualTo(Metrics.KILOMETERS); } + @Test + void executesGeoPageCountCorrectly() { + + Point farAway = new Point(-73.9, 40.7); + Point here = new Point(-73.99, 40.73); + + dave.setLocation(farAway); + oliver.setLocation(here); + carter.setLocation(here); + boyd.setLocation(here); + leroi.setLocation(here); + + repository.saveAll(Arrays.asList(dave, oliver, carter, boyd, leroi)); + + GeoPage results = repository.findByLocationNear(new Point(-73.99, 40.73), + Distance.of(5, Metrics.KILOMETERS), PageRequest.of(1, 2)); + + assertThat(results.getContent()).isNotEmpty(); + assertThat(results.getNumberOfElements()).isEqualTo(2); + assertThat(results.getTotalElements()).isEqualTo(4); + assertThat(results.isFirst()).isFalse(); + assertThat(results.isLast()).isTrue(); + } + @Test // DATAMONGO-1608 void findByFirstNameIgnoreCaseWithNull() { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java index 971cce7c32..1d3494e8ee 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java @@ -49,6 +49,7 @@ import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; import org.springframework.data.geo.Distance; +import org.springframework.data.geo.GeoPage; import org.springframework.data.geo.GeoResult; import org.springframework.data.geo.GeoResults; import org.springframework.data.geo.Metrics; @@ -59,7 +60,6 @@ import org.springframework.data.mongodb.core.aggregation.AggregationResults; import org.springframework.data.mongodb.core.geo.GeoJsonPoint; import org.springframework.data.mongodb.core.geo.GeoJsonPolygon; -import org.springframework.data.mongodb.repository.aot.AotFragmentTestConfigurationSupport.MethodNotImplementedException; import org.springframework.data.mongodb.test.util.Client; import org.springframework.data.mongodb.test.util.MongoClientExtension; import org.springframework.data.mongodb.test.util.MongoTestUtils; @@ -704,19 +704,14 @@ void testNearWithRange() { @Test void testNearReturningGeoPage() { - assertThatExceptionOfType(MethodNotImplementedException.class) - .isThrownBy(() -> fragment.findByLocationCoordinatesNear(new Point(-73.99, 40.73), - Distance.of(2000, Metrics.KILOMETERS), PageRequest.of(0, 1))); - - // TODO: still need to create the count and extract the total elements - // GeoPage page1 = fragment.findByLocationCoordinatesNear(new Point(-73.99, 40.73), - // Distance.of(2000, Metrics.KILOMETERS), PageRequest.of(0, 1)); - // - // assertThat(page1.hasNext()).isTrue(); - // - // GeoPage page2 = fragment.findByLocationCoordinatesNear(new Point(-73.99, 40.73), - // Distance.of(2000, Metrics.KILOMETERS), page1.nextPageable()); - // assertThat(page2.hasNext()).isFalse(); + GeoPage page1 = fragment.findByLocationCoordinatesNear(new Point(-73.99, 40.73), + Distance.of(2000, Metrics.KILOMETERS), PageRequest.of(0, 1)); + + assertThat(page1.hasNext()).isTrue(); + + GeoPage page2 = fragment.findByLocationCoordinatesNear(new Point(-73.99, 40.73), + Distance.of(2000, Metrics.KILOMETERS), page1.nextPageable()); + assertThat(page2.hasNext()).isFalse(); } /** diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java index b9c598dfce..cfc29cc32b 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java @@ -21,10 +21,14 @@ import example.aot.UserRepository; import java.lang.reflect.Method; +import java.util.Arrays; import javax.lang.model.element.Modifier; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Range; import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; import org.springframework.data.geo.Distance; @@ -118,8 +122,40 @@ void rendersNearQueryForGeoResults() throws NoSuchMethodException { .contains("NearQuery.near(point)") // .contains("nearQuery.maxDistance(maxDistance).in(maxDistance.getMetric())") // .contains(".withReadPreference(com.mongodb.ReadPreference.valueOf(\"NEAREST\")") // - .doesNotContain(".query(") // - .contains(".near(nearQuery).all()"); + .doesNotContain("nearQuery.query(") // + .contains(".near(nearQuery)") // + .contains("return nearFinder.all()"); + } + + @Test + void rendersNearQueryWithDistanceRangeForGeoResults() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesNear", Point.class, + Range.class); + + assertThat(methodSpec.toString()) // + .contains("NearQuery.near(point)") // + .contains("if(distance.getLowerBound().isBounded())") // + .contains("nearQuery.minDistance(min).in(min.getMetric())") // + .contains("if(distance.getUpperBound().isBounded())") // + .contains("nearQuery.maxDistance(max).in(max.getMetric())") // + .contains(".near(nearQuery)") // + .contains("return nearFinder.all()"); + } + + @Test + void rendersNearQueryReturningGeoPage() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesNear", Point.class, Distance.class, + Pageable.class); + + assertThat(methodSpec.toString()) // + .contains("NearQuery.near(point)") // + .contains("nearQuery.maxDistance(maxDistance).in(maxDistance.getMetric())") // + .doesNotContain("nearQuery.query(") // + .contains("var geoResult = nearFinder.all()") // + .contains("PageableExecutionUtils.getPage(geoResult.getContent(), pageable, () -> nearFinder.count())") + .contains("GeoPage<>(geoResult, pageable, resultPage.getTotalElements())"); } @Test @@ -133,7 +169,8 @@ void rendersNearQueryWithFilterForGeoResults() throws NoSuchMethodException { .contains("nearQuery.maxDistance(maxDistance).in(maxDistance.getMetric())") // .contains("filterQuery = createQuery(\"{'lastname':?0}\", new java.lang.Object[]{ lastname })") // .contains("nearQuery.query(filterQuery)") // - .contains(".near(nearQuery).all()"); + .contains(".near(nearQuery)") // + .contains("return nearFinder.all()"); } private static MethodSpec codeOf(Class repository, String methodName, Class... args) @@ -145,6 +182,10 @@ private static MethodSpec codeOf(Class repository, String methodName, Class methodContributor = contributor.contributeQueryMethod(method); + if (methodContributor == null) { + Assertions.fail("No contribution for method %s.%s(%s)".formatted(repository.getSimpleName(), methodName, + Arrays.stream(args).map(Class::getSimpleName).toList())); + } AotRepositoryFragmentMetadata metadata = new AotRepositoryFragmentMetadata(ClassName.get(UserRepository.class)); metadata.addField( FieldSpec.builder(MongoOperations.class, "mongoOperations", Modifier.PRIVATE, Modifier.FINAL).build()); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryExecutionUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryExecutionUnitTests.java index 2c0c996bc3..11a025ea5c 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryExecutionUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryExecutionUnitTests.java @@ -32,7 +32,6 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; - import org.springframework.data.domain.PageRequest; import org.springframework.data.domain.Pageable; import org.springframework.data.geo.Distance; @@ -171,7 +170,6 @@ void pagingGeoExecutionRetrievesObjectsForPageableOutOfRange() { when(mongoOperationsMock.query(any(Class.class))).thenReturn(findOperationMock); when(findOperationMock.near(any(NearQuery.class))).thenReturn(terminatingGeoMock); doReturn(new GeoResults<>(Collections.emptyList())).when(terminatingGeoMock).all(); - doReturn(terminatingMock).when(findOperationMock).matching(any(Query.class)); ConvertingParameterAccessor accessor = new ConvertingParameterAccessor(converter, new MongoParametersParameterAccessor(queryMethod, new Object[] { POINT, DISTANCE, PageRequest.of(2, 10) })); @@ -183,7 +181,7 @@ void pagingGeoExecutionRetrievesObjectsForPageableOutOfRange() { execution.execute(new Query()); verify(terminatingGeoMock).all(); - verify(terminatingMock).count(); + verify(terminatingGeoMock).count(); } @Test // DATAMONGO-2351 From 71192a96e98bfc5e56de9dfacd4d6f5dd8a73a29 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Tue, 17 Jun 2025 15:58:25 +0200 Subject: [PATCH 16/24] Update method metadata rendering --- .../aot/MongoRepositoryContributor.java | 6 +++--- .../repository/aot/NearQueryInteraction.java | 21 +++++++++++++------ .../aot/MongoRepositoryMetadataTests.java | 12 +++++------ 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java index e5ec0836c9..1b3a682a83 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java @@ -109,11 +109,11 @@ protected void customizeConstructor(AotRepositoryConstructorBuilder constructorB } QueryInteraction query = createStringQuery(getRepositoryInformation(), queryMethod, - AnnotatedElementUtils.findMergedAnnotation(method, Query.class), method.getParameterCount()); + AnnotatedElementUtils.findMergedAnnotation(method, Query.class)); if (queryMethod.isGeoNearQuery() || (queryMethod.getParameters().getMaxDistanceIndex() != -1 && queryMethod.getReturnType().isCollectionLike())) { - NearQueryInteraction near = new NearQueryInteraction(query); + NearQueryInteraction near = new NearQueryInteraction(query, queryMethod.getParameters()); return nearQueryMethodContributor(queryMethod, near); } @@ -160,7 +160,7 @@ protected void customizeConstructor(AotRepositoryConstructorBuilder constructorB @SuppressWarnings("NullAway") private QueryInteraction createStringQuery(RepositoryInformation repositoryInformation, MongoQueryMethod queryMethod, - @Nullable Query queryAnnotation, int parameterCount) { + @Nullable Query queryAnnotation) { QueryInteraction query; if (queryMethod.hasAnnotatedQuery() && queryAnnotation != null) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/NearQueryInteraction.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/NearQueryInteraction.java index 30609ea15d..2005626784 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/NearQueryInteraction.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/NearQueryInteraction.java @@ -18,9 +18,8 @@ import java.util.LinkedHashMap; import java.util.Map; -import org.springframework.data.mongodb.core.query.Query; +import org.springframework.data.mongodb.repository.query.MongoParameters; import org.springframework.data.repository.aot.generate.QueryMetadata; -import org.springframework.util.StringUtils; /** * An {@link MongoInteraction} to execute a query. @@ -32,10 +31,12 @@ class NearQueryInteraction extends MongoInteraction implements QueryMetadata { private final InteractionType interactionType; private final QueryInteraction query; + private final MongoParameters parameters; - NearQueryInteraction(QueryInteraction query) { + NearQueryInteraction(QueryInteraction query, MongoParameters parameters) { interactionType = InteractionType.QUERY; this.query = query; + this.parameters = parameters; } @Override @@ -51,9 +52,17 @@ public QueryInteraction getQuery() { public Map serialize() { Map serialized = new LinkedHashMap<>(); - - - + serialized.put("near", "?%s".formatted(parameters.getNearIndex())); + if (parameters.getRangeIndex() != -1) { + serialized.put("minDistance", "?%s".formatted(parameters.getRangeIndex())); + serialized.put("maxDistance", "?%s".formatted(parameters.getRangeIndex())); + } else if (parameters.getMaxDistanceIndex() != -1) { + serialized.put("minDistance", "?%s".formatted(parameters.getMaxDistanceIndex())); + } + Object filter = query.serialize().get("filter"); // TODO: filter position index can be off due to bindable params + if (filter != null) { + serialized.put("filter", filter); + } return serialized; } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryMetadataTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryMetadataTests.java index aa069a2710..7fb8870263 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryMetadataTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryMetadataTests.java @@ -15,19 +15,19 @@ */ package org.springframework.data.mongodb.repository.aot; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.*; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; import example.aot.UserRepository; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.List; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; - import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -150,7 +150,7 @@ void shouldDocumentAggregation() throws IOException { assertThatJson(json).inPath("$.methods[?(@.name == 'findAllLastnames')].query").isArray().element(0).isObject() .containsEntry("pipeline", - "[{ '$match' : { 'last_name' : { '$ne' : null } } }, { '$project': { '_id' : '$last_name' } }]"); + List.of("{ '$match' : { 'last_name' : { '$ne' : null } } }", "{ '$project': { '_id' : '$last_name' } }")); } @Test // GH-4964 @@ -165,7 +165,7 @@ void shouldDocumentPipelineUpdate() throws IOException { assertThatJson(json).inPath("$.methods[?(@.name == 'findAndIncrementVisitsViaPipelineByLastname')].query").isArray() .element(0).isObject().containsEntry("filter", "{'lastname':?0}").containsEntry("update-pipeline", - "[{ '$set' : { 'visits' : { '$ifNull' : [ {'$add' : [ '$visits', ?1 ] }, ?1 ] } } }]"); + List.of("{ '$set' : { 'visits' : { '$ifNull' : [ {'$add' : [ '$visits', ?1 ] }, ?1 ] } } }")); } @Test // GH-4964 From 0fb84ba9fb76461e761939d1e7147a360f55123f Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Wed, 18 Jun 2025 11:26:02 +0200 Subject: [PATCH 17/24] Update documentation --- .../mongodb/repository/PersonRepository.java | 1 + .../mongodb/repositories/query-methods.adoc | 32 ++++++++++++++++--- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java index 6c40ab622a..9ab0d71dc3 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java @@ -221,6 +221,7 @@ Window findByLastnameLikeOrderByLastnameAscFirstnameAsc(String lastname, List findByNamedQuery(String firstname); GeoResults findByLocationNear(Point point, Distance maxDistance); + GeoResults findByLocationNearAndLastname(Point point, Distance maxDistance, String Lastname); // DATAMONGO-1110 diff --git a/src/main/antora/modules/ROOT/pages/mongodb/repositories/query-methods.adoc b/src/main/antora/modules/ROOT/pages/mongodb/repositories/query-methods.adoc index adb2392f04..31a19b5aca 100644 --- a/src/main/antora/modules/ROOT/pages/mongodb/repositories/query-methods.adoc +++ b/src/main/antora/modules/ROOT/pages/mongodb/repositories/query-methods.adoc @@ -209,9 +209,9 @@ NOTE: If the property criterion compares a document, the order of the fields and == Geo-spatial Queries As you saw in the preceding table of keywords, a few keywords trigger geo-spatial operations within a MongoDB query. -The `Near` keyword allows some further modification, as the next few examples show. +The `Near` and `Within` keywords allows some further modification, as the next few examples show. -The following example shows how to define a `near` query that finds all persons with a given distance of a given point: +The following example shows how to define a `near` / `within` query that finds all persons using different shapes: .Advanced `Near` queries [tabs] @@ -222,8 +222,20 @@ Imperative:: ---- public interface PersonRepository extends MongoRepository { - // { 'location' : { '$near' : [point.x, point.y], '$maxDistance' : distance}} + // { 'location' : { '$near' : [point.x, point.y], '$maxDistance' : distance } } List findByLocationNear(Point location, Distance distance); + + // { 'location' : { $geoWithin: { $center: [ [ circle.center.x, circle.center.y ], circle.radius ] } } } + List findByLocationWithin(Circle circle); + + // { 'location' : { $geoWithin: { $box: [ [ box.first.x, box.first.y ], [ box.second.x, box.second.y ] ] } } } + List findByLocationWithin(Box box); + + // { 'location' : { $geoWithin: { $polygon: [ [ polygon.x1, polygon.y1 ], [ polygon.x2, polygon.y2 ], ... ] } } } + List findByLocationWithin(Polygon polygon); + + // { 'location' : { $geoWithin: { $geometry: { $type : 'polygon', coordinates: [[ polygon.x1, polygon.y1 ], [ polygon.x2, polygon.y2 ], ... ] } } } } + List findByLocationWithin(GeoJsonPolygon polygon); } ---- @@ -233,8 +245,20 @@ Reactive:: ---- interface PersonRepository extends ReactiveMongoRepository { - // { 'location' : { '$near' : [point.x, point.y], '$maxDistance' : distance}} + // { 'location' : { '$near' : [point.x, point.y], '$maxDistance' : distance } } Flux findByLocationNear(Point location, Distance distance); + + // { 'location' : { $geoWithin: { $center: [ [ circle.center.x, circle.center.y ], circle.radius ] } } } + Flux findByLocationWithin(Circle circle); + + // { 'location' : { $geoWithin: { $box: [ [ box.first.x, box.first.y ], [ box.second.x, box.second.y ] ] } } } + Flux findByLocationWithin(Box box); + + // { 'location' : { $geoWithin: { $polygon: [ [ polygon.x1, polygon.y1 ], [ polygon.x2, polygon.y2 ], ... ] } } } + Flux findByLocationWithin(Polygon polygon); + + // { 'location' : { $geoWithin: { $geometry: { $type : 'polygon', coordinates: [[ polygon.x1, polygon.y1 ], [ polygon.x2, polygon.y2 ], ... ] } } } } + Flux findByLocationWithin(GeoJsonPolygon polygon); } ---- ====== From eeb06cdfe1f60acf8183bb35190f674793e50a55 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Wed, 18 Jun 2025 11:32:24 +0200 Subject: [PATCH 18/24] Update issue reference & add missing test for geojson --- .../test/java/example/aot/UserRepository.java | 3 ++ .../aot/MongoRepositoryContributorTests.java | 35 +++++++++----- .../aot/QueryMethodContributionUnitTests.java | 47 ++++++++++++------- 3 files changed, 56 insertions(+), 29 deletions(-) diff --git a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java index 2044b3108a..aeb10285fa 100644 --- a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java +++ b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java @@ -42,6 +42,7 @@ import org.springframework.data.geo.Point; import org.springframework.data.geo.Polygon; import org.springframework.data.mongodb.core.aggregation.AggregationResults; +import org.springframework.data.mongodb.core.geo.GeoJson; import org.springframework.data.mongodb.core.geo.GeoJsonPolygon; import org.springframework.data.mongodb.core.geo.Sphere; import org.springframework.data.mongodb.repository.Aggregation; @@ -127,6 +128,8 @@ public interface UserRepository extends CrudRepository { List findByLocationCoordinatesWithin(GeoJsonPolygon polygon); + List findUserByLocationCoordinatesWithin(GeoJson geoJson); + GeoResults findByLocationCoordinatesNear(Point point, Distance maxDistance); GeoResults findByLocationCoordinatesNearAndLastname(Point point, Distance maxDistance, String lastname); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java index 1d3494e8ee..335ba33a18 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java @@ -613,28 +613,28 @@ void testAggregationWithCollation() { .withMessageContaining("'locale' is invalid"); } - @Test + @Test // GH-5004 void testNear() { List users = fragment.findByLocationCoordinatesNear(new Point(-73.99171, 40.738868)); assertThat(users).extracting(User::getUsername).containsExactly("leia", "vader"); } - @Test + @Test // GH-5004 void testNearWithGeoJson() { List users = fragment.findByLocationCoordinatesNear(new GeoJsonPoint(-73.99171, 40.738868)); assertThat(users).extracting(User::getUsername).containsExactly("leia", "vader"); } - @Test + @Test // GH-5004 void testGeoWithinCircle() { List users = fragment.findByLocationCoordinatesWithin(new Circle(-78.99171, 45.738868, 170)); assertThat(users).extracting(User::getUsername).containsExactly("leia", "vader"); } - @Test + @Test // GH-5004 void testWithinBox() { Box box = new Box(new Point(-78.99171, 35.738868), new Point(-68.99171, 45.738868)); @@ -643,7 +643,7 @@ void testWithinBox() { assertThat(result).extracting(User::getUsername).containsExactly("leia", "vader"); } - @Test + @Test // GH-5004 void findsPeopleByLocationWithinPolygon() { Point first = new Point(-78.99171, 35.738868); @@ -655,7 +655,7 @@ void findsPeopleByLocationWithinPolygon() { assertThat(result).extracting(User::getUsername).containsExactly("leia", "vader"); } - @Test + @Test // GH-5004 void findsPeopleByLocationWithinGeoJsonPolygon() { Point first = new Point(-78.99171, 35.738868); @@ -668,7 +668,20 @@ void findsPeopleByLocationWithinGeoJsonPolygon() { assertThat(result).extracting(User::getUsername).containsExactly("leia", "vader"); } - @Test + @Test // GH-5004 + void findsPeopleByLocationWithinSomeGenericGeoJsonObject() { + + Point first = new Point(-78.99171, 35.738868); + Point second = new Point(-78.99171, 45.738868); + Point third = new Point(-68.99171, 45.738868); + Point fourth = new Point(-68.99171, 35.738868); + + List result = fragment + .findUserByLocationCoordinatesWithin(new GeoJsonPolygon(first, second, third, fourth, first)); + assertThat(result).extracting(User::getUsername).containsExactly("leia", "vader"); + } + + @Test // GH-5004 void testNearWithGeoResult() { GeoResults users = fragment.findByLocationCoordinatesNear(new Point(-73.99, 40.73), @@ -676,7 +689,7 @@ void testNearWithGeoResult() { assertThat(users).extracting(GeoResult::getContent).extracting(User::getUsername).containsExactly("leia"); } - @Test + @Test // GH-5004 void testNearWithAdditionalFilterQueryAsGeoResult() { GeoResults users = fragment.findByLocationCoordinatesNearAndLastname(new Point(-73.99, 40.73), @@ -684,7 +697,7 @@ void testNearWithAdditionalFilterQueryAsGeoResult() { assertThat(users).extracting(GeoResult::getContent).extracting(User::getUsername).containsExactly("leia"); } - @Test + @Test // GH-5004 void testNearReturningListOfGeoResult() { List> users = fragment.findUserAsListByLocationCoordinatesNear(new Point(-73.99, 40.73), @@ -692,7 +705,7 @@ void testNearReturningListOfGeoResult() { assertThat(users).extracting(GeoResult::getContent).extracting(User::getUsername).containsExactly("leia"); } - @Test + @Test // GH-5004 void testNearWithRange() { Range range = Distance.between(Distance.of(5, Metrics.KILOMETERS), Distance.of(2000, Metrics.KILOMETERS)); @@ -701,7 +714,7 @@ void testNearWithRange() { assertThat(users).extracting(GeoResult::getContent).extracting(User::getUsername).containsExactly("vader"); } - @Test + @Test // GH-5004 void testNearReturningGeoPage() { GeoPage page1 = fragment.findByLocationCoordinatesNear(new Point(-73.99, 40.73), diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java index cfc29cc32b..cbfbf05231 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java @@ -36,6 +36,7 @@ import org.springframework.data.geo.Point; import org.springframework.data.geo.Polygon; import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.geo.GeoJsonPolygon; import org.springframework.data.mongodb.core.geo.Sphere; import org.springframework.data.mongodb.repository.ReadPreference; import org.springframework.data.repository.Repository; @@ -53,7 +54,7 @@ */ public class QueryMethodContributionUnitTests { - @Test + @Test // GH-5004 void rendersQueryForNearUsingPoint() throws NoSuchMethodException { MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesNear", Point.class); @@ -64,7 +65,7 @@ void rendersQueryForNearUsingPoint() throws NoSuchMethodException { .contains("return finder.matching(filterQuery).all()"); } - @Test + @Test // GH-5004 void rendersQueryForWithinUsingCircle() throws NoSuchMethodException { MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesWithin", Circle.class); @@ -76,7 +77,7 @@ void rendersQueryForWithinUsingCircle() throws NoSuchMethodException { .contains("return finder.matching(filterQuery).all()"); } - @Test + @Test // GH-5004 void rendersQueryForWithinUsingSphere() throws NoSuchMethodException { MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesWithin", Sphere.class); @@ -88,7 +89,7 @@ void rendersQueryForWithinUsingSphere() throws NoSuchMethodException { .contains("return finder.matching(filterQuery).all()"); } - @Test + @Test // GH-5004 void rendersQueryForWithinUsingBox() throws NoSuchMethodException { MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesWithin", Box.class); @@ -100,7 +101,7 @@ void rendersQueryForWithinUsingBox() throws NoSuchMethodException { .contains("return finder.matching(filterQuery).all()"); } - @Test + @Test // GH-5004 void rendersQueryForWithinUsingPolygon() throws NoSuchMethodException { MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesWithin", Polygon.class); @@ -112,7 +113,18 @@ void rendersQueryForWithinUsingPolygon() throws NoSuchMethodException { .contains("return finder.matching(filterQuery).all()"); } - @Test + @Test // GH-5004 + void rendersQueryForWithinUsingGeoJsonPolygon() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesWithin", GeoJsonPolygon.class); + + assertThat(methodSpec.toString()) // + .contains("{'location.coordinates':{'$geoWithin':{'$geometry':?0}}") // + .contains("Object[]{ polygon }") // + .contains("return finder.matching(filterQuery).all()"); + } + + @Test // GH-5004 void rendersNearQueryForGeoResults() throws NoSuchMethodException { MethodSpec methodSpec = codeOf(UserRepoWithMeta.class, "findByLocationCoordinatesNear", Point.class, @@ -127,23 +139,22 @@ void rendersNearQueryForGeoResults() throws NoSuchMethodException { .contains("return nearFinder.all()"); } - @Test + @Test // GH-5004 void rendersNearQueryWithDistanceRangeForGeoResults() throws NoSuchMethodException { - MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesNear", Point.class, - Range.class); + MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesNear", Point.class, Range.class); assertThat(methodSpec.toString()) // - .contains("NearQuery.near(point)") // - .contains("if(distance.getLowerBound().isBounded())") // - .contains("nearQuery.minDistance(min).in(min.getMetric())") // - .contains("if(distance.getUpperBound().isBounded())") // - .contains("nearQuery.maxDistance(max).in(max.getMetric())") // - .contains(".near(nearQuery)") // - .contains("return nearFinder.all()"); + .contains("NearQuery.near(point)") // + .contains("if(distance.getLowerBound().isBounded())") // + .contains("nearQuery.minDistance(min).in(min.getMetric())") // + .contains("if(distance.getUpperBound().isBounded())") // + .contains("nearQuery.maxDistance(max).in(max.getMetric())") // + .contains(".near(nearQuery)") // + .contains("return nearFinder.all()"); } - @Test + @Test // GH-5004 void rendersNearQueryReturningGeoPage() throws NoSuchMethodException { MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesNear", Point.class, Distance.class, @@ -158,7 +169,7 @@ void rendersNearQueryReturningGeoPage() throws NoSuchMethodException { .contains("GeoPage<>(geoResult, pageable, resultPage.getTotalElements())"); } - @Test + @Test // GH-5004 void rendersNearQueryWithFilterForGeoResults() throws NoSuchMethodException { MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesNearAndLastname", Point.class, From ad875100b1259ff9ed661b7c1b4e8c63008f7bbf Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Wed, 18 Jun 2025 14:45:17 +0200 Subject: [PATCH 19/24] Support generating queries containing expression parameter binding during AOT run. --- .../repository/aot/AggregationBlocks.java | 9 +- .../MongoAotRepositoryFragmentSupport.java | 57 ++++++- .../repository/aot/MongoCodeBlocks.java | 66 ++++++-- .../aot/MongoRepositoryContributor.java | 13 -- .../mongodb/repository/aot/QueryBlocks.java | 34 ++--- .../mongodb/repository/aot/UpdateBlocks.java | 142 +++++++++--------- .../repository/query/MongoParameters.java | 4 + .../test/java/example/aot/UserRepository.java | 8 +- .../aot/MongoRepositoryContributorTests.java | 14 ++ .../aot/QueryMethodContributionUnitTests.java | 20 +++ 10 files changed, 247 insertions(+), 120 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java index 11ef3a4822..37f24cd849 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java @@ -16,7 +16,9 @@ package org.springframework.data.mongodb.repository.aot; import java.util.ArrayList; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -150,7 +152,7 @@ static class AggregationCodeBlockBuilder { private final AotQueryMethodGenerationContext context; private final MongoQueryMethod queryMethod; - private final List arguments; + private final Map arguments; private AggregationInteraction source; @@ -160,7 +162,8 @@ static class AggregationCodeBlockBuilder { AggregationCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { this.context = context; - this.arguments = context.getBindableParameterNames().stream().map(CodeBlock::of).collect(Collectors.toList()); + this.arguments = new LinkedHashMap<>(); + context.getBindableParameterNames().forEach(it -> arguments.put(it, CodeBlock.of(it))); this.queryMethod = queryMethod; } @@ -288,7 +291,7 @@ private CodeBlock aggregationOptions(String aggregationVariableName) { } private CodeBlock aggregationStages(String stageListVariableName, Iterable stages, int stageCount, - List arguments) { + Map arguments) { Builder builder = CodeBlock.builder(); builder.addStatement("$T<$T> $L = new $T($L)", List.class, Object.class, stageListVariableName, ArrayList.class, diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java index 178ce4bda6..6686f9794d 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java @@ -21,6 +21,9 @@ import org.bson.Document; import org.jspecify.annotations.Nullable; +import org.springframework.data.expression.ValueEvaluationContext; +import org.springframework.data.expression.ValueExpression; +import org.springframework.data.mapping.model.ValueExpressionEvaluator; import org.springframework.data.mongodb.BindableMongoExpression; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.aggregation.AggregationOperation; @@ -28,9 +31,15 @@ import org.springframework.data.mongodb.core.convert.MongoConverter; import org.springframework.data.mongodb.core.mapping.FieldName; import org.springframework.data.mongodb.core.query.BasicQuery; +import org.springframework.data.mongodb.repository.query.MongoParameters; +import org.springframework.data.mongodb.util.json.ParameterBindingContext; +import org.springframework.data.mongodb.util.json.ParameterBindingDocumentCodec; +import org.springframework.data.mongodb.util.json.ValueProvider; import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.repository.core.RepositoryMetadata; import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport; +import org.springframework.data.repository.query.ValueExpressionDelegate; +import org.springframework.expression.EvaluationContext; import org.springframework.util.ClassUtils; import org.springframework.util.ObjectUtils; @@ -46,31 +55,69 @@ public class MongoAotRepositoryFragmentSupport { private final MongoOperations mongoOperations; private final MongoConverter mongoConverter; private final ProjectionFactory projectionFactory; + private final ValueExpressionDelegate valueExpressionDelegate; protected MongoAotRepositoryFragmentSupport(MongoOperations mongoOperations, RepositoryFactoryBeanSupport.FragmentCreationContext context) { - this(mongoOperations, context.getRepositoryMetadata(), context.getProjectionFactory()); + this(mongoOperations, context.getRepositoryMetadata(), context.getProjectionFactory(), + context.getValueExpressionDelegate()); } protected MongoAotRepositoryFragmentSupport(MongoOperations mongoOperations, RepositoryMetadata repositoryMetadata, - ProjectionFactory projectionFactory) { + ProjectionFactory projectionFactory, ValueExpressionDelegate valueExpressionDelegate) { this.mongoOperations = mongoOperations; this.mongoConverter = mongoOperations.getConverter(); this.repositoryMetadata = repositoryMetadata; this.projectionFactory = projectionFactory; + this.valueExpressionDelegate = valueExpressionDelegate; } protected Document bindParameters(String source, Object[] parameters) { return new BindableMongoExpression(source, this.mongoConverter, parameters).toDocument(); } + protected Document bindParameters(String source, Map parameters) { + + ValueEvaluationContext valueEvaluationContext = this.valueExpressionDelegate.getEvaluationContextAccessor() + .create(new NoMongoParameters()).getEvaluationContext(parameters.values()); + + EvaluationContext evaluationContext = valueEvaluationContext.getEvaluationContext(); + parameters.forEach(evaluationContext::setVariable); + + ParameterBindingContext bindingContext = new ParameterBindingContext(new ValueProvider() { + + private final List args = new ArrayList<>(parameters.values()); + + @Override + public @Nullable Object getBindableValue(int index) { + return args.get(index); + } + }, new ValueExpressionEvaluator() { + + @Override + @SuppressWarnings("unchecked") + public @Nullable T evaluate(String expression) { + ValueExpression parse = valueExpressionDelegate.getValueExpressionParser().parse(expression); + return (T) parse.evaluate(valueEvaluationContext); + } + }); + + return new ParameterBindingDocumentCodec().decode(source, bindingContext); + } + protected BasicQuery createQuery(String queryString, Object[] parameters) { Document queryDocument = bindParameters(queryString, parameters); return new BasicQuery(queryDocument); } + protected BasicQuery createQuery(String queryString, Map parameters) { + + Document queryDocument = bindParameters(queryString, parameters); + return new BasicQuery(queryDocument); + } + protected AggregationPipeline createPipeline(List rawStages) { List stages = new ArrayList<>(rawStages.size()); @@ -151,4 +198,10 @@ private static T getPotentiallyConvertedSimpleTypeValue(MongoConverter conve return converter.getConversionService().convert(value, targetType); } + static class NoMongoParameters extends MongoParameters { + + NoMongoParameters() { + super(); + } + } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java index 9a87d4afe8..3881994437 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java @@ -16,7 +16,8 @@ package org.springframework.data.mongodb.repository.aot; import java.util.Iterator; -import java.util.List; +import java.util.Map; +import java.util.Map.Entry; import java.util.regex.Pattern; import org.bson.Document; @@ -47,6 +48,7 @@ class MongoCodeBlocks { private static final Pattern PARAMETER_BINDING_PATTERN = Pattern.compile("\\?(\\d+)"); + private static final Pattern EXPRESSION_BINDING_PATTERN = Pattern.compile("[\\?:][#$]\\{.*\\}"); /** * Builder for generating query parsing {@link CodeBlock}. @@ -166,7 +168,8 @@ static GeoNearExecutionCodeBlockBuilder geoNearExecutionBlockBuilder(AotQueryMet return new GeoNearExecutionCodeBlockBuilder(context, queryMethod); } - static CodeBlock renderExpressionToDocument(@Nullable String source, String variableName, List arguments) { + static CodeBlock renderExpressionToDocument(@Nullable String source, String variableName, + Map arguments) { Builder builder = CodeBlock.builder(); if (!StringUtils.hasText(source)) { @@ -174,21 +177,60 @@ static CodeBlock renderExpressionToDocument(@Nullable String source, String vari } else if (!containsPlaceholder(source)) { builder.addStatement("$1T $2L = $1T.parse($3S)", Document.class, variableName, source); } else { + builder.add("$T $L = bindParameters($S, ", Document.class, variableName, source); + if (containsNamedPlaceholder(source)) { + renderArgumentMap(arguments); + } else { + builder.add(renderArgumentArray(arguments)); + } + builder.add(");\n"); + } + return builder.build(); + } + + static CodeBlock renderArgumentMap(Map arguments) { - builder.add("$T $L = bindParameters($S, new $T[]{ ", Document.class, variableName, source, Object.class); - Iterator iterator = arguments.iterator(); - while (iterator.hasNext()) { - builder.add(iterator.next()); - if (iterator.hasNext()) { - builder.add(", "); - } + Builder builder = CodeBlock.builder(); + builder.add("$T.of(", Map.class); + Iterator> iterator = arguments.entrySet().iterator(); + while (iterator.hasNext()) { + Entry next = iterator.next(); + builder.add("$S, ", next.getKey()); + builder.add(next.getValue()); + if (iterator.hasNext()) { + builder.add(", "); } - builder.add("});\n"); } + builder.add(")"); + return builder.build(); + } + + static CodeBlock renderArgumentArray(Map arguments) { + + Builder builder = CodeBlock.builder(); + builder.add("new $T[]{ ", Object.class); + Iterator iterator = arguments.values().iterator(); + while (iterator.hasNext()) { + builder.add(iterator.next()); + if (iterator.hasNext()) { + builder.add(", "); + } else { + builder.add(" "); + } + } + builder.add("}"); return builder.build(); } static boolean containsPlaceholder(String source) { + return containsIndexedPlaceholder(source) || containsNamedPlaceholder(source); + } + + static boolean containsNamedPlaceholder(String source) { + return EXPRESSION_BINDING_PATTERN.matcher(source).find(); + } + + static boolean containsIndexedPlaceholder(String source) { return PARAMETER_BINDING_PATTERN.matcher(source).find(); } @@ -198,8 +240,8 @@ static void appendReadPreference(AotQueryMethodGenerationContext context, Builde String readPreference = readPreferenceAnnotation.isPresent() ? readPreferenceAnnotation.getString("value") : null; if (StringUtils.hasText(readPreference)) { - builder.addStatement("$L.withReadPreference($T.valueOf($S))", queryVariableName, - com.mongodb.ReadPreference.class, readPreference); + builder.addStatement("$L.withReadPreference($T.valueOf($S))", queryVariableName, com.mongodb.ReadPreference.class, + readPreference); } } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java index 1b3a682a83..e741a28e27 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java @@ -28,7 +28,6 @@ import java.lang.reflect.Method; import java.util.Locale; -import java.util.regex.Pattern; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -117,18 +116,6 @@ protected void customizeConstructor(AotRepositoryConstructorBuilder constructorB return nearQueryMethodContributor(queryMethod, near); } - if (queryMethod.hasAnnotatedQuery()) { - if (StringUtils.hasText(queryMethod.getAnnotatedQuery()) - && Pattern.compile("[\\?:][#$]\\{.*\\}").matcher(queryMethod.getAnnotatedQuery()).find()) { - - if (logger.isDebugEnabled()) { - logger.debug( - "Skipping AOT generation for [%s]. SpEL expressions are not supported".formatted(method.getName())); - } - return MethodContributor.forQueryMethod(queryMethod).metadataOnly(query); - } - } - if (query.isDelete()) { return deleteMethodContributor(queryMethod, query); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java index 97a53f921d..e9425dce87 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java @@ -15,9 +15,9 @@ */ package org.springframework.data.mongodb.repository.aot; -import java.util.ArrayList; -import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import org.bson.Document; @@ -147,14 +147,14 @@ static class QueryCodeBlockBuilder { private final MongoQueryMethod queryMethod; private QueryInteraction source; - private final List arguments; + private final Map arguments; private String queryVariableName; QueryCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { this.context = context; - this.arguments = new ArrayList<>(); + this.arguments = new LinkedHashMap<>(); this.queryMethod = queryMethod; collectArguments(context); @@ -167,29 +167,29 @@ private void collectArguments(AotQueryMethodGenerationContext context) { if (ClassUtils.isAssignable(GeoJson.class, parameter.getType())) { // renders as generic $geometry, thus can be handled by the converter when parsing - arguments.add(CodeBlock.of(parameterName)); + arguments.put(parameterName, CodeBlock.of(parameterName)); } else if (ClassUtils.isAssignable(Circle.class, parameter.getType()) || ClassUtils.isAssignable(Sphere.class, parameter.getType())) { // $center | $centerSphere : [ [ , ], ] - arguments.add(CodeBlock.builder().add( + arguments.put(parameterName, CodeBlock.builder().add( "$1T.of($1T.of($2L.getCenter().getX(), $2L.getCenter().getY()), $2L.getRadius().getNormalizedValue())", List.class, parameterName).build()); } else if (ClassUtils.isAssignable(Box.class, parameter.getType())) { // $box: [ [ , ], [ , ] ] - arguments.add(CodeBlock.builder().add( + arguments.put(parameterName, CodeBlock.builder().add( "$1T.of($1T.of($2L.getFirst().getX(), $2L.getFirst().getY()), $1T.of($2L.getSecond().getX(), $2L.getSecond().getY()))", List.class, parameterName).build()); } else if (ClassUtils.isAssignable(Polygon.class, parameter.getType())) { // $polygon: [ [ , ], [ , ], [ , ], ... ] String localVar = context.localVariable("_p"); - arguments.add( + arguments.put(parameterName, CodeBlock.builder().add("$1L.getPoints().stream().map($2L -> $3T.of($2L.getX(), $2L.getY())).toList()", parameterName, localVar, List.class).build()); } else { - arguments.add(CodeBlock.of(parameterName)); + arguments.put(parameterName, CodeBlock.of(parameterName)); } } } @@ -284,17 +284,13 @@ private CodeBlock renderExpressionToQuery(@Nullable String source, String variab builder.addStatement("$1T $2L = new $1T($3T.parse($4S))", BasicQuery.class, variableName, Document.class, source); } else { - builder.add("$T $L = createQuery($S, new $T[]{ ", BasicQuery.class, variableName, source, Object.class); - Iterator iterator = arguments.iterator(); - while (iterator.hasNext()) { - builder.add(iterator.next()); - if (iterator.hasNext()) { - builder.add(", "); - } else { - builder.add(" "); - } + builder.add("$T $L = createQuery($S, ", BasicQuery.class, variableName, source); + if (MongoCodeBlocks.containsNamedPlaceholder(source)) { + builder.add(MongoCodeBlocks.renderArgumentMap(arguments)); + } else { + builder.add(MongoCodeBlocks.renderArgumentArray(arguments)); } - builder.add("});\n"); + builder.add(");\n"); } return builder.build(); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateBlocks.java index c4c9fb62c9..e4061c7717 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateBlocks.java @@ -31,8 +31,8 @@ */ package org.springframework.data.mongodb.repository.aot; -import java.util.List; -import java.util.stream.Collectors; +import java.util.LinkedHashMap; +import java.util.Map; import org.jspecify.annotations.NullUnmarked; import org.springframework.data.mongodb.core.ExecutableUpdateOperation.ExecutableUpdate; @@ -52,94 +52,96 @@ */ class UpdateBlocks { - @NullUnmarked - static class UpdateExecutionCodeBlockBuilder { + @NullUnmarked + static class UpdateExecutionCodeBlockBuilder { - private final AotQueryMethodGenerationContext context; - private final MongoQueryMethod queryMethod; - private String queryVariableName; - private String updateVariableName; + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + private String queryVariableName; + private String updateVariableName; - UpdateExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + UpdateExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { - this.context = context; - this.queryMethod = queryMethod; - } + this.context = context; + this.queryMethod = queryMethod; + } - UpdateExecutionCodeBlockBuilder withFilter(String queryVariableName) { + UpdateExecutionCodeBlockBuilder withFilter(String queryVariableName) { - this.queryVariableName = queryVariableName; - return this; - } + this.queryVariableName = queryVariableName; + return this; + } - UpdateExecutionCodeBlockBuilder referencingUpdate(String updateVariableName) { + UpdateExecutionCodeBlockBuilder referencingUpdate(String updateVariableName) { - this.updateVariableName = updateVariableName; - return this; - } + this.updateVariableName = updateVariableName; + return this; + } - CodeBlock build() { + CodeBlock build() { - String mongoOpsRef = context.fieldNameOf(MongoOperations.class); - Builder builder = CodeBlock.builder(); + String mongoOpsRef = context.fieldNameOf(MongoOperations.class); + Builder builder = CodeBlock.builder(); - builder.add("\n"); + builder.add("\n"); - String updateReference = updateVariableName; - Class domainType = context.getRepositoryInformation().getDomainType(); - builder.addStatement("$1T<$2T> $3L = $4L.update($2T.class)", ExecutableUpdate.class, domainType, - context.localVariable("updater"), mongoOpsRef); + String updateReference = updateVariableName; + Class domainType = context.getRepositoryInformation().getDomainType(); + builder.addStatement("$1T<$2T> $3L = $4L.update($2T.class)", ExecutableUpdate.class, domainType, + context.localVariable("updater"), mongoOpsRef); - Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); - if (ReflectionUtils.isVoid(returnType)) { - builder.addStatement("$L.matching($L).apply($L).all()", context.localVariable("updater"), queryVariableName, - updateReference); - } else if (ClassUtils.isAssignable(Long.class, returnType)) { - builder.addStatement("return $L.matching($L).apply($L).all().getModifiedCount()", - context.localVariable("updater"), queryVariableName, updateReference); - } else { - builder.addStatement("$T $L = $L.matching($L).apply($L).all().getModifiedCount()", Long.class, - context.localVariable("modifiedCount"), context.localVariable("updater"), queryVariableName, - updateReference); - builder.addStatement("return $T.convertNumberToTargetClass($L, $T.class)", NumberUtils.class, - context.localVariable("modifiedCount"), returnType); - } + Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); + if (ReflectionUtils.isVoid(returnType)) { + builder.addStatement("$L.matching($L).apply($L).all()", context.localVariable("updater"), queryVariableName, + updateReference); + } else if (ClassUtils.isAssignable(Long.class, returnType)) { + builder.addStatement("return $L.matching($L).apply($L).all().getModifiedCount()", + context.localVariable("updater"), queryVariableName, updateReference); + } else { + builder.addStatement("$T $L = $L.matching($L).apply($L).all().getModifiedCount()", Long.class, + context.localVariable("modifiedCount"), context.localVariable("updater"), queryVariableName, + updateReference); + builder.addStatement("return $T.convertNumberToTargetClass($L, $T.class)", NumberUtils.class, + context.localVariable("modifiedCount"), returnType); + } - return builder.build(); - } - } + return builder.build(); + } + } - @NullUnmarked - static class UpdateCodeBlockBuilder { + @NullUnmarked + static class UpdateCodeBlockBuilder { - private UpdateInteraction source; - private List arguments; - private String updateVariableName; + private UpdateInteraction source; + private Map arguments; + private String updateVariableName; - public UpdateCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { - this.arguments = context.getBindableParameterNames().stream().map(CodeBlock::of).collect(Collectors.toList()); - } + public UpdateCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + this.arguments = new LinkedHashMap<>(); + context.getBindableParameterNames().forEach(it -> arguments.put(it, CodeBlock.of(it))); + } - public UpdateCodeBlockBuilder update(UpdateInteraction update) { - this.source = update; - return this; - } + public UpdateCodeBlockBuilder update(UpdateInteraction update) { + this.source = update; + return this; + } - public UpdateCodeBlockBuilder usingUpdateVariableName(String updateVariableName) { - this.updateVariableName = updateVariableName; - return this; - } + public UpdateCodeBlockBuilder usingUpdateVariableName(String updateVariableName) { + this.updateVariableName = updateVariableName; + return this; + } - CodeBlock build() { + CodeBlock build() { - Builder builder = CodeBlock.builder(); + Builder builder = CodeBlock.builder(); - builder.add("\n"); - String tmpVariableName = updateVariableName + "Document"; - builder.add(MongoCodeBlocks.renderExpressionToDocument(source.getUpdate().getUpdateString(), tmpVariableName, arguments)); - builder.addStatement("$1T $2L = new $1T($3L)", BasicUpdate.class, updateVariableName, tmpVariableName); + builder.add("\n"); + String tmpVariableName = updateVariableName + "Document"; + builder.add( + MongoCodeBlocks.renderExpressionToDocument(source.getUpdate().getUpdateString(), tmpVariableName, arguments)); + builder.addStatement("$1T $2L = new $1T($3L)", BasicUpdate.class, updateVariableName, tmpVariableName); - return builder.build(); - } - } + return builder.build(); + } + } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java index 0aa9ad5fdf..76738bf375 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java @@ -122,6 +122,10 @@ private MongoParameters(List parameters, int maxDistanceIndex, i this.domainType = domainType; } + protected MongoParameters() { + this(List.of(), -1, -1, -1, -1, -1, -1, TypeInformation.OBJECT); + } + static boolean isGeoNearQuery(Method method) { Class returnType = method.getReturnType(); diff --git a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java index aeb10285fa..276a4fac8d 100644 --- a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java +++ b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java @@ -47,11 +47,11 @@ import org.springframework.data.mongodb.core.geo.Sphere; import org.springframework.data.mongodb.repository.Aggregation; import org.springframework.data.mongodb.repository.Hint; -import org.springframework.data.mongodb.repository.Person; import org.springframework.data.mongodb.repository.Query; import org.springframework.data.mongodb.repository.ReadPreference; import org.springframework.data.mongodb.repository.Update; import org.springframework.data.repository.CrudRepository; +import org.springframework.data.repository.query.Param; /** * @author Christoph Strobl @@ -179,6 +179,12 @@ public interface UserRepository extends CrudRepository { @Query("{ 'lastname' : { '$regex' : '^?0' } }") Slice findAnnotatedQuerySliceOfUsersByLastname(String lastname, Pageable pageable); + @Query("{ firstname : ?#{[0]} }") + List findWithExpressionUsingParameterIndex(String firstname); + + @Query("{ firstname : :#{#firstname} }") + List findWithExpressionUsingParameterName(@Param("firstname") String firstname); + /* deletes */ User deleteByUsername(String username); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java index 335ba33a18..6e6e6ccef5 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java @@ -334,6 +334,20 @@ void testAnnotatedFinderReturningSingleValueWithQuery() { assertThat(user).isNotNull().extracting(User::getUsername).isEqualTo("yoda"); } + @Test // GH-5006 + void testAnnotatedFinderWithExpressionUsingParameterIndex() { + + List users = fragment.findWithExpressionUsingParameterIndex("Luke"); + assertThat(users).extracting(User::getUsername).containsExactly("luke"); + } + + @Test // GH-5006 + void testAnnotatedFinderWithExpressionUsingParameterName() { + + List users = fragment.findWithExpressionUsingParameterName("Luke"); + assertThat(users).extracting(User::getUsername).containsExactly("luke"); + } + @Test void testAnnotatedCount() { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java index cbfbf05231..75a2e8d0a7 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java @@ -184,6 +184,26 @@ void rendersNearQueryWithFilterForGeoResults() throws NoSuchMethodException { .contains("return nearFinder.all()"); } + @Test // GH-5006 + void rendersExpressionUsingParameterIndex() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "findWithExpressionUsingParameterIndex", String.class); + + assertThat(methodSpec.toString()) // + .contains("createQuery(\"{ firstname : ?#{[0]} }\"") // + .contains("Map.of(\"firstname\", firstname)"); + } + + @Test // GH-5006 + void rendersExpressionUsingParameterName() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "findWithExpressionUsingParameterName", String.class); + + assertThat(methodSpec.toString()) // + .contains("createQuery(\"{ firstname : :#{#firstname} }\"") // + .contains("Map.of(\"firstname\", firstname)"); + } + private static MethodSpec codeOf(Class repository, String methodName, Class... args) throws NoSuchMethodException { From 52be56de8953437fc3f67e91143a639a4c11df8e Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Fri, 20 Jun 2025 10:42:39 +0200 Subject: [PATCH 20/24] Support generating queries containing regular expressions during AOT run. --- .../repository/aot/MongoRepositoryContributor.java | 4 +--- .../repository/query/MongoQueryCreator.java | 14 +++++++++++--- .../src/test/java/example/aot/UserRepository.java | 3 +++ .../aot/AotFragmentTestConfigurationSupport.java | 3 ++- .../aot/MongoRepositoryContributorTests.java | 8 ++++++++ .../aot/QueryMethodContributionUnitTests.java | 11 +++++++++++ 6 files changed, 36 insertions(+), 7 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java index e741a28e27..95b1108f9f 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java @@ -27,7 +27,6 @@ import static org.springframework.data.mongodb.repository.aot.QueryBlocks.QueryCodeBlockBuilder; import java.lang.reflect.Method; -import java.util.Locale; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -173,8 +172,7 @@ private QueryInteraction createStringQuery(RepositoryInformation repositoryInfor private static boolean backoff(MongoQueryMethod method) { // TODO: namedQuery, Regex queries, queries accepting Shapes (e.g. within) or returning arrays. - boolean skip = method.isSearchQuery() || method.getName().toLowerCase(Locale.ROOT).contains("regex") - || method.getReturnType().getType().isArray(); + boolean skip = method.isSearchQuery() || method.getReturnType().getType().isArray(); if (skip && logger.isDebugEnabled()) { logger.debug("Skipping AOT generation for [%s]. Method is either returning an array or a geo-near, regex query" diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java index 97712b61cb..3436a52d1f 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java @@ -221,8 +221,7 @@ private Criteria from(Part part, MongoPersistentProperty property, Criteria crit return createContainingCriteria(part, property, criteria.not(), parameters); case REGEX: - Object param = parameters.next(); - return param instanceof Pattern pattern ? criteria.regex(pattern) : criteria.regex(param.toString()); + return createPatternCriteria(criteria, parameters); case EXISTS: Object next = parameters.next(); if (next instanceof Placeholder placeholder) { @@ -257,8 +256,17 @@ private Criteria from(Part part, MongoPersistentProperty property, Criteria crit } @NonNull - private Criteria createNearCriteria(MongoPersistentProperty property, Criteria criteria, Iterator parameters) { + private static Criteria createPatternCriteria(Criteria criteria, Iterator parameters) { + Object param = parameters.next(); + if (param instanceof Placeholder) { + return criteria.raw("$regex", param); + } + return param instanceof Pattern pattern ? criteria.regex(pattern) : criteria.regex(param.toString()); + } + @NonNull + private Criteria createNearCriteria(MongoPersistentProperty property, Criteria criteria, + Iterator parameters) { Range range = accessor.getDistanceRange(); Optional distance = range.getUpperBound().getValue(); diff --git a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java index 276a4fac8d..e3d04293c2 100644 --- a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java +++ b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java @@ -22,6 +22,7 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.regex.Pattern; import java.util.stream.Stream; import org.springframework.data.annotation.Id; @@ -96,6 +97,8 @@ public interface UserRepository extends CrudRepository { List findByLastnameNot(String lastname); + List findByFirstnameRegex(Pattern pattern); + List findTop2ByLastnameStartingWith(String lastname); List findByLastnameStartingWithOrderByUsername(String lastname); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/AotFragmentTestConfigurationSupport.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/AotFragmentTestConfigurationSupport.java index 5b86acdace..0a0549eb1b 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/AotFragmentTestConfigurationSupport.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/AotFragmentTestConfigurationSupport.java @@ -93,7 +93,8 @@ private Object getFragmentFacadeProxy(Object fragment) { Method target = ReflectionUtils.findMethod(fragment.getClass(), method.getName(), method.getParameterTypes()); if (target == null) { - throw new MethodNotImplementedException("Method [%s] is not implemented by [%s]".formatted(method, target)); + throw new MethodNotImplementedException( + "Method [%s] is not implemented by [%s]".formatted(method, fragment.getClass())); } try { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java index 6e6e6ccef5..8b6331a717 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java @@ -27,6 +27,7 @@ import java.time.Instant; import java.util.List; import java.util.Optional; +import java.util.regex.Pattern; import org.bson.Document; import org.junit.jupiter.api.BeforeAll; @@ -229,6 +230,13 @@ void testNot() { assertThat(users).extracting(User::getUsername).isNotEmpty().doesNotContain("luke", "vader"); } + @Test // GH-4939 + void testRegex() { + + List lukes = fragment.findByFirstnameRegex(Pattern.compile(".*uk.*")); + assertThat(lukes).extracting(User::getUsername).containsExactly("luke"); + } + @Test void testExistsCriteria() { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java index 75a2e8d0a7..ef62b79694 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java @@ -22,6 +22,7 @@ import java.lang.reflect.Method; import java.util.Arrays; +import java.util.regex.Pattern; import javax.lang.model.element.Modifier; @@ -204,6 +205,16 @@ void rendersExpressionUsingParameterName() throws NoSuchMethodException { .contains("Map.of(\"firstname\", firstname)"); } + @Test // GH-4939 + void rendersRegexCriteria() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "findByFirstnameRegex", Pattern.class); + + assertThat(methodSpec.toString()) // + .contains("createQuery(\"{'firstname':{'$regex':?0}}\"") // + .contains("Object[]{ pattern }"); + } + private static MethodSpec codeOf(Class repository, String methodName, Class... args) throws NoSuchMethodException { From 09835d844b20e1dfd007e1e0ecd6bf3d74f6cf71 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Fri, 20 Jun 2025 14:04:09 +0200 Subject: [PATCH 21/24] Add Collation to generated query if present. --- .../MongoAotRepositoryFragmentSupport.java | 35 +++++++++++++++ .../mongodb/repository/aot/QueryBlocks.java | 19 +++++++- .../aot/QueryMethodContributionUnitTests.java | 44 +++++++++++++++++-- 3 files changed, 93 insertions(+), 5 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java index 6686f9794d..84de3bb835 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java @@ -17,6 +17,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Locale; import java.util.Map; import org.bson.Document; @@ -31,6 +32,7 @@ import org.springframework.data.mongodb.core.convert.MongoConverter; import org.springframework.data.mongodb.core.mapping.FieldName; import org.springframework.data.mongodb.core.query.BasicQuery; +import org.springframework.data.mongodb.core.query.Collation; import org.springframework.data.mongodb.repository.query.MongoParameters; import org.springframework.data.mongodb.util.json.ParameterBindingContext; import org.springframework.data.mongodb.util.json.ParameterBindingDocumentCodec; @@ -106,6 +108,39 @@ protected Document bindParameters(String source, Map parameters) return new ParameterBindingDocumentCodec().decode(source, bindingContext); } + protected Object evaluate(String source, Map parameters) { + + ValueEvaluationContext valueEvaluationContext = this.valueExpressionDelegate.getEvaluationContextAccessor() + .create(new NoMongoParameters()).getEvaluationContext(parameters.values()); + + EvaluationContext evaluationContext = valueEvaluationContext.getEvaluationContext(); + parameters.forEach(evaluationContext::setVariable); + + ValueExpression parse = valueExpressionDelegate.getValueExpressionParser().parse(source); + return parse.evaluate(valueEvaluationContext); + } + + protected Collation collationOf(@Nullable Object source) { + + if(source == null) { + return Collation.simple(); + } + if (source instanceof String) { + return Collation.parse(source.toString()); + } + if (source instanceof Locale locale) { + return Collation.of(locale); + } + if (source instanceof Document document) { + return Collation.from(document); + } + if (source instanceof Collation collation) { + return collation; + } + throw new IllegalArgumentException( + "Unsupported collation source [%s]".formatted(ObjectUtils.nullSafeClassName(source))); + } + protected BasicQuery createQuery(String queryString, Object[] parameters) { Document queryDocument = bindParameters(queryString, parameters); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java index e9425dce87..e014625474 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java @@ -29,6 +29,7 @@ import org.springframework.data.geo.Polygon; import org.springframework.data.mongodb.core.ExecutableFindOperation.FindWithQuery; import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.annotation.Collation; import org.springframework.data.mongodb.core.geo.GeoJson; import org.springframework.data.mongodb.core.geo.Sphere; import org.springframework.data.mongodb.core.query.BasicQuery; @@ -264,12 +265,26 @@ CodeBlock build() { } String comment = metaAnnotation.getString("comment"); - if (StringUtils.hasText("comment")) { + if (StringUtils.hasText(comment)) { builder.addStatement("$L.comment($S)", queryVariableName, comment); } } - // TODO: Meta annotation: Disk usage + MergedAnnotation collationAnnotation = context.getAnnotation(Collation.class); + if (collationAnnotation.isPresent()) { + + String collationString = collationAnnotation.getString("value"); + if(StringUtils.hasText(collationString)) { + if (!MongoCodeBlocks.containsPlaceholder(collationString)) { + builder.addStatement("$L.collation($T.parse($S))", queryVariableName, + org.springframework.data.mongodb.core.query.Collation.class, collationString); + } else { + builder.add("$L.collation(collationOf(evaluate($S, ", queryVariableName, collationString); + builder.add(MongoCodeBlocks.renderArgumentMap(arguments)); + builder.add(")));\n"); + } + } + } return builder.build(); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java index ef62b79694..d3d0f40e41 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java @@ -22,6 +22,7 @@ import java.lang.reflect.Method; import java.util.Arrays; +import java.util.List; import java.util.regex.Pattern; import javax.lang.model.element.Modifier; @@ -37,8 +38,10 @@ import org.springframework.data.geo.Point; import org.springframework.data.geo.Polygon; import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.annotation.Collation; import org.springframework.data.mongodb.core.geo.GeoJsonPolygon; import org.springframework.data.mongodb.core.geo.Sphere; +import org.springframework.data.mongodb.repository.Hint; import org.springframework.data.mongodb.repository.ReadPreference; import org.springframework.data.repository.Repository; import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; @@ -211,8 +214,36 @@ void rendersRegexCriteria() throws NoSuchMethodException { MethodSpec methodSpec = codeOf(UserRepository.class, "findByFirstnameRegex", Pattern.class); assertThat(methodSpec.toString()) // - .contains("createQuery(\"{'firstname':{'$regex':?0}}\"") // - .contains("Object[]{ pattern }"); + .contains("createQuery(\"{'firstname':{'$regex':?0}}\"") // + .contains("Object[]{ pattern }"); + } + + @Test // GH-4939 + void rendersHint() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepoWithMeta.class, "findByFirstname", String.class); + + assertThat(methodSpec.toString()) // + .contains(".withHint(\"fn-idx\")"); + } + + @Test // GH-4939 + void rendersCollation() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepoWithMeta.class, "findByFirstname", String.class); + + assertThat(methodSpec.toString()) // + .containsPattern(".*\\.collation\\(.*Collation\\.parse\\(\"en_US\"\\)\\)"); + } + + @Test // GH-4939 + void rendersCollationFromExpression() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepoWithMeta.class, "findWithCollationByFirstname", String.class, String.class); + + assertThat(methodSpec.toString()) // + .containsIgnoringWhitespaces( + "collationOf(evaluate(\"?#{[1]}\", java.util.Map.of(\"firstname\", firstname, \"locale\", locale)))"); } private static MethodSpec codeOf(Class repository, String methodName, Class... args) @@ -228,7 +259,7 @@ private static MethodSpec codeOf(Class repository, String methodName, Class { + @Hint("fn-idx") + @Collation("en_US") + List findByFirstname(String firstname); + + @Collation("?#{[1]}") + List findWithCollationByFirstname(String firstname, String locale); + @ReadPreference("NEAREST") GeoResults findByLocationCoordinatesNear(Point point, Distance maxDistance); } From 88fc95530ee6e74fcf3835a2b5a43cffcebc0d52 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Mon, 23 Jun 2025 08:08:09 +0200 Subject: [PATCH 22/24] Initial support for AOT generated VectorSearch --- .../repository/aot/AotQueryCreator.java | 18 +- .../MongoAotRepositoryFragmentSupport.java | 86 ++++++- .../repository/aot/MongoCodeBlocks.java | 31 ++- .../aot/MongoRepositoryContributor.java | 32 ++- .../mongodb/repository/aot/QueryBlocks.java | 11 +- .../repository/aot/SearchInteraction.java | 48 ++++ .../repository/aot/VectorSearchBocks.java | 211 ++++++++++++++++++ .../src/test/java/example/aot/User.java | 3 + .../test/java/example/aot/UserRepository.java | 30 +++ .../aot/MongoRepositoryContributorTests.java | 143 ++++++++++-- .../aot/QueryMethodContributionUnitTests.java | 144 +++++++++++- 11 files changed, 709 insertions(+), 48 deletions(-) create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/SearchInteraction.java create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBocks.java diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java index 11d6e8bdd2..219f90348c 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java @@ -15,6 +15,7 @@ */ package org.springframework.data.mongodb.repository.aot; +import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Iterator; import java.util.List; @@ -46,6 +47,7 @@ import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.TextCriteria; import org.springframework.data.mongodb.core.query.UpdateDefinition; +import org.springframework.data.mongodb.repository.VectorSearch; import org.springframework.data.mongodb.repository.query.ConvertingParameterAccessor; import org.springframework.data.mongodb.repository.query.MongoParameterAccessor; import org.springframework.data.mongodb.repository.query.MongoQueryCreator; @@ -79,14 +81,16 @@ public AotQueryCreator() { } @SuppressWarnings("NullAway") - StringQuery createQuery(PartTree partTree, QueryMethod queryMethod) { - + StringQuery createQuery(PartTree partTree, QueryMethod queryMethod, Method source) { boolean geoNear = queryMethod instanceof MongoQueryMethod mqm ? mqm.isGeoNearQuery() : false; + boolean searchQuery = queryMethod instanceof MongoQueryMethod mqm + ? mqm.isSearchQuery() || source.isAnnotationPresent(VectorSearch.class) + : source.isAnnotationPresent(VectorSearch.class); Query query = new MongoQueryCreator(partTree, - new PlaceholderConvertingParameterAccessor(new PlaceholderParameterAccessor(queryMethod)), mappingContext, geoNear, queryMethod.isSearchQuery()) - .createQuery(); + new PlaceholderConvertingParameterAccessor(new PlaceholderParameterAccessor(queryMethod)), mappingContext, + geoNear, searchQuery).createQuery(); if (partTree.isLimiting()) { query.limit(partTree.getMaxResults()); @@ -141,8 +145,7 @@ public PlaceholderParameterAccessor(QueryMethod queryMethod) { for (Parameter parameter : parameters.toList()) { if (ClassUtils.isAssignable(GeoJson.class, parameter.getType())) { placeholders.add(parameter.getIndex(), new GeoJsonPlaceholder(parameter.getIndex(), "")); - } - else if (ClassUtils.isAssignable(Point.class, parameter.getType())) { + } else if (ClassUtils.isAssignable(Point.class, parameter.getType())) { placeholders.add(parameter.getIndex(), new PointPlaceholder(parameter.getIndex())); } else if (ClassUtils.isAssignable(Circle.class, parameter.getType())) { placeholders.add(parameter.getIndex(), new CirclePlaceholder(parameter.getIndex())); @@ -152,8 +155,7 @@ else if (ClassUtils.isAssignable(Point.class, parameter.getType())) { placeholders.add(parameter.getIndex(), new SpherePlaceholder(parameter.getIndex())); } else if (ClassUtils.isAssignable(Polygon.class, parameter.getType())) { placeholders.add(parameter.getIndex(), new PolygonPlaceholder(parameter.getIndex())); - } - else { + } else { placeholders.add(parameter.getIndex(), Placeholder.indexed(parameter.getIndex())); } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java index 84de3bb835..86b3217b07 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java @@ -16,12 +16,17 @@ package org.springframework.data.mongodb.repository.aot; import java.util.ArrayList; +import java.util.LinkedHashMap; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.function.Consumer; import org.bson.Document; import org.jspecify.annotations.Nullable; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.ScoringFunction; import org.springframework.data.expression.ValueEvaluationContext; import org.springframework.data.expression.ValueExpression; import org.springframework.data.mapping.model.ValueExpressionEvaluator; @@ -33,6 +38,7 @@ import org.springframework.data.mongodb.core.mapping.FieldName; import org.springframework.data.mongodb.core.query.BasicQuery; import org.springframework.data.mongodb.core.query.Collation; +import org.springframework.data.mongodb.core.query.Criteria; import org.springframework.data.mongodb.repository.query.MongoParameters; import org.springframework.data.mongodb.util.json.ParameterBindingContext; import org.springframework.data.mongodb.util.json.ParameterBindingDocumentCodec; @@ -42,7 +48,9 @@ import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport; import org.springframework.data.repository.query.ValueExpressionDelegate; import org.springframework.expression.EvaluationContext; +import org.springframework.util.Assert; import org.springframework.util.ClassUtils; +import org.springframework.util.CollectionUtils; import org.springframework.util.ObjectUtils; /** @@ -108,7 +116,27 @@ protected Document bindParameters(String source, Map parameters) return new ParameterBindingDocumentCodec().decode(source, bindingContext); } - protected Object evaluate(String source, Map parameters) { + protected Object[] arguments(Object... arguments) { + return arguments; + } + + protected Map argumentMap(Object... parameters) { + + Assert.state(parameters.length % 2 == 0, "even number of args required"); + + LinkedHashMap argumentMap = CollectionUtils.newLinkedHashMap(parameters.length / 2); + for (int i = 0; i < parameters.length; i += 2) { + + if (!(parameters[i] instanceof String key)) { + throw new IllegalArgumentException("key must be a String"); + } + argumentMap.put(key, parameters[i + 1]); + } + + return argumentMap; + } + + protected @Nullable Object evaluate(String source, Map parameters) { ValueEvaluationContext valueEvaluationContext = this.valueExpressionDelegate.getEvaluationContextAccessor() .create(new NoMongoParameters()).getEvaluationContext(parameters.values()); @@ -120,9 +148,63 @@ protected Object evaluate(String source, Map parameters) { return parse.evaluate(valueEvaluationContext); } + protected Consumer scoreBetween(Range.Bound lower, Range.Bound upper) { + + return criteria -> { + if (lower.isBounded()) { + double value = lower.getValue().get().getValue(); + if (lower.isInclusive()) { + criteria.gte(value); + } else { + criteria.gt(value); + } + } + + if (upper.isBounded()) { + + double value = upper.getValue().get().getValue(); + if (upper.isInclusive()) { + criteria.lte(value); + } else { + criteria.lt(value); + } + } + + }; + } + + protected ScoringFunction scoringFunction(Range scoreRange) { + + if (scoreRange != null) { + if (scoreRange.getUpperBound().isBounded()) { + return scoreRange.getUpperBound().getValue().get().getFunction(); + } + + if (scoreRange.getLowerBound().isBounded()) { + return scoreRange.getLowerBound().getValue().get().getFunction(); + } + } + + return ScoringFunction.unspecified(); + } + + // Range scoreRange = accessor.getScoreRange(); + // + // if (scoreRange != null) { + // if (scoreRange.getUpperBound().isBounded()) { + // return scoreRange.getUpperBound().getValue().get().getFunction(); + // } + // + // if (scoreRange.getLowerBound().isBounded()) { + // return scoreRange.getLowerBound().getValue().get().getFunction(); + // } + // } + // + // return ScoringFunction.unspecified(); + protected Collation collationOf(@Nullable Object source) { - if(source == null) { + if (source == null) { return Collation.simple(); } if (source instanceof String) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java index 3881994437..4125139bd9 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java @@ -37,6 +37,7 @@ import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock.Builder; +import org.springframework.util.NumberUtils; import org.springframework.util.StringUtils; /** @@ -49,6 +50,7 @@ class MongoCodeBlocks { private static final Pattern PARAMETER_BINDING_PATTERN = Pattern.compile("\\?(\\d+)"); private static final Pattern EXPRESSION_BINDING_PATTERN = Pattern.compile("[\\?:][#$]\\{.*\\}"); + private static final Pattern VALUE_EXPRESSION_PATTERN = Pattern.compile("^#\\{.*}$"); /** * Builder for generating query parsing {@link CodeBlock}. @@ -179,7 +181,7 @@ static CodeBlock renderExpressionToDocument(@Nullable String source, String vari } else { builder.add("$T $L = bindParameters($S, ", Document.class, variableName, source); if (containsNamedPlaceholder(source)) { - renderArgumentMap(arguments); + builder.add(renderArgumentMap(arguments)); } else { builder.add(renderArgumentArray(arguments)); } @@ -191,7 +193,7 @@ static CodeBlock renderExpressionToDocument(@Nullable String source, String vari static CodeBlock renderArgumentMap(Map arguments) { Builder builder = CodeBlock.builder(); - builder.add("$T.of(", Map.class); + builder.add("argumentMap("); Iterator> iterator = arguments.entrySet().iterator(); while (iterator.hasNext()) { Entry next = iterator.next(); @@ -208,24 +210,41 @@ static CodeBlock renderArgumentMap(Map arguments) { static CodeBlock renderArgumentArray(Map arguments) { Builder builder = CodeBlock.builder(); - builder.add("new $T[]{ ", Object.class); + builder.add("arguments("); Iterator iterator = arguments.values().iterator(); while (iterator.hasNext()) { builder.add(iterator.next()); if (iterator.hasNext()) { builder.add(", "); - } else { - builder.add(" "); } } - builder.add("}"); + builder.add(")"); return builder.build(); } + static CodeBlock evaluateNumberPotentially(String value, Class targetType, + Map arguments) { + try { + Number number = NumberUtils.parseNumber(value, targetType); + return CodeBlock.of("$L", number); + } catch (IllegalArgumentException e) { + + Builder builder = CodeBlock.builder(); + builder.add("($T) evaluate($S, ", targetType, value); + builder.add(MongoCodeBlocks.renderArgumentMap(arguments)); + builder.add(")"); + return builder.build(); + } + } + static boolean containsPlaceholder(String source) { return containsIndexedPlaceholder(source) || containsNamedPlaceholder(source); } + static boolean containsExpression(String source) { + return VALUE_EXPRESSION_PATTERN.matcher(source).find(); + } + static boolean containsNamedPlaceholder(String source) { return EXPRESSION_BINDING_PATTERN.matcher(source).find(); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java index 95b1108f9f..524c5e8f23 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java @@ -37,6 +37,7 @@ import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.data.mongodb.repository.Query; import org.springframework.data.mongodb.repository.Update; +import org.springframework.data.mongodb.repository.VectorSearch; import org.springframework.data.mongodb.repository.query.MongoQueryMethod; import org.springframework.data.repository.aot.generate.AotRepositoryClassBuilder; import org.springframework.data.repository.aot.generate.AotRepositoryConstructorBuilder; @@ -107,7 +108,11 @@ protected void customizeConstructor(AotRepositoryConstructorBuilder constructorB } QueryInteraction query = createStringQuery(getRepositoryInformation(), queryMethod, - AnnotatedElementUtils.findMergedAnnotation(method, Query.class)); + AnnotatedElementUtils.findMergedAnnotation(method, Query.class), method); + + if (queryMethod.isSearchQuery() || method.isAnnotationPresent(VectorSearch.class)) { + return searchMethodContributor(queryMethod, new SearchInteraction(query.getQuery())); + } if (queryMethod.isGeoNearQuery() || (queryMethod.getParameters().getMaxDistanceIndex() != -1 && queryMethod.getReturnType().isCollectionLike())) { @@ -126,8 +131,8 @@ protected void customizeConstructor(AotRepositoryConstructorBuilder constructorB UpdateInteraction update = new UpdateInteraction(query, null, updateIndex); return updateMethodContributor(queryMethod, update); - } else { + Update updateSource = queryMethod.getUpdateSource(); if (StringUtils.hasText(updateSource.value())) { UpdateInteraction update = new UpdateInteraction(query, new StringUpdate(updateSource.value()), null); @@ -146,7 +151,7 @@ protected void customizeConstructor(AotRepositoryConstructorBuilder constructorB @SuppressWarnings("NullAway") private QueryInteraction createStringQuery(RepositoryInformation repositoryInformation, MongoQueryMethod queryMethod, - @Nullable Query queryAnnotation) { + @Nullable Query queryAnnotation, Method source) { QueryInteraction query; if (queryMethod.hasAnnotatedQuery() && queryAnnotation != null) { @@ -155,8 +160,8 @@ private QueryInteraction createStringQuery(RepositoryInformation repositoryInfor } else { PartTree partTree = new PartTree(queryMethod.getName(), repositoryInformation.getDomainType()); - query = new QueryInteraction(queryCreator.createQuery(partTree, queryMethod), partTree.isCountProjection(), - partTree.isDelete(), partTree.isExistsProjection()); + query = new QueryInteraction(queryCreator.createQuery(partTree, queryMethod, source), + partTree.isCountProjection(), partTree.isDelete(), partTree.isExistsProjection()); } if (queryAnnotation != null && StringUtils.hasText(queryAnnotation.sort())) { @@ -172,7 +177,7 @@ private QueryInteraction createStringQuery(RepositoryInformation repositoryInfor private static boolean backoff(MongoQueryMethod method) { // TODO: namedQuery, Regex queries, queries accepting Shapes (e.g. within) or returning arrays. - boolean skip = method.isSearchQuery() || method.getReturnType().getType().isArray(); + boolean skip = method.getReturnType().getType().isArray(); if (skip && logger.isDebugEnabled()) { logger.debug("Skipping AOT generation for [%s]. Method is either returning an array or a geo-near, regex query" @@ -220,6 +225,21 @@ static MethodContributor aggregationMethodContributor(MongoQue }); } + static MethodContributor searchMethodContributor(MongoQueryMethod queryMethod, + SearchInteraction interaction) { + return MethodContributor.forQueryMethod(queryMethod).withMetadata(interaction).contribute(context -> { + + CodeBlock.Builder builder = CodeBlock.builder(); + + String variableName = "search"; + + builder.add(new VectorSearchBocks.VectorSearchQueryCodeBlockBuilder(context, queryMethod) + .usingVariableName(variableName).withFilter(interaction.getFilter()).build()); + + return builder.build(); + }); + } + static MethodContributor updateMethodContributor(MongoQueryMethod queryMethod, UpdateInteraction update) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java index e014625474..7ad0c25b16 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java @@ -211,8 +211,7 @@ CodeBlock build() { Builder builder = CodeBlock.builder(); - builder.add("\n"); - builder.add(renderExpressionToQuery(source.getQuery().getQueryString(), queryVariableName)); + builder.add(buildJustTheQuery()); if (StringUtils.hasText(source.getQuery().getFieldsString())) { @@ -289,6 +288,14 @@ CodeBlock build() { return builder.build(); } + CodeBlock buildJustTheQuery() { + + Builder builder = CodeBlock.builder(); + builder.add("\n"); + builder.add(renderExpressionToQuery(source.getQuery().getQueryString(), queryVariableName)); + return builder.build(); + } + private CodeBlock renderExpressionToQuery(@Nullable String source, String variableName) { Builder builder = CodeBlock.builder(); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/SearchInteraction.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/SearchInteraction.java new file mode 100644 index 0000000000..a94ff1082b --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/SearchInteraction.java @@ -0,0 +1,48 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import java.util.Map; + +import org.jspecify.annotations.Nullable; +import org.springframework.data.repository.aot.generate.QueryMetadata; + +/** + * @author Christoph Strobl + */ +public class SearchInteraction extends MongoInteraction implements QueryMetadata { + + StringQuery filter; + + public SearchInteraction(StringQuery filter) { + this.filter = filter; + } + + public StringQuery getFilter() { + return filter; + } + + @Override + InteractionType getExecutionType() { + return InteractionType.AGGREGATION; + } + + @Override + public Map serialize() { + + return Map.of("FIXME", "please!"); + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBocks.java new file mode 100644 index 0000000000..3efdc080b2 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBocks.java @@ -0,0 +1,211 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import java.lang.reflect.Field; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.bson.Document; +import org.springframework.core.annotation.MergedAnnotation; +import org.springframework.data.domain.Limit; +import org.springframework.data.domain.ScoringFunction; +import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.aggregation.Aggregation; +import org.springframework.data.mongodb.core.aggregation.AggregationOperation; +import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext; +import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType; +import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.mongodb.repository.query.MongoQueryExecution.VectorSearchExecution; +import org.springframework.data.mongodb.repository.query.MongoQueryMethod; +import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; +import org.springframework.data.util.TypeInformation; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.CodeBlock.Builder; +import org.springframework.util.StringUtils; + +/** + * @author Christoph Strobl + * @since 5.0 + */ +class VectorSearchBocks { + + static class VectorSearchQueryCodeBlockBuilder { + + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + private String searchQueryVariableName; + private StringQuery filter; + private final Map arguments; + + VectorSearchQueryCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + + this.context = context; + this.queryMethod = queryMethod; + this.arguments = new LinkedHashMap<>(); + context.getBindableParameterNames().forEach(it -> arguments.put(it, CodeBlock.of(it))); + } + + VectorSearchQueryCodeBlockBuilder usingVariableName(String searchQueryVariableName) { + + this.searchQueryVariableName = searchQueryVariableName; + return this; + } + + CodeBlock build() { + + Builder builder = CodeBlock.builder(); + + String vectorParameterName = context.getVectorParameterName(); + + MergedAnnotation annotation = context.getAnnotation(VectorSearch.class); + String searchPath = annotation.getString("path"); + String indexName = annotation.getString("indexName"); + String numCandidates = annotation.getString("numCandidates"); + SearchType searchType = annotation.getEnum("searchType", SearchType.class); + String limit = annotation.getString("limit"); + + if (!StringUtils.hasText(searchPath)) { // FIXME: somehow duplicate logic of AnnotatedQueryFactory + + Field[] declaredFields = context.getRepositoryInformation().getDomainType().getDeclaredFields(); + for (Field field : declaredFields) { + if (Vector.class.isAssignableFrom(field.getType())) { + searchPath = field.getName(); + break; + } + } + + } + + String vectorSearchVar = context.localVariable("$vectorSearch"); + builder.add("$T $L = $T.vectorSearch($S).path($S).vector($L)", VectorSearchOperation.class, vectorSearchVar, + Aggregation.class, indexName, searchPath, vectorParameterName); + + if (StringUtils.hasText(context.getLimitParameterName())) { + builder.add(".limit($L);\n", context.getLimitParameterName()); + } else if (filter.isLimited()) { + builder.add(".limit($L);\n", filter.getLimit()); + } else if (StringUtils.hasText(limit)) { + if (MongoCodeBlocks.containsPlaceholder(limit) || MongoCodeBlocks.containsExpression(limit)) { + builder.add(".limit("); + builder.add(MongoCodeBlocks.evaluateNumberPotentially(limit, Integer.class, arguments)); + builder.add(");\n"); + } else { + builder.add(".limit($L);\n", limit); + } + } else { + builder.add(".limit($T.unlimited());\n", Limit.class); + } + + if (!searchType.equals(SearchType.DEFAULT)) { + builder.addStatement("$1L = $1L.searchType($2T.$3L)", vectorSearchVar, SearchType.class, searchType.name()); + } + + if (StringUtils.hasText(numCandidates)) { + builder.add("$1L = $1L.numCandidates(", vectorSearchVar); + builder.add(MongoCodeBlocks.evaluateNumberPotentially(numCandidates, Integer.class, arguments)); + builder.add(");\n"); + } else if (searchType == VectorSearchOperation.SearchType.ANN + || searchType == VectorSearchOperation.SearchType.DEFAULT) { + + builder.add( + "// MongoDB: We recommend that you specify a number at least 20 times higher than the number of documents to return\n"); + if (StringUtils.hasText(context.getLimitParameterName())) { + builder.addStatement("$1L = $1L.numCandidates($2L.max() * 20)", vectorSearchVar, + context.getLimitParameterName()); + } else if (StringUtils.hasText(limit)) { + if (MongoCodeBlocks.containsPlaceholder(limit) || MongoCodeBlocks.containsExpression(limit)) { + + builder.add("$1L = $1L.numCandidates((", vectorSearchVar); + builder.add(MongoCodeBlocks.evaluateNumberPotentially(limit, Integer.class, arguments)); + builder.add(") * 20);\n"); + } else { + builder.addStatement("$1L = $1L.numCandidates($2L * 20)", vectorSearchVar, limit); + } + } else { + builder.addStatement("$1L = $1L.numCandidates($2L)", vectorSearchVar, filter.getLimit() * 20); + } + } + + builder.addStatement("$1L = $1L.withSearchScore(\"__score__\")", vectorSearchVar); + if (StringUtils.hasText(context.getScoreParameterName())) { + + String scoreCriteriaVar = context.localVariable("criteria"); + builder.addStatement("$1L = $1L.withFilterBySore($2L -> { $2L.gt($3L.getValue()); })", vectorSearchVar, + scoreCriteriaVar, context.getScoreParameterName()); + } else if (StringUtils.hasText(context.getScoreRangeParameterName())) { + builder.addStatement("$1L = $1L.withFilterBySore(scoreBetween($2L.getLowerBound(), $2L.getUpperBound()))", + vectorSearchVar, context.getScoreRangeParameterName()); + } + + if (StringUtils.hasText(filter.getQueryString())) { + + String filterVar = context.localVariable("filter"); + builder.add(MongoCodeBlocks.queryBlockBuilder(context, queryMethod).usingQueryVariableName("filter") + .filter(new QueryInteraction(this.filter, false, false, false)).buildJustTheQuery()); + builder.addStatement("$1L = $1L.filter($2L.getQueryObject())", vectorSearchVar, filterVar); + builder.add("\n"); + } + + + String sortStageVar = context.localVariable("$sort"); + if(filter.isSorted()) { + + builder.add("$T $L = (_ctx) -> {\n", AggregationOperation.class, sortStageVar); + builder.indent(); + + builder.addStatement("$1T _mappedSort = _ctx.getMappedObject($1T.parse($2S), $3T.class)", Document.class, filter.getSortString(), context.getActualReturnType().getType()); + builder.addStatement("return new $T($S, _mappedSort.append(\"__score__\", -1))", Document.class, "$sort"); + builder.unindent(); + builder.add("};"); + + } else { + builder.addStatement("var $L = $T.sort($T.Direction.DESC, $S)", sortStageVar, Aggregation.class, Sort.class, "__score__"); + } + builder.add("\n"); + + builder.addStatement("$1T $2L = new $1T($3T.of($4L, $5L))", AggregationPipeline.class, searchQueryVariableName, + List.class, vectorSearchVar, sortStageVar); + + String scoringFunctionVar = context.localVariable("scoringFunction"); + builder.add("$1T $2L = ", ScoringFunction.class, scoringFunctionVar); + if (StringUtils.hasText(context.getScoreParameterName())) { + builder.add("$L.getFunction();\n", context.getScoreParameterName()); + } else if (StringUtils.hasText(context.getScoreRangeParameterName())) { + builder.add("scoringFunction($L);\n", context.getScoreRangeParameterName()); + } else { + builder.add("$1T.unspecified();\n", ScoringFunction.class); + } + + builder.addStatement( + "return ($5T) new $1T($2L, $3T.class, $2L.getCollectionName($3T.class), $4T.of($5T.class), $6L, $7L).execute(null)", + VectorSearchExecution.class, context.fieldNameOf(MongoOperations.class), + context.getRepositoryInformation().getDomainType(), TypeInformation.class, + queryMethod.getReturnType().getType(), searchQueryVariableName, scoringFunctionVar); + return builder.build(); + } + + public VectorSearchQueryCodeBlockBuilder withFilter(StringQuery filter) { + this.filter = filter; + return this; + } + } +} diff --git a/spring-data-mongodb/src/test/java/example/aot/User.java b/spring-data-mongodb/src/test/java/example/aot/User.java index 25514a518c..dfe3ec3553 100644 --- a/spring-data-mongodb/src/test/java/example/aot/User.java +++ b/spring-data-mongodb/src/test/java/example/aot/User.java @@ -17,6 +17,7 @@ import java.time.Instant; +import org.springframework.data.domain.Vector; import org.springframework.data.mongodb.core.mapping.Field; /** @@ -38,6 +39,8 @@ public class User { Instant lastSeen; Long visits; + Vector embedding; + public String getId() { return id; } diff --git a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java index e3d04293c2..ee1058fdc6 100644 --- a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java +++ b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java @@ -30,9 +30,13 @@ import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; import org.springframework.data.domain.ScrollPosition; +import org.springframework.data.domain.SearchResults; +import org.springframework.data.domain.Similarity; import org.springframework.data.domain.Slice; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; import org.springframework.data.domain.Window; import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; @@ -43,6 +47,7 @@ import org.springframework.data.geo.Point; import org.springframework.data.geo.Polygon; import org.springframework.data.mongodb.core.aggregation.AggregationResults; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; import org.springframework.data.mongodb.core.geo.GeoJson; import org.springframework.data.mongodb.core.geo.GeoJsonPolygon; import org.springframework.data.mongodb.core.geo.Sphere; @@ -51,6 +56,7 @@ import org.springframework.data.mongodb.repository.Query; import org.springframework.data.mongodb.repository.ReadPreference; import org.springframework.data.mongodb.repository.Update; +import org.springframework.data.mongodb.repository.VectorSearch; import org.springframework.data.repository.CrudRepository; import org.springframework.data.repository.query.Param; @@ -291,6 +297,30 @@ public interface UserRepository extends CrudRepository { "{ '$project': { '_id' : '$last_name' } }" }, collation = "no_collation") List findAllLastnamesWithCollation(); + // Vector Search + + @VectorSearch(indexName = "embedding.vector_cos", filter = "{lastname: ?0}", numCandidates = "#{10+10}", + searchType = VectorSearchOperation.SearchType.ANN) + SearchResults annotatedVectorSearch(String lastname, Vector vector, Score distance, Limit limit); + + @VectorSearch(indexName = "embedding.vector_cos") + SearchResults searchCosineByLastnameAndEmbeddingNear(String lastname, Vector vector, Score similarity, + Limit limit); + + @VectorSearch(indexName = "embedding.vector_cos") + List searchAsListByLastnameAndEmbeddingNear(String lastname, Vector vector, Limit limit); + + @VectorSearch(indexName = "embedding.vector_cos", limit = "10") + SearchResults searchByLastnameAndEmbeddingWithin(String lastname, Vector vector, Range distance); + + @VectorSearch(indexName = "embedding.vector_cos", limit = "10") + SearchResults searchByLastnameAndEmbeddingWithinOrderByFirstname(String lastname, Vector vector, + Range distance); + + @VectorSearch(indexName = "embedding.vector_cos") + SearchResults searchTop1ByLastnameAndEmbeddingWithin(String lastname, Vector vector, + Range distance); + class UserAggregate { @Id // diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java index 8b6331a717..e40f0cd53b 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java @@ -24,16 +24,18 @@ import example.aot.UserRepository; import example.aot.UserRepository.UserAggregate; +import java.time.Duration; import java.time.Instant; +import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.regex.Pattern; +import org.bson.BsonString; import org.bson.Document; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -43,9 +45,13 @@ import org.springframework.data.domain.Page; import org.springframework.data.domain.PageRequest; import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; import org.springframework.data.domain.ScrollPosition; +import org.springframework.data.domain.SearchResults; +import org.springframework.data.domain.Similarity; import org.springframework.data.domain.Slice; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; import org.springframework.data.domain.Window; import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; @@ -61,14 +67,20 @@ import org.springframework.data.mongodb.core.aggregation.AggregationResults; import org.springframework.data.mongodb.core.geo.GeoJsonPoint; import org.springframework.data.mongodb.core.geo.GeoJsonPolygon; -import org.springframework.data.mongodb.test.util.Client; -import org.springframework.data.mongodb.test.util.MongoClientExtension; +import org.springframework.data.mongodb.test.util.AtlasContainer; import org.springframework.data.mongodb.test.util.MongoTestUtils; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; import org.springframework.util.StringUtils; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.shaded.org.awaitility.Awaitility; import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; +import com.mongodb.client.MongoCollection; import com.mongodb.client.model.IndexOptions; +import com.mongodb.client.model.SearchIndexModel; +import com.mongodb.client.model.SearchIndexType; /** * Integration tests for the {@link UserRepository} AOT fragment. @@ -76,13 +88,15 @@ * @author Christoph Strobl * @author Mark Paluch */ -@ExtendWith(MongoClientExtension.class) +@Testcontainers(disabledWithoutDocker = true) @SpringJUnitConfig(classes = MongoRepositoryContributorTests.MongoRepositoryContributorConfiguration.class) class MongoRepositoryContributorTests { + private static final @Container AtlasContainer atlasLocal = AtlasContainer.bestMatch(); private static final String DB_NAME = "aot-repo-tests"; + private static final String COLLECTION_NAME = "user"; - @Client static MongoClient client; + static MongoClient client; @Autowired UserRepository fragment; @Configuration @@ -99,13 +113,39 @@ MongoOperations mongoOperations() { } @BeforeAll - static void beforeAll() { - client.getDatabase(DB_NAME).getCollection("user").createIndex(new Document("location.coordinates", "2d"), - new IndexOptions()); + static void beforeAll() throws InterruptedException { + + client = MongoClients.create(atlasLocal.getConnectionString()); + MongoCollection userCollection = client.getDatabase(DB_NAME).getCollection(COLLECTION_NAME); + userCollection.createIndex(new Document("location.coordinates", "2d"), new IndexOptions()); + userCollection.createIndex(new Document("location.coordinates", "2dsphere"), new IndexOptions()); + + Document searchIndex = new Document("fields", + List.of(new Document("type", "vector").append("path", "embedding").append("numDimensions", 5) + .append("similarity", "cosine"), new Document("type", "filter").append("path", "last_name"))); + + userCollection.createSearchIndexes(List.of( + new SearchIndexModel("embedding.vector_cos", searchIndex, SearchIndexType.of(new BsonString("vectorSearch"))))); + + Awaitility.await().atMost(Duration.ofSeconds(120)).pollInterval(Duration.ofMillis(200)).until(() -> { + + List execute = userCollection + .aggregate( + List.of(Document.parse("{'$listSearchIndexes': { 'name' : '%s'}}".formatted("embedding.vector_cos")))) + .into(new ArrayList<>()); + for (Document doc : execute) { + if (doc.getString("name").equals("embedding.vector_cos")) { + return doc.getString("status").equals("READY"); + } + } + return false; + }); + + Thread.sleep(250); // just wait a little or the index will be broken } @BeforeEach - void beforeEach() { + void beforeEach() throws InterruptedException { MongoTestUtils.flushCollection(DB_NAME, "user", client); initUsers(); @@ -749,10 +789,80 @@ void testNearReturningGeoPage() { assertThat(page2.hasNext()).isFalse(); } + @Test + void vectorSearchFromAnnotation() throws InterruptedException { + + Thread.sleep(1000); // srly - reindex for vector search + + Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d); + SearchResults results = fragment.annotatedVectorSearch("Skywalker", vector, Score.of(0.99), Limit.of(10)); + + assertThat(results).hasSize(1); + } + + @Test + void vectorSearchWithDerivedQuery() throws InterruptedException { + + Thread.sleep(1000); // srly - reindex for vector search + + Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d); + SearchResults results = fragment.searchCosineByLastnameAndEmbeddingNear("Skywalker", vector, Score.of(0.98), + Limit.of(10)); + + assertThat(results).hasSize(1); + } + + @Test + void vectorSearchReturningResultsAsList() throws InterruptedException { + + Thread.sleep(1000); // srly - reindex for vector search + + Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d); + List results = fragment.searchAsListByLastnameAndEmbeddingNear("Skywalker", vector, Limit.of(10)); + + assertThat(results).hasSize(2); + } + + @Test + void vectorSearchWithLimitFromAnnotation() throws InterruptedException { + + Thread.sleep(1000); // srly - reindex for vector search + + Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d); + SearchResults results = fragment.searchByLastnameAndEmbeddingWithin("Skywalker", vector, + Similarity.between(0.4, 0.99)); + + assertThat(results).hasSize(1); + } + + @Test + void vectorSearchWithSorting() throws InterruptedException { + + Thread.sleep(1000); // srly - reindex for vector search + + Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d); + SearchResults results = fragment.searchByLastnameAndEmbeddingWithinOrderByFirstname("Skywalker", vector, + Similarity.between(0.4, 1.0)); + + assertThat(results).hasSize(2); + } + + @Test + void vectorSearchWithLimitFromDerivedQuery() throws InterruptedException { + + Thread.sleep(1000); // srly - reindex for vector search + + Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d); + SearchResults results = fragment.searchTop1ByLastnameAndEmbeddingWithin("Skywalker", vector, + Similarity.between(0.4, 1.0)); + + assertThat(results).hasSize(1); + } + /** * GeoResults results = repository.findPersonByLocationNear(new Point(-73.99, 40.73), range); */ - private static void initUsers() { + private static void initUsers() throws InterruptedException { Document luke = Document.parse(""" { @@ -772,6 +882,7 @@ private static void initUsers() { } } ], + "embedding" : [1.00000, 1.12345, 2.23456, 3.34567, 4.45678], "_class": "example.springdata.aot.User" }"""); @@ -787,6 +898,7 @@ private static void initUsers() { "x" : -73.99171, "y" : 40.738868 } }, + "embedding" : [1.0001, 2.12345, 3.23456, 4.34567, 5.45678], "_class": "example.springdata.aot.User" }"""); @@ -804,6 +916,7 @@ private static void initUsers() { } } ], + "embedding" : [2.0002, 3.12345, 4.23456, 5.34567, 6.45678], "_class": "example.springdata.aot.User" }"""); @@ -814,6 +927,7 @@ private static void initUsers() { "lastSeen" : { "$date": "2025-01-01T00:00:00.000Z" }, + "embedding" : [3.0003, 4.12345, 5.23456, 6.34567, 7.45678], "_class": "example.springdata.aot.User" }"""); @@ -836,7 +950,8 @@ private static void initUsers() { "$date": "2025-01-15T13:53:33.855Z" } } - ] + ], + "embedding" : [4.0004, 5.12345, 6.23456, 7.34567, 8.45678] }"""); Document vader = Document.parse(""" @@ -859,7 +974,8 @@ private static void initUsers() { "$date": "2025-01-15T13:46:33.855Z" } } - ] + ], + "embedding" : [5.0005, 6.12345, 7.23456, 8.34567, 9.45678] }"""); Document kylo = Document.parse(""" @@ -867,7 +983,8 @@ private static void initUsers() { "_id": "id-7", "username": "kylo", "first_name": "Ben", - "last_name": "Solo" + "last_name": "Solo", + "embedding" : [6.0006, 7.12345, 8.23456, 9.34567, 10.45678] } """); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java index d3d0f40e41..d8de601d4e 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java @@ -29,8 +29,13 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import org.springframework.data.domain.Limit; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.SearchResults; +import org.springframework.data.domain.Similarity; +import org.springframework.data.domain.Vector; import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; import org.springframework.data.geo.Distance; @@ -43,6 +48,7 @@ import org.springframework.data.mongodb.core.geo.Sphere; import org.springframework.data.mongodb.repository.Hint; import org.springframework.data.mongodb.repository.ReadPreference; +import org.springframework.data.mongodb.repository.VectorSearch; import org.springframework.data.repository.Repository; import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; import org.springframework.data.repository.aot.generate.AotRepositoryFragmentMetadata; @@ -65,7 +71,7 @@ void rendersQueryForNearUsingPoint() throws NoSuchMethodException { assertThat(methodSpec.toString()) // .contains("{'location.coordinates':{'$near':?0}}") // - .contains("Object[]{ location }") // + .contains("arguments(location)") // .contains("return finder.matching(filterQuery).all()"); } @@ -124,7 +130,7 @@ void rendersQueryForWithinUsingGeoJsonPolygon() throws NoSuchMethodException { assertThat(methodSpec.toString()) // .contains("{'location.coordinates':{'$geoWithin':{'$geometry':?0}}") // - .contains("Object[]{ polygon }") // + .contains("arguments(polygon)") // .contains("return finder.matching(filterQuery).all()"); } @@ -182,7 +188,7 @@ void rendersNearQueryWithFilterForGeoResults() throws NoSuchMethodException { assertThat(methodSpec.toString()) // .contains("NearQuery.near(point)") // .contains("nearQuery.maxDistance(maxDistance).in(maxDistance.getMetric())") // - .contains("filterQuery = createQuery(\"{'lastname':?0}\", new java.lang.Object[]{ lastname })") // + .contains("filterQuery = createQuery(\"{'lastname':?0}\", arguments(lastname))") // .contains("nearQuery.query(filterQuery)") // .contains(".near(nearQuery)") // .contains("return nearFinder.all()"); @@ -194,8 +200,7 @@ void rendersExpressionUsingParameterIndex() throws NoSuchMethodException { MethodSpec methodSpec = codeOf(UserRepository.class, "findWithExpressionUsingParameterIndex", String.class); assertThat(methodSpec.toString()) // - .contains("createQuery(\"{ firstname : ?#{[0]} }\"") // - .contains("Map.of(\"firstname\", firstname)"); + .contains("createQuery(\"{ firstname : ?#{[0]} }\", argumentMap(\"firstname\", firstname))"); } @Test // GH-5006 @@ -204,8 +209,7 @@ void rendersExpressionUsingParameterName() throws NoSuchMethodException { MethodSpec methodSpec = codeOf(UserRepository.class, "findWithExpressionUsingParameterName", String.class); assertThat(methodSpec.toString()) // - .contains("createQuery(\"{ firstname : :#{#firstname} }\"") // - .contains("Map.of(\"firstname\", firstname)"); + .contains("createQuery(\"{ firstname : :#{#firstname} }\", argumentMap(\"firstname\", firstname))"); } @Test // GH-4939 @@ -214,8 +218,7 @@ void rendersRegexCriteria() throws NoSuchMethodException { MethodSpec methodSpec = codeOf(UserRepository.class, "findByFirstnameRegex", Pattern.class); assertThat(methodSpec.toString()) // - .contains("createQuery(\"{'firstname':{'$regex':?0}}\"") // - .contains("Object[]{ pattern }"); + .contains("createQuery(\"{'firstname':{'$regex':?0}}\", arguments(pattern))"); } @Test // GH-4939 @@ -233,7 +236,7 @@ void rendersCollation() throws NoSuchMethodException { MethodSpec methodSpec = codeOf(UserRepoWithMeta.class, "findByFirstname", String.class); assertThat(methodSpec.toString()) // - .containsPattern(".*\\.collation\\(.*Collation\\.parse\\(\"en_US\"\\)\\)"); + .containsSubsequence(".collation(", "Collation.parse(\"en_US\"))"); } @Test // GH-4939 @@ -243,7 +246,122 @@ void rendersCollationFromExpression() throws NoSuchMethodException { assertThat(methodSpec.toString()) // .containsIgnoringWhitespaces( - "collationOf(evaluate(\"?#{[1]}\", java.util.Map.of(\"firstname\", firstname, \"locale\", locale)))"); + "collationOf(evaluate(\"?#{[1]}\", argumentMap(\"firstname\", firstname, \"locale\", locale)))"); + } + + @Test + void rendersVectorSearchFilterFromAnnotatedQuery() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "annotatedVectorSearch", String.class, Vector.class, + Score.class, Limit.class); + + assertThat(methodSpec.toString()) // + .containsSubsequence("$vectorSearch =", + "Aggregation.vectorSearch(\"embedding.vector_cos\").path(\"embedding\").vector(vector).limit(limit);") + .contains("filter = createQuery(\"{lastname: ?0}\", arguments(lastname, distance))") + .contains("$vectorSearch.filter(filter.getQueryObject())"); + } + + @Test + void rendersVectorSearchNumCandidatesExpression() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "annotatedVectorSearch", String.class, Vector.class, + Score.class, Limit.class); + + assertThat(methodSpec.toString()) // + .containsSubsequence("$vectorSearch.numCandidates", + "evaluate(\"#{10+10}\", argumentMap(\"lastname\", lastname, \"distance\", distance)))"); + } + + @Test + void rendersVectorSearchScoringFunctionFromScore() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "annotatedVectorSearch", String.class, Vector.class, + Score.class, Limit.class); + + assertThat(methodSpec.toString()) // + .contains("ScoringFunction scoringFunction = distance.getFunction()"); + } + + @Test + void rendersVectorSearchSearchTypeFromAnnotation() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "annotatedVectorSearch", String.class, Vector.class, + Score.class, Limit.class); + + assertThat(methodSpec.toString()) // + .containsSubsequence("$vectorSearch.searchType(", "VectorSearchOperation.SearchType.ANN)"); + } + + @Test + void rendersVectorSearchQueryFromMethodName() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "searchCosineByLastnameAndEmbeddingNear", String.class, + Vector.class, Score.class, Limit.class); + + assertThat(methodSpec.toString()) // + .contains("filter = createQuery(\"{'lastname':?0}\", arguments(lastname, similarity))"); + } + + @Test + void rendersVectorSearchNumCandidatesFromLimitIfNotExplicitlyDefined() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "searchCosineByLastnameAndEmbeddingNear", String.class, + Vector.class, Score.class, Limit.class); + + assertThat(methodSpec.toString()) // + .contains("$vectorSearch.numCandidates(limit.max() * 20)"); + } + + @Test + void rendersVectorSearchLimitFromAnnotation() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "searchByLastnameAndEmbeddingWithin", String.class, + Vector.class, Range.class); + + assertThat(methodSpec.toString()) // + .contains("Aggregation.vectorSearch(\"embedding.vector_cos\").path(\"embedding\").vector(vector).limit(10)") + .contains("$vectorSearch.numCandidates(10 * 20)"); + } + + @Test + void rendersVectorSearchLimitFromExpression() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepoWithMeta.class, + "searchWithLimitAsExpressionByLastnameAndEmbeddingWithinOrderByFirstname", String.class, Vector.class, + Range.class); + + assertThat(methodSpec.toString()) // + .containsSubsequence( + "Aggregation.vectorSearch(\"embedding.vector_cos\").path(\"embedding\").vector(vector).limit(", + "evaluate(\"#{5+5}\", argumentMap(\"lastname\", lastname, \"distance\", distance)") + .containsSubsequence("$vectorSearch.numCandidates(", + "evaluate(\"#{5+5}\", argumentMap(\"lastname\", lastname, \"distance\", distance))) * 20)"); + } + + @Test + void rendersVectorSearchOrderByScoreAsDefault() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "searchCosineByLastnameAndEmbeddingNear", String.class, + Vector.class, Score.class, Limit.class); + + assertThat(methodSpec.toString()) // + .contains("$vectorSearch.withSearchScore(\"__score__\")") + .containsSubsequence("$sort = ", "Aggregation.sort(", "DESC, \"__score__\")") + .containsSubsequence("AggregationPipeline(", "List.of($vectorSearch, $sort))"); + } + + @Test + void rendersVectorSearchOrderByWithScoreLast() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "searchByLastnameAndEmbeddingWithinOrderByFirstname", + String.class, Vector.class, Range.class); + + assertThat(methodSpec.toString()) // + .containsSubsequence("AggregationOperation $sort = (_ctx) -> {", // + "_mappedSort = _ctx.getMappedObject(", // + "Document.parse(\"{'firstname':{'$numberInt':'1'}}\")", // + "Document(\"$sort\", _mappedSort.append(\"__score__\", -1))"); } private static MethodSpec codeOf(Class repository, String methodName, Class... args) @@ -287,5 +405,9 @@ interface UserRepoWithMeta extends Repository { @ReadPreference("NEAREST") GeoResults findByLocationCoordinatesNear(Point point, Distance maxDistance); + + @VectorSearch(indexName = "embedding.vector_cos", limit = "#{5+5}") + SearchResults searchWithLimitAsExpressionByLastnameAndEmbeddingWithinOrderByFirstname(String lastname, + Vector vector, Range distance); } } From 4e58d13ae651b9e2b16c257bdd7cab0a37f4124d Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Thu, 26 Jun 2025 09:55:53 +0200 Subject: [PATCH 23/24] Smart wait for search index to become available --- .../aot/MongoRepositoryContributorTests.java | 58 +++++++++++++------ 1 file changed, 40 insertions(+), 18 deletions(-) diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java index e40f0cd53b..995da14e74 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java @@ -78,6 +78,7 @@ import com.mongodb.client.MongoClient; import com.mongodb.client.MongoClients; import com.mongodb.client.MongoCollection; +import com.mongodb.client.MongoCursor; import com.mongodb.client.model.IndexOptions; import com.mongodb.client.model.SearchIndexModel; import com.mongodb.client.model.SearchIndexType; @@ -120,28 +121,49 @@ static void beforeAll() throws InterruptedException { userCollection.createIndex(new Document("location.coordinates", "2d"), new IndexOptions()); userCollection.createIndex(new Document("location.coordinates", "2dsphere"), new IndexOptions()); + Thread.sleep(250); // just wait a little or the index will be broken + } + + /** + * Create the vector search index and wait till it is queryable and actually serving data. Since this may slow down + * tests quite a bit, better call it only when needed to run certain tests. + */ + private static void initializeVectorIndex() { + + String indexName = "embedding.vector_cos"; + Document searchIndex = new Document("fields", List.of(new Document("type", "vector").append("path", "embedding").append("numDimensions", 5) .append("similarity", "cosine"), new Document("type", "filter").append("path", "last_name"))); - userCollection.createSearchIndexes(List.of( - new SearchIndexModel("embedding.vector_cos", searchIndex, SearchIndexType.of(new BsonString("vectorSearch"))))); + MongoCollection userCollection = client.getDatabase(DB_NAME).getCollection(COLLECTION_NAME); + userCollection.createSearchIndexes( + List.of(new SearchIndexModel(indexName, searchIndex, SearchIndexType.of(new BsonString("vectorSearch"))))); + // wait for search index to be queryable Awaitility.await().atMost(Duration.ofSeconds(120)).pollInterval(Duration.ofMillis(200)).until(() -> { List execute = userCollection - .aggregate( - List.of(Document.parse("{'$listSearchIndexes': { 'name' : '%s'}}".formatted("embedding.vector_cos")))) + .aggregate(List.of(Document.parse("{'$listSearchIndexes': { 'name' : '%s'}}".formatted(indexName)))) .into(new ArrayList<>()); for (Document doc : execute) { - if (doc.getString("name").equals("embedding.vector_cos")) { + if (doc.getString("name").equals(indexName)) { return doc.getString("status").equals("READY"); } } return false; }); - Thread.sleep(250); // just wait a little or the index will be broken + Document $vectorSearch = new Document("$vectorSearch", + new Document("index", indexName).append("limit", 1).append("numCandidates", 20).append("path", "embedding") + .append("queryVector", List.of(1.0, 1.12345, 2.23456, 3.34567, 4.45678))); + + // wait for search index to serve data + Awaitility.await().atMost(Duration.ofSeconds(120)).pollInterval(Duration.ofMillis(200)).until(() -> { + try (MongoCursor cursor = userCollection.aggregate(List.of($vectorSearch)).iterator()) { + return cursor.hasNext(); + } + }); } @BeforeEach @@ -790,9 +812,9 @@ void testNearReturningGeoPage() { } @Test - void vectorSearchFromAnnotation() throws InterruptedException { + void vectorSearchFromAnnotation() { - Thread.sleep(1000); // srly - reindex for vector search + initializeVectorIndex(); Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d); SearchResults results = fragment.annotatedVectorSearch("Skywalker", vector, Score.of(0.99), Limit.of(10)); @@ -801,9 +823,9 @@ void vectorSearchFromAnnotation() throws InterruptedException { } @Test - void vectorSearchWithDerivedQuery() throws InterruptedException { + void vectorSearchWithDerivedQuery() { - Thread.sleep(1000); // srly - reindex for vector search + initializeVectorIndex(); Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d); SearchResults results = fragment.searchCosineByLastnameAndEmbeddingNear("Skywalker", vector, Score.of(0.98), @@ -813,9 +835,9 @@ void vectorSearchWithDerivedQuery() throws InterruptedException { } @Test - void vectorSearchReturningResultsAsList() throws InterruptedException { + void vectorSearchReturningResultsAsList() { - Thread.sleep(1000); // srly - reindex for vector search + initializeVectorIndex(); Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d); List results = fragment.searchAsListByLastnameAndEmbeddingNear("Skywalker", vector, Limit.of(10)); @@ -824,9 +846,9 @@ void vectorSearchReturningResultsAsList() throws InterruptedException { } @Test - void vectorSearchWithLimitFromAnnotation() throws InterruptedException { + void vectorSearchWithLimitFromAnnotation() { - Thread.sleep(1000); // srly - reindex for vector search + initializeVectorIndex(); Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d); SearchResults results = fragment.searchByLastnameAndEmbeddingWithin("Skywalker", vector, @@ -836,9 +858,9 @@ void vectorSearchWithLimitFromAnnotation() throws InterruptedException { } @Test - void vectorSearchWithSorting() throws InterruptedException { + void vectorSearchWithSorting() { - Thread.sleep(1000); // srly - reindex for vector search + initializeVectorIndex(); Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d); SearchResults results = fragment.searchByLastnameAndEmbeddingWithinOrderByFirstname("Skywalker", vector, @@ -848,9 +870,9 @@ void vectorSearchWithSorting() throws InterruptedException { } @Test - void vectorSearchWithLimitFromDerivedQuery() throws InterruptedException { + void vectorSearchWithLimitFromDerivedQuery() { - Thread.sleep(1000); // srly - reindex for vector search + initializeVectorIndex(); Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d); SearchResults results = fragment.searchTop1ByLastnameAndEmbeddingWithin("Skywalker", vector, From cdb46ca1c74b2508f1cd0cf27a0fffa022fc16a6 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Thu, 26 Jun 2025 11:36:03 +0200 Subject: [PATCH 24/24] refactor a bit --- .../repository/aot/AggregationBlocks.java | 563 +++++++++--------- .../repository/aot/BuilderStyleSnippet.java | 44 ++ .../mongodb/repository/aot/DeleteBlocks.java | 127 ++-- .../repository/aot/ExpressionSnippet.java | 53 ++ .../mongodb/repository/aot/GeoBlocks.java | 63 +- .../repository/aot/MongoCodeBlocks.java | 19 + .../aot/MongoRepositoryContributor.java | 10 +- .../mongodb/repository/aot/QueryBlocks.java | 47 +- .../repository/aot/SearchInteraction.java | 91 ++- .../data/mongodb/repository/aot/Snippet.java | 224 +++++++ .../mongodb/repository/aot/UpdateBlocks.java | 23 +- .../repository/aot/VariableSnippet.java | 119 ++++ .../repository/aot/VectorSearchBocks.java | 221 ++++--- .../aot/QueryMethodContributionUnitTests.java | 11 +- 14 files changed, 1110 insertions(+), 505 deletions(-) create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/BuilderStyleSnippet.java create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/ExpressionSnippet.java create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/Snippet.java create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VariableSnippet.java diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java index 37f24cd849..e40bd518ce 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java @@ -19,11 +19,11 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import java.util.stream.Stream; import org.bson.Document; import org.jspecify.annotations.NullUnmarked; +import org.springframework.core.ResolvableType; import org.springframework.core.annotation.MergedAnnotation; import org.springframework.data.domain.SliceImpl; import org.springframework.data.domain.Sort.Order; @@ -52,312 +52,323 @@ */ class AggregationBlocks { - @NullUnmarked - static class AggregationExecutionCodeBlockBuilder { - - private final AotQueryMethodGenerationContext context; - private final MongoQueryMethod queryMethod; - private String aggregationVariableName; + @NullUnmarked + static class AggregationExecutionCodeBlockBuilder { - AggregationExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + private String aggregationVariableName; - this.context = context; - this.queryMethod = queryMethod; - } + AggregationExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { - AggregationExecutionCodeBlockBuilder referencing(String aggregationVariableName) { + this.context = context; + this.queryMethod = queryMethod; + } - this.aggregationVariableName = aggregationVariableName; - return this; - } + AggregationExecutionCodeBlockBuilder referencing(String aggregationVariableName) { - CodeBlock build() { + this.aggregationVariableName = aggregationVariableName; + return this; + } - String mongoOpsRef = context.fieldNameOf(MongoOperations.class); - Builder builder = CodeBlock.builder(); + CodeBlock build() { - builder.add("\n"); + String mongoOpsRef = context.fieldNameOf(MongoOperations.class); + Builder builder = CodeBlock.builder(); - Class outputType = queryMethod.getReturnedObjectType(); - if (MongoSimpleTypes.HOLDER.isSimpleType(outputType)) { - outputType = Document.class; - } else if (ClassUtils.isAssignable(AggregationResults.class, outputType)) { - outputType = queryMethod.getReturnType().getComponentType().getType(); - } + builder.add("\n"); - if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) { - builder.addStatement("$L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType); - return builder.build(); - } + Class outputType = queryMethod.getReturnedObjectType(); + if (MongoSimpleTypes.HOLDER.isSimpleType(outputType)) { + outputType = Document.class; + } else if (ClassUtils.isAssignable(AggregationResults.class, outputType)) { + outputType = queryMethod.getReturnType().getComponentType().getType(); + } - if (ClassUtils.isAssignable(AggregationResults.class, context.getMethod().getReturnType())) { - builder.addStatement("return $L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType); - return builder.build(); - } + if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) { + builder.addStatement("$L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType); + return builder.build(); + } - if (outputType == Document.class) { + if (ClassUtils.isAssignable(AggregationResults.class, context.getMethod().getReturnType())) { + builder.addStatement("return $L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType); + return builder.build(); + } - Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); + if (outputType == Document.class) { - if (queryMethod.isStreamQuery()) { + Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); - builder.addStatement("$T<$T> $L = $L.aggregateStream($L, $T.class)", Stream.class, Document.class, - context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); + if (queryMethod.isStreamQuery()) { - builder.addStatement("return $1L.map(it -> ($2T) convertSimpleRawResult($2T.class, it))", - context.localVariable("results"), returnType); - } else { + VariableSnippet results = Snippet.declare(builder) + .variable(ResolvableType.forClassWithGenerics(Stream.class, Document.class), + context.localVariable("results")) + .as("$L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType); - builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class, - context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); + builder.addStatement("return $1L.map(it -> ($2T) convertSimpleRawResult($2T.class, it))", + results.getVariableName(), returnType); + } else { - if (!queryMethod.isCollectionQuery()) { - builder.addStatement( - "return $1T.<$2T>firstElement(convertSimpleRawResults($2T.class, $3L.getMappedResults()))", - CollectionUtils.class, returnType, context.localVariable("results")); - } else { - builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType, - context.localVariable("results")); - } - } - } else { - if (queryMethod.isSliceQuery()) { - builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class, - context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); - builder.addStatement("boolean $L = $L.getMappedResults().size() > $L.getPageSize()", - context.localVariable("hasNext"), context.localVariable("results"), context.getPageableParameterName()); - builder.addStatement( - "return new $1T<>($2L ? $3L.getMappedResults().subList(0, $4L.getPageSize()) : $3L.getMappedResults(), $4L, $2L)", - SliceImpl.class, context.localVariable("hasNext"), context.localVariable("results"), - context.getPageableParameterName()); - } else { + VariableSnippet results = Snippet.declare(builder) + .variable(AggregationResults.class, context.localVariable("results")) + .as("$L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType); - if (queryMethod.isStreamQuery()) { - builder.addStatement("return $L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName, - outputType); - } else { + if (!queryMethod.isCollectionQuery()) { + builder.addStatement( + "return $1T.<$2T>firstElement(convertSimpleRawResults($2T.class, $3L.getMappedResults()))", + CollectionUtils.class, returnType, results.getVariableName()); + } else { + builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType, + results.getVariableName()); + } + } + } else { + if (queryMethod.isSliceQuery()) { - builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef, - aggregationVariableName, outputType); - } - } - } + VariableSnippet results = Snippet.declare(builder) + .variable(AggregationResults.class, context.localVariable("results")) + .as("$L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType); - return builder.build(); - } - } + VariableSnippet hasNext = Snippet.declare(builder).variable("hasNext").as( + "$L.getMappedResults().size() > $L.getPageSize()", results.getVariableName(), + context.getPageableParameterName()); - @NullUnmarked - static class AggregationCodeBlockBuilder { + builder.addStatement( + "return new $1T<>($2L ? $3L.getMappedResults().subList(0, $4L.getPageSize()) : $3L.getMappedResults(), $4L, $2L)", + SliceImpl.class, hasNext.getVariableName(), results.getVariableName(), + context.getPageableParameterName()); + } else { - private final AotQueryMethodGenerationContext context; - private final MongoQueryMethod queryMethod; - private final Map arguments; + if (queryMethod.isStreamQuery()) { + builder.addStatement("return $L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName, + outputType); + } else { - private AggregationInteraction source; + builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef, + aggregationVariableName, outputType); + } + } + } - private String aggregationVariableName; - private boolean pipelineOnly; + return builder.build(); + } + } - AggregationCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + @NullUnmarked + static class AggregationCodeBlockBuilder { - this.context = context; - this.arguments = new LinkedHashMap<>(); - context.getBindableParameterNames().forEach(it -> arguments.put(it, CodeBlock.of(it))); - this.queryMethod = queryMethod; - } + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + private final Map arguments; - AggregationCodeBlockBuilder stages(AggregationInteraction aggregation) { - - this.source = aggregation; - return this; - } + private AggregationInteraction source; - AggregationCodeBlockBuilder usingAggregationVariableName(String aggregationVariableName) { - - this.aggregationVariableName = aggregationVariableName; - return this; - } - - AggregationCodeBlockBuilder pipelineOnly(boolean pipelineOnly) { - - this.pipelineOnly = pipelineOnly; - return this; - } - - CodeBlock build() { - - Builder builder = CodeBlock.builder(); - builder.add("\n"); - - String pipelineName = context.localVariable(aggregationVariableName + (pipelineOnly ? "" : "Pipeline")); - builder.add(pipeline(pipelineName)); - - if (!pipelineOnly) { - - builder.addStatement("$1T<$2T> $3L = $4T.newAggregation($2T.class, $5L.getOperations())", - TypedAggregation.class, context.getRepositoryInformation().getDomainType(), aggregationVariableName, - Aggregation.class, pipelineName); - - builder.add(aggregationOptions(aggregationVariableName)); - } - - return builder.build(); - } - - private CodeBlock pipeline(String pipelineVariableName) { - - String sortParameter = context.getSortParameterName(); - String limitParameter = context.getLimitParameterName(); - String pageableParameter = context.getPageableParameterName(); - - boolean mightBeSorted = StringUtils.hasText(sortParameter); - boolean mightBeLimited = StringUtils.hasText(limitParameter); - boolean mightBePaged = StringUtils.hasText(pageableParameter); - - int stageCount = source.stages().size(); - if (mightBeSorted) { - stageCount++; - } - if (mightBeLimited) { - stageCount++; - } - if (mightBePaged) { - stageCount += 3; - } - - Builder builder = CodeBlock.builder(); - builder.add(aggregationStages(context.localVariable("stages"), source.stages(), stageCount, arguments)); + private String aggregationVariableName; + private boolean pipelineOnly; - if (mightBeSorted) { - builder.add(sortingStage(sortParameter)); - } + AggregationCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { - if (mightBeLimited) { - builder.add(limitingStage(limitParameter)); - } + this.context = context; + this.arguments = new LinkedHashMap<>(); + context.getBindableParameterNames().forEach(it -> arguments.put(it, CodeBlock.of(it))); + this.queryMethod = queryMethod; + } - if (mightBePaged) { - builder.add(pagingStage(pageableParameter, queryMethod.isSliceQuery())); - } - - builder.addStatement("$T $L = createPipeline($L)", AggregationPipeline.class, pipelineVariableName, - context.localVariable("stages")); - return builder.build(); - } - - private CodeBlock aggregationOptions(String aggregationVariableName) { - - Builder builder = CodeBlock.builder(); - List options = new ArrayList<>(5); - if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) { - options.add(CodeBlock.of(".skipOutput()")); - } + AggregationCodeBlockBuilder stages(AggregationInteraction aggregation) { + + this.source = aggregation; + return this; + } - MergedAnnotation hintAnnotation = context.getAnnotation(Hint.class); - String hint = hintAnnotation.isPresent() ? hintAnnotation.getString("value") : null; - if (StringUtils.hasText(hint)) { - options.add(CodeBlock.of(".hint($S)", hint)); - } - - MergedAnnotation readPreferenceAnnotation = context.getAnnotation(ReadPreference.class); - String readPreference = readPreferenceAnnotation.isPresent() ? readPreferenceAnnotation.getString("value") : null; - if (StringUtils.hasText(readPreference)) { - options.add(CodeBlock.of(".readPreference($T.valueOf($S))", com.mongodb.ReadPreference.class, readPreference)); - } + AggregationCodeBlockBuilder usingAggregationVariableName(String aggregationVariableName) { - if (queryMethod.hasAnnotatedCollation()) { - options.add(CodeBlock.of(".collation($T.parse($S))", Collation.class, queryMethod.getAnnotatedCollation())); - } - - if (!options.isEmpty()) { - - Builder optionsBuilder = CodeBlock.builder(); - optionsBuilder.add("$1T $2L = $1T.builder()\n", AggregationOptions.class, - context.localVariable("aggregationOptions")); - optionsBuilder.indent(); - for (CodeBlock optionBlock : options) { - optionsBuilder.add(optionBlock); - optionsBuilder.add("\n"); - } - optionsBuilder.add(".build();\n"); - optionsBuilder.unindent(); - builder.add(optionsBuilder.build()); - - builder.addStatement("$1L = $1L.withOptions($2L)", aggregationVariableName, - context.localVariable("aggregationOptions")); - } - return builder.build(); - } - - private CodeBlock aggregationStages(String stageListVariableName, Iterable stages, int stageCount, - Map arguments) { - - Builder builder = CodeBlock.builder(); - builder.addStatement("$T<$T> $L = new $T($L)", List.class, Object.class, stageListVariableName, ArrayList.class, - stageCount); - int stageCounter = 0; - - for (String stage : stages) { - String stageName = context.localVariable("stage_%s".formatted(stageCounter++)); - builder.add(MongoCodeBlocks.renderExpressionToDocument(stage, stageName, arguments)); - builder.addStatement("$L.add($L)", context.localVariable("stages"), stageName); - } - - return builder.build(); - } - - private CodeBlock sortingStage(String sortProvider) { - - Builder builder = CodeBlock.builder(); - - builder.beginControlFlow("if ($L.isSorted())", sortProvider); - builder.addStatement("$1T $2L = new $1T()", Document.class, context.localVariable("sortDocument")); - builder.beginControlFlow("for ($T $L : $L)", Order.class, context.localVariable("order"), sortProvider); - builder.addStatement("$1L.append($2L.getProperty(), $2L.isAscending() ? 1 : -1);", - context.localVariable("sortDocument"), context.localVariable("order")); - builder.endControlFlow(); - builder.addStatement("stages.add(new $T($S, $L))", Document.class, "$sort", - context.localVariable("sortDocument")); - builder.endControlFlow(); - - return builder.build(); - } - - private CodeBlock pagingStage(String pageableProvider, boolean slice) { - - Builder builder = CodeBlock.builder(); - - builder.add(sortingStage(pageableProvider + ".getSort()")); - - builder.beginControlFlow("if ($L.isPaged())", pageableProvider); - builder.beginControlFlow("if ($L.getOffset() > 0)", pageableProvider); - builder.addStatement("$L.add($T.skip($L.getOffset()))", context.localVariable("stages"), Aggregation.class, - pageableProvider); - builder.endControlFlow(); - if (slice) { - builder.addStatement("$L.add($T.limit($L.getPageSize() + 1))", context.localVariable("stages"), - Aggregation.class, pageableProvider); - } else { - builder.addStatement("$L.add($T.limit($L.getPageSize()))", context.localVariable("stages"), Aggregation.class, - pageableProvider); - } - builder.endControlFlow(); - - return builder.build(); - } - - private CodeBlock limitingStage(String limitProvider) { - - Builder builder = CodeBlock.builder(); - - builder.beginControlFlow("if ($L.isLimited())", limitProvider); - builder.addStatement("$L.add($T.limit($L.max()))", context.localVariable("stages"), Aggregation.class, - limitProvider); - builder.endControlFlow(); - - return builder.build(); - } - - } + this.aggregationVariableName = aggregationVariableName; + return this; + } + + AggregationCodeBlockBuilder pipelineOnly(boolean pipelineOnly) { + + this.pipelineOnly = pipelineOnly; + return this; + } + + CodeBlock build() { + + Builder builder = CodeBlock.builder(); + builder.add("\n"); + + String pipelineName = context.localVariable(aggregationVariableName + (pipelineOnly ? "" : "Pipeline")); + builder.add(pipeline(pipelineName)); + + if (!pipelineOnly) { + + Class domainType = context.getRepositoryInformation().getDomainType(); + Snippet.declare(builder) + .variable(ResolvableType.forClassWithGenerics(TypedAggregation.class, domainType), aggregationVariableName) + .as("$T.newAggregation($T.class, $L.getOperations())", Aggregation.class, domainType, pipelineName); + + builder.add(aggregationOptions(aggregationVariableName)); + } + + return builder.build(); + } + + private CodeBlock pipeline(String pipelineVariableName) { + + String sortParameter = context.getSortParameterName(); + String limitParameter = context.getLimitParameterName(); + String pageableParameter = context.getPageableParameterName(); + + boolean mightBeSorted = StringUtils.hasText(sortParameter); + boolean mightBeLimited = StringUtils.hasText(limitParameter); + boolean mightBePaged = StringUtils.hasText(pageableParameter); + + int stageCount = source.stages().size(); + if (mightBeSorted) { + stageCount++; + } + if (mightBeLimited) { + stageCount++; + } + if (mightBePaged) { + stageCount += 3; + } + + Builder builder = CodeBlock.builder(); + builder.add(aggregationStages(context.localVariable("stages"), source.stages(), stageCount, arguments)); + + if (mightBeSorted) { + builder.add(sortingStage(sortParameter)); + } + + if (mightBeLimited) { + builder.add(limitingStage(limitParameter)); + } + + if (mightBePaged) { + builder.add(pagingStage(pageableParameter, queryMethod.isSliceQuery())); + } + + builder.addStatement("$T $L = createPipeline($L)", AggregationPipeline.class, pipelineVariableName, + context.localVariable("stages")); + return builder.build(); + } + + private CodeBlock aggregationOptions(String aggregationVariableName) { + + Builder builder = CodeBlock.builder(); + List options = new ArrayList<>(5); + if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) { + options.add(CodeBlock.of(".skipOutput()")); + } + + MergedAnnotation hintAnnotation = context.getAnnotation(Hint.class); + String hint = hintAnnotation.isPresent() ? hintAnnotation.getString("value") : null; + if (StringUtils.hasText(hint)) { + options.add(CodeBlock.of(".hint($S)", hint)); + } + + MergedAnnotation readPreferenceAnnotation = context.getAnnotation(ReadPreference.class); + String readPreference = readPreferenceAnnotation.isPresent() ? readPreferenceAnnotation.getString("value") : null; + if (StringUtils.hasText(readPreference)) { + options.add(CodeBlock.of(".readPreference($T.valueOf($S))", com.mongodb.ReadPreference.class, readPreference)); + } + + if (queryMethod.hasAnnotatedCollation()) { + options.add(CodeBlock.of(".collation($T.parse($S))", Collation.class, queryMethod.getAnnotatedCollation())); + } + + if (!options.isEmpty()) { + + Builder optionsBuilder = CodeBlock.builder(); + optionsBuilder.add("$1T $2L = $1T.builder()\n", AggregationOptions.class, + context.localVariable("aggregationOptions")); + optionsBuilder.indent(); + for (CodeBlock optionBlock : options) { + optionsBuilder.add(optionBlock); + optionsBuilder.add("\n"); + } + optionsBuilder.add(".build();\n"); + optionsBuilder.unindent(); + builder.add(optionsBuilder.build()); + + builder.addStatement("$1L = $1L.withOptions($2L)", aggregationVariableName, + context.localVariable("aggregationOptions")); + } + return builder.build(); + } + + private CodeBlock aggregationStages(String stageListVariableName, Iterable stages, int stageCount, + Map arguments) { + + Builder builder = CodeBlock.builder(); + builder.addStatement("$T<$T> $L = new $T($L)", List.class, Object.class, stageListVariableName, ArrayList.class, + stageCount); + int stageCounter = 0; + + for (String stage : stages) { + + VariableSnippet stageSnippet = Snippet.declare(builder) + .variable(Document.class, context.localVariable("stage_%s".formatted(stageCounter++))) + .of(MongoCodeBlocks.asDocument(stage, arguments)); + builder.addStatement("$L.add($L)", stageListVariableName, stageSnippet.getVariableName()); + } + + return builder.build(); + } + + private CodeBlock sortingStage(String sortProvider) { + + Builder builder = CodeBlock.builder(); + + builder.beginControlFlow("if ($L.isSorted())", sortProvider); + builder.addStatement("$1T $2L = new $1T()", Document.class, context.localVariable("sortDocument")); + builder.beginControlFlow("for ($T $L : $L)", Order.class, context.localVariable("order"), sortProvider); + builder.addStatement("$1L.append($2L.getProperty(), $2L.isAscending() ? 1 : -1);", + context.localVariable("sortDocument"), context.localVariable("order")); + builder.endControlFlow(); + builder.addStatement("stages.add(new $T($S, $L))", Document.class, "$sort", + context.localVariable("sortDocument")); + builder.endControlFlow(); + + return builder.build(); + } + + private CodeBlock pagingStage(String pageableProvider, boolean slice) { + + Builder builder = CodeBlock.builder(); + + builder.add(sortingStage(pageableProvider + ".getSort()")); + + builder.beginControlFlow("if ($L.isPaged())", pageableProvider); + builder.beginControlFlow("if ($L.getOffset() > 0)", pageableProvider); + builder.addStatement("$L.add($T.skip($L.getOffset()))", context.localVariable("stages"), Aggregation.class, + pageableProvider); + builder.endControlFlow(); + if (slice) { + builder.addStatement("$L.add($T.limit($L.getPageSize() + 1))", context.localVariable("stages"), + Aggregation.class, pageableProvider); + } else { + builder.addStatement("$L.add($T.limit($L.getPageSize()))", context.localVariable("stages"), Aggregation.class, + pageableProvider); + } + builder.endControlFlow(); + + return builder.build(); + } + + private CodeBlock limitingStage(String limitProvider) { + + Builder builder = CodeBlock.builder(); + + builder.beginControlFlow("if ($L.isLimited())", limitProvider); + builder.addStatement("$L.add($T.limit($L.max()))", context.localVariable("stages"), Aggregation.class, + limitProvider); + builder.endControlFlow(); + + return builder.build(); + } + + } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/BuilderStyleSnippet.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/BuilderStyleSnippet.java new file mode 100644 index 0000000000..42627839e8 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/BuilderStyleSnippet.java @@ -0,0 +1,44 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.CodeBlock.Builder; + +/** + * @author Christoph Strobl + */ +public class BuilderStyleSnippet implements Snippet { + + private final String targetVariableName; + private final String methodName; + private final Snippet argumentValue; + + BuilderStyleSnippet(String targetVariableName, String methodName, Snippet argumentValue) { + this.targetVariableName = targetVariableName; + this.methodName = methodName; + this.argumentValue = argumentValue; + } + + @Override + public CodeBlock code() { + + Builder builder = CodeBlock.builder(); + builder.add("$1L = $1L.$2L($3L);\n", targetVariableName, methodName, argumentValue.code()); + return builder.build(); + } + +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/DeleteBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/DeleteBlocks.java index 1d009f3085..74f11a1ae8 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/DeleteBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/DeleteBlocks.java @@ -18,6 +18,7 @@ import java.util.Optional; import org.jspecify.annotations.NullUnmarked; +import org.springframework.core.ResolvableType; import org.springframework.data.mongodb.core.ExecutableRemoveOperation.ExecutableRemove; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.repository.query.MongoQueryExecution.DeleteExecution; @@ -35,66 +36,68 @@ */ class DeleteBlocks { - @NullUnmarked - static class DeleteExecutionCodeBlockBuilder { - - private final AotQueryMethodGenerationContext context; - private final MongoQueryMethod queryMethod; - private String queryVariableName; - - DeleteExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { - - this.context = context; - this.queryMethod = queryMethod; - } - - DeleteExecutionCodeBlockBuilder referencing(String queryVariableName) { - - this.queryVariableName = queryVariableName; - return this; - } - - CodeBlock build() { - - String mongoOpsRef = context.fieldNameOf(MongoOperations.class); - Builder builder = CodeBlock.builder(); - - Class domainType = context.getRepositoryInformation().getDomainType(); - boolean isProjecting = context.getActualReturnType() != null - && !ObjectUtils.nullSafeEquals(TypeName.get(domainType), context.getActualReturnType()); - - Object actualReturnType = isProjecting ? context.getActualReturnType().getType() : domainType; - - builder.add("\n"); - builder.addStatement("$1T<$2T> $3L = $4L.remove($2T.class)", ExecutableRemove.class, domainType, - context.localVariable("remover"), mongoOpsRef); - - DeleteExecution.Type type = DeleteExecution.Type.FIND_AND_REMOVE_ALL; - if (!queryMethod.isCollectionQuery()) { - if (!ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType())) { - type = DeleteExecution.Type.FIND_AND_REMOVE_ONE; - } else { - type = DeleteExecution.Type.ALL; - } - } - - actualReturnType = ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType()) - ? TypeName.get(context.getMethod().getReturnType()) - : queryMethod.isCollectionQuery() ? context.getReturnTypeName() : actualReturnType; - - if (ClassUtils.isVoidType(context.getMethod().getReturnType())) { - builder.addStatement("new $T($L, $T.$L).execute($L)", DeleteExecution.class, context.localVariable("remover"), - DeleteExecution.Type.class, type.name(), queryVariableName); - } else if (context.getMethod().getReturnType() == Optional.class) { - builder.addStatement("return $T.ofNullable(($T) new $T($L, $T.$L).execute($L))", Optional.class, - actualReturnType, DeleteExecution.class, context.localVariable("remover"), DeleteExecution.Type.class, - type.name(), queryVariableName); - } else { - builder.addStatement("return ($T) new $T($L, $T.$L).execute($L)", actualReturnType, DeleteExecution.class, - context.localVariable("remover"), DeleteExecution.Type.class, type.name(), queryVariableName); - } - - return builder.build(); - } - } + @NullUnmarked + static class DeleteExecutionCodeBlockBuilder { + + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + private String queryVariableName; + + DeleteExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + + this.context = context; + this.queryMethod = queryMethod; + } + + DeleteExecutionCodeBlockBuilder referencing(String queryVariableName) { + + this.queryVariableName = queryVariableName; + return this; + } + + CodeBlock build() { + + String mongoOpsRef = context.fieldNameOf(MongoOperations.class); + Builder builder = CodeBlock.builder(); + + Class domainType = context.getRepositoryInformation().getDomainType(); + boolean isProjecting = context.getActualReturnType() != null + && !ObjectUtils.nullSafeEquals(TypeName.get(domainType), context.getActualReturnType()); + + Object actualReturnType = isProjecting ? context.getActualReturnType().getType() : domainType; + + builder.add("\n"); + VariableSnippet remover = Snippet.declare(builder) + .variable(ResolvableType.forClassWithGenerics(ExecutableRemove.class, domainType), + context.localVariable("remover")) + .as("$L.remove($T.class)", mongoOpsRef, domainType); + + DeleteExecution.Type type = DeleteExecution.Type.FIND_AND_REMOVE_ALL; + if (!queryMethod.isCollectionQuery()) { + if (!ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType())) { + type = DeleteExecution.Type.FIND_AND_REMOVE_ONE; + } else { + type = DeleteExecution.Type.ALL; + } + } + + actualReturnType = ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType()) + ? TypeName.get(context.getMethod().getReturnType()) + : queryMethod.isCollectionQuery() ? context.getReturnTypeName() : actualReturnType; + + if (ClassUtils.isVoidType(context.getMethod().getReturnType())) { + builder.addStatement("new $T($L, $T.$L).execute($L)", DeleteExecution.class, remover.getVariableName(), + DeleteExecution.Type.class, type.name(), queryVariableName); + } else if (context.getMethod().getReturnType() == Optional.class) { + builder.addStatement("return $T.ofNullable(($T) new $T($L, $T.$L).execute($L))", Optional.class, + actualReturnType, DeleteExecution.class, remover.getVariableName(), DeleteExecution.Type.class, type.name(), + queryVariableName); + } else { + builder.addStatement("return ($T) new $T($L, $T.$L).execute($L)", actualReturnType, DeleteExecution.class, + context.localVariable("remover"), DeleteExecution.Type.class, type.name(), queryVariableName); + } + + return builder.build(); + } + } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/ExpressionSnippet.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/ExpressionSnippet.java new file mode 100644 index 0000000000..a803704bef --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/ExpressionSnippet.java @@ -0,0 +1,53 @@ +/* + * Copyright 2025-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import org.springframework.javapoet.CodeBlock; + +/** + * @author Christoph Strobl + * @since 5.0 + */ +class ExpressionSnippet implements Snippet { + + private final CodeBlock block; + private final boolean requiresEvaluation; + + public ExpressionSnippet(CodeBlock block) { + this(block, false); + } + + public ExpressionSnippet(Snippet block) { + this(block.code(), block instanceof ExpressionSnippet eb && eb.requiresEvaluation()); + } + + public ExpressionSnippet(CodeBlock block, boolean requiresEvaluation) { + this.block = block; + this.requiresEvaluation = requiresEvaluation; + } + + public static ExpressionSnippet empty() { + return new ExpressionSnippet(CodeBlock.builder().build()); + } + + public boolean requiresEvaluation() { + return requiresEvaluation; + } + + public CodeBlock code() { + return block; + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java index b94f55adc2..8f2df3e4c4 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java @@ -50,38 +50,37 @@ CodeBlock build() { CodeBlock.Builder builder = CodeBlock.builder(); builder.add("\n"); - String locationParameterName = context.getParameterName(queryMethod.getParameters().getNearIndex()); - - builder.addStatement("$1T $2L = $1T.near($3L)", NearQuery.class, variableName, locationParameterName); + VariableSnippet query = Snippet.declare(builder).variable(NearQuery.class, variableName).as("$T.near($L)", + NearQuery.class, context.getParameterName(queryMethod.getParameters().getNearIndex())); if (queryMethod.getParameters().getRangeIndex() != -1) { - String rangeParametername = context.getParameterName(queryMethod.getParameters().getRangeIndex()); - String minVarName = context.localVariable("min"); - String maxVarName = context.localVariable("max"); + String rangeParameter = context.getParameterName(queryMethod.getParameters().getRangeIndex()); - builder.beginControlFlow("if($L.getLowerBound().isBounded())", rangeParametername); - builder.addStatement("$1T $2L = $3L.getLowerBound().getValue().get()", Distance.class, minVarName, - rangeParametername); - builder.addStatement("$1L.minDistance($2L).in($2L.getMetric())", variableName, minVarName); + builder.beginControlFlow("if($L.getLowerBound().isBounded())", rangeParameter); + VariableSnippet min = Snippet.declare(builder).variable(Distance.class, context.localVariable("min")) + .as("$L.getLowerBound().getValue().get()", rangeParameter); + builder.addStatement("$1L.minDistance($2L).in($2L.getMetric())", query.getVariableName(), + min.getVariableName()); builder.endControlFlow(); - builder.beginControlFlow("if($L.getUpperBound().isBounded())", rangeParametername); - builder.addStatement("$1T $2L = $3L.getUpperBound().getValue().get()", Distance.class, maxVarName, - rangeParametername); - builder.addStatement("$1L.maxDistance($2L).in($2L.getMetric())", variableName, maxVarName); + builder.beginControlFlow("if($L.getUpperBound().isBounded())", rangeParameter); + VariableSnippet max = Snippet.declare(builder).variable(Distance.class, context.localVariable("max")) + .as("$L.getUpperBound().getValue().get()", rangeParameter); + builder.addStatement("$1L.maxDistance($2L).in($2L.getMetric())", query.getVariableName(), + max.getVariableName()); builder.endControlFlow(); } else { - String distanceParametername = context.getParameterName(queryMethod.getParameters().getMaxDistanceIndex()); - builder.addStatement("$1L.maxDistance($2L).in($2L.getMetric())", variableName, distanceParametername); + String distanceParameter = context.getParameterName(queryMethod.getParameters().getMaxDistanceIndex()); + builder.addStatement("$1L.maxDistance($2L).in($2L.getMetric())", query.code(), distanceParameter); } if (context.getPageableParameterName() != null) { - builder.addStatement("$L.with($L)", variableName, context.getPageableParameterName()); + builder.addStatement("$L.with($L)", query.code(), context.getPageableParameterName()); } - MongoCodeBlocks.appendReadPreference(context, builder, variableName); + MongoCodeBlocks.appendReadPreference(context, builder, query.getVariableName()); return builder.build(); } @@ -115,29 +114,29 @@ CodeBlock build() { CodeBlock.Builder builder = CodeBlock.builder(); builder.add("\n"); - String executorVar = context.localVariable("nearFinder"); - builder.addStatement("var $L = $L.query($T.class).near($L)", executorVar, - context.fieldNameOf(MongoOperations.class), context.getRepositoryInformation().getDomainType(), - queryVariableName); + VariableSnippet queryExecutor = Snippet.declare(builder).variable(context.localVariable("nearFinder")).as( + "$L.query($T.class).near($L)", context.fieldNameOf(MongoOperations.class), + context.getRepositoryInformation().getDomainType(), queryVariableName); if (ClassUtils.isAssignable(GeoPage.class, context.getReturnType().getRawClass())) { - String geoResultVar = context.localVariable("geoResult"); - builder.addStatement("var $L = $L.all()", geoResultVar, executorVar); + VariableSnippet geoResult = Snippet.declare(builder).variable(context.localVariable("geoResult")).as("$L.all()", + queryExecutor.getVariableName()); builder.beginControlFlow("if($L.isUnpaged())", context.getPageableParameterName()); - builder.addStatement("return new $T<>($L)", GeoPage.class, geoResultVar); + builder.addStatement("return new $T<>($L)", GeoPage.class, geoResult.getVariableName()); builder.endControlFlow(); - String pageVar = context.localVariable("resultPage"); - builder.addStatement("var $L = $T.getPage($L.getContent(), $L, () -> $L.count())", pageVar, - PageableExecutionUtils.class, geoResultVar, context.getPageableParameterName(), executorVar); - builder.addStatement("return new $T<>($L, $L, $L.getTotalElements())", GeoPage.class, geoResultVar, - context.getPageableParameterName(), pageVar); + VariableSnippet resultPage = Snippet.declare(builder).variable(context.localVariable("resultPage")).as( + "$T.getPage($L.getContent(), $L, () -> $L.count())", PageableExecutionUtils.class, + geoResult.getVariableName(), context.getPageableParameterName(), queryExecutor.getVariableName()); + + builder.addStatement("return new $T<>($L, $L, $L.getTotalElements())", GeoPage.class, + geoResult.getVariableName(), context.getPageableParameterName(), resultPage.getVariableName()); } else if (ClassUtils.isAssignable(GeoResults.class, context.getReturnType().getRawClass())) { - builder.addStatement("return $L.all()", executorVar); + builder.addStatement("return $L.all()", queryExecutor.getVariableName()); } else { - builder.addStatement("return $L.all().getContent()", executorVar); + builder.addStatement("return $L.all().getContent()", queryExecutor.getVariableName()); } return builder.build(); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java index 4125139bd9..3cbb4e8562 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java @@ -170,6 +170,25 @@ static GeoNearExecutionCodeBlockBuilder geoNearExecutionBlockBuilder(AotQueryMet return new GeoNearExecutionCodeBlockBuilder(context, queryMethod); } + static CodeBlock asDocument(String source, Map arguments) { + + Builder builder = CodeBlock.builder(); + if (!StringUtils.hasText(source)) { + builder.add("new $T()", Document.class); + } else if (!containsPlaceholder(source)) { + builder.add("$T.parse($S)", Document.class, source); + } else { + builder.add("bindParameters($S, ", source); + if (containsNamedPlaceholder(source)) { + builder.add(renderArgumentMap(arguments)); + } else { + builder.add(renderArgumentArray(arguments)); + } + builder.add(");\n"); + } + return builder.build(); + } + static CodeBlock renderExpressionToDocument(@Nullable String source, String variableName, Map arguments) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java index 524c5e8f23..7f4f069a85 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java @@ -111,7 +111,10 @@ protected void customizeConstructor(AotRepositoryConstructorBuilder constructorB AnnotatedElementUtils.findMergedAnnotation(method, Query.class), method); if (queryMethod.isSearchQuery() || method.isAnnotationPresent(VectorSearch.class)) { - return searchMethodContributor(queryMethod, new SearchInteraction(query.getQuery())); + + VectorSearch vectorSearch = AnnotatedElementUtils.findMergedAnnotation(method, VectorSearch.class); + return searchMethodContributor(queryMethod, new SearchInteraction(getRepositoryInformation().getDomainType(), + vectorSearch, query.getQuery(), queryMethod.getParameters())); } if (queryMethod.isGeoNearQuery() || (queryMethod.getParameters().getMaxDistanceIndex() != -1 @@ -233,8 +236,9 @@ static MethodContributor searchMethodContributor(MongoQueryMet String variableName = "search"; - builder.add(new VectorSearchBocks.VectorSearchQueryCodeBlockBuilder(context, queryMethod) - .usingVariableName(variableName).withFilter(interaction.getFilter()).build()); + builder.add( + new VectorSearchBocks.VectorSearchQueryCodeBlockBuilder(context, queryMethod, interaction.getSearchPath()) + .usingVariableName(variableName).withFilter(interaction.getFilter()).build()); return builder.build(); }); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java index 7ad0c25b16..8e4e6b4603 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java @@ -22,7 +22,6 @@ import org.bson.Document; import org.jspecify.annotations.NullUnmarked; -import org.jspecify.annotations.Nullable; import org.springframework.core.annotation.MergedAnnotation; import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; @@ -215,9 +214,9 @@ CodeBlock build() { if (StringUtils.hasText(source.getQuery().getFieldsString())) { - builder - .add(MongoCodeBlocks.renderExpressionToDocument(source.getQuery().getFieldsString(), "fields", arguments)); - builder.addStatement("$L.setFieldsObject(fields)", queryVariableName); + VariableSnippet fields = Snippet.declare(builder).variable(Document.class, context.localVariable("fields")) + .of(MongoCodeBlocks.asDocument(source.getQuery().getFieldsString(), arguments)); + builder.addStatement("$L.setFieldsObject($L)", queryVariableName, fields.getVariableName()); } String sortParameter = context.getSortParameterName(); @@ -225,8 +224,9 @@ CodeBlock build() { builder.addStatement("$L.with($L)", queryVariableName, sortParameter); } else if (StringUtils.hasText(source.getQuery().getSortString())) { - builder.add(MongoCodeBlocks.renderExpressionToDocument(source.getQuery().getSortString(), "sort", arguments)); - builder.addStatement("$L.setSortObject(sort)", queryVariableName); + VariableSnippet sort = Snippet.declare(builder).variable(Document.class, context.localVariable("sort")) + .of(MongoCodeBlocks.asDocument(source.getQuery().getSortString(), arguments)); + builder.addStatement("$L.setSortObject($L)", queryVariableName, sort.getVariableName()); } String limitParameter = context.getLimitParameterName(); @@ -273,10 +273,10 @@ CodeBlock build() { if (collationAnnotation.isPresent()) { String collationString = collationAnnotation.getString("value"); - if(StringUtils.hasText(collationString)) { + if (StringUtils.hasText(collationString)) { if (!MongoCodeBlocks.containsPlaceholder(collationString)) { builder.addStatement("$L.collation($T.parse($S))", queryVariableName, - org.springframework.data.mongodb.core.query.Collation.class, collationString); + org.springframework.data.mongodb.core.query.Collation.class, collationString); } else { builder.add("$L.collation(collationOf(evaluate($S, ", queryVariableName, collationString); builder.add(MongoCodeBlocks.renderArgumentMap(arguments)); @@ -292,29 +292,28 @@ CodeBlock buildJustTheQuery() { Builder builder = CodeBlock.builder(); builder.add("\n"); - builder.add(renderExpressionToQuery(source.getQuery().getQueryString(), queryVariableName)); + + Snippet.declare(builder).variable(BasicQuery.class, this.queryVariableName).of(renderExpressionToQuery()); return builder.build(); } - private CodeBlock renderExpressionToQuery(@Nullable String source, String variableName) { + private CodeBlock renderExpressionToQuery() { - Builder builder = CodeBlock.builder(); + String source = this.source.getQuery().getQueryString(); if (!StringUtils.hasText(source)) { - - builder.addStatement("$1T $2L = new $1T(new $3T())", BasicQuery.class, variableName, Document.class); - } else if (!MongoCodeBlocks.containsPlaceholder(source)) { - builder.addStatement("$1T $2L = new $1T($3T.parse($4S))", BasicQuery.class, variableName, Document.class, - source); + return CodeBlock.of("new $T(new $T())", BasicQuery.class, Document.class); + } + if (!MongoCodeBlocks.containsPlaceholder(source)) { + return CodeBlock.of("new $T($T.parse($S))", BasicQuery.class, Document.class, source); + } + Builder builder = CodeBlock.builder(); + builder.add("createQuery($S, ", source); + if (MongoCodeBlocks.containsNamedPlaceholder(source)) { + builder.add(MongoCodeBlocks.renderArgumentMap(arguments)); } else { - builder.add("$T $L = createQuery($S, ", BasicQuery.class, variableName, source); - if (MongoCodeBlocks.containsNamedPlaceholder(source)) { - builder.add(MongoCodeBlocks.renderArgumentMap(arguments)); - } else { - builder.add(MongoCodeBlocks.renderArgumentArray(arguments)); - } - builder.add(");\n"); + builder.add(MongoCodeBlocks.renderArgumentArray(arguments)); } - + builder.add(");\n"); return builder.build(); } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/SearchInteraction.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/SearchInteraction.java index a94ff1082b..906061018c 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/SearchInteraction.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/SearchInteraction.java @@ -15,26 +15,55 @@ */ package org.springframework.data.mongodb.repository.aot; +import java.lang.reflect.Field; +import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; +import org.bson.Document; +import org.bson.json.JsonMode; +import org.bson.json.JsonWriterSettings; import org.jspecify.annotations.Nullable; +import org.springframework.data.domain.Vector; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType; +import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.mongodb.repository.query.MongoParameters; import org.springframework.data.repository.aot.generate.QueryMetadata; +import org.springframework.util.StringUtils; /** * @author Christoph Strobl */ public class SearchInteraction extends MongoInteraction implements QueryMetadata { - StringQuery filter; + private final Class domainType; + private final StringQuery filter; + private final @Nullable VectorSearch vectorSearch; + private final MongoParameters parameters; + + public SearchInteraction(Class domainType, @Nullable VectorSearch vectorSearch, StringQuery filter, + MongoParameters parameters) { + + this.domainType = domainType; + this.vectorSearch = vectorSearch; - public SearchInteraction(StringQuery filter) { this.filter = filter; + this.parameters = parameters; } public StringQuery getFilter() { return filter; } + @Nullable + String getIndexName() { + return vectorSearch != null ? vectorSearch.indexName() : null; + } + + public MongoParameters getParameters() { + return parameters; + } + @Override InteractionType getExecutionType() { return InteractionType.AGGREGATION; @@ -43,6 +72,62 @@ InteractionType getExecutionType() { @Override public Map serialize() { - return Map.of("FIXME", "please!"); + Map serialized = new LinkedHashMap<>(); + + if (vectorSearch != null && StringUtils.hasText(vectorSearch.indexName())) { + serialized.put("index", vectorSearch.indexName()); + } + + serialized.put("path", getSearchPath()); + + if (vectorSearch.searchType().equals(SearchType.ENN)) { + serialized.put("exact", true); + } + + if (StringUtils.hasText(filter.getQueryString())) { + serialized.put("filter", filter.getQueryString()); + } + + String limit = limitParameter(); + if (StringUtils.hasText(limit)) { + serialized.put("limit", limit); + } + + if (StringUtils.hasText(vectorSearch.numCandidates())) { + serialized.put("numCandidates", vectorSearch.numCandidates()); + } else if (StringUtils.hasText(limit)) { + serialized.put("numCandidates", limit + " * 20"); + } + + serialized.put("queryVector", "?" + parameters.getVectorIndex()); + + return Map.of("pipeline", List.of(new Document("$vectorSearch", serialized) + .toJson(JsonWriterSettings.builder().outputMode(JsonMode.RELAXED).build()).replaceAll("\\\"", "'"))); + } + + private @Nullable String limitParameter() { + + if (parameters.hasLimitParameter()) { + return "?" + parameters.getLimitIndex(); + } else if (StringUtils.hasText(vectorSearch.limit())) { + return vectorSearch.limit(); + } + return null; + } + + public String getSearchPath() { + + if (vectorSearch != null && StringUtils.hasText(vectorSearch.path())) { + return vectorSearch.path(); + } + + Field[] declaredFields = domainType.getDeclaredFields(); + for (Field field : declaredFields) { + if (Vector.class.isAssignableFrom(field.getType())) { + return field.getName(); + } + } + + throw new IllegalArgumentException("No vector search path found for type %s".formatted(domainType)); } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/Snippet.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/Snippet.java new file mode 100644 index 0000000000..7f62f0cdda --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/Snippet.java @@ -0,0 +1,224 @@ +/* + * Copyright 2025-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import java.util.function.Function; + +import org.jspecify.annotations.Nullable; +import org.springframework.core.ResolvableType; +import org.springframework.data.mongodb.repository.aot.Snippet.BuilderStyleBuilder.BuilderStyleMethodArgumentBuilder; +import org.springframework.data.mongodb.repository.aot.Snippet.BuilderStyleVariableBuilder.BuilderStyleVariableBuilderImpl; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.CodeBlock.Builder; + +/** + * @author Christoph Strobl + * @since 5.0 + */ +interface Snippet { + + CodeBlock code(); + + default boolean isEmpty() { + return code().isEmpty(); + } + + default void appendTo(CodeBlock.Builder builder) { + if (!isEmpty()) { + builder.add(code()); + } + } + + default T as(Function transformer) { + return transformer.apply(this); + } + + default Snippet wrap(String prefix, String suffix) { + return wrap("%s$L%s".formatted(prefix, suffix)); + } + + default Snippet wrap(CodeBlock prefix, CodeBlock suffix) { + return new Snippet() { + + @Override + public CodeBlock code() { + return CodeBlock.builder().add(prefix).add(Snippet.this.code()).add(suffix).build(); + } + }; + } + + default Snippet wrap(String statement) { + return new Snippet() { + + @Override + public CodeBlock code() { + return CodeBlock.of(statement, Snippet.this.code()); + } + }; + } + + static Snippet just(CodeBlock codeBlock) { + return new Snippet() { + @Override + public CodeBlock code() { + return codeBlock; + } + }; + } + + static ContextualSnippetBuilder declare(CodeBlock.Builder builder) { + + return new ContextualSnippetBuilder() { + + @Override + public VariableBuilder variable(String variableName) { + return VariableSnippet.variable(variableName).targeting(builder); + } + + @Override + public VariableBuilder variable(Class type, String variableName) { + return VariableSnippet.variable(type, variableName).targeting(builder); + } + + @Override + public VariableBuilder variable(ResolvableType resolvableType, String variableName) { + return VariableSnippet.variable(resolvableType, variableName).targeting(builder); + } + + @Override + public BuilderStyleVariableBuilder variableBuilder(String variableName) { + return new BuilderStyleVariableBuilderImpl(builder, null, variableName); + } + + @Override + public BuilderStyleVariableBuilder variableBuilder(Class type, String variableName) { + return variableBuilder(ResolvableType.forClass(type), variableName); + } + + @Override + public BuilderStyleVariableBuilder variableBuilder(ResolvableType resolvableType, String variableName) { + return new BuilderStyleVariableBuilderImpl(builder, resolvableType, variableName); + } + }; + } + + interface ContextualSnippetBuilder { + + VariableBuilder variable(String variableName); + + VariableBuilder variable(Class type, String variableName); + + VariableBuilder variable(ResolvableType resolvableType, String variableName); + + BuilderStyleVariableBuilder variableBuilder(String variableName); + + BuilderStyleVariableBuilder variableBuilder(Class type, String variableName); + + BuilderStyleVariableBuilder variableBuilder(ResolvableType resolvableType, String variableName); + } + + interface VariableBuilder { + + default VariableSnippet as(String declaration, Object... args) { + return of(CodeBlock.of(declaration, args)); + } + + VariableSnippet of(CodeBlock codeBlock); + } + + interface BuilderStyleVariableBuilder { + + default BuilderStyleBuilder as(String declaration, Object... args) { + return of(CodeBlock.of(declaration, args)); + } + + BuilderStyleBuilder of(CodeBlock codeBlock); + + class BuilderStyleVariableBuilderImpl + implements BuilderStyleVariableBuilder, BuilderStyleBuilder, BuilderStyleMethodArgumentBuilder { + + Builder targetBuilder; + @Nullable ResolvableType type; + String targetVariableName; + @Nullable String targetMethodName; + @Nullable VariableSnippet variableSnippet; + + public BuilderStyleVariableBuilderImpl(Builder targetBuilder, @Nullable ResolvableType type, + String targetVariableName) { + this.targetBuilder = targetBuilder; + this.type = type; + this.targetVariableName = targetVariableName; + } + + @Override + public BuilderStyleBuilder as(String declaration, Object... args) { + + if (type != null) { + this.variableSnippet = Snippet.declare(targetBuilder).variable(type, targetVariableName).as(declaration, args); + } else { + this.variableSnippet = Snippet.declare(targetBuilder).variable(targetVariableName).as(declaration, args); + } + return this; + } + + @Override + public BuilderStyleBuilder of(CodeBlock codeBlock) { + if (type != null) { + this.variableSnippet = Snippet.declare(targetBuilder).variable(type, targetVariableName).of(codeBlock); + } else { + this.variableSnippet = Snippet.declare(targetBuilder).variable(targetVariableName).of(codeBlock); + } + return this; + } + + @Override + public BuilderStyleMethodArgumentBuilder call(String methodName) { + this.targetMethodName = methodName; + return this; + } + + @Override + public BuilderStyleBuilder with(Snippet snippet) { + new BuilderStyleSnippet(targetVariableName, targetMethodName, snippet).appendTo(targetBuilder); + return this; + } + + @Override + public VariableSnippet variable() { + return this.variableSnippet; + } + } + } + + interface BuilderStyleBuilder { + + BuilderStyleMethodArgumentBuilder call(String methodName); + VariableSnippet variable(); + + interface BuilderStyleMethodArgumentBuilder { + default BuilderStyleBuilder with(String statement, Object... args) { + return with(CodeBlock.of(statement, args)); + } + + default BuilderStyleBuilder with(CodeBlock codeBlock) { + return with(Snippet.just(codeBlock)); + } + + BuilderStyleBuilder with(Snippet snippet); + } + } + +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateBlocks.java index e4061c7717..3ece399edb 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateBlocks.java @@ -35,6 +35,7 @@ import java.util.Map; import org.jspecify.annotations.NullUnmarked; +import org.springframework.core.ResolvableType; import org.springframework.data.mongodb.core.ExecutableUpdateOperation.ExecutableUpdate; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.query.BasicUpdate; @@ -87,22 +88,26 @@ CodeBlock build() { String updateReference = updateVariableName; Class domainType = context.getRepositoryInformation().getDomainType(); - builder.addStatement("$1T<$2T> $3L = $4L.update($2T.class)", ExecutableUpdate.class, domainType, - context.localVariable("updater"), mongoOpsRef); + VariableSnippet updater = Snippet.declare(builder) + .variable(ResolvableType.forClassWithGenerics(ExecutableUpdate.class, domainType), + context.localVariable("updater")) + .as("$L.update($T.class)", mongoOpsRef, domainType); Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); if (ReflectionUtils.isVoid(returnType)) { - builder.addStatement("$L.matching($L).apply($L).all()", context.localVariable("updater"), queryVariableName, + builder.addStatement("$L.matching($L).apply($L).all()", updater.getVariableName(), queryVariableName, updateReference); } else if (ClassUtils.isAssignable(Long.class, returnType)) { - builder.addStatement("return $L.matching($L).apply($L).all().getModifiedCount()", - context.localVariable("updater"), queryVariableName, updateReference); + builder.addStatement("return $L.matching($L).apply($L).all().getModifiedCount()", updater.getVariableName(), + queryVariableName, updateReference); } else { - builder.addStatement("$T $L = $L.matching($L).apply($L).all().getModifiedCount()", Long.class, - context.localVariable("modifiedCount"), context.localVariable("updater"), queryVariableName, - updateReference); + + VariableSnippet modifiedCount = Snippet.declare(builder) + .variable(Long.class, context.localVariable("modifiedCount")) + .as("$L.matching($L).apply($L).all().getModifiedCount()", updater.getVariableName(), queryVariableName, + updateReference); builder.addStatement("return $T.convertNumberToTargetClass($L, $T.class)", NumberUtils.class, - context.localVariable("modifiedCount"), returnType); + modifiedCount.getVariableName(), returnType); } return builder.build(); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VariableSnippet.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VariableSnippet.java new file mode 100644 index 0000000000..e3d67049af --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VariableSnippet.java @@ -0,0 +1,119 @@ +/* + * Copyright 2025-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import org.jspecify.annotations.Nullable; +import org.springframework.core.ResolvableType; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.CodeBlock.Builder; +import org.springframework.javapoet.TypeName; + +/** + * @author Christoph Strobl + * @since 5.0 + */ +class VariableSnippet extends ExpressionSnippet { + + private final String variableName; + private final @Nullable TypeName typeName; + + public VariableSnippet(String variableName, Snippet delegate) { + this((TypeName) null, variableName, delegate); + } + + public VariableSnippet(Class typeName, String variableName, Snippet delegate) { + this(TypeName.get(typeName), variableName, delegate); + } + + public VariableSnippet(@Nullable TypeName typeName, String variableName, Snippet delegate) { + super(delegate); + this.typeName = typeName; + this.variableName = variableName; + } + + static VariableBuilderImp variable(String name) { + return new VariableBuilderImp(null, name); + } + + static VariableBuilderImp variable(Class typeName, String name) { + return variable(TypeName.get(typeName), name); + } + + static VariableBuilderImp variable(ResolvableType resolvableType, String name) { + return variable(TypeName.get(resolvableType.getType()), name); + } + + static VariableBuilderImp variable(TypeName typeName, String name) { + return new VariableBuilderImp(typeName, name); + } + + static class VariableBuilderImp implements VariableBuilder { + + private @Nullable TypeName typeName; + private String variableName; + + private CodeBlock.@Nullable Builder target; + + VariableBuilderImp(@Nullable TypeName typeName, String variableName) { + this.typeName = typeName; + this.variableName = variableName; + } + + @Override + public VariableSnippet of(CodeBlock codeBlock) { + + VariableSnippet variableSnippet = new VariableSnippet(typeName, variableName, Snippet.just(codeBlock)); + if (target != null) { + variableSnippet.renderDeclaration(target); + } + return variableSnippet; + } + + VariableBuilderImp targeting(@Nullable Builder target) { + this.target = target; + return this; + } + } + + @Override + public CodeBlock code() { + return CodeBlock.of("$L", variableName); + } + + public String getVariableName() { + return variableName; + } + + void renderDeclaration(CodeBlock.Builder builder) { + if (typeName != null) { + builder.addStatement("$T $L = $L", typeName, variableName, super.code()); + } else { + builder.addStatement("var $L = $L", variableName, super.code()); + } + } + + static VariableBlockBuilder create(Snippet snippet) { + return variableName -> create(variableName, snippet); + } + + static VariableSnippet create(String variableName, Snippet snippet) { + return new VariableSnippet(variableName, snippet); + } + + interface VariableBlockBuilder { + VariableSnippet variableName(String variableName); + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBocks.java index 3efdc080b2..c9d74edc63 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBocks.java @@ -15,7 +15,6 @@ */ package org.springframework.data.mongodb.repository.aot; -import java.lang.reflect.Field; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -25,15 +24,14 @@ import org.springframework.data.domain.Limit; import org.springframework.data.domain.ScoringFunction; import org.springframework.data.domain.Sort; -import org.springframework.data.domain.Vector; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.aggregation.Aggregation; import org.springframework.data.mongodb.core.aggregation.AggregationOperation; -import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext; import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType; import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.mongodb.repository.aot.Snippet.BuilderStyleBuilder; import org.springframework.data.mongodb.repository.query.MongoQueryExecution.VectorSearchExecution; import org.springframework.data.mongodb.repository.query.MongoQueryMethod; import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; @@ -55,11 +53,14 @@ static class VectorSearchQueryCodeBlockBuilder { private String searchQueryVariableName; private StringQuery filter; private final Map arguments; + private final String searchPath; - VectorSearchQueryCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + VectorSearchQueryCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod, + String searchPath) { this.context = context; this.queryMethod = queryMethod; + this.searchPath = searchPath; this.arguments = new LinkedHashMap<>(); context.getBindableParameterNames().forEach(it -> arguments.put(it, CodeBlock.of(it))); } @@ -77,113 +78,56 @@ CodeBlock build() { String vectorParameterName = context.getVectorParameterName(); MergedAnnotation annotation = context.getAnnotation(VectorSearch.class); - String searchPath = annotation.getString("path"); String indexName = annotation.getString("indexName"); - String numCandidates = annotation.getString("numCandidates"); SearchType searchType = annotation.getEnum("searchType", SearchType.class); - String limit = annotation.getString("limit"); - if (!StringUtils.hasText(searchPath)) { // FIXME: somehow duplicate logic of AnnotatedQueryFactory + ExpressionSnippet limit = getLimitExpression(); - Field[] declaredFields = context.getRepositoryInformation().getDomainType().getDeclaredFields(); - for (Field field : declaredFields) { - if (Vector.class.isAssignableFrom(field.getType())) { - searchPath = field.getName(); - break; - } - } + if (limit.requiresEvaluation() && !StringUtils.hasText(annotation.getString("numCandidates")) + && (searchType == VectorSearchOperation.SearchType.ANN + || searchType == VectorSearchOperation.SearchType.DEFAULT)) { + VariableSnippet variableBlock = limit.as(VariableSnippet::create) + .variableName(context.localVariable("limitToUse")); + variableBlock.renderDeclaration(builder); + limit = variableBlock; } - String vectorSearchVar = context.localVariable("$vectorSearch"); - builder.add("$T $L = $T.vectorSearch($S).path($S).vector($L)", VectorSearchOperation.class, vectorSearchVar, - Aggregation.class, indexName, searchPath, vectorParameterName); - - if (StringUtils.hasText(context.getLimitParameterName())) { - builder.add(".limit($L);\n", context.getLimitParameterName()); - } else if (filter.isLimited()) { - builder.add(".limit($L);\n", filter.getLimit()); - } else if (StringUtils.hasText(limit)) { - if (MongoCodeBlocks.containsPlaceholder(limit) || MongoCodeBlocks.containsExpression(limit)) { - builder.add(".limit("); - builder.add(MongoCodeBlocks.evaluateNumberPotentially(limit, Integer.class, arguments)); - builder.add(");\n"); - } else { - builder.add(".limit($L);\n", limit); - } - } else { - builder.add(".limit($T.unlimited());\n", Limit.class); - } + BuilderStyleBuilder vectorSearchOperationBuilder = Snippet.declare(builder) + .variableBuilder(VectorSearchOperation.class, context.localVariable("$vectorSearch")) + .as("$T.vectorSearch($S).path($S).vector($L).limit($L)", Aggregation.class, indexName, searchPath, + vectorParameterName, limit.code()); if (!searchType.equals(SearchType.DEFAULT)) { - builder.addStatement("$1L = $1L.searchType($2T.$3L)", vectorSearchVar, SearchType.class, searchType.name()); + vectorSearchOperationBuilder.call("searchType").with("$T.$L", SearchType.class, searchType.name()); } - if (StringUtils.hasText(numCandidates)) { - builder.add("$1L = $1L.numCandidates(", vectorSearchVar); - builder.add(MongoCodeBlocks.evaluateNumberPotentially(numCandidates, Integer.class, arguments)); - builder.add(");\n"); - } else if (searchType == VectorSearchOperation.SearchType.ANN - || searchType == VectorSearchOperation.SearchType.DEFAULT) { - - builder.add( - "// MongoDB: We recommend that you specify a number at least 20 times higher than the number of documents to return\n"); - if (StringUtils.hasText(context.getLimitParameterName())) { - builder.addStatement("$1L = $1L.numCandidates($2L.max() * 20)", vectorSearchVar, - context.getLimitParameterName()); - } else if (StringUtils.hasText(limit)) { - if (MongoCodeBlocks.containsPlaceholder(limit) || MongoCodeBlocks.containsExpression(limit)) { - - builder.add("$1L = $1L.numCandidates((", vectorSearchVar); - builder.add(MongoCodeBlocks.evaluateNumberPotentially(limit, Integer.class, arguments)); - builder.add(") * 20);\n"); - } else { - builder.addStatement("$1L = $1L.numCandidates($2L * 20)", vectorSearchVar, limit); - } - } else { - builder.addStatement("$1L = $1L.numCandidates($2L)", vectorSearchVar, filter.getLimit() * 20); - } + ExpressionSnippet numCandidates = getNumCandidatesExpression(searchType, limit); + if (!numCandidates.isEmpty()) { + vectorSearchOperationBuilder.call("numCandidates").with(numCandidates); } - builder.addStatement("$1L = $1L.withSearchScore(\"__score__\")", vectorSearchVar); - if (StringUtils.hasText(context.getScoreParameterName())) { + vectorSearchOperationBuilder.call("withSearchScore").with("\"__score__\""); - String scoreCriteriaVar = context.localVariable("criteria"); - builder.addStatement("$1L = $1L.withFilterBySore($2L -> { $2L.gt($3L.getValue()); })", vectorSearchVar, - scoreCriteriaVar, context.getScoreParameterName()); + if (StringUtils.hasText(context.getScoreParameterName())) { + vectorSearchOperationBuilder.call("withFilterBySore").with("$1L -> { $1L.gt($2L.getValue()); }", + context.localVariable("criteria"), context.getScoreParameterName()); } else if (StringUtils.hasText(context.getScoreRangeParameterName())) { - builder.addStatement("$1L = $1L.withFilterBySore(scoreBetween($2L.getLowerBound(), $2L.getUpperBound()))", - vectorSearchVar, context.getScoreRangeParameterName()); + vectorSearchOperationBuilder.call("withFilterBySore") + .with("scoreBetween($1L.getLowerBound(), $1L.getUpperBound())", context.getScoreRangeParameterName()); } - if (StringUtils.hasText(filter.getQueryString())) { - - String filterVar = context.localVariable("filter"); - builder.add(MongoCodeBlocks.queryBlockBuilder(context, queryMethod).usingQueryVariableName("filter") - .filter(new QueryInteraction(this.filter, false, false, false)).buildJustTheQuery()); - builder.addStatement("$1L = $1L.filter($2L.getQueryObject())", vectorSearchVar, filterVar); - builder.add("\n"); - } + VariableSnippet vectorSearchOperation = vectorSearchOperationBuilder.variable(); + getFilter(vectorSearchOperation.getVariableName()).appendTo(builder); + VariableSnippet sortStage = getSort().as(VariableSnippet::create).variableName(context.localVariable("$sort")); + sortStage.renderDeclaration(builder); - String sortStageVar = context.localVariable("$sort"); - if(filter.isSorted()) { - - builder.add("$T $L = (_ctx) -> {\n", AggregationOperation.class, sortStageVar); - builder.indent(); - - builder.addStatement("$1T _mappedSort = _ctx.getMappedObject($1T.parse($2S), $3T.class)", Document.class, filter.getSortString(), context.getActualReturnType().getType()); - builder.addStatement("return new $T($S, _mappedSort.append(\"__score__\", -1))", Document.class, "$sort"); - builder.unindent(); - builder.add("};"); - - } else { - builder.addStatement("var $L = $T.sort($T.Direction.DESC, $S)", sortStageVar, Aggregation.class, Sort.class, "__score__"); - } builder.add("\n"); - builder.addStatement("$1T $2L = new $1T($3T.of($4L, $5L))", AggregationPipeline.class, searchQueryVariableName, - List.class, vectorSearchVar, sortStageVar); + VariableSnippet aggregationPipeline = Snippet.declare(builder) + .variable(AggregationPipeline.class, searchQueryVariableName).as("new $T($T.of($L, $L))", + AggregationPipeline.class, List.class, vectorSearchOperation.getVariableName(), sortStage.code()); String scoringFunctionVar = context.localVariable("scoringFunction"); builder.add("$1T $2L = ", ScoringFunction.class, scoringFunctionVar); @@ -199,13 +143,108 @@ CodeBlock build() { "return ($5T) new $1T($2L, $3T.class, $2L.getCollectionName($3T.class), $4T.of($5T.class), $6L, $7L).execute(null)", VectorSearchExecution.class, context.fieldNameOf(MongoOperations.class), context.getRepositoryInformation().getDomainType(), TypeInformation.class, - queryMethod.getReturnType().getType(), searchQueryVariableName, scoringFunctionVar); + queryMethod.getReturnType().getType(), aggregationPipeline.getVariableName(), scoringFunctionVar); return builder.build(); } + private ExpressionSnippet getSort() { + + if (!filter.isSorted()) { + return new ExpressionSnippet( + CodeBlock.of("$T.sort($T.Direction.DESC, $S)", Aggregation.class, Sort.class, "__score__")); + } + + Builder builder = CodeBlock.builder(); + + builder.add("($T) (_ctx) -> {\n", AggregationOperation.class); + builder.indent(); + + builder.add("$1T _mappedSort = _ctx.getMappedObject($1T.parse($2S), $3T.class);\n", Document.class, + filter.getSortString(), context.getActualReturnType().getType()); + builder.add("return new $T($S, _mappedSort.append(\"__score__\", -1));\n", Document.class, "$sort"); + builder.unindent(); + builder.add("};"); + + return new ExpressionSnippet(builder.build()); + } + + private Snippet getFilter(String vectorSearchVar) { + + if (!StringUtils.hasText(filter.getQueryString())) { + return ExpressionSnippet.empty(); + } + + Builder builder = CodeBlock.builder(); + String filterVar = context.localVariable("filter"); + builder.add(MongoCodeBlocks.queryBlockBuilder(context, queryMethod).usingQueryVariableName("filter") + .filter(new QueryInteraction(this.filter, false, false, false)).buildJustTheQuery()); + builder.addStatement("$1L = $1L.filter($2L.getQueryObject())", vectorSearchVar, filterVar); + builder.add("\n"); + + return new ExpressionSnippet(builder.build()); + } + public VectorSearchQueryCodeBlockBuilder withFilter(StringQuery filter) { this.filter = filter; return this; } + + private ExpressionSnippet getNumCandidatesExpression(SearchType searchType, ExpressionSnippet limit) { + + MergedAnnotation annotation = context.getAnnotation(VectorSearch.class); + String numCandidates = annotation.getString("numCandidates"); + + if (StringUtils.hasText(numCandidates)) { + if (MongoCodeBlocks.containsPlaceholder(numCandidates) || MongoCodeBlocks.containsExpression(numCandidates)) { + return new ExpressionSnippet( + MongoCodeBlocks.evaluateNumberPotentially(numCandidates, Integer.class, arguments), true); + } else { + return new ExpressionSnippet(CodeBlock.of("$L", numCandidates)); + } + } + + if (searchType == VectorSearchOperation.SearchType.ANN + || searchType == VectorSearchOperation.SearchType.DEFAULT) { + + Builder builder = CodeBlock.builder(); + + if (StringUtils.hasText(context.getLimitParameterName())) { + builder.add("$L.max() * 20", context.getLimitParameterName()); + } else if (filter.isLimited()) { + builder.add("$L", filter.getLimit() * 20); + } else { + builder.add("$L * 20", limit.code()); + } + + return new ExpressionSnippet(builder.build()); + } + + return ExpressionSnippet.empty(); + } + + private ExpressionSnippet getLimitExpression() { + + if (StringUtils.hasText(context.getLimitParameterName())) { + return new ExpressionSnippet(CodeBlock.of("$L", context.getLimitParameterName())); + } + + if (filter.isLimited()) { + return new ExpressionSnippet(CodeBlock.of("$L", filter.getLimit())); + } + + MergedAnnotation annotation = context.getAnnotation(VectorSearch.class); + String limit = annotation.getString("limit"); + + if (StringUtils.hasText(limit)) { + + if (MongoCodeBlocks.containsPlaceholder(limit) || MongoCodeBlocks.containsExpression(limit)) { + return new ExpressionSnippet(MongoCodeBlocks.evaluateNumberPotentially(limit, Integer.class, arguments), + true); + } else { + return new ExpressionSnippet(CodeBlock.of("$L", limit)); + } + } + return new ExpressionSnippet(CodeBlock.of("$T.unlimited()", Limit.class)); + } } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java index d8de601d4e..5395dba2bf 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java @@ -332,11 +332,11 @@ void rendersVectorSearchLimitFromExpression() throws NoSuchMethodException { Range.class); assertThat(methodSpec.toString()) // - .containsSubsequence( - "Aggregation.vectorSearch(\"embedding.vector_cos\").path(\"embedding\").vector(vector).limit(", + .containsSubsequence("var limitToUse = ", "evaluate(\"#{5+5}\", argumentMap(\"lastname\", lastname, \"distance\", distance)") - .containsSubsequence("$vectorSearch.numCandidates(", - "evaluate(\"#{5+5}\", argumentMap(\"lastname\", lastname, \"distance\", distance))) * 20)"); + .contains( + "Aggregation.vectorSearch(\"embedding.vector_cos\").path(\"embedding\").vector(vector).limit(limitToUse)") + .contains("$vectorSearch.numCandidates(limitToUse * 20)"); } @Test @@ -358,7 +358,8 @@ void rendersVectorSearchOrderByWithScoreLast() throws NoSuchMethodException { String.class, Vector.class, Range.class); assertThat(methodSpec.toString()) // - .containsSubsequence("AggregationOperation $sort = (_ctx) -> {", // + .containsSubsequence("var $sort = ", // + "(_ctx) -> {", // "_mappedSort = _ctx.getMappedObject(", // "Document.parse(\"{'firstname':{'$numberInt':'1'}}\")", // "Document(\"$sort\", _mappedSort.append(\"__score__\", -1))");