-
-
Save helxsz/74b927c2d2733f79bba0739ff05af171 to your computer and use it in GitHub Desktop.
Python script for automatically generating the HED (https://github.com/s9xie/hed) network, compatible with the newest Caffe (https://github.com/bvlc/caffe)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# HED batch processing script; modified from https://github.com/s9xie/hed/blob/master/examples/hed/HED-tutorial.ipynb | |
# Step 1: download the hed repo: https://github.com/s9xie/hed | |
# Step 2: download the models and prototxt files, and put them under {caffe_root}/examples/hed/
# Step 3: put this script under {caffe_root}/examples/hed/ | |
# Step 4: run the following script: | |
# python batch_hed.py --images_dir=/data/to/path/photos/ --hed_mat_dir=/data/to/path/hed_mat_files/ | |
# The code sometimes crashes after computation is done. Error looks like "Check failed: ... driver shutting down". You can just kill the job. | |
# Large images can exhaust GPU memory, so resize the images before running this script.
# Step 5: run the MATLAB post-processing script "PostprocessHED.m" | |
# https://github.com/phillipi/pix2pix/blob/master/scripts/edges/batch_hed.py | |
import numpy as np | |
import scipy.misc | |
import Image | |
import scipy.io | |
import os | |
import cv2 | |
import argparse | |
def parse_args(argv=None):
    """Parse the command-line options of the batch HED edge-extraction script.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly like the zero-argument call.

    Returns:
        argparse.Namespace with ``caffe_root``, ``caffemodel``, ``prototxt``,
        ``images_dir``, ``hed_mat_dir``, ``border`` and ``gpu_id``.

    Raises:
        SystemExit: (from argparse) when ``--images_dir`` or ``--hed_mat_dir``
            is missing. Both are needed later, so fail early with a usage message
            instead of crashing on ``os.path.exists(None)``.
    """
    parser = argparse.ArgumentParser(description='batch proccesing: photos->edges')
    parser.add_argument('--caffe_root', dest='caffe_root', help='caffe root', default='../../', type=str)
    parser.add_argument('--caffemodel', dest='caffemodel', help='caffemodel', default='./hed_pretrained_bsds.caffemodel', type=str)
    parser.add_argument('--prototxt', dest='prototxt', help='caffe prototxt file', default='./deploy.prototxt', type=str)
    parser.add_argument('--images_dir', dest='images_dir', help='directory to store input photos', type=str, required=True)
    parser.add_argument('--hed_mat_dir', dest='hed_mat_dir', help='directory to store output hed edges in mat file', type=str, required=True)
    parser.add_argument('--border', dest='border', help='padding border', type=int, default=128)
    parser.add_argument('--gpu_id', dest='gpu_id', help='gpu id', type=int, default=1)
    args = parser.parse_args(argv)
    return args
args = parse_args()
for arg in vars(args):
    print('[%s] =' % arg, getattr(args, arg))

# Make sure that caffe is on the python path:
caffe_root = args.caffe_root  # this file is expected to be in {caffe_root}/examples/hed/
import sys
# os.path.join also works when --caffe_root is given without a trailing slash.
sys.path.insert(0, os.path.join(caffe_root, 'python'))
import caffe
import scipy.io as sio

if not os.path.exists(args.hed_mat_dir):
    print('create output directory %s' % args.hed_mat_dir)
    os.makedirs(args.hed_mat_dir)

# Only regular files are candidates (sub-directories would crash Image.open);
# sorting gives a deterministic processing order.
imgList = sorted(entry for entry in os.listdir(args.images_dir)
                 if os.path.isfile(os.path.join(args.images_dir, entry)))
nImgs = len(imgList)
print('#images = %d' % nImgs)

# GPU execution; replace these two lines with caffe.set_mode_cpu() to run on the CPU.
caffe.set_mode_gpu()
caffe.set_device(args.gpu_id)

# load net
net = caffe.Net(args.prototxt, args.caffemodel, caffe.TEST)

# Per-channel mean subtracted after the channel order is reversed (RGB -> BGR) below.
mean_bgr = np.array((104.00698793, 116.66876762, 122.67891434))

# pad border
border = args.border
for i, img_name in enumerate(imgList):
    if i % 500 == 0:
        print('processing image %d/%d' % (i, nImgs))
    try:
        # convert('RGB') guarantees 3 channels, so grayscale/RGBA/palette inputs
        # no longer break np.pad or the 3-element mean subtraction.
        im = Image.open(os.path.join(args.images_dir, img_name)).convert('RGB')
    except IOError:
        print('skipping unreadable image %s' % img_name)
        continue
    in_ = np.array(im, dtype=np.float32)
    height, width = in_.shape[:2]
    in_ = np.pad(in_, ((border, border), (border, border), (0, 0)), 'reflect')
    in_ = in_[:, :, ::-1]
    in_ -= mean_bgr
    in_ = in_.transpose((2, 0, 1))
    # shape for input (data blob is N x C x H x W), set data
    net.blobs['data'].reshape(1, *in_.shape)
    net.blobs['data'].data[...] = in_
    # run net; the fused output is read as-is (no argmax)
    net.forward()
    fuse = net.blobs['sigmoid-fuse'].data[0][0, :, :]
    # Remove the padding. Explicit end indices keep this correct for
    # --border 0, where fuse[0:-0] would have produced an empty array.
    fuse = fuse[border:border + height, border:border + width]
    # save hed file to the disk
    name = os.path.splitext(img_name)[0]
    sio.savemat(os.path.join(args.hed_mat_dir, name + '.mat'), {'predict': fuse})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
name: "FCN"

# Network input, declared with the legacy input/input_dim fields: one blob
# named "data" of shape 1 x 3 x 500 x 500. Callers may reshape it before a
# forward pass (batch_hed.py does so for every image).
input: "data"
input_dim: 1
input_dim: 3
input_dim: 500
input_dim: 500

# ===================== VGG-16 trunk =====================
# Each Convolution carries two param specs: weights first, then bias (twice
# the weight lr_mult, no weight decay). Every ReLU runs in place on its conv.

# ---- Stage 1: 64 channels ----
# conv1_1 pads by 35 instead of 1; the side-output Crop layers later cut
# their outputs back to the size of "data".
layer {
  name: "conv1_1"
  type: "Convolution"
  bottom: "data"
  top: "conv1_1"
  param { lr_mult: 1 decay_mult: 1 }
  param { lr_mult: 2 decay_mult: 0 }
  convolution_param { engine: CAFFE num_output: 64 pad: 35 kernel_size: 3 }
}
layer { name: "relu1_1" type: "ReLU" bottom: "conv1_1" top: "conv1_1" }
layer {
  name: "conv1_2"
  type: "Convolution"
  bottom: "conv1_1"
  top: "conv1_2"
  param { lr_mult: 1 decay_mult: 1 }
  param { lr_mult: 2 decay_mult: 0 }
  convolution_param { engine: CAFFE num_output: 64 pad: 1 kernel_size: 3 }
}
layer { name: "relu1_2" type: "ReLU" bottom: "conv1_2" top: "conv1_2" }
layer {
  name: "pool1"
  type: "Pooling"
  bottom: "conv1_2"
  top: "pool1"
  pooling_param { pool: MAX kernel_size: 2 stride: 2 }
}

# ---- Stage 2: 128 channels ----
layer {
  name: "conv2_1"
  type: "Convolution"
  bottom: "pool1"
  top: "conv2_1"
  param { lr_mult: 1 decay_mult: 1 }
  param { lr_mult: 2 decay_mult: 0 }
  convolution_param { engine: CAFFE num_output: 128 pad: 1 kernel_size: 3 }
}
layer { name: "relu2_1" type: "ReLU" bottom: "conv2_1" top: "conv2_1" }
layer {
  name: "conv2_2"
  type: "Convolution"
  bottom: "conv2_1"
  top: "conv2_2"
  param { lr_mult: 1 decay_mult: 1 }
  param { lr_mult: 2 decay_mult: 0 }
  convolution_param { engine: CAFFE num_output: 128 pad: 1 kernel_size: 3 }
}
layer { name: "relu2_2" type: "ReLU" bottom: "conv2_2" top: "conv2_2" }
layer {
  name: "pool2"
  type: "Pooling"
  bottom: "conv2_2"
  top: "pool2"
  pooling_param { pool: MAX kernel_size: 2 stride: 2 }
}

# ---- Stage 3: 256 channels ----
layer {
  name: "conv3_1"
  type: "Convolution"
  bottom: "pool2"
  top: "conv3_1"
  param { lr_mult: 1 decay_mult: 1 }
  param { lr_mult: 2 decay_mult: 0 }
  convolution_param { engine: CAFFE num_output: 256 pad: 1 kernel_size: 3 }
}
layer { name: "relu3_1" type: "ReLU" bottom: "conv3_1" top: "conv3_1" }
layer {
  name: "conv3_2"
  type: "Convolution"
  bottom: "conv3_1"
  top: "conv3_2"
  param { lr_mult: 1 decay_mult: 1 }
  param { lr_mult: 2 decay_mult: 0 }
  convolution_param { engine: CAFFE num_output: 256 pad: 1 kernel_size: 3 }
}
layer { name: "relu3_2" type: "ReLU" bottom: "conv3_2" top: "conv3_2" }
layer {
  name: "conv3_3"
  type: "Convolution"
  bottom: "conv3_2"
  top: "conv3_3"
  param { lr_mult: 1 decay_mult: 1 }
  param { lr_mult: 2 decay_mult: 0 }
  convolution_param { engine: CAFFE num_output: 256 pad: 1 kernel_size: 3 }
}
layer { name: "relu3_3" type: "ReLU" bottom: "conv3_3" top: "conv3_3" }
layer {
  name: "pool3"
  type: "Pooling"
  bottom: "conv3_3"
  top: "pool3"
  pooling_param { pool: MAX kernel_size: 2 stride: 2 }
}

# ---- Stage 4: 512 channels ----
layer {
  name: "conv4_1"
  type: "Convolution"
  bottom: "pool3"
  top: "conv4_1"
  param { lr_mult: 1 decay_mult: 1 }
  param { lr_mult: 2 decay_mult: 0 }
  convolution_param { engine: CAFFE num_output: 512 pad: 1 kernel_size: 3 }
}
layer { name: "relu4_1" type: "ReLU" bottom: "conv4_1" top: "conv4_1" }
layer {
  name: "conv4_2"
  type: "Convolution"
  bottom: "conv4_1"
  top: "conv4_2"
  param { lr_mult: 1 decay_mult: 1 }
  param { lr_mult: 2 decay_mult: 0 }
  convolution_param { engine: CAFFE num_output: 512 pad: 1 kernel_size: 3 }
}
layer { name: "relu4_2" type: "ReLU" bottom: "conv4_2" top: "conv4_2" }
layer {
  name: "conv4_3"
  type: "Convolution"
  bottom: "conv4_2"
  top: "conv4_3"
  param { lr_mult: 1 decay_mult: 1 }
  param { lr_mult: 2 decay_mult: 0 }
  convolution_param { engine: CAFFE num_output: 512 pad: 1 kernel_size: 3 }
}
layer { name: "relu4_3" type: "ReLU" bottom: "conv4_3" top: "conv4_3" }
layer {
  name: "pool4"
  type: "Pooling"
  bottom: "conv4_3"
  top: "pool4"
  pooling_param { pool: MAX kernel_size: 2 stride: 2 }
}

# ---- Stage 5: 512 channels, 100x / 200x learning-rate multipliers ----
layer {
  name: "conv5_1"
  type: "Convolution"
  bottom: "pool4"
  top: "conv5_1"
  param { lr_mult: 100 decay_mult: 1 }
  param { lr_mult: 200 decay_mult: 0 }
  convolution_param { engine: CAFFE num_output: 512 pad: 1 kernel_size: 3 }
}
layer { name: "relu5_1" type: "ReLU" bottom: "conv5_1" top: "conv5_1" }
layer {
  name: "conv5_2"
  type: "Convolution"
  bottom: "conv5_1"
  top: "conv5_2"
  param { lr_mult: 100 decay_mult: 1 }
  param { lr_mult: 200 decay_mult: 0 }
  convolution_param { engine: CAFFE num_output: 512 pad: 1 kernel_size: 3 }
}
layer { name: "relu5_2" type: "ReLU" bottom: "conv5_2" top: "conv5_2" }
layer {
  name: "conv5_3"
  type: "Convolution"
  bottom: "conv5_2"
  top: "conv5_3"
  param { lr_mult: 100 decay_mult: 1 }
  param { lr_mult: 200 decay_mult: 0 }
  convolution_param { engine: CAFFE num_output: 512 pad: 1 kernel_size: 3 }
}
layer { name: "relu5_3" type: "ReLU" bottom: "conv5_3" top: "conv5_3" }
## DSN conv 1 ###
# Side output 1: 1x1 score on conv1_2, cropped onto the "data" grid.
# conv1_1's pad of 35 shifts conv1_2 coordinates by 34 pixels relative to
# "data", so the crop starts at offset 34 (the value caffe.coord_map.crop
# derives). Without crop_param, BVLC Caffe's Crop layer crops at offset 0,
# which misaligns the side output.
layer { name: "score-dsn1" type: "Convolution" bottom: "conv1_2" top: "score-dsn1-up"
  param { lr_mult: 0.01 decay_mult: 1 } param { lr_mult: 0.02 decay_mult: 0 }
  convolution_param { engine: CAFFE num_output: 1 kernel_size: 1 } }
layer { name: "crop-dsn1" type: "Crop" bottom: "score-dsn1-up" bottom: "data" top: "upscore-dsn1"
  crop_param { axis: 2 offset: 34 } }
layer { name: "sigmoid-dsn1" type: "Sigmoid" bottom: "upscore-dsn1" top: "sigmoid-dsn1" }
### DSN conv 2 ###
# Side output 2: 1x1 score on conv2_2, upsampled 2x by a fixed (lr_mult 0)
# deconvolution, then cropped onto the "data" grid. With conv1_1 pad 35 the
# upsampled map is shifted by 35 pixels relative to "data" (coord_map offset).
# BVLC Caffe's Crop layer would otherwise crop at offset 0.
layer { name: "score-dsn2" type: "Convolution" bottom: "conv2_2" top: "score-dsn2"
  param { lr_mult: 0.01 decay_mult: 1 } param { lr_mult: 0.02 decay_mult: 0 }
  convolution_param { engine: CAFFE num_output: 1 kernel_size: 1 } }
layer { name: "upsample_2" type: "Deconvolution" bottom: "score-dsn2" top: "score-dsn2-up"
  param { lr_mult: 0 decay_mult: 1 } param { lr_mult: 0 decay_mult: 0 }
  convolution_param { kernel_size: 4 stride: 2 num_output: 1 } }
layer { name: "crop-dsn2" type: "Crop" bottom: "score-dsn2-up" bottom: "data" top: "upscore-dsn2"
  crop_param { axis: 2 offset: 35 } }
layer { name: "sigmoid-dsn2" type: "Sigmoid" bottom: "upscore-dsn2" top: "sigmoid-dsn2" }
### DSN conv 3 ###
# Side output 3: 1x1 score on conv3_3, upsampled 4x by a fixed (lr_mult 0)
# deconvolution, then cropped onto the "data" grid at offset 36, the shift
# caffe.coord_map.crop derives for this geometry (conv1_1 pad 35).
layer { name: "score-dsn3" type: "Convolution" bottom: "conv3_3" top: "score-dsn3"
  param { lr_mult: 0.01 decay_mult: 1 } param { lr_mult: 0.02 decay_mult: 0 }
  convolution_param { engine: CAFFE num_output: 1 kernel_size: 1 } }
layer { name: "upsample_4" type: "Deconvolution" bottom: "score-dsn3" top: "score-dsn3-up"
  param { lr_mult: 0 decay_mult: 1 } param { lr_mult: 0 decay_mult: 0 }
  convolution_param { kernel_size: 8 stride: 4 num_output: 1 } }
layer { name: "crop-dsn3" type: "Crop" bottom: "score-dsn3-up" bottom: "data" top: "upscore-dsn3"
  crop_param { axis: 2 offset: 36 } }
layer { name: "sigmoid-dsn3" type: "Sigmoid" bottom: "upscore-dsn3" top: "sigmoid-dsn3" }
###DSN conv 4###
# Side output 4: 1x1 score on conv4_3, upsampled 8x by a fixed (lr_mult 0)
# deconvolution, then cropped onto the "data" grid at offset 38, the shift
# caffe.coord_map.crop derives for this geometry (conv1_1 pad 35).
layer { name: "score-dsn4" type: "Convolution" bottom: "conv4_3" top: "score-dsn4"
  param { lr_mult: 0.01 decay_mult: 1 } param { lr_mult: 0.02 decay_mult: 0 }
  convolution_param { engine: CAFFE num_output: 1 kernel_size: 1 } }
layer { name: "upsample_8" type: "Deconvolution" bottom: "score-dsn4" top: "score-dsn4-up"
  param { lr_mult: 0 decay_mult: 1 } param { lr_mult: 0 decay_mult: 0 }
  convolution_param { kernel_size: 16 stride: 8 num_output: 1 } }
layer { name: "crop-dsn4" type: "Crop" bottom: "score-dsn4-up" bottom: "data" top: "upscore-dsn4"
  crop_param { axis: 2 offset: 38 } }
layer { name: "sigmoid-dsn4" type: "Sigmoid" bottom: "upscore-dsn4" top: "sigmoid-dsn4" }
###DSN conv 5###
# Side output 5: 1x1 score on conv5_3, upsampled 16x by a fixed (lr_mult 0)
# deconvolution, then cropped onto the "data" grid at offset 42, the shift
# caffe.coord_map.crop derives for this geometry (conv1_1 pad 35).
layer { name: "score-dsn5" type: "Convolution" bottom: "conv5_3" top: "score-dsn5"
  param { lr_mult: 0.01 decay_mult: 1 } param { lr_mult: 0.02 decay_mult: 0 }
  convolution_param { engine: CAFFE num_output: 1 kernel_size: 1 } }
layer { name: "upsample_16" type: "Deconvolution" bottom: "score-dsn5" top: "score-dsn5-up"
  param { lr_mult: 0 decay_mult: 1 } param { lr_mult: 0 decay_mult: 0 }
  convolution_param { kernel_size: 32 stride: 16 num_output: 1 } }
layer { name: "crop-dsn5" type: "Crop" bottom: "score-dsn5-up" bottom: "data" top: "upscore-dsn5"
  crop_param { axis: 2 offset: 42 } }
layer { name: "sigmoid-dsn5" type: "Sigmoid" bottom: "upscore-dsn5" top: "sigmoid-dsn5" }
### Concat and multiscale weight layer ###
# Stack the five pre-sigmoid side outputs along the channel axis (axis 1;
# "axis" replaces the deprecated "concat_dim" alias) and fuse them with a
# 1x1 convolution whose weights are initialised to the constant 0.2.
# "sigmoid-fuse" is the blob batch_hed.py reads as the fused prediction.
layer { name: "concat" type: "Concat"
  bottom: "upscore-dsn1" bottom: "upscore-dsn2" bottom: "upscore-dsn3"
  bottom: "upscore-dsn4" bottom: "upscore-dsn5" top: "concat-upscore"
  concat_param { axis: 1 } }
layer { name: "new-score-weighting" type: "Convolution" bottom: "concat-upscore" top: "upscore-fuse"
  param { lr_mult: 0.01 decay_mult: 1 } param { lr_mult: 0.02 decay_mult: 0 }
  convolution_param { engine: CAFFE num_output: 1 kernel_size: 1 weight_filler { type: "constant" value: 0.2 } } }
layer { name: "sigmoid-fuse" type: "Sigmoid" bottom: "upscore-fuse" top: "sigmoid-fuse" }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys, os | |
sys.path.insert(0, 'path/to/caffe/python') | |
import caffe | |
from caffe import layers as L, params as P | |
from caffe.coord_map import crop | |
import numpy as np | |
def conv_relu(bottom, nout, ks=3, stride=1, pad=1, mult=(1, 1, 2, 0)):
    """Create a Convolution layer followed by an in-place ReLU.

    Args:
        bottom: input top to convolve.
        nout: number of output channels (``num_output``).
        ks: kernel size.
        stride: convolution stride.
        pad: zero padding added on each spatial side.
        mult: ``(weight lr_mult, weight decay_mult, bias lr_mult,
            bias decay_mult)``. A tuple rather than a list so the shared
            default can never be mutated between calls; any indexable
            sequence of at least four numbers is accepted.

    Returns:
        ``(conv, relu)``: the convolution top and the ReLU applied in place on it.
    """
    conv = L.Convolution(bottom, kernel_size=ks, stride=stride,
                         num_output=nout, pad=pad, weight_filler=dict(type='xavier'),
                         param=[dict(lr_mult=mult[0], decay_mult=mult[1]),
                                dict(lr_mult=mult[2], decay_mult=mult[3])])
    return conv, L.ReLU(conv, in_place=True)
def max_pool(bottom, ks=2, stride=2):
    """Return a MAX Pooling layer over *bottom* (2x2 window, stride 2 by default)."""
    pooled = L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=ks, stride=stride)
    return pooled
def full_conv(bottom, name, lr):
    """Return a named 1x1 Convolution producing a single-channel score map.

    The weight learning rate is ``0.01 * lr`` (decay 1) and the bias learning
    rate is ``0.02 * lr`` (decay 0). No weight_filler is passed.
    """
    weight_spec = dict(lr_mult=0.01 * lr, decay_mult=1)
    bias_spec = dict(lr_mult=0.02 * lr, decay_mult=0)
    return L.Convolution(bottom, name=name, kernel_size=1, num_output=1,
                         param=[weight_spec, bias_spec])
def _add_side_output(n, stage, tap, split):
    """Attach deeply supervised side output *stage* (1-5), tapped from *tap*, to NetSpec *n*.

    Adds these tops, in this order:
      - ``score_dsn<stage>``: a 1x1 score convolution on *tap*.
      - ``score_dsn<stage>_up`` (stages 2-5 only): a fixed-weight Deconvolution
        named ``upsample_<2**(stage-1)>``, with stride ``2**(stage-1)`` and a
        kernel twice the stride.
      - ``upscore_dsn<stage>``: the crop back onto ``n.data``.
      - ``loss<stage>`` for ``split == 'train'``, or ``sigmoid_dsn<stage>``
        for ``split == 'test'``.
    """
    score = full_conv(tap, 'score-dsn%d' % stage, lr=1)
    setattr(n, 'score_dsn%d' % stage, score)
    if stage > 1:
        stride = 2 ** (stage - 1)
        score = L.Deconvolution(
            score, name='upsample_%d' % stride,
            # lr_mult 0 keeps the upsampling kernel fixed. The bilinear filler
            # gives it a meaningful value whenever nothing overwrites it
            # (e.g. no interpolation surgery or pretrained weights).
            convolution_param=dict(num_output=1, kernel_size=2 * stride, stride=stride,
                                   weight_filler=dict(type='bilinear')),
            param=[dict(lr_mult=0, decay_mult=1), dict(lr_mult=0, decay_mult=0)])
        setattr(n, 'score_dsn%d_up' % stage, score)
    upscore = crop(score, n.data)
    setattr(n, 'upscore_dsn%d' % stage, upscore)
    if split == 'train':
        # Caffe registers this layer as "SigmoidCrossEntropyLoss" (capital E);
        # the previous "SigmoidCrossentropyLoss" is an unknown layer type.
        setattr(n, 'loss%d' % stage, L.SigmoidCrossEntropyLoss(upscore, n.label))
    if split == 'test':
        setattr(n, 'sigmoid_dsn%d' % stage, L.Sigmoid(upscore))


def fcn(split):
    """Build the HED network with caffe.NetSpec and return its NetParameter.

    Args:
        split: ``'train'`` adds a ``label`` input and a
            ``SigmoidCrossEntropyLoss`` on each side output and on the fused
            output. ``'test'`` adds Sigmoid layers instead (``sigmoid_dsn1``..5
            and ``sigmoid_fuse``). Any other value yields only the trunk, side
            outputs and fusion convolution.

    Returns:
        The ``NetParameter`` produced by ``n.to_proto()``.
    """
    n = caffe.NetSpec()
    n.data = L.Input(name='data', input_param=dict(shape=dict(dim=[1, 3, 500, 500])))
    if split == 'train':
        n.label = L.Input(name='label', input_param=dict(shape=dict(dim=[1, 1, 500, 500])))

    # VGG-16 trunk. conv1_1 uses pad=100; crop() below derives the matching
    # crop offsets from the layer geometry.
    n.conv1_1, n.relu1_1 = conv_relu(n.data, 64, pad=100)
    n.conv1_2, n.relu1_2 = conv_relu(n.relu1_1, 64)
    n.pool1 = max_pool(n.relu1_2)
    n.conv2_1, n.relu2_1 = conv_relu(n.pool1, 128)
    n.conv2_2, n.relu2_2 = conv_relu(n.relu2_1, 128)
    n.pool2 = max_pool(n.relu2_2)
    n.conv3_1, n.relu3_1 = conv_relu(n.pool2, 256)
    n.conv3_2, n.relu3_2 = conv_relu(n.relu3_1, 256)
    n.conv3_3, n.relu3_3 = conv_relu(n.relu3_2, 256)
    n.pool3 = max_pool(n.relu3_3)
    n.conv4_1, n.relu4_1 = conv_relu(n.pool3, 512)
    n.conv4_2, n.relu4_2 = conv_relu(n.relu4_1, 512)
    n.conv4_3, n.relu4_3 = conv_relu(n.relu4_2, 512)
    n.pool4 = max_pool(n.relu4_3)
    # conv5 uses 100x / 200x learning-rate multipliers.
    n.conv5_1, n.relu5_1 = conv_relu(n.pool4, 512, mult=(100, 1, 200, 0))
    n.conv5_2, n.relu5_2 = conv_relu(n.relu5_1, 512, mult=(100, 1, 200, 0))
    n.conv5_3, n.relu5_3 = conv_relu(n.relu5_2, 512, mult=(100, 1, 200, 0))

    # Deeply supervised side outputs, one per VGG stage (DSN1 .. DSN5).
    taps = (n.conv1_2, n.conv2_2, n.conv3_3, n.conv4_3, n.conv5_3)
    for stage, tap in enumerate(taps, 1):
        _add_side_output(n, stage, tap, split)

    # Concatenate the five pre-sigmoid side outputs along channels ("axis"
    # replaces the deprecated "concat_dim" alias) and fuse them with a 1x1
    # convolution whose weights start at the constant 0.2.
    side_outputs = [getattr(n, 'upscore_dsn%d' % stage) for stage in range(1, 6)]
    n.concat_upscore = L.Concat(*side_outputs, name='concat', concat_param=dict(axis=1))
    n.upscore_fuse = L.Convolution(n.concat_upscore, name='new-score-weighting',
                                   num_output=1, kernel_size=1,
                                   param=[dict(lr_mult=0.001, decay_mult=1),
                                          dict(lr_mult=0.002, decay_mult=0)],
                                   weight_filler=dict(type='constant', value=0.2))
    if split == 'test':
        n.sigmoid_fuse = L.Sigmoid(n.upscore_fuse)
    if split == 'train':
        n.loss_fuse = L.SigmoidCrossEntropyLoss(n.upscore_fuse, n.label)
    return n.to_proto()
def make_net(train_path='hed_train.pt', test_path='hed_test.pt'):
    """Write the HED train and test net definitions produced by ``fcn``.

    Args:
        train_path: output file for the training net (default ``hed_train.pt``).
            It begins with a ``force_backward: true`` line, followed by the
            ``fcn('train')`` definition.
        test_path: output file for the test net (default ``hed_test.pt``).
    """
    with open(train_path, 'w') as f:
        # Text mode already translates '\n' to the platform newline; writing
        # os.linesep here would produce '\r\r\n' on Windows.
        f.write('\nforce_backward: true\n')
        f.write(str(fcn('train')))
    with open(test_path, 'w') as f:
        f.write(str(fcn('test')))
def make_solver(path='solver.pt', net_path='hed_train.pt'):
    """Write a Caffe solver definition for training HED.

    Args:
        path: file to write (default ``solver.pt``).
        net_path: training net referenced by the solver's ``net`` field.
            Defaults to ``hed_train.pt``, the file ``make_net`` writes.

    Raises:
        TypeError: if any solver value is not a string. The check runs before
            the output file is opened, so a bad value leaves no truncated file.
    """
    sp = {
        'net': '"%s"' % net_path,
        'base_lr': '0.001',
        'lr_policy': '"step"',
        'momentum': '0.9',
        'weight_decay': '0.0002',
        'iter_size': '10',
        'stepsize': '1000',
        'display': '20',
        'snapshot': '100000',
        'snapshot_prefix': '"net"',
        'gamma': '0.1',
        'max_iter': '100000',
        'solver_mode': 'CPU',
    }
    # Values are pasted verbatim into the text proto, so string quoting is the
    # caller's job (note the embedded quotes on string-typed fields).
    for value in sp.values():
        if not isinstance(value, str):
            raise TypeError('All solver parameters must be strings')
    with open(path, 'w') as f:
        for k, v in sorted(sp.items()):
            f.write('%s: %s\n' % (k, v))
if __name__ == '__main__':
    # Generate the net definitions first, then the solver definition.
    for _step in (make_net, make_solver):
        _step()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment