Skip to content

Commit

Permalink
Add StatementFilterFunction to R2dbcEntityTemplate.
Browse files Browse the repository at this point in the history
See #1652
  • Loading branch information
mp911de committed Oct 1, 2024
1 parent 90b6d8e commit 4834d08
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ public class R2dbcEntityTemplate implements R2dbcEntityOperations, BeanFactoryAw

private @Nullable ReactiveEntityCallbacks entityCallbacks;

private Function<Statement, Statement> statementFilterFunction = Function.identity();

/**
* Create a new {@link R2dbcEntityTemplate} given {@link ConnectionFactory}.
*
Expand Down Expand Up @@ -174,6 +176,19 @@ public R2dbcEntityTemplate(DatabaseClient databaseClient, ReactiveDataAccessStra
this.projectionFactory = new SpelAwareProxyProjectionFactory();
}

/**
* Set a {@link Function Statement Filter Function} that is applied to every {@link Statement}.
*
* @param statementFilterFunction must not be {@literal null}.
* @since 3.4
*/
public void setStatementFilterFunction(Function<Statement, Statement> statementFilterFunction) {

Assert.notNull(statementFilterFunction, "StatementFilterFunction must not be null");

this.statementFilterFunction = statementFilterFunction;
}

@Override
public DatabaseClient getDatabaseClient() {
return this.databaseClient;
Expand Down Expand Up @@ -274,6 +289,7 @@ Mono<Long> doCount(Query query, Class<?> entityClass, SqlIdentifier tableName) {
PreparedOperation<?> operation = statementMapper.getMappedObject(selectSpec);

return this.databaseClient.sql(operation) //
.filter(statementFilterFunction) //
.map((r, md) -> r.get(0, Long.class)) //
.first() //
.defaultIfEmpty(0L);
Expand Down Expand Up @@ -302,6 +318,7 @@ Mono<Boolean> doExists(Query query, Class<?> entityClass, SqlIdentifier tableNam
PreparedOperation<?> operation = statementMapper.getMappedObject(selectSpec);

return this.databaseClient.sql(operation) //
.filter(statementFilterFunction) //
.map((r, md) -> r) //
.first() //
.hasElement();
Expand Down Expand Up @@ -362,7 +379,7 @@ private <T> RowsFetchSpec<T> doSelect(Query query, Class<?> entityType, SqlIdent
PreparedOperation<?> operation = statementMapper.getMappedObject(selectSpec);

return getRowsFetchSpec(
databaseClient.sql(operation).filter(filterFunction),
databaseClient.sql(operation).filter(statementFilterFunction.andThen(filterFunction)),
entityType,
returnType
);
Expand Down Expand Up @@ -397,7 +414,7 @@ Mono<Long> doUpdate(Query query, Update update, Class<?> entityClass, SqlIdentif
}

PreparedOperation<?> operation = statementMapper.getMappedObject(selectSpec);
return this.databaseClient.sql(operation).fetch().rowsUpdated();
return this.databaseClient.sql(operation).filter(statementFilterFunction).fetch().rowsUpdated();
}

@Override
Expand All @@ -422,7 +439,7 @@ Mono<Long> doDelete(Query query, Class<?> entityClass, SqlIdentifier tableName)
}

PreparedOperation<?> operation = statementMapper.getMappedObject(deleteSpec);
return this.databaseClient.sql(operation).fetch().rowsUpdated().defaultIfEmpty(0L);
return this.databaseClient.sql(operation).filter(statementFilterFunction).fetch().rowsUpdated().defaultIfEmpty(0L);
}

// -------------------------------------------------------------------------
Expand All @@ -441,7 +458,8 @@ public <T> RowsFetchSpec<T> query(PreparedOperation<?> operation, Class<?> entit
Assert.notNull(operation, "PreparedOperation must not be null");
Assert.notNull(entityClass, "Entity class must not be null");

return new EntityCallbackAdapter<>(getRowsFetchSpec(databaseClient.sql(operation), entityClass, resultType),
return new EntityCallbackAdapter<>(
getRowsFetchSpec(databaseClient.sql(operation).filter(statementFilterFunction), entityClass, resultType),
getTableNameOrEmpty(entityClass));
}

Expand All @@ -451,7 +469,8 @@ public <T> RowsFetchSpec<T> query(PreparedOperation<?> operation, BiFunction<Row
Assert.notNull(operation, "PreparedOperation must not be null");
Assert.notNull(rowMapper, "Row mapper must not be null");

return new EntityCallbackAdapter<>(databaseClient.sql(operation).map(rowMapper), SqlIdentifier.EMPTY);
return new EntityCallbackAdapter<>(databaseClient.sql(operation).filter(statementFilterFunction).map(rowMapper),
SqlIdentifier.EMPTY);
}

@Override
Expand All @@ -462,7 +481,8 @@ public <T> RowsFetchSpec<T> query(PreparedOperation<?> operation, Class<?> entit
Assert.notNull(entityClass, "Entity class must not be null");
Assert.notNull(rowMapper, "Row mapper must not be null");

return new EntityCallbackAdapter<>(databaseClient.sql(operation).map(rowMapper), getTableNameOrEmpty(entityClass));
return new EntityCallbackAdapter<>(databaseClient.sql(operation).filter(statementFilterFunction).map(rowMapper),
getTableNameOrEmpty(entityClass));
}

// -------------------------------------------------------------------------
Expand Down Expand Up @@ -541,6 +561,8 @@ private <T> Mono<T> doInsert(T entity, SqlIdentifier tableName, OutboundRow outb
return this.databaseClient.sql(operation) //
.filter(statement -> {

statement = statementFilterFunction.apply(statement);

if (identifierColumns.isEmpty()) {
return statement.returnGeneratedValues();
}
Expand Down Expand Up @@ -632,6 +654,7 @@ private <T> Mono<T> doUpdate(T entity, SqlIdentifier tableName, RelationalPersis
PreparedOperation<?> operation = mapper.getMappedObject(updateSpec);

return this.databaseClient.sql(operation) //
.filter(statementFilterFunction) //
.fetch() //
.rowsUpdated() //
.handle((rowsUpdated, sink) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,6 @@ void shouldProjectCountResultWithoutId() {
@Test // GH-469
void shouldExistsByCriteria() {

MockRowMetadata metadata = MockRowMetadata.builder()
.columnMetadata(MockColumnMetadata.builder().name("name").type(R2dbcType.VARCHAR).build()).build();
MockResult result = MockResult.builder().row(MockRow.builder().identified(0, Long.class, 1L).build()).build();

recorder.addStubbing(s -> s.startsWith("SELECT"), result);
Expand Down Expand Up @@ -654,6 +652,24 @@ void projectDtoShouldReadPropertiesOnce() {
}).verifyComplete();
}

@Test // GH-1652
void shouldConsiderFilterFunction() {

MockResult result = MockResult.builder().row(MockRow.builder().identified(0, Long.class, 1L).build()).build();

recorder.addStubbing(s -> s.startsWith("SELECT"), result);

entityTemplate.setStatementFilterFunction(statement -> statement.fetchSize(10));
entityTemplate.count(Query.empty(), Person.class) //
.as(StepVerifier::create) //
.expectNext(1L) //
.verifyComplete();

StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("SELECT"));

assertThat(statement.getFetchSize()).isEqualTo(10);
}

@ReadingConverter
static class PkConverter implements Converter<ByteBuffer, DoubleHolder> {

Expand Down

0 comments on commit 4834d08

Please sign in to comment.