Created: December 5, 2018, 15:10
-
-
Save astojilj/15b2e02089e733f5512463aeff90a2e4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review it, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters.
| diff --git a/src/ops/batchnorm_test.ts b/src/ops/batchnorm_test.ts | |
| index c8ac335..6bd30e1 100644 | |
| --- a/src/ops/batchnorm_test.ts | |
| +++ b/src/ops/batchnorm_test.ts | |
| @@ -16,8 +16,10 @@ | |
| */ | |
| import * as tf from '../index'; | |
| -import {describeWithFlags} from '../jasmine_util'; | |
| +import {describeWithFlags, TEST_ENVS} from '../jasmine_util'; | |
| import {ALL_ENVS, expectArraysClose, WEBGL_ENVS} from '../test_util'; | |
| +import {MathBackendWebGL} from '../kernels/backend_webgl'; | |
| +import {ENV} from '../environment'; | |
| describeWithFlags('packed batchNormalization', WEBGL_ENVS, () => { | |
| const webglPackedBatchNormalizationSavedFlag = | |
| @@ -161,6 +163,18 @@ describeWithFlags('packed batchNormalization', WEBGL_ENVS, () => { | |
| }); | |
| }); | |
| +if (!ENV.get('IS_NODE')) { // Register an extra packed-WebGL test env, skipped when IS_NODE is set; it is popped again after the batchNormalization4D suite. | |
| + TEST_ENVS.push( // NOTE(review): mutates the shared TEST_ENVS array at module load — presumably so describeWithFlags suites declared before the pop also run against this env; confirm describeWithFlags reads TEST_ENVS at call time. | |
| + { | |
| + name: 'webgl2-packedBinary', // NOTE(review): "webgl2"/"Binary" may be misleading — WEBGL_VERSION is copied from ENV below and the affected suite is batchnorm. | |
| + factory: () => new MathBackendWebGL(), // Builds a new MathBackendWebGL on each call. | |
| + features: {'WEBGL_VERSION': ENV.get('WEBGL_VERSION'), // Mirror the current WEBGL_VERSION rather than forcing one. | |
| + 'WEBGL_CPU_FORWARD': false, // Presumably keeps ops on the GPU instead of forwarding them to CPU — confirm flag semantics. | |
| + 'WEBGL_PACK': true} // Packing enabled; the matching pop after batchNormalization4D checks this flag. | |
| + } | |
| + ); | |
| + } | |
| + | |
| describeWithFlags('batchNormalization4D', ALL_ENVS, () => { | |
| it('simple batchnorm4D, no offset or scale, 2x1x1x2', () => { | |
| const x = tf.tensor4d([2, 4, 9, 23], [2, 1, 1, 2]); | |
| @@ -358,6 +372,12 @@ describeWithFlags('batchNormalization4D', ALL_ENVS, () => { | |
| }); | |
| }); | |
| +if (!ENV.get('IS_NODE')) { // Undo the TEST_ENVS.push made before batchNormalization4D, so the packed env is removed before batchNormalization3D is declared. | |
| + if (TEST_ENVS.pop().features['WEBGL_PACK'] !== true) { // Sanity check on the removed entry. NOTE(review): only the WEBGL_PACK flag is compared, not the env name/identity, so a different packed env at the end of TEST_ENVS would also pass. | |
| + throw new Error('Error with WEBGL_PACK setup.'); // Top-level statement, so this fails at module evaluation if push/pop got out of sync. | |
| + } | |
| + } | |
| + | |
| describeWithFlags('batchNormalization3D', ALL_ENVS, () => { | |
| it('simple batchnorm3D, no offset or scale, 2x1x2', () => { | |
| const x = tf.tensor3d([2, 4, 9, 23], [2, 1, 2]); |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.