Last active: April 17, 2024 05:23
-
-
Save cesar1000/b610c1c95169ca6e83c173d888321632 to your computer and use it in GitHub Desktop.
Ordered unit tests in Gradle
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Gradle plugin descriptor: tells Gradle which class implements this plugin.
implementation-class=com.twitter.gradle.plugin.orderedtest.OrderedTestPlugin |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.twitter.gradle.plugin.orderedtest; | |
import org.gradle.api.file.EmptyFileVisitor; | |
import org.gradle.api.file.FileTree; | |
import org.gradle.api.file.FileVisitDetails; | |
import org.gradle.api.internal.file.RelativeFile; | |
import org.gradle.api.internal.tasks.testing.DefaultTestClassRunInfo; | |
import org.gradle.api.internal.tasks.testing.TestClassProcessor; | |
import org.gradle.api.internal.tasks.testing.TestClassRunInfo; | |
import org.gradle.api.internal.tasks.testing.detection.TestFrameworkDetector; | |
import org.jetbrains.annotations.NotNull; | |
import org.jetbrains.annotations.Nullable; | |
import java.io.File; | |
import java.util.ArrayList; | |
import java.util.Collections; | |
import java.util.Comparator; | |
import java.util.List; | |
import java.util.Random; | |
/** | |
* Based on {@link org.gradle.api.internal.tasks.testing.detection.DefaultTestClassScanner} from gradle REL_5.1 | |
* Depending on the availability of a test framework detector, a detection or filename scan is | |
* performed to find test classes. Implements stable run ordering for the tests. | |
*/ | |
@SuppressWarnings("all") | |
public class OrderedTestClassScanner implements Runnable { | |
private final FileTree mCandidateClassFiles; | |
private final TestFrameworkDetector mTestFrameworkDetector; | |
private final TestClassProcessor mTestClassProcessor; | |
private final long mTestSeed; | |
private final int mShardCount; | |
private final int mShardIndex; | |
public OrderedTestClassScanner(@NotNull FileTree candidateClassFiles, | |
TestFrameworkDetector testFrameworkDetector, | |
@NotNull TestClassProcessor testClassProcessor, | |
long testSeed, int shardCount, int shardIndex) { | |
mCandidateClassFiles = candidateClassFiles; | |
mTestFrameworkDetector = testFrameworkDetector; | |
mTestClassProcessor = testClassProcessor; | |
mTestSeed = testSeed; | |
mShardCount = shardCount; | |
mShardIndex = shardIndex; | |
} | |
@Override | |
public void run() { | |
if (mTestFrameworkDetector == null) { | |
final List<TestClassRunInfo> scanned = filenameScan(); | |
final List<TestClassRunInfo> sorted = sort(scanned, mTestSeed, | |
(o1, o2) -> o1.getTestClassName().compareTo(o2.getTestClassName())); | |
final List<TestClassRunInfo> filtered = filter(sorted, mShardCount, mShardIndex); | |
filtered.forEach(s -> mTestClassProcessor.processTestClass(s)); | |
} else { | |
final List<RelativeFile> scanned = detectionScan(); | |
final List<RelativeFile> sorted = sort(scanned, mTestSeed, | |
new Comparator<RelativeFile>() { | |
@Override | |
public int compare(RelativeFile o1, RelativeFile o2) { | |
return o1.getFile().compareTo(o2.getFile()); | |
} | |
}); | |
final List<RelativeFile> filtered = filter(sorted, mShardCount, mShardIndex); | |
filtered.forEach(s -> mTestFrameworkDetector.processTestClass(s)); | |
} | |
} | |
private List<RelativeFile> detectionScan() { | |
final List<RelativeFile> scanned = new ArrayList<>(); | |
mTestFrameworkDetector.startDetection(mTestClassProcessor); | |
mCandidateClassFiles.visit(new ClassFileVisitor() { | |
@Override | |
public void visitClassFile(FileVisitDetails fileDetails) { | |
scanned.add(new RelativeFile(fileDetails.getFile(), fileDetails.getRelativePath())); | |
} | |
}); | |
return scanned; | |
} | |
private List<TestClassRunInfo> filenameScan() { | |
final List<TestClassRunInfo> scanned = new ArrayList<>(); | |
mCandidateClassFiles.visit(new ClassFileVisitor() { | |
@Override | |
public void visitClassFile(FileVisitDetails fileDetails) { | |
final String className = fileDetails.getRelativePath().getPathString().replaceAll("\\.class", "") | |
.replace('/', '.'); | |
final TestClassRunInfo testClass = new DefaultTestClassRunInfo(className); | |
scanned.add(testClass); | |
} | |
}); | |
return scanned; | |
} | |
private abstract class ClassFileVisitor extends EmptyFileVisitor { | |
@Override | |
public void visitFile(FileVisitDetails fileDetails) { | |
final File file = fileDetails.getFile(); | |
if (file.getAbsolutePath().endsWith(".class")) { | |
visitClassFile(fileDetails); | |
} | |
} | |
public abstract void visitClassFile(FileVisitDetails fileDetails); | |
} | |
private static <T> List<T> sort(@NotNull List<T> items, long testSeed, @Nullable Comparator<? super T> cmp) { | |
if (testSeed == 0) { | |
// special case natural order (lex for strings), still 'random' enough in the big picture | |
Collections.sort(items, cmp); | |
} else { | |
// start from ascending natural order, then random shuffle with known seed, so we get a stable random sort | |
Collections.sort(items, null); | |
Collections.shuffle(items, new Random(testSeed)); | |
} | |
return items; | |
} | |
private static <T> List<T> filter(@NotNull List<T> items, int shardCount, int shardIndex) { | |
if (shardCount != 0) { | |
final List<T> filtered = new ArrayList<>(items.size() / shardCount + 1); | |
for (int i = shardIndex; i < items.size(); i += shardCount) { | |
filtered.add(items.get(i)); | |
} | |
return filtered; | |
} else { | |
return items; | |
} | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.twitter.gradle.plugin.orderedtest; | |
import com.google.common.collect.ImmutableSet; | |
import org.gradle.api.file.FileTree; | |
import org.gradle.api.internal.DocumentationRegistry; | |
import org.gradle.api.internal.classpath.ModuleRegistry; | |
import org.gradle.api.internal.tasks.testing.JvmTestExecutionSpec; | |
import org.gradle.api.internal.tasks.testing.TestClassProcessor; | |
import org.gradle.api.internal.tasks.testing.TestExecuter; | |
import org.gradle.api.internal.tasks.testing.TestFramework; | |
import org.gradle.api.internal.tasks.testing.TestResultProcessor; | |
import org.gradle.api.internal.tasks.testing.WorkerTestClassProcessorFactory; | |
import org.gradle.api.internal.tasks.testing.detection.TestFrameworkDetector; | |
import org.gradle.api.internal.tasks.testing.processors.MaxNParallelTestClassProcessor; | |
import org.gradle.api.internal.tasks.testing.processors.RestartEveryNTestClassProcessor; | |
import org.gradle.api.internal.tasks.testing.processors.TestMainAction; | |
import org.gradle.api.internal.tasks.testing.worker.ForkingTestClassProcessor; | |
import org.gradle.api.logging.Logger; | |
import org.gradle.api.logging.Logging; | |
import org.gradle.internal.Factory; | |
import org.gradle.internal.actor.ActorFactory; | |
import org.gradle.internal.operations.BuildOperationExecutor; | |
import org.gradle.internal.time.Clock; | |
import org.gradle.internal.work.WorkerLeaseRegistry; | |
import org.gradle.process.internal.worker.WorkerProcessFactory; | |
import java.io.File; | |
import java.util.Set; | |
/** | |
* Based on DefaultTestExecuter from gradle REL_4.4.0. | |
*/ | |
public class OrderedTestExecuter implements TestExecuter<JvmTestExecutionSpec> { | |
private static final Logger LOGGER = Logging.getLogger(OrderedTestExecuter.class); | |
private final WorkerProcessFactory mWorkerFactory; | |
private final ActorFactory mActorFactory; | |
private final ModuleRegistry mModuleRegistry; | |
private final WorkerLeaseRegistry mWorkerLeaseRegistry; | |
private final DocumentationRegistry mDocumentationRegistry; | |
private final BuildOperationExecutor mBuildOperationExecutor; | |
private final int mMaxWorkerCount; | |
private final Clock mClock; | |
private final long mTestSeed; | |
private final int mShardCount; | |
private final int mShardIndex; | |
public OrderedTestExecuter(WorkerProcessFactory workerFactory, ActorFactory actorFactory, | |
ModuleRegistry moduleRegistry, WorkerLeaseRegistry workerLeaseRegistry, | |
DocumentationRegistry documentationRegistry, | |
BuildOperationExecutor buildOperationExecutor, int maxWorkerCount, Clock clock, | |
long testSeed, int shardCount, int shardIndex) { | |
mWorkerFactory = workerFactory; | |
mActorFactory = actorFactory; | |
mModuleRegistry = moduleRegistry; | |
mWorkerLeaseRegistry = workerLeaseRegistry; | |
mDocumentationRegistry = documentationRegistry; | |
mBuildOperationExecutor = buildOperationExecutor; | |
mMaxWorkerCount = maxWorkerCount; | |
mClock = clock; | |
mTestSeed = testSeed; | |
mShardCount = shardCount; | |
mShardIndex = shardIndex; | |
} | |
@Override | |
public void execute(final JvmTestExecutionSpec testExecutionSpec, final TestResultProcessor testResultProcessor) { | |
final TestFramework testFramework = testExecutionSpec.getTestFramework(); | |
final WorkerTestClassProcessorFactory testInstanceFactory = testFramework.getProcessorFactory(); | |
final WorkerLeaseRegistry.WorkerLease currentWorkerLease = mWorkerLeaseRegistry.getCurrentWorkerLease(); | |
final Set<File> classpath = ImmutableSet.copyOf(testExecutionSpec.getClasspath()); | |
final Factory<TestClassProcessor> forkingProcessorFactory = | |
() -> new ForkingTestClassProcessor(currentWorkerLease, mWorkerFactory, testInstanceFactory, | |
testExecutionSpec.getJavaForkOptions(), classpath, testFramework.getWorkerConfigurationAction(), | |
mModuleRegistry, mDocumentationRegistry); | |
final Factory<TestClassProcessor> reforkingProcessorFactory = | |
() -> new RestartEveryNTestClassProcessor(forkingProcessorFactory, testExecutionSpec.getForkEvery()); | |
final TestClassProcessor processor = new MaxNParallelTestClassProcessor(getMaxParallelForks(testExecutionSpec), | |
reforkingProcessorFactory, mActorFactory); | |
final FileTree testClassFiles = testExecutionSpec.getCandidateClassFiles(); | |
final Runnable detector; | |
if (testExecutionSpec.isScanForTestClasses()) { | |
final TestFrameworkDetector testFrameworkDetector = testFramework.getDetector(); | |
testFrameworkDetector.setTestClasses(testExecutionSpec.getTestClassesDirs().getFiles()); | |
testFrameworkDetector.setTestClasspath(classpath); | |
detector = new OrderedTestClassScanner(testClassFiles, testFrameworkDetector, processor, mTestSeed, | |
mShardCount, mShardIndex); | |
} else { | |
detector = new OrderedTestClassScanner(testClassFiles, null, processor, mTestSeed, mShardCount, | |
mShardIndex); | |
} | |
final Object testTaskOperationId = mBuildOperationExecutor.getCurrentOperation().getParentId(); | |
new TestMainAction(detector, processor, testResultProcessor, mClock, testTaskOperationId, | |
testExecutionSpec.getPath(), "Gradle Test Run " + testExecutionSpec.getIdentityPath()).run(); | |
} | |
@Override | |
public void stopNow() { | |
throw new UnsupportedOperationException("stopNow() unsupported by OrderedTestExecuter"); | |
} | |
private int getMaxParallelForks(JvmTestExecutionSpec testExecutionSpec) { | |
int maxParallelForks = testExecutionSpec.getMaxParallelForks(); | |
if (maxParallelForks > mMaxWorkerCount) { | |
LOGGER.info("{}.maxParallelForks ({}) is larger than max-workers ({}), forcing it to {}", | |
testExecutionSpec.getPath(), maxParallelForks, mMaxWorkerCount, mMaxWorkerCount); | |
maxParallelForks = mMaxWorkerCount; | |
} | |
return maxParallelForks; | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.twitter.gradle.plugin.orderedtest | |
import groovy.transform.CompileStatic | |
/**
 * Settings of the {@code orderedTest} DSL block. Groovy property syntax generates the public
 * getters and setters for each field below.
 */
@CompileStatic
class OrderedTestExtension {
    /** Ordering policy: null or 'lex' for sorted order, 'random' for a fresh shuffle, or a hex long seed. */
    String seed

    /** Number of shards to split the test classes into; 0 (the default) disables sharding. */
    int shardCount

    /** Zero-based shard to run; must be smaller than shardCount when sharding is enabled. */
    int shardIndex
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.twitter.gradle.plugin.orderedtest | |
import org.gradle.StartParameter | |
import org.gradle.api.InvalidUserDataException | |
import org.gradle.api.Plugin | |
import org.gradle.api.Project | |
import org.gradle.api.internal.DocumentationRegistry | |
import org.gradle.api.tasks.testing.Test | |
import org.gradle.internal.operations.BuildOperationExecutor | |
import org.gradle.internal.time.Clock | |
import org.gradle.internal.work.WorkerLeaseRegistry | |
/**
 * Gradle plugin that replaces the test executer of every {@link Test} task with an
 * {@link OrderedTestExecuter}. This gives a reproducible test-class order and optional sharding.
 * Settings come from the {@code orderedTest} extension ({@link OrderedTestExtension}).
 */
class OrderedTestPlugin implements Plugin<Project> {
    @Override
    void apply(Project project) {
        def orderedTestExt = project.extensions.create('orderedTest', OrderedTestExtension)
        // Update test tasks after project evaluation, since that will allow the orderedTest extension to be configured.
        project.afterEvaluate {
            // Parse and validate once per project rather than once per Test task. With seed 'random',
            // all Test tasks then share one seed, and the seed printed below reproduces every task's order.
            long seed = parseSeed(orderedTestExt.seed)
            int shardCount = orderedTestExt.shardCount
            if (shardCount < 0) {
                throw new IllegalArgumentException("Invalid shard count ${shardCount}")
            }
            int shardIndex = orderedTestExt.shardIndex
            if (shardIndex < 0 || shardCount > 0 && shardIndex >= shardCount) {
                throw new IllegalArgumentException(
                        "Invalid shard index ${shardIndex} with shard count ${shardCount}")
            }
            project.tasks.withType(Test).configureEach {
                // The factory/service getters below are resolved on the Test task being configured.
                OrderedTestExecuter executer = new OrderedTestExecuter(
                        getProcessBuilderFactory(),
                        getActorFactory(), getModuleRegistry(),
                        getServices().get(WorkerLeaseRegistry.class),
                        getServices().get(DocumentationRegistry.class),
                        getServices().get(BuildOperationExecutor.class),
                        getServices().get(StartParameter.class).getMaxWorkerCount(),
                        getServices().get(Clock.class),
                        seed, shardCount, shardIndex)
                setTestExecuter(executer)
                // Declared as task inputs so that a different order or shard is not treated as up to date.
                inputs.property('seed', seed)
                inputs.property('shardCount', shardCount)
                inputs.property('shardIndex', shardIndex)
                // afterSuite is a Test task hook; it does not belong inside the testLogging block.
                afterSuite { desc, result ->
                    // Report the seed once, for the root suite, when tests ran in a shuffled order.
                    if (!desc.parent && result.testCount != 0 && seed != 0) {
                        logger.lifecycle("To re-run tests in the same order use this with your gradle task:" +
                                " -Dtest.seed=${Long.toHexString(seed)}")
                    }
                }
            }
        }
    }

    // implement our documented user interface/policy for this plugin
    // 1. null or 'lex' or 0 (the default) results in a lex sort of the test classes
    //    the 0 is really a magic value used further down the stack
    // 2. 'random' results in a random run order, of use for CI jobs which aim to shake out dependencies
    //    the seed is published on the test console output to enable reproduction of the same order
    // 3. hex long value allows a specific seed to be used allowing reproduction of 'random' orders
    // in a perfect world we'd default to 'random', in a world where we minimize developer pain we default to 'lex'
    static long parseSeed(String seed) {
        if (seed == null || seed.equals("lex")) {
            return 0
        } else if (seed.equals("random")) {
            return new Random().nextLong()
        } else {
            return parseHexToLong(seed)
        }
    }

    // Long.parseLong does not grok 2's complement, so seeds printed by Long.toHexString for negative
    // values (e.g. "ffffffffffffffff") would be rejected. Long.parseUnsignedLong accepts the full
    // 64-bit range and rejects anything longer than 16 hex digits. The previous split-and-shift
    // parsing silently produced a wrong seed for such input.
    static long parseHexToLong(String seed) {
        if (seed.isEmpty()) {
            // An empty value was previously zero-padded to 0 (lex order); keep accepting it.
            return 0
        }
        try {
            return Long.parseUnsignedLong(seed, 16)
        } catch (NumberFormatException e) {
            throw new InvalidUserDataException(
                    "expected -Dtest.seed=N where N must be a hex long integer, got '${seed}'", e)
        }
    }
}
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment