Created
August 30, 2016 22:38
-
-
Save dineshj1/839e08576d441944fd6f36ca6896453b to your computer and use it in GitHub Desktop.
Training on Pycaffe
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import time

# Wall-clock start for the periodic "--- Runtime ---" printouts below.
start_time = time.time()

################## Argument Parsing #####################################
# CLI for the pycaffe training driver.  The --solver_* options are
# written verbatim into a generated solver prototxt when --solver is
# empty (see the solver-writing section later in this script).
parser = argparse.ArgumentParser()
parser.add_argument('-s', '--solver', default='', type=str)           # if empty, solver is created, else read
parser.add_argument('-res', '--resume_from', default='', type=str)    # if not empty, resumes training from given file
parser.add_argument('-ft', '--finetune_from', default='', type=str)   # weights to copy_from (finetuning)
#parser.add_argument('-r','--rng_seed', default=242351, type=int); # not implemented
# NOTE(review): argparse's type=bool converts ANY non-empty string to True
# (e.g. "--debug_mode False" yields True).  Kept as-is on these two flags
# for CLI backward compatibility; they work correctly when simply omitted.
parser.add_argument('-d', '--debug_mode', default=False, type=bool)
parser.add_argument('-p', '--prefix', default='../clust_runs/', type=str)  # output path prefix for logs/figures
parser.add_argument('--showfigs', default=False, type=bool)
# BUG FIX: was type=bool, which would mangle any explicitly passed value;
# this interval is used in integer arithmetic (it % args.log_interval,
# rounding against solver_display), so parse it as int.
parser.add_argument('--log_interval', default=2000, type=int)
# TODO not implemented (pending net read/write)
#parser.add_argument('-t','--trn_data', default='./panocon_cls_trn.h5', type=str);
#parser.add_argument('-v','--val_data', default='./panocon_cls_val.h5', type=str);
#parser.add_argument('-v','--test_data', default='./panocon_cls_val.h5', type=str);
#parser.add_argument('-s','--split_partition', default=1, type=int);
#parser.add_argument('-lw','--loss_weight', default=1, type=float);
# Solver parameters taken over within Pycaffe
parser.add_argument('--solver_max_iter', default=10000, type=int)
parser.add_argument('--solver_display', default=20, type=int)        # iters between training-output samples
parser.add_argument('--solver_test_iter', default=20, type=int)      # test-net forward passes per evaluation
parser.add_argument('--solver_test_interval', default=200, type=int) # iters between evaluations
parser.add_argument('--sys_cmd', default='', type=str)               # can be used to rsync etc.
# TODO Solver functions not implemented completely (pending solver read/write)
parser.add_argument('--solver_snapshot', default=1000, type=int)
parser.add_argument('--solver_snapshot_prefix', default='../clust_runs/caffe_snapshots/snap', type=str)
parser.add_argument('--solver_mode', default='GPU', type=str)
parser.add_argument('--solver_net', default='"../SUN360/clsnet_32_net.prototxt"', type=str)
parser.add_argument('--solver_base_lr', default=0.0001, type=float)
parser.add_argument('--solver_momentum', default=0.9, type=float)
parser.add_argument('--solver_weight_decay', default=5e-4, type=float)
parser.add_argument('--solver_lr_policy', default='"fixed"', type=str)
parser.add_argument('--caffe_pythonpath', default='/vision/vision_users/dineshj/caffe_vis/python/', type=str)
# Early termination: stop once training performance is within
# overfit_margin of target_perfect AND the test target has not improved
# for saturation_wait iterations (see the training loop below).
parser.add_argument('-time', '--max_time', default=float('inf'), type=float)  # in minutes
parser.add_argument('-ET_w', '--saturation_wait', default=0, type=int)
parser.add_argument('-ET_tar', '--target_output', default='', type=str)
parser.add_argument('-ET_drn', '--target_drn', default='h', type=str)         # 'h'/'l': higher or lower is better
parser.add_argument('-ET_ov', '--overfit_margin', default=0, type=float)      # how close to target_perfect counts
parser.add_argument('-ET_perf', '--target_perfect', default=1.0, type=float)  # "perfect" training score
args = parser.parse_args()
print(args)
#np.random.seed(args.rng_seed); #doesn't affect in any way at the moment (Caffe uses a different random seed)
#########################################################################
print "Importing necessary libraries"
import sys
import matplotlib as mpl
if not args.showfigs:
    # Headless backend; must be selected BEFORE importing pyplot.
    mpl.use('Agg');
import matplotlib.pyplot as plt
#caffe_pythonpath='/vision/vision_users/dineshj/caffe_vis/python/';
# Make the (user-specified) pycaffe build importable.
sys.path.insert(0, args.caffe_pythonpath)
import caffe
#import lmdb
#from pylab import *
#import pylab
from matplotlib.figure import Figure
import numpy as np
import cPickle as pk
import scipy.io as sio
import pdb
import os
print("--- Runtime: %.2f secs ---" % (time.time()-start_time))
def update_logs(all_op_names, train_op_names, test_op_names, train_ops, test_ops, args, best_test_iter, best_test_score, target_op): | |
plt.close('all'); | |
print "Creating and saving plots" | |
num_all_outputs=len(all_op_names); | |
for opno in range(num_all_outputs): | |
fig, ax_array=plt.subplots(nrows=2, sharex=True); | |
if all_op_names[opno] in train_op_names: | |
curr_op=train_ops[opno]; | |
ax_array[0].plot((np.arange(len(curr_op)))*args.solver_display, curr_op); ax_array[0].set_title('train %s' % all_op_names[opno]); | |
ax_array[0].axvline(x=best_test_iter, color='r'); | |
if all_op_names[opno] in test_op_names: | |
curr_op=test_ops[opno]; | |
ax_array[1].plot((np.arange(len(curr_op)))*args.solver_test_interval, curr_op); ax_array[1].set_title('test %s' % all_op_names[opno]); | |
ax_array[1].axvline(x=best_test_iter, color='r'); | |
if opno==target_op: | |
y_range=ax_array[1].get_ylim(); | |
y_ht=y_range[1]-y_range[0]; | |
scale=np.round(np.log10(y_ht))<=1; | |
if scale: | |
prec_str="%%%df"%(scale+3) | |
print prec_str | |
else: | |
prec_str="%d"; | |
ax_array[1].text(best_test_iter, best_test_score+y_ht*0.05, (prec_str%best_test_score).replace("-0", "-").lstrip("0") , color='r'); | |
if args.showfigs: | |
print "Trying to show plots" | |
try: | |
plt.ion(); | |
plt.show(); | |
except Exception as e: | |
print e.__doc__ | |
print e.message | |
print "Skipping showing figure. Saving directly." | |
fig_name_root="%s_%s" % (args.prefix, all_op_names[opno]); | |
print "Storing fig to %s(.png/.pkfig)" % fig_name_root | |
plt.savefig(fig_name_root+'.png'); | |
pk.dump(fig, file(fig_name_root + '.pkfig', 'w')); | |
matfilename="%s.mat" % args.prefix; | |
print "Saving records to %s" % matfilename; | |
sys.stdout.flush(); | |
sio.savemat(matfilename, | |
{ | |
'train_ops':train_ops, | |
'test_ops':test_ops, | |
'train_op_names':train_op_names, | |
'test_op_names':test_op_names, | |
'all_op_names':all_op_names, | |
'best_test_score': best_test_score, | |
'best_test_iter': best_test_iter | |
} | |
); | |
if args.sys_cmd: | |
print "Running sys cmd: %s" % args.sys_cmd; | |
try: | |
os.system(args.sys_cmd); | |
except Exception as e: | |
print e.__doc__ | |
print e.message | |
print "Cmd did not work." | |
######################################################################### | |
#print "Setting up network" | |
#from caffe import layers as L | |
#from caffe import params as P | |
#print "Beginning net definition" | |
#def net(lmdbname, batch_size): | |
# n = caffe.NetSpec() | |
# n.data, n.label = L.Data( | |
# batch_size=batch_size, | |
# backend=P.Data.LMDB, | |
# source=lmdbname, | |
# transform_param=dict( | |
# #mirror=True, | |
# #crop_size=227, | |
# #mean_file='/scratch/vision/dineshj/caffe2/data/ilsvrc12/imagenet_mean.binaryproto' | |
# ), | |
# ntop=2, | |
# ) | |
# n.data | |
# return n.to_proto() | |
#with open('auto_train.prototxt', 'w') as f: | |
# f.write(str(mini_net(lmdbname, 64))) | |
#print("--- Runtime: %.2f secs ---" % (time.time()-start_time)) | |
## Writing a solver | |
solver_file=args.solver; | |
if not solver_file: | |
print "Setting up solver" | |
solver_file="%s_solver.prototxt" % args.prefix; | |
print "Writing a solver" | |
solver_dict={}; | |
if not args.solver_net[0]=='"': | |
args.solver_net='"'+ args.solver_net + '"'; | |
if not args.solver_lr_policy[0]=='"': | |
args.solver_lr_policy='"'+ args.solver_lr_policy + '"'; | |
if not args.solver_snapshot_prefix[0]=='"': | |
solver_dict['snapshot_prefix']=str('"'+args.solver_snapshot_prefix + '"'); | |
solver_dict['net']=args.solver_net; | |
solver_dict['test_iter']=str(0); | |
solver_dict['test_interval']=str(int(args.solver_max_iter)*2); | |
solver_dict['base_lr']=str(args.solver_base_lr); | |
solver_dict['momentum']=str(args.solver_momentum); | |
solver_dict['weight_decay']=str(args.solver_weight_decay); | |
solver_dict['lr_policy']=str(args.solver_lr_policy); | |
solver_dict['display']=str(0); | |
solver_dict['max_iter']=str(args.solver_max_iter); | |
solver_dict['snapshot']=str(args.solver_snapshot); | |
solver_dict['solver_mode']=args.solver_mode; | |
with file(solver_file, 'w') as f: | |
for key in solver_dict: | |
f.write(key+':'+solver_dict[key]+'\n'); | |
print("--- Runtime: %.2f secs ---" % (time.time()-start_time)) | |
print "Loading solver"
#caffe.set_mode_cpu();
solver=caffe.SGDSolver(solver_file);
if args.resume_from:
    # Resume restores solver state (iteration, momentum) from a snapshot.
    print "Resuming from %s" %(args.resume_from)
    solver.restore(args.resume_from);
elif args.finetune_from:
    # Finetune copies only the weights; solver state starts fresh.
    print "Finetuning %s" %(args.finetune_from)
    solver.net.copy_from(args.finetune_from);
print("--- Runtime: %.2f secs ---" % (time.time()-start_time))
# Batch size of the test net, read off its first blob (py2 .items()).
test_batch_sz=solver.test_nets[0].blobs.items()[0][1].shape[0];
it=0;
# get outputs automatically from solver.net.outputs
train_op_names=solver.net.outputs;
#num_outputs=len(train_op_names);
# get test outputs automatically from solver.test_nets[0].outputs
test_op_names=solver.test_nets[0].outputs;
#num_test_outputs=len(test_op_names);
# Union of train/test output names; row index into train_ops/test_ops.
all_op_names= list(set(train_op_names) | set(test_op_names));
num_all_outputs=len(all_op_names);
# One column per logging/testing event over the whole run (trimmed on
# early termination at the bottom of the script).
train_ops=np.zeros((num_all_outputs, max(args.solver_max_iter/args.solver_display, 1)));
test_ops=np.zeros((num_all_outputs, max(args.solver_max_iter/args.solver_test_interval,1)));
# Automatically determine the target variable to determine early termination based on, etc.
if args.target_output:
    try:
        target_op=all_op_names.index(args.target_output);
    except Exception as e:
        print e.__doc__
        print e.message
        print "Could not find output %s. Setting to empty." % (args.target_output);
        args.target_output='';
if not args.target_output:
    # TODO include a regex search for outputs starting with "target_"
    # Fall back to common accuracy-style output names; 'h' = higher is better.
    if 'accuracy' in test_op_names:
        target_op=all_op_names.index('accuracy');
        args.target_drn='h';
    elif 'cls_accuracy' in test_op_names:
        target_op=all_op_names.index('cls_accuracy');
        args.target_drn='h';
    elif 'acc' in test_op_names:
        target_op=all_op_names.index('acc');
        args.target_drn='h';
    elif 'cls_acc' in test_op_names:
        target_op=all_op_names.index('cls_acc');
        args.target_drn='h';
    else:
        target_op=0; # setting randomly to the first output
print "Setting target output to #%d:%s (dir: %s)" % (target_op, all_op_names[target_op], args.target_drn);
ET_on=False
if args.saturation_wait>0 or args.overfit_margin>0:
    ET_on=True;
    print "EARLY TERMINATION ON";
# Round both intervals to multiples of solver_display so the modulo
# checks in the training loop (which advances it by solver_display) fire.
args.solver_test_interval=np.round(args.solver_test_interval/args.solver_display)*args.solver_display;
args.log_interval=np.round(args.log_interval/args.solver_display)*args.solver_display;
# NaN sentinel: any real score compares as an improvement on first test.
best_test_score=np.NaN;
best_test_iter=0;
timeout_flag=False;
termination_flag=False;
# Each early-termination criterion starts "clear" when it is disabled.
if args.overfit_margin>0:
    overfit_clear=False;
else:
    overfit_clear=True;
if args.saturation_wait>0:
    saturate_clear=False;
else:
    saturate_clear=True;
overfit_window=5;  # running-average window (in solver_display samples)
# Prime both nets so blobs hold valid data before the first logging pass.
solver.net.forward();
solver.test_nets[0].forward();
if args.debug_mode:
    pdb.set_trace()
print "Beginning iterations"
# Main training loop: steps the solver solver_display iterations at a
# time, sampling training outputs each pass, testing every
# solver_test_interval iterations, and logging every log_interval.
while it < args.solver_max_iter and not termination_flag:
    sys.stdout.flush()
    # Record current training outputs (blobs hold values from the last step).
    for opno in range(num_all_outputs):
        if all_op_names[opno] in train_op_names:
            train_ops[opno, it/args.solver_display]=solver.net.blobs[all_op_names[opno]].data;
    print 'Py-iteration', it, 'training outputs ...'
    print train_ops[:, it/args.solver_display]
    # Overfit criterion: running average of the training target is within
    # overfit_margin of target_perfect.
    if not overfit_clear:
        if it/args.solver_display>=overfit_window:
            train_score_running_avg = np.mean(train_ops[target_op, it/args.solver_display:it/args.solver_display-overfit_window:-1]);
            if np.abs(train_score_running_avg-args.target_perfect) < args.overfit_margin:
                overfit_clear=True;
                print "Running avg: %f" % train_score_running_avg;
                print "Overfit target achieved";
    runtime = (time.time()-start_time)/60;
    if runtime>args.max_time:
        print "Ran out of time. Finishing up early."
        timeout_flag=True;
    # Periodic evaluation on the test net.
    if it % args.solver_test_interval==0:
        print '====================================='
        print("--- runtime: %.2f / %.2f mins ---" % (runtime, args.max_time));
        print 'Py-iteration', it, 'testing outputs ...'
        correct=0;
        loss_sum=0;
        test_op_sum=np.zeros(num_all_outputs);
        # Average each test output over solver_test_iter forward passes.
        for test_it in range(args.solver_test_iter):
            solver.test_nets[0].forward()
            for opno in range(num_all_outputs):
                if all_op_names[opno] in test_op_names:
                    test_op_sum[opno]+=solver.test_nets[0].blobs[all_op_names[opno]].data;
            #correct+=sum(solver.test_nets[0].blobs['cls_ip2'].data.argmax(1) == solver.test_nets[0].blobs['cls_label'].data);
        test_ops[:, it/args.solver_test_interval]=test_op_sum[:]/args.solver_test_iter;
        #test_acc[it/args.solver_test_interval] = correct/(args.solver_test_iter*test_batch_sz);
        print test_ops[:, it/args.solver_test_interval]
        # implement early termination
        if args.debug_mode:
            pdb.set_trace()
        if args.target_drn=='h':
            # NaN best_test_score makes this condition true on first test.
            if not best_test_score > test_ops[target_op, it/args.solver_test_interval]: # i.e. current score is highest
                best_test_score=test_ops[target_op, it/args.solver_test_interval];
                best_test_iter=it;
                print "%s (op #%d/%d) improved!" % (all_op_names[target_op], target_op+1, len(all_op_names));
                print(args.solver_snapshot_prefix + '_bestweights')
                solver.net.save(args.solver_snapshot_prefix + '_bestweights'); # update saved best model
            else:
                print "best %s (op #%d/%d) so far: %f (at iter %d)" %(all_op_names[target_op], target_op+1, len(all_op_names), best_test_score, best_test_iter);
        elif args.target_drn=='l':
            # Mirror case: lower is better (e.g. a loss output).
            if not best_test_score < test_ops[target_op, it/args.solver_test_interval]: # i.e. current score is highest
                best_test_score=test_ops[target_op, it/args.solver_test_interval];
                best_test_iter=it;
                print "%s (op #%d/%d) improved!" % (all_op_names[target_op], target_op+1, len(all_op_names));
                solver.net.save(args.solver_snapshot_prefix + '_bestweights'); # update saved best model
            else:
                print "best %s (op #%d/%d) so far: %f (at iter %d)" %(all_op_names[target_op], target_op+1, len(all_op_names), best_test_score, best_test_iter);
        else:
            raise NameError('Unknown target direction (target_drn) %s' % args.target_drn);
        # Saturation criterion: no test improvement for saturation_wait
        # iterations (only checked once the overfit criterion cleared).
        if overfit_clear and not saturate_clear:
            if it - best_test_iter > args.saturation_wait: # time to quit!
                saturate_clear=True;
                print "Test performance saturation target cleared"
        print '====================================='
    if it % args.log_interval==0:
        update_logs(all_op_names, train_op_names, test_op_names, train_ops, test_ops, args, best_test_iter, best_test_score, target_op);
    # Terminate on timeout, or when BOTH early-termination criteria clear.
    termination_flag = timeout_flag or (ET_on and overfit_clear and saturate_clear);
    if not termination_flag:
        solver.step(args.solver_display);
        it=it+args.solver_display;
    else:
        print "Triggering early termination";
if termination_flag:
    # Trim the pre-allocated record arrays to the iterations actually run.
    train_ops=train_ops[:,:it/args.solver_display+1];
    test_ops=test_ops[:,:it/args.solver_test_interval+1];
    update_logs(all_op_names, train_op_names, test_op_names, train_ops, test_ops, args, best_test_iter, best_test_score, target_op);
elapsed_s=time.time()-start_time;
print("--- %.2f secs (%.2f mins) ---" % (elapsed_s, elapsed_s/60))
sys.stdout.flush();
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment