Last active: February 15, 2022 15:22
-
-
Save dsyme/ccb5d1e4ec4ba32405709fe9b2cda152 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env -S dotnet fsi | |
#r "nuget: DiffSharp-cuda, 1.0.6" | |
#r @"..\bin\Debug\net6.0\fwdsgd.dll" | |
#load "argparse.fsx" | |
open System.IO | |
open DiffSharp | |
open DiffSharp.Model | |
open DiffSharp.Util | |
open DiffSharp.Data | |
open Helpers | |
// Build the command-line interface for this training benchmark script.
// NOTE(review): ArgumentParser is defined in argparse.fsx; add_argument appears to
// register a string-valued option and add_argument0 a no-value flag — confirm there.
let parser = ArgumentParser()
// Data/output directory (no default; presumably required by the caller).
parser.add_argument("--dir")
// Learning-rate schedule and stopping criteria; defaults are strings here and
// are converted to float/int when read back below.
parser.add_argument("--lr", dflt="0.0002")
parser.add_argument("--lr_decay", dflt="0.0001")
parser.add_argument("--threshold", dflt="0.00001")
parser.add_argument("--n", dflt="25000")
parser.add_argument("--runs", dflt="1")
parser.add_argument("--batch_size", dflt="128")
parser.add_argument("--num_workers", dflt="8")
parser.add_argument("--valid_every", dflt="500")
parser.add_argument("--gc_every", help="GC every N iterations. Default is determined by model. Note DiffSharp needs regular GC for large models.")
// Constrained-choice options: compute device, model architecture, optimizer.
parser.add_argument("--device", choices=["cpu";"cuda"], dflt="cpu")
parser.add_argument("--model", choices=["logreg"; "mlp"; "cnn"; "cnn2"; "cnn4"; "cnn4b"; "vgg16"; "resnet18"; "resnet50"], dflt="logreg")
parser.add_argument("--optimizer", choices=["sgd"; "sgdn"; "adam"], dflt="sgd")
parser.add_argument("--momentum", dflt="0.2")
// Boolean flags to skip the reverse-mode / forward-mode benchmark passes.
parser.add_argument0("--skiprev")
parser.add_argument0("--skipfwd")
// Populate parser state from the actual command line.
parser.parse_args()
// Read the parsed arguments into typed bindings.
// Fix: --n and --runs previously used the untyped `parser.result(...) |> int`
// while every other integer option used the typed `resultInt` helper; all
// integer options now go through `resultInt` for consistent parsing/errors.
let optimizer = parser.result("--optimizer")
let modelArg = parser.result("--model")
let dirArg = parser.result("--dir")
let lr = parser.resultFloat("--lr")
let lr_decay = parser.resultFloat("--lr_decay")
let n = parser.resultInt("--n")
let runs = parser.resultInt("--runs")
let batch_size = parser.resultInt("--batch_size")
let device = parser.result("--device")
let skiprev = parser.resultBool("--skiprev")
let skipfwd = parser.resultBool("--skipfwd")
let valid_every = parser.resultInt("--valid_every")
let momentum = parser.resultFloat("--momentum")
let threshold = parser.resultFloat("--threshold")
// Optional: None when --gc_every was not supplied on the command line.
let gc_every_arg = parser.resultIntOption("--gc_every")
// Select the Torch backend on the requested device, then fix the RNG seed
// so repeated runs are reproducible. Any --device value other than "cpu"
// (i.e. "cuda", the only other allowed choice) maps to the GPU.
dsharp.config(backend = Backend.Torch,
              device = (match device with
                        | "cpu" -> Device.CPU
                        | _ -> Device.GPU))
dsharp.seed(0)
/// Fresh parameters for the logistic-regression model: a Kaiming-initialised
/// 784x10 weight matrix "w1" and a 10-element bias "b1", as a parameter dict.
let logreg_ps() =
    let init =
        [ "w1", Weight.kaiming(28*28, 10)
          "b1", Weight.bias(10) ]
    parameters init
/// Forward pass of the logistic-regression model: flatten each input to a
/// 784-vector and apply the affine map x*w1 + b1. Returns raw logits.
/// (A second layer w2/b2 was present but commented out in the original.)
let logreg_eval(ps: ParameterDict, x: Tensor) =
    x.view([-1; 28*28]).matmul(ps["w1"]) + ps["b1"]
/// Evaluate the model on a batch and return the cross-entropy loss together
/// with the count of correctly classified examples.
let logreg_loss(ps, x, target) =
    let logits = logreg_eval(ps, x)
    // TODO: check this is the same as cross_entropy in utils.py
    let loss = dsharp.crossEntropyLoss(logits, target)
    let hits = dsharp.eq(logits.argmax(dim=1), target).sum().toInt32()
    loss, hits
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment