@caisq
Created September 11, 2018 03:46
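The script below, written against the TensorFlow.js Layers API with the tfjs-node-gpu backend, builds a small convolutional network and times 100 training steps of optimizer.minimize() on a single batch of random data.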
// Benchmark: time 100 SGD training steps of a small convnet on random data,
// using the TensorFlow.js Node GPU backend.
const tf = require('@tensorflow/tfjs');
require('@tensorflow/tfjs-node-gpu');

(async function() {
  // Single-channel 43x232 inputs, 10 output classes.
  const inputShape = [43, 232, 1];
  const numClasses = 10;

  // Four conv2d + maxPooling2d blocks followed by two dense layers.
  const model = tf.sequential();
  model.add(tf.layers.conv2d({
      filters: 8, kernelSize: [2, 8], activation: 'relu', inputShape}));
  model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));
  model.add(tf.layers.conv2d({
      filters: 32, kernelSize: [2, 4], activation: 'relu'}));
  model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));
  model.add(tf.layers.conv2d({
      filters: 32, kernelSize: [2, 4], activation: 'relu'}));
  model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));
  model.add(tf.layers.conv2d({
      filters: 32, kernelSize: [2, 4], activation: 'relu'}));
  model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [1, 2]}));
  model.add(tf.layers.flatten());
  model.add(tf.layers.dense({units: 2000, activation: 'relu'}));
  model.add(tf.layers.dense({units: numClasses, activation: 'softmax'}));

  const optimizer = tf.train.sgd(0.001);

  // Synthetic batch: the values are irrelevant; only the training speed matters.
  const batchSize = 48;
  const xs = tf.randomNormal([batchSize].concat(inputShape));
  const ys = tf.randomUniform([batchSize, numClasses]);

  console.log('Calling minimize');
  const t0 = tf.util.now();
  for (let i = 0; i < 100; ++i) {
    optimizer.minimize(() => tf.losses.softmaxCrossEntropy(ys, model.apply(xs)));
  }
  const t1 = tf.util.now();
  console.log(`DONE Calling minimize: ${t1 - t0} ms`);
  // model.apply(xs).print();
})();
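To run the benchmark, install the dependencies with npm install @tensorflow/tfjs @tensorflow/tfjs-node-gpu and execute the file with node. The GPU binding requires a CUDA-compatible GPU with the matching CUDA/cuDNN libraries; @tensorflow/tfjs-node can be substituted for a CPU-only run.

The same training can also be expressed through the higher-level Layers API. The following is a minimal sketch, not part of the original gist, assuming the same model, xs, and ys defined above:

  // Hypothetical alternative: let model.fit() drive the training loop instead
  // of calling optimizer.minimize() directly. With batchSize equal to the
  // number of examples, 100 epochs corresponds to 100 update steps.
  model.compile({optimizer: tf.train.sgd(0.001), loss: 'categoricalCrossentropy'});
  await model.fit(xs, ys, {batchSize: 48, epochs: 100});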