From dd4da8a30c1913da6995157291606eb350a2553b Mon Sep 17 00:00:00 2001 From: Jens Schauder Date: Wed, 25 Sep 2024 09:07:04 +0200 Subject: [PATCH] Repository and template now return Lists. Closes #1623 --- .../data/jdbc/core/JdbcAggregateTemplate.java | 24 ++++++++----- .../support/SimpleJdbcRepository.java | 36 ++++++++++++------- 2 files changed, 39 insertions(+), 21 deletions(-) diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java index ad690dc273..dbef6d1e1a 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java @@ -52,6 +52,7 @@ import org.springframework.data.relational.core.mapping.event.*; import org.springframework.data.relational.core.query.Query; import org.springframework.data.support.PageableExecutionUtils; +import org.springframework.data.util.Streamable; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; @@ -171,7 +172,7 @@ public T save(T instance) { } @Override - public Iterable saveAll(Iterable instances) { + public List saveAll(Iterable instances) { Assert.notNull(instances, "Aggregate instances must not be null"); @@ -204,7 +205,7 @@ public T insert(T instance) { } @Override - public Iterable insertAll(Iterable instances) { + public List insertAll(Iterable instances) { Assert.notNull(instances, "Aggregate instances must not be null"); @@ -239,7 +240,7 @@ public T update(T instance) { } @Override - public Iterable updateAll(Iterable instances) { + public List updateAll(Iterable instances) { Assert.notNull(instances, "Aggregate instances must not be null"); @@ -298,7 +299,7 @@ public T findById(Object id, Class domainType) { } @Override - public Iterable findAll(Class domainType, Sort sort) { + public List findAll(Class domainType, Sort sort) { Assert.notNull(domainType, "Domain type must not be null"); @@ -323,8 +324,13 @@ public Optional findOne(Query query, Class domainType) { } @Override - public Iterable findAll(Query query, Class domainType) { - return accessStrategy.findAll(query, domainType); + public List findAll(Query query, Class domainType) { + + Iterable all = accessStrategy.findAll(query, domainType); + if (all instanceof List list) { + return list; + } + return Streamable.of(all).toList(); } @Override @@ -337,7 +343,7 @@ public Page findAll(Query query, Class domainType, Pageable pageable) } @Override - public Iterable findAll(Class domainType) { + public List findAll(Class domainType) { Assert.notNull(domainType, "Domain type must not be null"); @@ -346,7 +352,7 @@ public Iterable findAll(Class domainType) { } @Override - public Iterable findAllById(Iterable ids, Class domainType) { + public List findAllById(Iterable ids, Class domainType) { Assert.notNull(ids, "Ids must not be null"); Assert.notNull(domainType, "Domain type must not be null"); @@ -607,7 +613,7 @@ private MutableAggregateChange createDeletingChange(Class domainType) { return aggregateChange; } - private Iterable triggerAfterConvert(Iterable all) { + private List triggerAfterConvert(Iterable all) { List result = new ArrayList<>(); diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/support/SimpleJdbcRepository.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/support/SimpleJdbcRepository.java index 628c37afbf..8a55f9e668 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/support/SimpleJdbcRepository.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/support/SimpleJdbcRepository.java @@ -15,6 +15,7 @@ */ package org.springframework.data.jdbc.repository.support; +import java.util.List; import java.util.Optional; import java.util.function.Function; @@ -30,6 +31,7 @@ import org.springframework.data.repository.PagingAndSortingRepository; import org.springframework.data.repository.query.FluentQuery; import org.springframework.data.repository.query.QueryByExampleExecutor; +import org.springframework.data.util.Streamable; import org.springframework.transaction.annotation.Transactional; import org.springframework.util.Assert; @@ -70,8 +72,8 @@ public S save(S instance) { @Transactional @Override - public Iterable saveAll(Iterable entities) { - return entityOperations.saveAll(entities); + public List saveAll(Iterable entities) { + return asList(entityOperations.saveAll(entities)); } @Override @@ -85,13 +87,13 @@ public boolean existsById(ID id) { } @Override - public Iterable findAll() { - return entityOperations.findAll(entity.getType()); + public List findAll() { + return asList(entityOperations.findAll(entity.getType())); } @Override - public Iterable findAllById(Iterable ids) { - return entityOperations.findAllById(ids, entity.getType()); + public List findAllById(Iterable ids) { + return asList(entityOperations.findAllById(ids, entity.getType())); } @Override @@ -130,8 +132,8 @@ public void deleteAll() { } @Override - public Iterable findAll(Sort sort) { - return entityOperations.findAll(entity.getType(), sort); + public List findAll(Sort sort) { + return asList(entityOperations.findAll(entity.getType(), sort)); } @Override @@ -148,7 +150,7 @@ public Optional findOne(Example example) { } @Override - public Iterable findAll(Example example) { + public List findAll(Example example) { Assert.notNull(example, "Example must not be null"); @@ -156,13 +158,13 @@ public Iterable findAll(Example example) { } @Override - public Iterable findAll(Example example, Sort sort) { + public List findAll(Example example, Sort sort) { Assert.notNull(example, "Example must not be null"); Assert.notNull(sort, "Sort must not be null"); - return this.entityOperations.findAll(this.exampleMapper.getMappedExample(example).sort(sort), - example.getProbeType()); + return asList(this.entityOperations.findAll(this.exampleMapper.getMappedExample(example).sort(sort), + example.getProbeType())); } @Override @@ -200,4 +202,14 @@ public R findBy(Example example, Function List asList(Iterable iterable) { + + if (iterable instanceof List list) { + return list; + } + return Streamable.of(iterable).stream().toList(); + } + }