@tcollins
Last active September 5, 2024 13:26
Spring Data JPA - Limit results when using Specifications without an unnecessary count query being executed
When you call the findAll(Specification, Pageable) method, a count query is executed first, and the
data query is then executed only if the count is greater than the offset.
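To make that control flow concrete, here is a rough plain-Java model of the count-then-fetch behaviour described above. It does not use Spring at all: PagedReadSketch, readPage and the two supplier parameters are hypothetical names that stand in for the count query and the data query.

```java
import java.util.List;
import java.util.function.LongSupplier;
import java.util.function.Supplier;

// Hypothetical model, not Spring code: mirrors the "count first, then data" flow.
final class PagedReadSketch {

    // Number of simulated queries issued so far.
    static int queriesIssued = 0;

    static <T> List<T> readPage(LongSupplier countQuery, Supplier<List<T>> dataQuery, long offset) {
        queriesIssued++; // the count query always runs
        long total = countQuery.getAsLong();
        if (total > offset) {
            queriesIssued++; // the data query runs only when rows exist past the offset
            return dataQuery.get();
        }
        return List.of();
    }
}
```

When all you want is a row limit, the total is never used, so the first of those two queries is pure overhead.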
For what I was doing I did not need a Pageable; I simply wanted to limit my results. This is easy
to do with static named queries and methodNameMagicGoodness (derived) queries, but after a few hours
of googling I couldn't find a way to do it with dynamic criteria queries using Specifications.
During my search I found two things that helped me figure out how to do it myself.
1.) A Stack Overflow question:
How to disable count when Specification and Pageable are used together?
http://stackoverflow.com/questions/26738199/how-to-disable-count-when-specification-and-pageable-are-used-together
(where I will add a link to this gist)
2.) Spring documentation - Adding custom behavior to all repositories
http://docs.spring.io/spring-data/data-jpa/docs/current/reference/html/#repositories.custom-behaviour-for-all-repositories
I followed the Spring documentation closely and got this all working quickly without
any real problems.
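// Base interface shared by all repositories; adds limit/offset queries that skip the count query.
@NoRepositoryBean
public interface BaseRepository<T, ID extends Serializable> extends JpaRepository<T, ID> {

    // Returns at most maxResults rows matching spec, starting at offset, ordered by sort.
    List<T> findAll(Specification<T> spec, int offset, int maxResults, Sort sort);

    // Same as above, without an explicit sort order.
    List<T> findAll(Specification<T> spec, int offset, int maxResults);
}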
// Factory bean that makes every repository use BaseRepositoryImpl as its implementation.
public class BaseRepositoryFactoryBean<R extends JpaRepository<T, I>, T, I extends Serializable>
        extends JpaRepositoryFactoryBean<R, T, I> {

    @SuppressWarnings("rawtypes")
    @Override
    protected RepositoryFactorySupport createRepositoryFactory(EntityManager entityManager) {
        return new BaseRepositoryFactory(entityManager);
    }

    private static class BaseRepositoryFactory<T, I extends Serializable> extends JpaRepositoryFactory {

        private final EntityManager em;

        public BaseRepositoryFactory(EntityManager em) {
            super(em);
            this.em = em;
        }

        // Creates the custom base implementation for the repository's domain type.
        @SuppressWarnings({ "unchecked", "rawtypes", "hiding" })
        protected <T, ID extends Serializable> SimpleJpaRepository<?, ?> getTargetRepository(RepositoryMetadata metadata, EntityManager entityManager) {
            SimpleJpaRepository<?, ?> repo = new BaseRepositoryImpl(metadata.getDomainType(), entityManager);
            return repo;
        }

        // Tells Spring Data which class backs the repository proxies.
        protected Class<?> getRepositoryBaseClass(RepositoryMetadata metadata) {
            return BaseRepositoryImpl.class;
        }
    }
}
// Custom base implementation: runs only the data query, with an offset and a row limit.
public class BaseRepositoryImpl<T, ID extends Serializable> extends SimpleJpaRepository<T, ID> implements BaseRepository<T, ID> {

    private final EntityManager entityManager;

    public BaseRepositoryImpl(Class<T> domainClass, EntityManager entityManager) {
        super(domainClass, entityManager);
        this.entityManager = entityManager;
    }

    public List<T> findAll(Specification<T> spec, int offset, int maxResults) {
        // A null Sort means "unsorted" here; newer Spring Data versions may expect Sort.unsorted() instead.
        return findAll(spec, offset, maxResults, null);
    }

    public List<T> findAll(Specification<T> spec, int offset, int maxResults, Sort sort) {
        // Validate the arguments before building the query.
        if (offset < 0) {
            throw new IllegalArgumentException("Offset must not be less than zero!");
        }
        if (maxResults < 1) {
            throw new IllegalArgumentException("Max results must not be less than one!");
        }
        // getQuery builds only the data query, so no count query is issued.
        TypedQuery<T> query = getQuery(spec, sort);
        query.setFirstResult(offset);
        query.setMaxResults(maxResults);
        return query.getResultList();
    }
}
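As a rough illustration of what the offset and maxResults arguments mean, the following plain-Java sketch applies the same validation and the same skip-then-limit semantics to an in-memory list. It is only a model: OffsetLimitSketch and window are hypothetical names, and the real method pushes the offset and limit down to the database through setFirstResult and setMaxResults.

```java
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical in-memory model of findAll(spec, offset, maxResults); not Spring code.
final class OffsetLimitSketch {

    static <T> List<T> window(List<T> rows, int offset, int maxResults) {
        if (offset < 0) {
            throw new IllegalArgumentException("Offset must not be less than zero!");
        }
        if (maxResults < 1) {
            throw new IllegalArgumentException("Max results must not be less than one!");
        }
        // Skip the first `offset` rows, then keep at most `maxResults` of the remainder.
        return rows.stream()
                .skip(offset)
                .limit(maxResults)
                .collect(Collectors.toList());
    }
}
```

An offset past the end of the data simply yields an empty list, which is also what the database-backed version would return.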
@SpringBootApplication
@EnableJpaRepositories(repositoryFactoryBeanClass = BaseRepositoryFactoryBean.class)
public class MySpringBootApplication {

    public static void main(String[] args) {
        SpringApplication app = new SpringApplication(MySpringBootApplication.class);
        app.run(args);
    }
}
// This is just to show an example of a repo
public interface UserRepository extends BaseRepository<User, Long>, JpaSpecificationExecutor<User> {
}

josergdev commented Sep 4, 2024

Slice is essentially a special case of Window, so the same goal can be achieved using the
<S extends T, R> R findBy(Specification<T> spec, Function<FluentQuery.FetchableFluentQuery<S>, R> queryFunction);
method of JpaSpecificationExecutor.

https://gist.github.com/josergdev/06c82891a719eca4834410339885ad23

package dev.joserg.jpa;

import static org.springframework.data.domain.ScrollPosition.offset;

import java.util.function.Function;

import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.ScrollPosition;
import org.springframework.data.domain.Slice;
import org.springframework.data.domain.SliceImpl;
import org.springframework.data.domain.Sort;
import org.springframework.data.domain.Window;
import org.springframework.data.jpa.domain.Specification;
import org.springframework.data.jpa.repository.JpaSpecificationExecutor;
import org.springframework.data.repository.query.FluentQuery.FetchableFluentQuery;

public interface SliceSpecificationExecutor<T> extends JpaSpecificationExecutor<T> {

  default Window<T> findAllWindowed(Specification<T> spec, Sort sort, int limit, ScrollPosition scrollPosition) {
    return this.findBy(spec, toWindow(sort, limit, scrollPosition));
  }

  default Window<T> findAllWindowed(Specification<T> spec, Sort sort, ScrollPosition scrollPosition) {
    return this.findBy(spec, toWindow(sort, scrollPosition));
  }

  default Window<T> findAllWindowed(Specification<T> spec, ScrollPosition scrollPosition) {
    return this.findAllWindowed(spec, Sort.unsorted(), scrollPosition);
  }

  default Slice<T> findAllSliced(Specification<T> spec, Pageable pageable) {
    final var window = pageable.isUnpaged()
        ? this.findAllWindowed(spec, pageable.getSort(), offset())
        : this.findAllWindowed(spec, pageable.getSort(), pageable.getPageSize(), offset(pageable.getOffset()));
    return new SliceImpl<>(window.getContent(), pageable, window.hasNext());
  }

  private static <T> Function<FetchableFluentQuery<T>, Window<T>> toWindow(Sort sort, int limit, ScrollPosition scrollPosition) {
    return fetchableFluentQuery -> fetchableFluentQuery.sortBy(sort).limit(limit).scroll(scrollPosition);
  }

  private static <T> Function<FetchableFluentQuery<T>, Window<T>> toWindow(Sort sort, ScrollPosition scrollPosition) {
    return fetchableFluentQuery -> fetchableFluentQuery.sortBy(sort).scroll(scrollPosition);
  }

}
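A Slice or Window can report whether more rows follow without running a count query. One common way to do that, sketched here in plain Java, is to request one row more than the page size and check whether the extra row came back. This illustrates the general technique under that assumption and is not a copy of Spring Data's internals; SliceSketch, SimpleSlice and fetch are hypothetical names.

```java
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical model of "has next without counting"; not Spring code.
final class SliceSketch {

    // Page content plus a flag saying whether more rows follow.
    static final class SimpleSlice<T> {
        final List<T> content;
        final boolean hasNext;

        SimpleSlice(List<T> content, boolean hasNext) {
            this.content = content;
            this.hasNext = hasNext;
        }
    }

    static <T> SimpleSlice<T> fetch(List<T> table, int offset, int limit) {
        // Ask for limit + 1 rows; the extra row only signals that another slice exists.
        List<T> rows = table.stream()
                .skip(offset)
                .limit((long) limit + 1)
                .collect(Collectors.toList());
        boolean hasNext = rows.size() > limit;
        // Drop the probe row so the caller sees at most `limit` rows.
        List<T> content = hasNext ? rows.subList(0, limit) : rows;
        return new SimpleSlice<>(content, hasNext);
    }
}
```

The trade-off is one extra row per query in exchange for skipping the count entirely, which is usually cheap compared with counting a large filtered result.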
