Skip to content

Instantly share code, notes, and snippets.

@lfyzjck
Last active April 3, 2025 05:55
Show Gist options
Save lfyzjck/627dadd5b976a0e2fc15c804fd0b7d60 to your computer and use it in GitHub Desktop.
celeborn_flush_benchmark.java
package main.java;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Random;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
/**
 * Micro-benchmark comparing two ways of flushing a Netty {@link CompositeByteBuf} into a
 * {@link FileChannel}: a {@code write(ByteBuffer)} loop per NIO buffer ({@link LocalFlushTask})
 * versus gathering {@code write(ByteBuffer[])} calls ({@link LocalBatchFlushTask}).
 *
 * <p>Each iteration writes freshly generated random data into its own file and times only the
 * flush itself. The channel is never {@code force()}d, so the numbers presumably reflect writes
 * into the OS page cache rather than to the storage device.
 */
public class FlushBenchmark {
  // Chunk size range, aligned with celeborn.client.push.buffer.max.size: every generated chunk
  // is between 1KB and 64KB (inclusive).
  private static final int BUFFER_SIZE_MIN = 1 * 1024; // 1KB
  private static final int BUFFER_SIZE_MAX = 64 * 1024; // 64KB
  // Default amount of data written per flush, in bytes (overridable with -b).
  private static final long FLUSH_BUFFER_SIZE = 1 * 1024 * 1024; // 1MB
  private static final int DEFAULT_BUFFER_COUNT = 1000; // default max number of chunks per flush
  private static final int DEFAULT_ITERATIONS = 100; // default number of timed flushes
  private static final String DEFAULT_TASK_TYPE = "both"; // run both strategies by default
  private static final double BYTES_PER_MB = 1024.0 * 1024.0;

  /** One flush of a composite buffer into a file channel; subclasses differ in how they write. */
  abstract static class FlushTask {
    protected final CompositeByteBuf buffer;
    protected final FileChannel channel;
    protected final String name;

    public FlushTask(String name, CompositeByteBuf buffer, FileChannel channel) {
      this.name = name;
      this.buffer = buffer;
      this.channel = channel;
    }

    /**
     * Writes every readable byte of {@link #buffer} to {@link #channel}.
     *
     * @throws IOException if a channel write fails
     */
    public abstract void flush() throws IOException;

    /** Returns the display name used in reports and test-file names. */
    public String getName() {
      return name;
    }
  }

  /** Flushes by issuing single-buffer writes, one NIO view of the composite at a time. */
  static class LocalFlushTask extends FlushTask {
    public LocalFlushTask(CompositeByteBuf buffer, FileChannel channel) {
      super("LocalFlushTask", buffer, channel);
    }

    @Override
    public void flush() throws IOException {
      for (ByteBuffer view : buffer.nioBuffers()) {
        // A single write may be partial; keep writing until this view is drained.
        while (view.hasRemaining()) {
          channel.write(view);
        }
      }
    }
  }

  /** Flushes by handing all NIO views of the composite to gathering writes. */
  static class LocalBatchFlushTask extends FlushTask {
    public LocalBatchFlushTask(CompositeByteBuf buffer, FileChannel channel) {
      super("LocalBatchFlushTask", buffer, channel);
    }

    @Override
    public void flush() throws IOException {
      final ByteBuffer[] views = buffer.nioBuffers();
      final long readableBytes = buffer.readableBytes();
      long written = 0L;
      // A gathering write may write fewer bytes than requested; repeat until everything is out.
      // '<' (rather than '!=') guarantees the loop can never spin past the target.
      while (written < readableBytes) {
        written += channel.write(views);
      }
    }
  }

  /**
   * Entry point: parses options, prints the effective configuration and runs the selected
   * benchmarks (batch first, then single).
   *
   * @param args command-line options, see {@link #printHelp()}
   * @throws IOException if a test file cannot be created or cleaned up
   */
  public static void main(String[] args) throws IOException {
    BenchmarkConfig config = parseCommandLineArgs(args);
    System.out.println("========= 性能测试配置 =========");
    System.out.println("缓冲区数量: " + config.bufferCount);
    System.out.println("测试迭代次数: " + config.iterations);
    System.out.println("测试任务类型: " + config.taskType);
    System.out.println("测试缓冲区大小: " + config.flushBufferSize);
    if (config.outputPath != null) {
      System.out.println("输出路径: " + config.outputPath);
    }
    System.out.println("保留测试文件: " + (config.keepFiles ? "是" : "否"));
    // The templates only select which task kind to run; runBenchmark builds a fresh task over
    // newly generated data for every iteration, so no buffer or channel is needed here.
    if ("both".equals(config.taskType) || "batch".equals(config.taskType)) {
      TestResult batchResult = runBenchmark(new LocalBatchFlushTask(null, null), config);
      printStatistics(batchResult);
    }
    if ("both".equals(config.taskType) || "single".equals(config.taskType)) {
      TestResult localResult = runBenchmark(new LocalFlushTask(null, null), config);
      printStatistics(localResult);
    }
  }

  /** Benchmark settings resolved from the command line. */
  private static class BenchmarkConfig {
    final int bufferCount; // max number of chunks generated per flush
    final int iterations; // number of timed flushes
    final String outputPath; // directory for test files, or null for the system temp dir
    final boolean keepFiles; // keep test files after the run instead of deleting them
    final String taskType; // "batch", "single" or "both"
    final long flushBufferSize; // target bytes per flush
    final boolean verbose; // print per-iteration details

    BenchmarkConfig(
        int bufferCount,
        long flushBufferSize,
        int iterations,
        String outputPath,
        boolean keepFiles,
        String taskType,
        boolean verbose) {
      this.bufferCount = bufferCount;
      this.flushBufferSize = flushBufferSize;
      this.iterations = iterations;
      this.outputPath = outputPath;
      this.keepFiles = keepFiles;
      this.taskType = taskType;
      this.verbose = verbose;
    }
  }

  /**
   * Parses command-line options. Invalid values are reported on stderr and ignored.
   *
   * <p>Supported options:
   *
   * <ul>
   *   <li>{@code -n <count>}: max number of chunks per flush
   *   <li>{@code -b <bytes>}: target bytes per flush
   *   <li>{@code -k <count>}: number of iterations
   *   <li>{@code -o <dir>}: output directory for test files (created if missing)
   *   <li>{@code -t <type>}: task type (batch|single|both)
   *   <li>{@code -keep}: keep test files
   *   <li>{@code -v}: verbose output
   *   <li>{@code -h}, {@code --help}: print help and exit
   * </ul>
   */
  private static BenchmarkConfig parseCommandLineArgs(String[] args) {
    int bufferCount = DEFAULT_BUFFER_COUNT;
    int iterations = DEFAULT_ITERATIONS;
    String outputPath = null;
    boolean keepFiles = false;
    String taskType = DEFAULT_TASK_TYPE;
    long flushBufferSize = FLUSH_BUFFER_SIZE;
    boolean verbose = false;
    for (int i = 0; i < args.length; i++) {
      switch (args[i]) {
        case "-n":
          if (hasOptionValue(args, i)) {
            bufferCount =
                (int)
                    parsePositive(
                        args[++i],
                        bufferCount,
                        Integer.MAX_VALUE,
                        "缓冲区数量必须大于0,使用默认值: " + DEFAULT_BUFFER_COUNT,
                        "无效的缓冲区数量,使用默认值: " + DEFAULT_BUFFER_COUNT);
          }
          break;
        case "-b":
          if (hasOptionValue(args, i)) {
            flushBufferSize =
                parsePositive(
                    args[++i],
                    flushBufferSize,
                    Long.MAX_VALUE,
                    "Flush数据量必须大于0,使用默认值: " + FLUSH_BUFFER_SIZE,
                    "无效的Flush数据量,使用默认值: " + FLUSH_BUFFER_SIZE);
          }
          break;
        case "-k":
          if (hasOptionValue(args, i)) {
            iterations =
                (int)
                    parsePositive(
                        args[++i],
                        iterations,
                        Integer.MAX_VALUE,
                        "迭代次数必须大于0,使用默认值: " + DEFAULT_ITERATIONS,
                        "无效的迭代次数,使用默认值: " + DEFAULT_ITERATIONS);
          }
          break;
        case "-t":
          if (hasOptionValue(args, i)) {
            // Locale.ROOT: under a Turkish default locale "SINGLE" would lowercase to "sıngle".
            String value = args[++i].toLowerCase(Locale.ROOT);
            if ("batch".equals(value) || "single".equals(value) || "both".equals(value)) {
              taskType = value;
            } else {
              System.err.println("无效的任务类型,使用默认值: " + DEFAULT_TASK_TYPE);
            }
          }
          break;
        case "-o":
          if (hasOptionValue(args, i)) {
            outputPath = prepareOutputDir(args[++i]);
          }
          break;
        case "-keep":
          keepFiles = true;
          break;
        case "-v":
          verbose = true;
          break;
        case "-h":
        case "--help":
          printHelp();
          System.exit(0);
          break;
        default:
          System.err.println("未知参数: " + args[i]);
          printHelp();
          break;
      }
    }
    return new BenchmarkConfig(
        bufferCount, flushBufferSize, iterations, outputPath, keepFiles, taskType, verbose);
  }

  /** Returns true if option {@code args[i]} is followed by a value; otherwise reports it. */
  private static boolean hasOptionValue(String[] args, int i) {
    if (i + 1 < args.length) {
      return true;
    }
    System.err.println("缺少参数值: " + args[i]);
    return false;
  }

  /**
   * Parses {@code raw} as a positive integer no larger than {@code max}.
   *
   * @param raw the command-line value
   * @param current value returned when {@code raw} is rejected
   * @param max largest accepted value; larger inputs are reported as invalid
   * @param nonPositiveMessage printed to stderr when the value is zero or negative
   * @param invalidMessage printed to stderr when the value is not a number or exceeds {@code max}
   * @return the parsed value, or {@code current} if it was rejected
   */
  private static long parsePositive(
      String raw, long current, long max, String nonPositiveMessage, String invalidMessage) {
    final long value;
    try {
      value = Long.parseLong(raw);
    } catch (NumberFormatException e) {
      System.err.println(invalidMessage);
      return current;
    }
    if (value > max) {
      System.err.println(invalidMessage);
      return current;
    }
    if (value <= 0) {
      System.err.println(nonPositiveMessage);
      return current;
    }
    return value;
  }

  /**
   * Ensures {@code path} names a directory, creating it if it does not exist.
   *
   * @return {@code path} if it is (now) a directory, or {@code null} if it cannot be used
   */
  private static String prepareOutputDir(String path) {
    File dir = new File(path);
    if (!dir.exists()) {
      if (dir.mkdirs()) {
        System.out.println("已创建输出目录: " + path);
        return path;
      }
      System.err.println("无法创建输出目录: " + path);
      return null;
    }
    if (!dir.isDirectory()) {
      System.err.println("指定的路径不是目录: " + path);
      return null;
    }
    return path;
  }

  /** Prints command-line usage to stdout. */
  private static void printHelp() {
    System.out.println("用法: java FlushBenchmark [选项]");
    System.out.println("选项:");
    System.out.println(" -n <数量> 设置缓冲区数量 (默认: " + DEFAULT_BUFFER_COUNT + ")");
    System.out.println(" -b <字节数> 设置每次Flush的数据量 (默认: " + FLUSH_BUFFER_SIZE + ")");
    System.out.println(" -k <次数> 设置测试迭代次数 (默认: " + DEFAULT_ITERATIONS + ")");
    System.out.println(" -t <类型> 测试任务类型 (batch|single|both) (默认: " + DEFAULT_TASK_TYPE + ")");
    System.out.println(" batch: 仅测试批量写入 (LocalBatchFlushTask)");
    System.out.println(" single: 仅测试单次写入 (LocalFlushTask)");
    System.out.println(" both: 测试两种写入方式");
    System.out.println(" -o <路径> 设置输出文件路径");
    System.out.println(" -keep 保留测试文件 (默认删除)");
    System.out.println(" -v 打印详细信息");
    System.out.println(" -h, --help 显示帮助信息");
  }

  /**
   * Runs {@code config.iterations} timed flushes of the same kind as {@code taskTemplate}.
   *
   * <p>Every iteration generates new random data (up to {@code config.flushBufferSize} bytes,
   * outside the timed region) and writes it to its own file. Iterations whose flush fails with an
   * {@link IOException} are reported and excluded from the statistics. Test files are deleted
   * afterwards unless {@code config.keepFiles} is set, even if the run aborts.
   *
   * @param taskTemplate selects the task kind; its buffer and channel are not used
   * @param config benchmark settings
   * @return aggregated timings of the completed flushes
   * @throws IOException if a test file cannot be created or deleted
   */
  private static TestResult runBenchmark(FlushTask taskTemplate, BenchmarkConfig config)
      throws IOException {
    List<File> tempFiles = new ArrayList<>();
    long totalTime = 0;
    long totalBytes = 0;
    int completedFlushes = 0;
    System.out.println("\n========= 开始测试: " + taskTemplate.getName() + " =========");
    try {
      for (int i = 0; i < config.iterations; i++) {
        File tempFile = createTestFile(taskTemplate.getName(), i, config);
        tempFiles.add(tempFile);
        if (config.verbose) {
          System.out.println("创建测试文件: " + tempFile.getAbsolutePath());
        }
        // Honor -b: the configured size, not the compile-time default.
        CompositeByteBuf data = generateRandomBuffers(config.bufferCount, config.flushBufferSize);
        try (FileChannel channel = createWritableFileChannel(tempFile.getAbsolutePath())) {
          FlushTask task = newTaskLike(taskTemplate, data, channel);
          long bytes = data.readableBytes();
          long startTime = System.nanoTime();
          task.flush();
          long iterationTime = System.nanoTime() - startTime;
          totalTime += iterationTime;
          totalBytes += bytes;
          completedFlushes++;
          if (config.verbose) {
            System.out.printf("迭代 %d: %.3f 毫秒\n", i + 1, iterationTime / 1e6);
          }
        } catch (IOException e) {
          // Best effort: report the failed iteration and keep benchmarking.
          e.printStackTrace();
        } finally {
          data.release();
        }
      }
    } finally {
      if (!config.keepFiles) {
        cleanTempFiles(tempFiles);
      } else {
        System.out.println("保留测试文件: " + tempFiles.size() + "个");
        for (File file : tempFiles) {
          System.out.println(" " + file.getAbsolutePath());
        }
      }
    }
    return new TestResult(
        taskTemplate.getName(),
        config.bufferCount,
        config.iterations,
        completedFlushes,
        totalTime,
        totalBytes);
  }

  /**
   * Creates the empty file that iteration {@code index} writes to: a fresh temp file, or a named
   * file under {@code config.outputPath}.
   */
  private static File createTestFile(String taskName, int index, BenchmarkConfig config)
      throws IOException {
    if (config.outputPath == null) {
      return File.createTempFile("flush_test_" + taskName + "_", ".dat");
    }
    File file = new File(config.outputPath, "flush_test_" + taskName + "_" + index + ".dat");
    // A file kept by an earlier -keep run would otherwise be overwritten in place rather than
    // written fresh, because the channel is opened without TRUNCATE_EXISTING.
    Files.deleteIfExists(file.toPath());
    file.createNewFile();
    return file;
  }

  /**
   * Builds a task of the same concrete kind as {@code template} over the given buffer/channel.
   *
   * @throws IllegalArgumentException if {@code template} is of an unsupported kind
   */
  private static FlushTask newTaskLike(
      FlushTask template, CompositeByteBuf buffer, FileChannel channel) {
    if (template instanceof LocalFlushTask) {
      return new LocalFlushTask(buffer, channel);
    }
    if (template instanceof LocalBatchFlushTask) {
      return new LocalBatchFlushTask(buffer, channel);
    }
    throw new IllegalArgumentException(
        "Unsupported flush task type: " + template.getClass().getName());
  }

  /**
   * Builds a composite of up to {@code n} random chunks of {@value #BUFFER_SIZE_MIN}..{@value
   * #BUFFER_SIZE_MAX} bytes each.
   *
   * <p>Generation stops once the total reaches {@code maxSize}, so the result may exceed {@code
   * maxSize} by less than one chunk, or fall short of it if {@code n} chunks are not enough.
   *
   * @param n maximum number of chunks
   * @param maxSize target total size in bytes
   */
  private static CompositeByteBuf generateRandomBuffers(int n, long maxSize) {
    CompositeByteBuf compositeBuf = Unpooled.compositeBuffer(n);
    Random random = new Random();
    long totalSize = 0;
    for (int i = 0; i < n; i++) {
      int bufferSize = BUFFER_SIZE_MIN + random.nextInt(BUFFER_SIZE_MAX - BUFFER_SIZE_MIN + 1);
      byte[] data = new byte[bufferSize];
      random.nextBytes(data);
      ByteBuf buf = Unpooled.wrappedBuffer(data);
      // 'true' advances the writer index so the new chunk becomes readable.
      compositeBuf.addComponent(true, buf);
      totalSize += bufferSize;
      if (totalSize >= maxSize) {
        // Target size reached.
        break;
      }
    }
    return compositeBuf;
  }

  /**
   * Opens {@code filePath} for writing, creating it if absent. Existing content is not
   * truncated; writes start at position 0.
   */
  public static FileChannel createWritableFileChannel(String filePath) throws IOException {
    return FileChannel.open(
        Paths.get(filePath), StandardOpenOption.CREATE, StandardOpenOption.WRITE);
  }

  /** Aggregated results of one benchmark run. */
  private static class TestResult {
    final String taskName;
    final int bufferCount;
    final int iterations; // iterations attempted
    final int completedFlushes; // iterations whose flush finished without an IOException
    final long totalTimeNanos; // summed flush time of the completed iterations
    final long totalBytes; // bytes written by the completed iterations

    TestResult(
        String taskName,
        int bufferCount,
        int iterations,
        int completedFlushes,
        long totalTimeNanos,
        long totalBytes) {
      this.taskName = taskName;
      this.bufferCount = bufferCount;
      this.iterations = iterations;
      this.completedFlushes = completedFlushes;
      this.totalTimeNanos = totalTimeNanos;
      this.totalBytes = totalBytes;
    }
  }

  /** Prints the flush count, bytes actually written, mean flush latency and throughput. */
  private static void printStatistics(TestResult result) {
    double totalMb = result.totalBytes / BYTES_PER_MB;
    double totalSeconds = result.totalTimeNanos / 1e9;
    // Guard against division by zero when every iteration failed or the timer reported 0ns.
    double avgTimeMs =
        result.completedFlushes == 0
            ? 0.0
            : (result.totalTimeNanos / 1e6) / result.completedFlushes;
    double throughput = totalSeconds > 0 ? totalMb / totalSeconds : 0.0;
    System.out.println("\n----- " + result.taskName + " 测试结果 -----");
    System.out.println("累计Flush次数: " + result.completedFlushes);
    if (result.completedFlushes < result.iterations) {
      System.out.println("失败Flush次数: " + (result.iterations - result.completedFlushes));
    }
    System.out.printf("总写入数据量: %.2f MB\n", totalMb);
    System.out.printf("平均Flush耗时: %.3f 毫秒\n", avgTimeMs);
    System.out.printf("数据写入速度: %.2f MB/s\n", throughput);
  }

  /** Deletes the given files (ignoring ones that are already gone) and reports the count. */
  private static void cleanTempFiles(List<File> files) throws IOException {
    for (File file : files) {
      Files.deleteIfExists(file.toPath());
    }
    System.out.println("已清理临时文件: " + files.size() + "个");
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment