Created
February 18, 2019 17:58
-
-
Save astojilj/b1185e799f1a8f952d820a85957b7b93 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| diff --git a/integration_tests/benchmarks/benchmark_test.ts b/integration_tests/benchmarks/benchmark_test.ts | |
| index 8148ea22..7a38c6da 100644 | |
| --- a/integration_tests/benchmarks/benchmark_test.ts | |
| +++ b/integration_tests/benchmarks/benchmark_test.ts | |
| @@ -16,8 +16,8 @@ | |
| */ | |
| import {ConvGPUBenchmark, RegularConvParams} from './conv_benchmarks'; | |
| -import {MatmulGPUBenchmark} from './matmul_benchmarks'; | |
| import {MobileNetV1GPUBenchmark} from './mobilenet_benchmarks'; | |
| +import {ReductionOpsGPUBenchmark} from './reduction_ops_benchmark'; | |
| import * as test_util from './test_util'; | |
| const BENCHMARK_RUNS = 100; | |
| @@ -27,14 +27,13 @@ describe('benchmarks', () => { | |
| jasmine.DEFAULT_TIMEOUT_INTERVAL = 600000; | |
| }); | |
| - it('matmul', async done => { | |
| - const sizes = [1, 100, 400, 1000]; | |
| - | |
| - const benchmark = new MatmulGPUBenchmark(); | |
| + it('reduction', async done => { | |
| + const sizes = [257, 513, 769]; | |
| + const benchmark = new ReductionOpsGPUBenchmark(); | |
| await test_util.benchmarkAndLog( | |
| - 'matmul', size => benchmark.run(size), sizes, size => `N=${size}`, | |
| - BENCHMARK_RUNS); | |
| + 'reduction', size => benchmark.run(size, 'argMax'), sizes, | |
| + size => `N=${size}`, BENCHMARK_RUNS); | |
| done(); | |
| }); | |
| diff --git a/integration_tests/benchmarks/reduction_ops_benchmark.ts b/integration_tests/benchmarks/reduction_ops_benchmark.ts | |
| index 6df93425..46b69a54 100644 | |
| --- a/integration_tests/benchmarks/reduction_ops_benchmark.ts | |
| +++ b/integration_tests/benchmarks/reduction_ops_benchmark.ts | |
| @@ -26,7 +26,7 @@ function getReductionOp(option: string): (x: tf.Tensor) => tf.Scalar { | |
| case 'min': | |
| return x => x.min(); | |
| case 'argMax': | |
| - return x => x.argMax(); | |
| + return x => x.argMax(-1); | |
| case 'argMin': | |
| return x => x.argMin(); | |
| case 'sum': | |
| @@ -49,7 +49,7 @@ export class ReductionOpsCPUBenchmark implements BenchmarkTest { | |
| const start = performance.now(); | |
| tf.tidy(() => { | |
| - op(input).get(); | |
| + op(input).arraySync(); | |
| }); | |
| const end = performance.now(); | |
| @@ -60,10 +60,9 @@ export class ReductionOpsCPUBenchmark implements BenchmarkTest { | |
| export class ReductionOpsGPUBenchmark implements BenchmarkTest { | |
| async run(size: number, option: string) { | |
| tf.setBackend('webgl'); | |
| + const input: tf.Tensor4D = | |
| + tf.randomUniform([1, size, size, 21], -1, 1).add(1); | |
| - // Square the provided size to make these 1D benchmarks comparable to the | |
| - // other 2D ones. | |
| - const input: tf.Tensor1D = tf.randomUniform([size * size], -1, 1); | |
| const op = getReductionOp(option); | |
| const benchmark = () => op(input); |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment