Skip to content

Instantly share code, notes, and snippets.

@astojilj
Created December 5, 2018 15:10
Show Gist options
  • Select an option

  • Save astojilj/15b2e02089e733f5512463aeff90a2e4 to your computer and use it in GitHub Desktop.

Select an option

Save astojilj/15b2e02089e733f5512463aeff90a2e4 to your computer and use it in GitHub Desktop.
diff --git a/src/ops/batchnorm_test.ts b/src/ops/batchnorm_test.ts
index c8ac335..6bd30e1 100644
--- a/src/ops/batchnorm_test.ts
+++ b/src/ops/batchnorm_test.ts
@@ -16,8 +16,10 @@
*/
import * as tf from '../index';
-import {describeWithFlags} from '../jasmine_util';
+import {describeWithFlags, TEST_ENVS} from '../jasmine_util';
import {ALL_ENVS, expectArraysClose, WEBGL_ENVS} from '../test_util';
+import {MathBackendWebGL} from '../kernels/backend_webgl';
+import {ENV} from '../environment';
describeWithFlags('packed batchNormalization', WEBGL_ENVS, () => {
const webglPackedBatchNormalizationSavedFlag =
@@ -161,6 +163,18 @@ describeWithFlags('packed batchNormalization', WEBGL_ENVS, () => {
});
});
+if (!ENV.get('IS_NODE')) {
+  // Temporarily register one extra test environment: a MathBackendWebGL with
+  // packed textures enabled ('WEBGL_PACK': true) and CPU forwarding disabled,
+  // so the next suite ('batchNormalization4D') also exercises the packed path.
+  // NOTE(review): this relies on describeWithFlags reading TEST_ENVS at the
+  // moment it is called (module load) -- confirm in jasmine_util. The entry is
+  // popped again right after 'batchNormalization4D'.
+  //
+  // The WebGL version comes from the current ENV rather than being hard-coded,
+  // so the env name is derived from it as well instead of always claiming
+  // 'webgl2'.
+  const webglVersion = ENV.get('WEBGL_VERSION');
+  TEST_ENVS.push({
+    name: `webgl${webglVersion}-packed`,
+    factory: () => new MathBackendWebGL(),
+    features: {
+      'WEBGL_VERSION': webglVersion,
+      'WEBGL_CPU_FORWARD': false,
+      'WEBGL_PACK': true
+    }
+  });
+}
+
describeWithFlags('batchNormalization4D', ALL_ENVS, () => {
it('simple batchnorm4D, no offset or scale, 2x1x1x2', () => {
const x = tf.tensor4d([2, 4, 9, 23], [2, 1, 1, 2]);
@@ -358,6 +372,12 @@ describeWithFlags('batchNormalization4D', ALL_ENVS, () => {
});
});
+if (!ENV.get('IS_NODE')) {
+  // Undo the temporary TEST_ENVS.push() made before 'batchNormalization4D',
+  // so suites registered after this point no longer see the packed env.
+  // The popped entry is checked to be the packed one (WEBGL_PACK === true).
+  // pop() returns undefined on an empty array; that case is reported with the
+  // same error instead of crashing with a TypeError, which also keeps this
+  // code valid under strictNullChecks.
+  const packedEnv = TEST_ENVS.pop();
+  if (packedEnv == null || packedEnv.features['WEBGL_PACK'] !== true) {
+    throw new Error('Error with WEBGL_PACK setup.');
+  }
+}
+
describeWithFlags('batchNormalization3D', ALL_ENVS, () => {
it('simple batchnorm3D, no offset or scale, 2x1x2', () => {
const x = tf.tensor3d([2, 4, 9, 23], [2, 1, 2]);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment