Skip to content

Instantly share code, notes, and snippets.

@cesar1000
Last active April 17, 2024 05:23
Show Gist options
  • Save cesar1000/b610c1c95169ca6e83c173d888321632 to your computer and use it in GitHub Desktop.
Save cesar1000/b610c1c95169ca6e83c173d888321632 to your computer and use it in GitHub Desktop.
Ordered unit tests in Gradle
implementation-class=com.twitter.gradle.plugin.orderedtest.OrderedTestPlugin
package com.twitter.gradle.plugin.orderedtest;
import org.gradle.api.file.EmptyFileVisitor;
import org.gradle.api.file.FileTree;
import org.gradle.api.file.FileVisitDetails;
import org.gradle.api.internal.file.RelativeFile;
import org.gradle.api.internal.tasks.testing.DefaultTestClassRunInfo;
import org.gradle.api.internal.tasks.testing.TestClassProcessor;
import org.gradle.api.internal.tasks.testing.TestClassRunInfo;
import org.gradle.api.internal.tasks.testing.detection.TestFrameworkDetector;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Random;
/**
 * Based on {@link org.gradle.api.internal.tasks.testing.detection.DefaultTestClassScanner} from gradle REL_5.1
 * Depending on the availability of a test framework detector, a detection or filename scan is
 * performed to find test classes. Implements stable run ordering for the tests.
 *
 * <p>Ordering is controlled by {@code testSeed}: {@code 0} runs the classes in ascending order (by class
 * name for the filename scan, by class file for the detection scan). Any other value applies the same sort
 * and then shuffles with {@code new Random(testSeed)}, so a given seed reproduces the same order for the
 * same set of classes.
 *
 * <p>Sharding is controlled by {@code shardCount}/{@code shardIndex}: {@code shardCount == 0} disables
 * sharding. Otherwise only the classes at positions {@code shardIndex, shardIndex + shardCount, ...} of the
 * ordered list are handed on.
 */
@SuppressWarnings("all")
public class OrderedTestClassScanner implements Runnable {
    private static final String CLASS_FILE_SUFFIX = ".class";

    private final FileTree mCandidateClassFiles;
    @Nullable
    private final TestFrameworkDetector mTestFrameworkDetector;
    private final TestClassProcessor mTestClassProcessor;
    private final long mTestSeed;
    private final int mShardCount;
    private final int mShardIndex;

    /**
     * @param candidateClassFiles files to scan; only files whose path ends in {@code .class} are considered
     * @param testFrameworkDetector detector that decides which class files are tests, or {@code null} to
     *     hand every class file to {@code testClassProcessor} by class name
     * @param testClassProcessor receives the selected test classes
     * @param testSeed {@code 0} for sorted order, otherwise the seed of the reproducible shuffle
     * @param shardCount number of shards, or {@code 0} to disable sharding
     * @param shardIndex zero-based shard to run when {@code shardCount > 0}
     */
    public OrderedTestClassScanner(@NotNull FileTree candidateClassFiles,
            @Nullable TestFrameworkDetector testFrameworkDetector,
            @NotNull TestClassProcessor testClassProcessor,
            long testSeed, int shardCount, int shardIndex) {
        mCandidateClassFiles = candidateClassFiles;
        mTestFrameworkDetector = testFrameworkDetector;
        mTestClassProcessor = testClassProcessor;
        mTestSeed = testSeed;
        mShardCount = shardCount;
        mShardIndex = shardIndex;
    }

    /**
     * Scans the candidate files, orders them, keeps this shard's share and dispatches them.
     *
     * @throws IllegalArgumentException if the shard configuration is invalid
     */
    @Override
    public void run() {
        if (mTestFrameworkDetector == null) {
            final List<TestClassRunInfo> ordered = sort(filenameScan(), mTestSeed,
                    Comparator.comparing(TestClassRunInfo::getTestClassName));
            filter(ordered, mShardCount, mShardIndex).forEach(mTestClassProcessor::processTestClass);
        } else {
            // detectionScan() also starts detection, so it must run before any processTestClass call.
            final List<RelativeFile> ordered = sort(detectionScan(), mTestSeed,
                    Comparator.comparing(RelativeFile::getFile));
            filter(ordered, mShardCount, mShardIndex).forEach(mTestFrameworkDetector::processTestClass);
        }
    }

    /** Starts detection and collects every candidate class file, in visit order. */
    private List<RelativeFile> detectionScan() {
        final List<RelativeFile> scanned = new ArrayList<>();
        mTestFrameworkDetector.startDetection(mTestClassProcessor);
        mCandidateClassFiles.visit(new ClassFileVisitor() {
            @Override
            public void visitClassFile(FileVisitDetails fileDetails) {
                scanned.add(new RelativeFile(fileDetails.getFile(), fileDetails.getRelativePath()));
            }
        });
        return scanned;
    }

    /** Collects a run-info entry, named after its relative path, for every candidate class file. */
    private List<TestClassRunInfo> filenameScan() {
        final List<TestClassRunInfo> scanned = new ArrayList<>();
        mCandidateClassFiles.visit(new ClassFileVisitor() {
            @Override
            public void visitClassFile(FileVisitDetails fileDetails) {
                final String className = toClassName(fileDetails.getRelativePath().getPathString());
                scanned.add(new DefaultTestClassRunInfo(className));
            }
        });
        return scanned;
    }

    /**
     * Converts a relative class file path such as {@code com/example/FooTest.class} into
     * {@code com.example.FooTest}. Only the trailing {@code .class} suffix is removed, so directory names
     * that happen to contain ".class" are left intact.
     */
    private static String toClassName(String classFilePath) {
        final String withoutSuffix = classFilePath.endsWith(CLASS_FILE_SUFFIX)
                ? classFilePath.substring(0, classFilePath.length() - CLASS_FILE_SUFFIX.length())
                : classFilePath;
        return withoutSuffix.replace('/', '.');
    }

    /** File visitor that forwards only files whose absolute path ends in {@code .class}. */
    private abstract static class ClassFileVisitor extends EmptyFileVisitor {
        @Override
        public void visitFile(FileVisitDetails fileDetails) {
            final File file = fileDetails.getFile();
            if (file.getAbsolutePath().endsWith(CLASS_FILE_SUFFIX)) {
                visitClassFile(fileDetails);
            }
        }

        public abstract void visitClassFile(FileVisitDetails fileDetails);
    }

    /**
     * Orders {@code items} in place and returns the same list.
     *
     * <p>The list is always sorted with {@code cmp} first (natural order if {@code cmp} is null), so the
     * starting point does not depend on file-system visit order. With a non-zero seed it is then shuffled
     * with that seed, which gives a stable "random" order. The previous implementation sorted with a
     * {@code null} comparator in the seeded case. That requires {@code Comparable} elements and ignores
     * {@code cmp}. NOTE(review): the element types passed by {@link #run()} do not appear to be
     * {@code Comparable}, so the old code could throw {@code ClassCastException} there.
     */
    private static <T> List<T> sort(@NotNull List<T> items, long testSeed, @Nullable Comparator<? super T> cmp) {
        Collections.sort(items, cmp);
        if (testSeed != 0) {
            Collections.shuffle(items, new Random(testSeed));
        }
        return items;
    }

    /**
     * Returns the items at positions {@code shardIndex, shardIndex + shardCount, ...}, or {@code items}
     * itself when {@code shardCount == 0}.
     *
     * @throws IllegalArgumentException if {@code shardCount} is negative, or sharding is enabled and
     *     {@code shardIndex} is outside {@code [0, shardCount)}. This avoids index errors and shards that
     *     overlap one another.
     */
    private static <T> List<T> filter(@NotNull List<T> items, int shardCount, int shardIndex) {
        if (shardCount == 0) {
            return items;
        }
        if (shardCount < 0 || shardIndex < 0 || shardIndex >= shardCount) {
            throw new IllegalArgumentException(
                    "Invalid shard index " + shardIndex + " with shard count " + shardCount);
        }
        final List<T> filtered = new ArrayList<>(items.size() / shardCount + 1);
        for (int i = shardIndex; i < items.size(); i += shardCount) {
            filtered.add(items.get(i));
        }
        return filtered;
    }
}
package com.twitter.gradle.plugin.orderedtest;
import com.google.common.collect.ImmutableSet;
import org.gradle.api.file.FileTree;
import org.gradle.api.internal.DocumentationRegistry;
import org.gradle.api.internal.classpath.ModuleRegistry;
import org.gradle.api.internal.tasks.testing.JvmTestExecutionSpec;
import org.gradle.api.internal.tasks.testing.TestClassProcessor;
import org.gradle.api.internal.tasks.testing.TestExecuter;
import org.gradle.api.internal.tasks.testing.TestFramework;
import org.gradle.api.internal.tasks.testing.TestResultProcessor;
import org.gradle.api.internal.tasks.testing.WorkerTestClassProcessorFactory;
import org.gradle.api.internal.tasks.testing.detection.TestFrameworkDetector;
import org.gradle.api.internal.tasks.testing.processors.MaxNParallelTestClassProcessor;
import org.gradle.api.internal.tasks.testing.processors.RestartEveryNTestClassProcessor;
import org.gradle.api.internal.tasks.testing.processors.TestMainAction;
import org.gradle.api.internal.tasks.testing.worker.ForkingTestClassProcessor;
import org.gradle.api.logging.Logger;
import org.gradle.api.logging.Logging;
import org.gradle.internal.Factory;
import org.gradle.internal.actor.ActorFactory;
import org.gradle.internal.operations.BuildOperationExecutor;
import org.gradle.internal.time.Clock;
import org.gradle.internal.work.WorkerLeaseRegistry;
import org.gradle.process.internal.worker.WorkerProcessFactory;
import java.io.File;
import java.util.Set;
/**
 * Based on DefaultTestExecuter from gradle REL_4.4.0.
 *
 * <p>Builds the same forking / restart-every-N / max-N-parallel processor chain as Gradle's executer. Test
 * classes are discovered through {@link OrderedTestClassScanner}, so they are dispatched in a seed-controlled
 * order and can be limited to one shard.
 */
public class OrderedTestExecuter implements TestExecuter<JvmTestExecutionSpec> {
    private static final Logger LOGGER = Logging.getLogger(OrderedTestExecuter.class);
    private final WorkerProcessFactory mWorkerFactory;
    private final ActorFactory mActorFactory;
    private final ModuleRegistry mModuleRegistry;
    private final WorkerLeaseRegistry mWorkerLeaseRegistry;
    private final DocumentationRegistry mDocumentationRegistry;
    private final BuildOperationExecutor mBuildOperationExecutor;
    private final int mMaxWorkerCount;
    private final Clock mClock;
    private final long mTestSeed;
    private final int mShardCount;
    private final int mShardIndex;
    /**
     * Root processor of the most recent {@link #execute} call. It is kept so {@link #stopNow()} can cancel the
     * run, and is {@code null} until execution starts. It is volatile because a cancellation hook may be
     * invoked from a thread other than the one running {@code execute()}.
     * NOTE(review): confirm the caller's threading.
     */
    private volatile TestClassProcessor mProcessor;

    public OrderedTestExecuter(WorkerProcessFactory workerFactory, ActorFactory actorFactory,
            ModuleRegistry moduleRegistry, WorkerLeaseRegistry workerLeaseRegistry,
            DocumentationRegistry documentationRegistry,
            BuildOperationExecutor buildOperationExecutor, int maxWorkerCount, Clock clock,
            long testSeed, int shardCount, int shardIndex) {
        mWorkerFactory = workerFactory;
        mActorFactory = actorFactory;
        mModuleRegistry = moduleRegistry;
        mWorkerLeaseRegistry = workerLeaseRegistry;
        mDocumentationRegistry = documentationRegistry;
        mBuildOperationExecutor = buildOperationExecutor;
        mMaxWorkerCount = maxWorkerCount;
        mClock = clock;
        mTestSeed = testSeed;
        mShardCount = shardCount;
        mShardIndex = shardIndex;
    }

    /**
     * Runs the tests described by {@code testExecutionSpec}, reporting results to {@code testResultProcessor}.
     */
    @Override
    public void execute(final JvmTestExecutionSpec testExecutionSpec, final TestResultProcessor testResultProcessor) {
        final TestFramework testFramework = testExecutionSpec.getTestFramework();
        final WorkerTestClassProcessorFactory testInstanceFactory = testFramework.getProcessorFactory();
        final WorkerLeaseRegistry.WorkerLease currentWorkerLease = mWorkerLeaseRegistry.getCurrentWorkerLease();
        final Set<File> classpath = ImmutableSet.copyOf(testExecutionSpec.getClasspath());
        final Factory<TestClassProcessor> forkingProcessorFactory =
                () -> new ForkingTestClassProcessor(currentWorkerLease, mWorkerFactory, testInstanceFactory,
                        testExecutionSpec.getJavaForkOptions(), classpath, testFramework.getWorkerConfigurationAction(),
                        mModuleRegistry, mDocumentationRegistry);
        final Factory<TestClassProcessor> reforkingProcessorFactory =
                () -> new RestartEveryNTestClassProcessor(forkingProcessorFactory, testExecutionSpec.getForkEvery());
        final TestClassProcessor processor = new MaxNParallelTestClassProcessor(getMaxParallelForks(testExecutionSpec),
                reforkingProcessorFactory, mActorFactory);
        mProcessor = processor;

        final FileTree testClassFiles = testExecutionSpec.getCandidateClassFiles();
        // Use detection only when scanning is enabled AND the framework actually supplies a detector.
        // getDetector() may return null (NOTE(review): e.g. JUnit Platform; confirm for the Gradle version
        // in use). Dereferencing it unconditionally would NPE. With no detector, the scanner falls back to
        // handing every class file to the processor by name.
        final TestFrameworkDetector testFrameworkDetector =
                testExecutionSpec.isScanForTestClasses() ? testFramework.getDetector() : null;
        if (testFrameworkDetector != null) {
            testFrameworkDetector.setTestClasses(testExecutionSpec.getTestClassesDirs().getFiles());
            testFrameworkDetector.setTestClasspath(classpath);
        }
        final Runnable detector = new OrderedTestClassScanner(testClassFiles, testFrameworkDetector, processor,
                mTestSeed, mShardCount, mShardIndex);

        final Object testTaskOperationId = mBuildOperationExecutor.getCurrentOperation().getParentId();
        new TestMainAction(detector, processor, testResultProcessor, mClock, testTaskOperationId,
                testExecutionSpec.getPath(), "Gradle Test Run " + testExecutionSpec.getIdentityPath()).run();
    }

    /**
     * Stops the run started by {@link #execute}, if any, by delegating to its root processor.
     * Does nothing if no run has been started. The previous implementation threw
     * {@code UnsupportedOperationException}, which would fail the build whenever Gradle asked the executer to
     * stop early.
     */
    @Override
    public void stopNow() {
        final TestClassProcessor processor = mProcessor;
        if (processor != null) {
            processor.stopNow();
        }
    }

    /** Caps the spec's {@code maxParallelForks} at the build's max worker count, logging when it is reduced. */
    private int getMaxParallelForks(JvmTestExecutionSpec testExecutionSpec) {
        int maxParallelForks = testExecutionSpec.getMaxParallelForks();
        if (maxParallelForks > mMaxWorkerCount) {
            LOGGER.info("{}.maxParallelForks ({}) is larger than max-workers ({}), forcing it to {}",
                    testExecutionSpec.getPath(), maxParallelForks, mMaxWorkerCount, mMaxWorkerCount);
            maxParallelForks = mMaxWorkerCount;
        }
        return maxParallelForks;
    }
}
package com.twitter.gradle.plugin.orderedtest
import groovy.transform.CompileStatic
/**
 * Settings of the {@code orderedTest} extension that {@link OrderedTestPlugin} registers on the project.
 * Plain Groovy properties: the compiler generates the getters and setters.
 */
@CompileStatic
class OrderedTestExtension {
    /**
     * Run-order policy, interpreted by {@code OrderedTestPlugin.parseSeed}:
     * {@code null} or {@code 'lex'} for sorted order, {@code 'random'} for a fresh random seed,
     * or a hex long to reproduce a specific order.
     */
    String seed

    /** Number of shards to split the test classes into; {@code 0} (the default) disables sharding. */
    int shardCount

    /** Zero-based shard to run; must be below {@link #shardCount} when sharding is enabled. */
    int shardIndex
}
package com.twitter.gradle.plugin.orderedtest
import org.gradle.StartParameter
import org.gradle.api.InvalidUserDataException
import org.gradle.api.Plugin
import org.gradle.api.Project
import org.gradle.api.internal.DocumentationRegistry
import org.gradle.api.tasks.testing.Test
import org.gradle.internal.operations.BuildOperationExecutor
import org.gradle.internal.time.Clock
import org.gradle.internal.work.WorkerLeaseRegistry
/**
 * Installs an {@link OrderedTestExecuter} on every {@link Test} task, so test classes run in a reproducible
 * order (optionally restricted to one shard). It is configured through the {@code orderedTest} extension.
 */
class OrderedTestPlugin implements Plugin<Project> {
    @Override
    void apply(Project project) {
        def orderedTestExt = project.extensions.create('orderedTest', OrderedTestExtension)
        // Update test tasks after project evaluation, since that will allow the orderedTest extension to be configured.
        project.afterEvaluate {
            project.tasks.withType(Test).configureEach {
                // Parsed per task, so with 'random' each Test task draws (and logs) its own seed.
                long seed = parseSeed(orderedTestExt.seed)
                int shardCount = orderedTestExt.shardCount
                if (shardCount < 0) {
                    throw new IllegalArgumentException("Invalid shard count ${shardCount}")
                }
                int shardIndex = orderedTestExt.shardIndex
                if (shardIndex < 0 || shardCount > 0 && shardIndex >= shardCount) {
                    throw new IllegalArgumentException(
                            "Invalid shard index ${shardIndex} with shard count ${shardCount}")
                }
                OrderedTestExecuter executer = new OrderedTestExecuter(
                        getProcessBuilderFactory(),
                        getActorFactory(), getModuleRegistry(),
                        getServices().get(WorkerLeaseRegistry.class),
                        getServices().get(DocumentationRegistry.class),
                        getServices().get(BuildOperationExecutor.class),
                        getServices().get(StartParameter.class).getMaxWorkerCount(),
                        getServices().get(Clock.class),
                        seed, shardCount, shardIndex)
                setTestExecuter(executer)
                // Declared as inputs so a change of order or shard invalidates up-to-date checks.
                inputs.property('seed', seed)
                inputs.property('shardCount', shardCount)
                inputs.property('shardIndex', shardIndex)
                testLogging {
                    afterSuite { desc, result ->
                        if (!desc.parent && result.testCount != 0 && seed != 0) {
                            logger.lifecycle("To re-run tests in the same order use this with your gradle task:" +
                                    " -Dtest.seed=${Long.toHexString(seed)}")
                        }
                    }
                }
            }
        }
    }

    // implement our documented user interface/policy for this plugin
    // 1. null or 'lex' or 0 (the default) results in a lex sort of the test classes
    //    the 0 is really a magic value used further down the stack
    // 2. 'random' results in a random run order, of use for CI jobs which aim to shake out dependencies
    //    the seed is published on the test console output to enable reproduction of the same order
    //    (it is never 0, since 0 would silently mean 'lex' and suppress the reproduction hint)
    // 3. hex long value allows a specific seed to be used allowing reproduction of 'random' orders
    // in a perfect world we'd default to 'random', in a world where we minimize developer pain we default to 'lex'
    static long parseSeed(String seed) {
        if (seed == null || seed.equals("lex")) {
            return 0
        } else if (seed.equals("random")) {
            Random random = new Random()
            long value = 0
            while (value == 0) {
                value = random.nextLong()
            }
            return value
        } else {
            return parseHexToLong(seed)
        }
    }

    // Parses up to 16 hex digits as the unsigned two's-complement bit pattern of a long. This is the inverse
    // of Long.toHexString, which Long.parseLong cannot handle for values with the top bit set. The input is
    // validated first: the former split-in-halves parsing silently produced wrong seeds for strings longer
    // than 16 digits or with a sign character in the low half. An empty string parses to 0 ('lex'), as before.
    static long parseHexToLong(String seed) {
        if (!(seed ==~ /[0-9a-fA-F]{0,16}/)) {
            throw new InvalidUserDataException("expected -Dtest.seed=N where N must be a hex long integer")
        }
        return seed.isEmpty() ? 0L : Long.parseUnsignedLong(seed, 16)
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment