Skip to content

Commit

Permalink
Polishing.
Browse files Browse the repository at this point in the history
Replace code duplications with doWithBatch(…) method. Return most concrete type in DefaultDataAccessStrategy and MyBatisDataAccessStrategy.

See #1623
Original pull request: #1897
  • Loading branch information
mp911de committed Oct 1, 2024
1 parent c4f62e9 commit 7cf81ae
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
package org.springframework.data.jdbc.core;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
Expand Down Expand Up @@ -56,6 +58,7 @@
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ObjectUtils;

/**
* {@link JdbcAggregateOperations} implementation, storing aggregates in and obtaining them from a JDBC data store.
Expand Down Expand Up @@ -173,19 +176,8 @@ public <T> T save(T instance) {

@Override
public <T> List<T> saveAll(Iterable<T> instances) {

Assert.notNull(instances, "Aggregate instances must not be null");

if (!instances.iterator().hasNext()) {
return Collections.emptyList();
}

List<EntityAndChangeCreator<T>> entityAndChangeCreators = new ArrayList<>();
for (T instance : instances) {
verifyIdProperty(instance);
entityAndChangeCreators.add(new EntityAndChangeCreator<>(instance, changeCreatorSelectorForSave(instance)));
}
return performSaveAll(entityAndChangeCreators);
return doWithBatch(instances, entity -> changeCreatorSelectorForSave(entity).apply(entity), this::verifyIdProperty,
this::performSaveAll);
}

/**
Expand All @@ -206,21 +198,7 @@ public <T> T insert(T instance) {

@Override
public <T> List<T> insertAll(Iterable<T> instances) {

Assert.notNull(instances, "Aggregate instances must not be null");

if (!instances.iterator().hasNext()) {
return Collections.emptyList();
}

List<EntityAndChangeCreator<T>> entityAndChangeCreators = new ArrayList<>();
for (T instance : instances) {

Function<T, RootAggregateChange<T>> changeCreator = entity -> createInsertChange(prepareVersionForInsert(entity));
EntityAndChangeCreator<T> entityChange = new EntityAndChangeCreator<>(instance, changeCreator);
entityAndChangeCreators.add(entityChange);
}
return performSaveAll(entityAndChangeCreators);
return doWithBatch(instances, entity -> createInsertChange(prepareVersionForInsert(entity)), this::performSaveAll);
}

/**
Expand All @@ -241,21 +219,35 @@ public <T> T update(T instance) {

@Override
public <T> List<T> updateAll(Iterable<T> instances) {
return doWithBatch(instances, entity -> createUpdateChange(prepareVersionForUpdate(entity)), this::performSaveAll);
}

private <T> List<T> doWithBatch(Iterable<T> iterable, Function<T, RootAggregateChange<T>> changeCreator,
Function<List<EntityAndChangeCreator<T>>, List<T>> performFunction) {
return doWithBatch(iterable, changeCreator, entity -> {}, performFunction);
}

Assert.notNull(instances, "Aggregate instances must not be null");
private <T> List<T> doWithBatch(Iterable<T> iterable, Function<T, RootAggregateChange<T>> changeCreator,
Consumer<T> beforeEntityChange, Function<List<EntityAndChangeCreator<T>>, List<T>> performFunction) {

if (!instances.iterator().hasNext()) {
Assert.notNull(iterable, "Aggregate instances must not be null");

if (ObjectUtils.isEmpty(iterable)) {
return Collections.emptyList();
}

List<EntityAndChangeCreator<T>> entityAndChangeCreators = new ArrayList<>();
for (T instance : instances) {
List<EntityAndChangeCreator<T>> entityAndChangeCreators = new ArrayList<>(
iterable instanceof Collection<?> c ? c.size() : 16);

for (T instance : iterable) {

beforeEntityChange.accept(instance);

Function<T, RootAggregateChange<T>> changeCreator = entity -> createUpdateChange(prepareVersionForUpdate(entity));
EntityAndChangeCreator<T> entityChange = new EntityAndChangeCreator<>(instance, changeCreator);
entityAndChangeCreators.add(entityChange);
}
return performSaveAll(entityAndChangeCreators);

return performFunction.apply(entityAndChangeCreators);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,12 @@ public <T> T findById(Object id, Class<T> domainType) {
}

@Override
public <T> Iterable<T> findAll(Class<T> domainType) {
public <T> List<T> findAll(Class<T> domainType) {
return operations.query(sql(domainType).getFindAll(), getEntityRowMapper(domainType));
}

@Override
public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {
public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {

if (!ids.iterator().hasNext()) {
return Collections.emptyList();
Expand All @@ -290,7 +290,7 @@ public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {

@Override
@SuppressWarnings("unchecked")
public Iterable<Object> findAllByPath(Identifier identifier,
public List<Object> findAllByPath(Identifier identifier,
PersistentPropertyPath<? extends RelationalPersistentProperty> propertyPath) {

Assert.notNull(identifier, "identifier must not be null");
Expand Down Expand Up @@ -338,12 +338,12 @@ public <T> boolean existsById(Object id, Class<T> domainType) {
}

@Override
public <T> Iterable<T> findAll(Class<T> domainType, Sort sort) {
public <T> List<T> findAll(Class<T> domainType, Sort sort) {
return operations.query(sql(domainType).getFindAll(sort), getEntityRowMapper(domainType));
}

@Override
public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) {
public <T> List<T> findAll(Class<T> domainType, Pageable pageable) {
return operations.query(sql(domainType).getFindAll(pageable), getEntityRowMapper(domainType));
}

Expand All @@ -361,7 +361,7 @@ public <T> Optional<T> findOne(Query query, Class<T> domainType) {
}

@Override
public <T> Iterable<T> findAll(Query query, Class<T> domainType) {
public <T> List<T> findAll(Query query, Class<T> domainType) {

MapSqlParameterSource parameterSource = new MapSqlParameterSource();
String sqlQuery = sql(domainType).selectByQuery(query, parameterSource);
Expand All @@ -370,7 +370,7 @@ public <T> Iterable<T> findAll(Query query, Class<T> domainType) {
}

@Override
public <T> Iterable<T> findAll(Query query, Class<T> domainType, Pageable pageable) {
public <T> List<T> findAll(Query query, Class<T> domainType, Pageable pageable) {

MapSqlParameterSource parameterSource = new MapSqlParameterSource();
String sqlQuery = sql(domainType).selectByQuery(query, parameterSource, pageable);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,21 +256,21 @@ public <T> T findById(Object id, Class<T> domainType) {
}

@Override
public <T> Iterable<T> findAll(Class<T> domainType) {
public <T> List<T> findAll(Class<T> domainType) {

String statement = namespace(domainType) + ".findAll";
MyBatisContext parameter = new MyBatisContext(null, null, domainType, Collections.emptyMap());
return sqlSession().selectList(statement, parameter);
}

@Override
public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {
public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
return sqlSession().selectList(namespace(domainType) + ".findAllById",
new MyBatisContext(ids, null, domainType, Collections.emptyMap()));
}

@Override
public Iterable<Object> findAllByPath(Identifier identifier,
public List<Object> findAllByPath(Identifier identifier,
PersistentPropertyPath<? extends RelationalPersistentProperty> path) {

String statementName = namespace(getOwnerTyp(path)) + ".findAllByPath-" + path.toDotPath();
Expand All @@ -288,7 +288,7 @@ public <T> boolean existsById(Object id, Class<T> domainType) {
}

@Override
public <T> Iterable<T> findAll(Class<T> domainType, Sort sort) {
public <T> List<T> findAll(Class<T> domainType, Sort sort) {

Map<String, Object> additionalContext = new HashMap<>();
additionalContext.put("sort", sort);
Expand All @@ -297,7 +297,7 @@ public <T> Iterable<T> findAll(Class<T> domainType, Sort sort) {
}

@Override
public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) {
public <T> List<T> findAll(Class<T> domainType, Pageable pageable) {

Map<String, Object> additionalContext = new HashMap<>();
additionalContext.put("pageable", pageable);
Expand All @@ -311,12 +311,12 @@ public <T> Optional<T> findOne(Query query, Class<T> probeType) {
}

@Override
public <T> Iterable<T> findAll(Query query, Class<T> probeType) {
public <T> List<T> findAll(Query query, Class<T> probeType) {
throw new UnsupportedOperationException("Not implemented");
}

@Override
public <T> Iterable<T> findAll(Query query, Class<T> probeType, Pageable pageable) {
public <T> List<T> findAll(Query query, Class<T> probeType, Pageable pageable) {
throw new UnsupportedOperationException("Not implemented");
}

Expand Down

0 comments on commit 7cf81ae

Please sign in to comment.