Skip to content

Instantly share code, notes, and snippets.

@Nucs
Last active September 29, 2019 12:29
Show Gist options
  • Select an option

  • Save Nucs/4cd1220cc676945447f663ee93114578 to your computer and use it in GitHub Desktop.

Select an option

Save Nucs/4cd1220cc676945447f663ee93114578 to your computer and use it in GitHub Desktop.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Drawing;
using System.Drawing.Imaging;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading.Tasks;
using NumSharp;
using NumSharp.Backends;
using NumSharp.Backends.Unmanaged;
using Tensorflow;
using static Tensorflow.Binding;
using Buffer = System.Buffer;
namespace ConsoleApp2
{
class Program
{
    /// <summary>
    /// Demo entry point: builds a <see cref="DnCNN"/>, loads every file under ./github_dataset as a
    /// <see cref="Bitmap"/>, and trains for 100 epochs.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the targets are clones of the very same bitmaps used as inputs, so the network is
    /// trained toward the identity mapping; swap in genuinely clean targets for real denoising.
    /// </remarks>
    static void Main(string[] args)
    {
        var model = new DnCNN();

        FileInfo[] files = new DirectoryInfo("./github_dataset").GetFiles();
        var sources = new List<Bitmap>(files.Length);
        foreach (FileInfo file in files)
        {
            sources.Add(new Bitmap(file.FullName));
        }

        var targets = new List<Bitmap>(sources.Count);
        foreach (Bitmap bitmap in sources)
        {
            targets.Add((Bitmap)bitmap.Clone());
        }

        model.Train(sources, targets, 100);
    }
}
public class DnCNN : IDisposable
{
    /// <summary>
    /// Nominal batch size used only to scale the loss. <see cref="Train"/> feeds the whole dataset
    /// in every step, so this is a normalisation constant rather than a real batch size.
    /// </summary>
    private const int batch_size = 128;

    /// <summary>Checkpoint path used by <see cref="Train"/> when the caller does not supply one.</summary>
    private const string DefaultCheckpointPath = @"E:\Downloads\ciao.ckpt";

    // X: input batch, Y_: target batch, Y: model prediction (all NHWC with 3 channels).
    private readonly Tensor X, Y_, Y;
    private readonly Tensor loss;
    private readonly Operation optimizer;
    private readonly Session sess;

    /// <summary>
    /// Builds the graph (placeholders, residual CNN, loss, Adam optimizer), opens a session and
    /// initialises all variables.
    /// </summary>
    public DnCNN()
    {
        X = tf.placeholder(tf.float32, shape: (-1, -1, -1, 3), name: "input_image");
        Y_ = tf.placeholder(tf.float32, shape: (-1, -1, -1, 3), name: "clean_image");
        Y = BuildModel(X);

        // Scalar L2 loss: sum((Y_ - Y)^2) / (2 * batch_size), i.e. l2_loss(Y_ - Y) / batch_size.
        // The previous relu(Y_ - Y) was non-scalar and asymmetric: pixels the model over-predicted
        // contributed zero loss and zero gradient, so those errors were never corrected.
        loss = tf.reduce_sum(tf.pow(Y_ - Y, 2.0f)) / (2.0f * batch_size);
        optimizer = tf.train.AdamOptimizer(0.001f, name: "AdamOptimizer").minimize(loss);

        sess = new Session();
        var init = tf.global_variables_initializer();
        sess.run(init);
    }

    /// <summary>
    /// Builds the residual network: 20 "same"-padded 3x3 convolutions (conv1..conv20) that predict a
    /// residual, which is subtracted from <paramref name="input"/>.
    /// </summary>
    /// <param name="input">NHWC image batch with 3 channels.</param>
    /// <param name="is_training">Currently unused; kept for API compatibility.</param>
    /// <returns><c>input - residual</c>.</returns>
    private Tensor BuildModel(Tensor input, bool is_training = true)
    {
        // conv1: 64 filters with bias, followed by ReLU.
        var output = tf.nn.relu(tf.layers.conv2d(input, 64, new int[] {3, 3}, name: "conv1", padding: "same"));

        // conv2..conv19: bias-free 64-filter convolutions, each followed by ReLU. Without a
        // non-linearity the whole stack collapses into a single linear filter.
        for (int i = 2; i < 20; i++)
        {
            output = tf.nn.relu(tf.layers.conv2d(output, 64, new int[] {3, 3}, name: "conv" + i, padding: "same", use_bias: false));
        }

        // conv20: linear projection back to 3 channels -> predicted residual.
        output = tf.layers.conv2d(output, 3, new int[] {3, 3}, name: "conv20", padding: "same", use_bias: false);
        return input - output;
    }

    /// <summary>
    /// Trains on the full dataset for <paramref name="epochs"/> optimizer steps. After each step it
    /// prints the loss and elapsed time, then saves a checkpoint.
    /// </summary>
    /// <param name="inputImages">Network inputs; must all share the same size.</param>
    /// <param name="outputImages">Targets, index-aligned with <paramref name="inputImages"/>.</param>
    /// <param name="epochs">Number of optimizer steps (each step uses the whole dataset).</param>
    /// <param name="checkpointPath">Where the checkpoint is written after every step.</param>
    /// <exception cref="ArgumentNullException">Either image list is null.</exception>
    /// <exception cref="ArgumentException">The two lists differ in length.</exception>
    public void Train(List<Bitmap> inputImages, List<Bitmap> outputImages, int epochs = 1, string checkpointPath = DefaultCheckpointPath)
    {
        if (inputImages == null) throw new ArgumentNullException(nameof(inputImages));
        if (outputImages == null) throw new ArgumentNullException(nameof(outputImages));
        if (inputImages.Count != outputImages.Count)
            throw new ArgumentException("inputImages and outputImages must contain the same number of images.", nameof(outputImages));

        var sw = Stopwatch.StartNew();
        NDArray x_train = ToModelInput(GenerateDataset(inputImages));
        NDArray y_train = ToModelInput(GenerateDataset(outputImages));
        var feeds = new[] {new FeedItem(X, x_train), new FeedItem(Y_, y_train)};
        var saver = new Saver();
        print($"Dataset created in {sw.ElapsedMilliseconds}ms");
        sw.Restart();
        for (int i = 0; i < epochs; i++)
        {
            sess.run(optimizer, feeds);
            // Loss after this step, on the same (whole) dataset.
            var batchLoss = sess.run(loss, feeds);
            print($"iter {i:000}: loss {batchLoss}, {sw.ElapsedMilliseconds}ms");
            sw.Restart();
            saver.save(sess, checkpointPath);
        }
    }

    /// <summary>
    /// Runs the network on <paramref name="inputImages"/> and returns the first result as a bitmap.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="inputImages"/> is null.</exception>
    public Bitmap Evaluate(List<Bitmap> inputImages)
    {
        if (inputImages == null) throw new ArgumentNullException(nameof(inputImages));

        var sw = Stopwatch.StartNew();
        NDArray x_eval = ToModelInput(GenerateDataset(inputImages));
        var output = sess.run(Y, new FeedItem(X, x_eval));
        sw.Stop();
        print($"Inference done in {sw.ElapsedMilliseconds}ms");
        // Undo the [0, 1] scaling applied by ToModelInput.
        // NOTE(review): values outside [0, 255] are not clipped before the byte cast.
        output = 255 * output;
        return image(output[0].astype(NPTypeCode.Byte));
    }

    /// <summary>
    /// Converts each bitmap with <c>ToNDArray</c> and stacks the results along the first axis.
    /// </summary>
    /// <remarks>
    /// Returns raw pixel data; <see cref="Train"/> and <see cref="Evaluate"/> rescale it before
    /// feeding the graph. NOTE(review): assumes every bitmap has the same size and 3 channels
    /// (the placeholders are declared with 3 channels) — TODO confirm for images with alpha.
    /// </remarks>
    /// <exception cref="ArgumentNullException"><paramref name="imgs"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="imgs"/> is empty.</exception>
    public NDArray GenerateDataset(List<Bitmap> imgs)
    {
        if (imgs == null) throw new ArgumentNullException(nameof(imgs));
        if (imgs.Count == 0) throw new ArgumentException("At least one image is required.", nameof(imgs));
        return np.vstack(imgs.Select(img => img.ToNDArray(false, false)).ToArray());
    }

    /// <summary>
    /// Converts a single image array into a <see cref="Bitmap"/>. Accepts either a
    /// (1, height, width, channels) or a (height, width, channels) array.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="nd"/> is null.</exception>
    public Bitmap image(NDArray nd)
    {
        if (nd == null) throw new ArgumentNullException(nameof(nd));
        // The width/height below are read from shape[2]/shape[1], i.e. a (1, H, W, C) layout.
        // A 3-D (H, W, C) array (e.g. output[0] from Evaluate) would otherwise pass C as the width.
        if (nd.ndim == 3)
            nd = nd.reshape(1, nd.shape[0], nd.shape[1], nd.shape[2]);
        return nd.ToBitmap(nd.shape[2], nd.shape[1]);
    }

    /// <summary>Releases the TensorFlow session owned by this instance.</summary>
    public void Dispose()
    {
        sess.Dispose();
    }

    /// <summary>
    /// Rescales a raw image batch into the float32 [0, 1] range expected by the float32
    /// placeholders (and undone by the <c>255 *</c> in <see cref="Evaluate"/>).
    /// </summary>
    private static NDArray ToModelInput(NDArray raw)
    {
        return raw.astype(NPTypeCode.Single) / 255f;
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment