Skip to content

Commit

Permalink
Fixed spring data cosmos query plan caching. Added new tests (Azure#2…
Browse files Browse the repository at this point in the history
  • Loading branch information
kushagraThapar authored Aug 31, 2021
1 parent dc02a92 commit b5e6f63
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,20 @@
package com.azure.spring.data.cosmos.core;

import com.azure.cosmos.CosmosAsyncClient;
import com.azure.cosmos.CosmosBridgeInternal;
import com.azure.cosmos.CosmosClientBuilder;
import com.azure.cosmos.implementation.AsyncDocumentClient;
import com.azure.cosmos.implementation.query.PartitionedQueryExecutionInfo;
import com.azure.cosmos.models.PartitionKey;
import com.azure.cosmos.models.SqlQuerySpec;
import com.azure.spring.data.cosmos.CosmosFactory;
import com.azure.spring.data.cosmos.IntegrationTestCollectionManager;
import com.azure.spring.data.cosmos.common.PageTestUtils;
import com.azure.spring.data.cosmos.common.TestConstants;
import com.azure.spring.data.cosmos.common.TestUtils;
import com.azure.spring.data.cosmos.config.CosmosConfig;
import com.azure.spring.data.cosmos.core.convert.MappingCosmosConverter;
import com.azure.spring.data.cosmos.core.generator.FindQuerySpecGenerator;
import com.azure.spring.data.cosmos.core.mapping.CosmosMappingContext;
import com.azure.spring.data.cosmos.core.query.CosmosPageRequest;
import com.azure.spring.data.cosmos.core.query.CosmosQuery;
Expand All @@ -38,6 +43,7 @@
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ConcurrentMap;

import static com.azure.spring.data.cosmos.common.TestConstants.ADDRESSES;
import static com.azure.spring.data.cosmos.common.TestConstants.FIRST_NAME;
Expand All @@ -63,11 +69,12 @@ public class CosmosTemplatePartitionIT {
HOBBIES, ADDRESSES);

private static final PartitionPerson TEST_PERSON_2 = new PartitionPerson(ID_2, NEW_FIRST_NAME,
TEST_PERSON.getZipCode(), HOBBIES, ADDRESSES);
NEW_ZIP_CODE, HOBBIES, ADDRESSES);

@ClassRule
public static final IntegrationTestCollectionManager collectionManager = new IntegrationTestCollectionManager();

private static CosmosFactory cosmosFactory;
private static CosmosTemplate cosmosTemplate;
private static String containerName;
private static CosmosEntityInformation<PartitionPerson, String> personInfo;
Expand All @@ -82,8 +89,10 @@ public class CosmosTemplatePartitionIT {
@Before
public void setUp() throws ClassNotFoundException {
if (cosmosTemplate == null) {
// Enable Query plan caching for testing
System.setProperty("COSMOS.QUERYPLAN_CACHING_ENABLED", "true");
CosmosAsyncClient client = CosmosFactory.createCosmosAsyncClient(cosmosClientBuilder);
final CosmosFactory cosmosFactory = new CosmosFactory(client, TestConstants.DB_NAME);
cosmosFactory = new CosmosFactory(client, TestConstants.DB_NAME);
final CosmosMappingContext mappingContext = new CosmosMappingContext();

personInfo = new CosmosEntityInformation<>(PartitionPerson.class);
Expand Down Expand Up @@ -121,6 +130,41 @@ public void testFindWithPartition() {
assertEquals(TEST_PERSON, result.get(0));
}

@Test
public void testFindWithPartitionWithQueryPlanCachingEnabled() {
Criteria criteria = Criteria.getInstance(CriteriaType.IS_EQUAL, PROPERTY_ZIP_CODE,
Collections.singletonList(ZIP_CODE), Part.IgnoreCaseType.NEVER);
CosmosQuery query = new CosmosQuery(criteria);
SqlQuerySpec sqlQuerySpec = new FindQuerySpecGenerator().generateCosmos(query);
List<PartitionPerson> result = TestUtils.toList(cosmosTemplate.find(query, PartitionPerson.class,
PartitionPerson.class.getSimpleName()));

assertThat(result.size()).isEqualTo(1);
assertEquals(TEST_PERSON, result.get(0));

CosmosAsyncClient cosmosAsyncClient = cosmosFactory.getCosmosAsyncClient();
AsyncDocumentClient asyncDocumentClient = CosmosBridgeInternal.getAsyncDocumentClient(cosmosAsyncClient);
ConcurrentMap<String, PartitionedQueryExecutionInfo> initialCache = asyncDocumentClient.getQueryPlanCache();
assertThat(initialCache.containsKey(sqlQuerySpec.getQueryText())).isTrue();
int initialSize = initialCache.size();

cosmosTemplate.insert(TEST_PERSON_2, new PartitionKey(TEST_PERSON_2.getZipCode()));

criteria = Criteria.getInstance(CriteriaType.IS_EQUAL, PROPERTY_ZIP_CODE,
Collections.singletonList(NEW_ZIP_CODE), Part.IgnoreCaseType.NEVER);
query = new CosmosQuery(criteria);
// Fire the same query but with different partition key value to make sure query plan caching is enabled
sqlQuerySpec = new FindQuerySpecGenerator().generateCosmos(query);
result = TestUtils.toList(cosmosTemplate.find(query, PartitionPerson.class,
PartitionPerson.class.getSimpleName()));

ConcurrentMap<String, PartitionedQueryExecutionInfo> postQueryCallCache = asyncDocumentClient.getQueryPlanCache();
assertThat(postQueryCallCache.containsKey(sqlQuerySpec.getQueryText())).isTrue();
assertThat(postQueryCallCache.size()).isEqualTo(initialSize);
assertThat(result.size()).isEqualTo(1);
assertEquals(TEST_PERSON_2, result.get(0));
}

@Test
public void testFindIgnoreCaseWithPartition() {
Criteria criteria = Criteria.getInstance(CriteriaType.IS_EQUAL, PROPERTY_ZIP_CODE,
Expand Down Expand Up @@ -196,8 +240,6 @@ public void testDeleteByIdPartition() {

final List<PartitionPerson> inserted = TestUtils.toList(cosmosTemplate.findAll(PartitionPerson.class));
assertThat(inserted.size()).isEqualTo(2);
assertThat(inserted.get(0).getZipCode()).isEqualTo(TEST_PERSON.getZipCode());
assertThat(inserted.get(1).getZipCode()).isEqualTo(TEST_PERSON.getZipCode());

cosmosTemplate.deleteById(PartitionPerson.class.getSimpleName(),
TEST_PERSON.getId(), new PartitionKey(TEST_PERSON.getZipCode()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@
package com.azure.spring.data.cosmos.core;

import com.azure.cosmos.CosmosAsyncClient;
import com.azure.cosmos.CosmosBridgeInternal;
import com.azure.cosmos.CosmosClientBuilder;
import com.azure.cosmos.implementation.AsyncDocumentClient;
import com.azure.cosmos.implementation.query.PartitionedQueryExecutionInfo;
import com.azure.cosmos.models.PartitionKey;
import com.azure.cosmos.models.SqlQuerySpec;
import com.azure.spring.data.cosmos.CosmosFactory;
import com.azure.spring.data.cosmos.ReactiveIntegrationTestCollectionManager;
import com.azure.spring.data.cosmos.common.TestConstants;
import com.azure.spring.data.cosmos.config.CosmosConfig;
import com.azure.spring.data.cosmos.core.convert.MappingCosmosConverter;
import com.azure.spring.data.cosmos.core.generator.FindQuerySpecGenerator;
import com.azure.spring.data.cosmos.core.mapping.CosmosMappingContext;
import com.azure.spring.data.cosmos.core.query.CosmosQuery;
import com.azure.spring.data.cosmos.core.query.Criteria;
Expand All @@ -35,7 +40,9 @@

import java.util.Collections;
import java.util.UUID;
import java.util.concurrent.ConcurrentMap;

import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertEquals;
Expand All @@ -49,11 +56,12 @@ public class ReactiveCosmosTemplatePartitionIT {

private static final PartitionPerson TEST_PERSON_2 = new PartitionPerson(TestConstants.ID_2,
TestConstants.NEW_FIRST_NAME,
TEST_PERSON.getZipCode(), TestConstants.HOBBIES, TestConstants.ADDRESSES);
TestConstants.NEW_ZIP_CODE, TestConstants.HOBBIES, TestConstants.ADDRESSES);

@ClassRule
public static final ReactiveIntegrationTestCollectionManager collectionManager = new ReactiveIntegrationTestCollectionManager();

private static CosmosFactory cosmosFactory;
private static ReactiveCosmosTemplate cosmosTemplate;
private static String containerName;
private static CosmosEntityInformation<PartitionPerson, String> personInfo;
Expand All @@ -68,8 +76,10 @@ public class ReactiveCosmosTemplatePartitionIT {
@Before
public void setUp() throws ClassNotFoundException {
if (cosmosTemplate == null) {
// Enable Query plan caching for testing
System.setProperty("COSMOS.QUERYPLAN_CACHING_ENABLED", "true");
CosmosAsyncClient client = CosmosFactory.createCosmosAsyncClient(cosmosClientBuilder);
final CosmosFactory dbFactory = new CosmosFactory(client, TestConstants.DB_NAME);
cosmosFactory = new CosmosFactory(client, TestConstants.DB_NAME);

final CosmosMappingContext mappingContext = new CosmosMappingContext();
personInfo =
Expand All @@ -80,7 +90,7 @@ public void setUp() throws ClassNotFoundException {

final MappingCosmosConverter dbConverter = new MappingCosmosConverter(mappingContext,
null);
cosmosTemplate = new ReactiveCosmosTemplate(dbFactory, cosmosConfig, dbConverter);
cosmosTemplate = new ReactiveCosmosTemplate(cosmosFactory, cosmosConfig, dbConverter);
}
collectionManager.ensureContainersCreatedAndEmpty(cosmosTemplate, PartitionPerson.class);
cosmosTemplate.insert(TEST_PERSON).block();
Expand All @@ -100,6 +110,46 @@ public void testFindWithPartition() {
}).verifyComplete();
}

@Test
public void testFindWithPartitionWithQueryPlanCachingEnabled() {
Criteria criteria = Criteria.getInstance(CriteriaType.IS_EQUAL, TestConstants.PROPERTY_ZIP_CODE,
Collections.singletonList(TestConstants.ZIP_CODE), Part.IgnoreCaseType.NEVER);
CosmosQuery query = new CosmosQuery(criteria);
SqlQuerySpec sqlQuerySpec = new FindQuerySpecGenerator().generateCosmos(query);
Flux<PartitionPerson> partitionPersonFlux = cosmosTemplate.find(query,
PartitionPerson.class,
PartitionPerson.class.getSimpleName());
StepVerifier.create(partitionPersonFlux).consumeNextWith(actual -> {
Assert.assertThat(actual.getFirstName(), is(equalTo(TEST_PERSON.getFirstName())));
Assert.assertThat(actual.getZipCode(), is(equalTo(TEST_PERSON.getZipCode())));
}).verifyComplete();

CosmosAsyncClient cosmosAsyncClient = cosmosFactory.getCosmosAsyncClient();
AsyncDocumentClient asyncDocumentClient = CosmosBridgeInternal.getAsyncDocumentClient(cosmosAsyncClient);
ConcurrentMap<String, PartitionedQueryExecutionInfo> initialCache = asyncDocumentClient.getQueryPlanCache();
assertThat(initialCache.containsKey(sqlQuerySpec.getQueryText())).isTrue();
int initialSize = initialCache.size();

cosmosTemplate.insert(TEST_PERSON_2, new PartitionKey(TEST_PERSON_2.getZipCode())).block();

// Fire the same query with different partition key value to make sure query plan caching is enabled
criteria = Criteria.getInstance(CriteriaType.IS_EQUAL, TestConstants.PROPERTY_ZIP_CODE,
Collections.singletonList(TestConstants.NEW_ZIP_CODE), Part.IgnoreCaseType.NEVER);
query = new CosmosQuery(criteria);
sqlQuerySpec = new FindQuerySpecGenerator().generateCosmos(query);
partitionPersonFlux = cosmosTemplate.find(query,
PartitionPerson.class,
PartitionPerson.class.getSimpleName());
StepVerifier.create(partitionPersonFlux).consumeNextWith(actual -> {
Assert.assertThat(actual.getFirstName(), is(equalTo(TEST_PERSON_2.getFirstName())));
Assert.assertThat(actual.getZipCode(), is(equalTo(TEST_PERSON_2.getZipCode())));
}).verifyComplete();

ConcurrentMap<String, PartitionedQueryExecutionInfo> postQueryCallCache = asyncDocumentClient.getQueryPlanCache();
assertThat(postQueryCallCache.containsKey(sqlQuerySpec.getQueryText())).isTrue();
assertThat(postQueryCallCache.size()).isEqualTo(initialSize);
}

@Test
public void testFindIgnoreCaseWithPartition() {
final Criteria criteria = Criteria.getInstance(CriteriaType.IS_EQUAL, TestConstants.PROPERTY_ZIP_CODE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import static com.azure.spring.data.cosmos.core.convert.MappingCosmosConverter.toCosmosDbValue;
Expand All @@ -31,9 +31,9 @@ public abstract class AbstractQueryGenerator {
protected AbstractQueryGenerator() {
}

private String generateQueryParameter(@NonNull String subject) {
private String generateQueryParameter(@NonNull String subject, int counter) {
// user.name, user['name'] or user["first name"] are not valid sql parameter identifiers.
return subject.replaceAll("[^a-zA-Z\\d]", "_") + UUID.randomUUID().toString().replaceAll("-", "_");
return subject.replaceAll("[^a-zA-Z\\d]", "_") + counter;
}

private String generateUnaryQuery(@NonNull Criteria criteria) {
Expand All @@ -48,14 +48,14 @@ private String generateUnaryQuery(@NonNull Criteria criteria) {
}
}

private String generateBinaryQuery(@NonNull Criteria criteria, @NonNull List<Pair<String, Object>> parameters) {
private String generateBinaryQuery(@NonNull Criteria criteria, @NonNull List<Pair<String, Object>> parameters, int counter) {
Assert.isTrue(criteria.getSubjectValues().size() == 1,
"Binary criteria should have only one subject value");
Assert.isTrue(CriteriaType.isBinary(criteria.getType()), "Criteria type should be binary operation");

final String subject = criteria.getSubject();
final Object subjectValue = toCosmosDbValue(criteria.getSubjectValues().get(0));
final String parameter = generateQueryParameter(subject);
final String parameter = generateQueryParameter(subject, counter);
final Part.IgnoreCaseType ignoreCase = criteria.getIgnoreCase();
final String sqlKeyword = criteria.getType().getSqlKeyword();
parameters.add(Pair.of(parameter, subjectValue));
Expand Down Expand Up @@ -103,14 +103,14 @@ private String getFunctionCondition(final Part.IgnoreCaseType ignoreCase, final
}
}

private String generateBetween(@NonNull Criteria criteria, @NonNull List<Pair<String, Object>> parameters) {
private String generateBetween(@NonNull Criteria criteria, @NonNull List<Pair<String, Object>> parameters, int counter) {
final String subject = criteria.getSubject();
final Object value1 = toCosmosDbValue(criteria.getSubjectValues().get(0));
final Object value2 = toCosmosDbValue(criteria.getSubjectValues().get(1));
final String subject1 = subject + "start";
final String subject2 = subject + "end";
final String parameter1 = generateQueryParameter(subject1);
final String parameter2 = generateQueryParameter(subject2);
final String parameter1 = generateQueryParameter(subject1, counter);
final String parameter2 = generateQueryParameter(subject2, counter);
final String keyword = criteria.getType().getSqlKeyword();

parameters.add(Pair.of(parameter1, value1));
Expand Down Expand Up @@ -152,7 +152,7 @@ private String generateInQuery(@NonNull Criteria criteria, @NonNull List<Pair<St
String.join(",", paras));
}

private String generateQueryBody(@NonNull Criteria criteria, @NonNull List<Pair<String, Object>> parameters) {
private String generateQueryBody(@NonNull Criteria criteria, @NonNull List<Pair<String, Object>> parameters, @NonNull final AtomicInteger counter) {
final CriteriaType type = criteria.getType();

switch (type) {
Expand All @@ -162,7 +162,7 @@ private String generateQueryBody(@NonNull Criteria criteria, @NonNull List<Pair<
case NOT_IN:
return generateInQuery(criteria, parameters);
case BETWEEN:
return generateBetween(criteria, parameters);
return generateBetween(criteria, parameters, counter.getAndIncrement());
case IS_NULL:
case IS_NOT_NULL:
case FALSE:
Expand All @@ -180,14 +180,14 @@ private String generateQueryBody(@NonNull Criteria criteria, @NonNull List<Pair<
case ENDS_WITH:
case STARTS_WITH:
case ARRAY_CONTAINS:
return generateBinaryQuery(criteria, parameters);
return generateBinaryQuery(criteria, parameters, counter.getAndIncrement());
case AND:
case OR:
Assert.isTrue(criteria.getSubCriteria().size() == 2,
"criteria should have two SubCriteria");

final String left = generateQueryBody(criteria.getSubCriteria().get(0), parameters);
final String right = generateQueryBody(criteria.getSubCriteria().get(1), parameters);
final String left = generateQueryBody(criteria.getSubCriteria().get(0), parameters, counter);
final String right = generateQueryBody(criteria.getSubCriteria().get(1), parameters, counter);

return generateClosedQuery(left, right, type);
default:
Expand All @@ -204,9 +204,9 @@ private String generateQueryBody(@NonNull Criteria criteria, @NonNull List<Pair<
* @return A pair tuple compose of Sql query.
*/
@NonNull
private Pair<String, List<Pair<String, Object>>> generateQueryBody(@NonNull CosmosQuery query) {
private Pair<String, List<Pair<String, Object>>> generateQueryBody(@NonNull CosmosQuery query, @NonNull final AtomicInteger counter) {
final List<Pair<String, Object>> parameters = new ArrayList<>();
String queryString = this.generateQueryBody(query.getCriteria(), parameters);
String queryString = this.generateQueryBody(query.getCriteria(), parameters, counter);

if (StringUtils.hasText(queryString)) {
queryString = String.join(" ", "WHERE", queryString);
Expand Down Expand Up @@ -248,7 +248,8 @@ private String generateQueryTail(@NonNull CosmosQuery query) {

protected SqlQuerySpec generateCosmosQuery(@NonNull CosmosQuery query,
@NonNull String queryHead) {
final Pair<String, List<Pair<String, Object>>> queryBody = generateQueryBody(query);
final AtomicInteger counter = new AtomicInteger();
final Pair<String, List<Pair<String, Object>>> queryBody = generateQueryBody(query, counter);
String queryString = String.join(" ", queryHead, queryBody.getFirst(), generateQueryTail(query));
final List<Pair<String, Object>> parameters = queryBody.getSecond();

Expand Down

0 comments on commit b5e6f63

Please sign in to comment.