Skip to content

Instantly share code, notes, and snippets.

@enijkamp
Created August 18, 2018 00:14
Show Gist options
  • Save enijkamp/55cef624a0fba8a4e2ca221b386b0d98 to your computer and use it in GitHub Desktop.
Save enijkamp/55cef624a0fba8a4e2ca221b386b0d98 to your computer and use it in GitHub Desktop.
acd fashion-mnist
# Map Cooperative Networks
import shutil
import csv
from map_code.coop_elm import *
import argparse
from acd.models import _netE, _netG, _netI
def _str2bool(value):
    """Convert a command-line token to a real boolean.

    argparse's ``type=bool`` calls ``bool(token)``, and every non-empty
    string (including ``'False'``) is truthy, so such a flag can never be
    turned off from the command line. This converter parses the text instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # the empty string stays False, as it was under ``type=bool``
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser()

# mode (map, viz, or tree)
parser.add_argument('--mode', default='map', help='mode for code (map, viz, or tree)')
parser.add_argument('--delete_log', type=_str2bool, default=True, help='delete previous ELM record')

# experiment and file info
parser.add_argument('--map_exp_id', type=int, default=1, help='ID of ELM experiment')
parser.add_argument('--category', default='cifar', help='Image category')
parser.add_argument('--net_dir', default='.', help='location of nets')
parser.add_argument('--gen_net_file', default='netG_epoch_199.pth', help='file for trained generator net')
parser.add_argument('--des_net_file', default='netE_epoch_199.pth', help='file for trained descriptor net')
parser.add_argument('--inf_net_file', default='netI_epoch_199.pth', help='file for trained inference net (netI)')
parser.add_argument('--load_state_dict', type=_str2bool, default=True, help='load nets from state_dict or not')
parser.add_argument('--elm_out', default='./map_out3', help='root directory for ELM output')
parser.add_argument('--elm_file_in', default='ELM', help='name of file for loading ELM')
parser.add_argument('--elm_file_out', default='ELM', help='name of file for ELM results')
parser.add_argument('--gpu_id', type=int, default=3, help='id for gpu (-1 for cpu, otherwise gpu id >=0)')

# mapping parameters
parser.add_argument('--num_steps', type=int, default=1000, help='number of mapping iterations')
parser.add_argument('--num_mins', type=int, default=1000, help='number of minima in record')
parser.add_argument('--continue_elm', type=_str2bool, default=False, help='load previous ELM record')
parser.add_argument('--log_step', type=int, default=1, help='frequency for logging ELM record')
parser.add_argument('--en_min_limit', type=float, default=-40000, help='cutoff for min en (ignore degeneracy)')

# parameters for gaussian energy prior
parser.add_argument('--ref_normal', default='none', help='normal prior space (none, latent, image, both)')
parser.add_argument('--z_sigma_sq', type=float, default=1.0, help='variance of latent prior')
parser.add_argument('--im_sigma_sq', type=float, default=.10, help='variance of image prior')
#parser.add_argument('--vae_penalty', type=float, default=4e+4)
parser.add_argument('--vae_penalty', type=float, default=1e+6)

# size of image and latent space
parser.add_argument('--image_size', type=int, default=32, help='image height/width')
parser.add_argument('--num_channels', type=int, default=3, help='number of color channels (1 for gray, 3 for RGB)')
parser.add_argument('--z_size', type=int, default=20, help='dimension of latent space')

# min search parameters
#parser.add_argument('--min_eps', type=float, default=5e-9, help='step size for min search')
parser.add_argument('--min_eps', type=float, default=1e-12, help='step size for min search')
parser.add_argument('--min_limit', type=int, default=5000, help='max number of steps for min search')
parser.add_argument('--min_window', type=int, default=250, help='energy improvement window for min search')

# AD parameters
parser.add_argument('--ad_eps', type=float, default=0.01, help='step size for AD trials')
parser.add_argument('--ad_alpha', type=float, default=100.0, help='magnetization strength for AD trials')
parser.add_argument('--ad_temp', type=float, default=150.0, help='temperature for AD trials')
parser.add_argument('--ad_noise', type=float, default=0.02, help='noise magnitude for ad langevin')
parser.add_argument('--dist_res', type=float, default=0.5, help='max separation for successful AD trial')
parser.add_argument('--ad_limit', type=int, default=50000, help='max number of steps for AD trial')
parser.add_argument('--ad_window', type=int, default=200, help='distance improvement window for AD trial')
parser.add_argument('--max_ad_pairs', type=int, default=10, help='max number of AD pairings in each iteration')
parser.add_argument('--ad_reps', type=int, default=0, help='number of AD trials between each pair')
parser.add_argument('--update_mins', type=_str2bool, default=True, help='update basin representatives during training')
parser.add_argument('--print_ad_log', type=_str2bool, default=True, help='print long log of AD trials')
parser.add_argument('--ad_type', default='mh', help='mode for AD trials (langevin or mh)')

# get config for mapping
config = parser.parse_args()

# derived output locations: <elm_out>/<category>/exp<map_exp_id>/{ims,maps}/
config.elm_out_dir = os.path.join(config.elm_out, config.category, 'exp{}/'.format(config.map_exp_id))
config.elm_out_ims = os.path.join(config.elm_out_dir, 'ims/')
config.elm_out_data = os.path.join(config.elm_out_dir, 'maps/')
def _prepare_output_dirs():
    """Optionally wipe the old experiment directory, then create the output tree."""
    # only discard the previous record for a fresh run, never when continuing
    if config.delete_log and not config.continue_elm:
        if os.path.exists(config.elm_out_dir):
            shutil.rmtree(config.elm_out_dir)
    # makedirs also creates the missing parents (elm_out, category, exp dir)
    # and does not fail when a directory is already there
    os.makedirs(config.elm_out_ims, exist_ok=True)
    os.makedirs(config.elm_out_data, exist_ok=True)


def _save_config():
    """Write the run configuration to the experiment directory as CSV and pickle."""
    csv_path = os.path.join(config.elm_out_dir, 'config' + str(config.map_exp_id) + '.csv')
    # newline='' is what the csv module asks for; without it some platforms
    # emit a blank line after every row
    with open(csv_path, 'w', newline='') as f:
        writer = csv.writer(f)
        for arg in vars(config):
            writer.writerow([arg, getattr(config, arg)])
    with open(os.path.join(config.elm_out_dir, 'config.pkl'), 'wb') as f:
        pickle.dump(config, f)


def _load_nets():
    """Build the three networks and load their saved weights.

    Returns:
        Tuple ``(netG, netE, netI)`` with weights loaded and ``eval()`` applied.

    Raises:
        NotImplementedError: if ``config.load_state_dict`` is false.
    """
    if not config.load_state_dict:
        raise NotImplementedError('not implemented')

    # take sizes from the command line instead of hard-coding 20 / 3; the
    # defaults of --z_size and --num_channels reproduce the old values
    nz = config.z_size
    nc = config.num_channels
    nez = 1
    # third constructor argument of every net (ndf / ngf / nif in the original)
    # NOTE(review): presumably a feature-map width - confirm in acd.models
    n_filters = 64

    netG = _netG(nz, nc, n_filters)
    netE = _netE(nc, nez, n_filters)
    netI = _netI(nc, nz, n_filters)

    checkpoints = (
        (netG, config.gen_net_file),
        (netE, config.des_net_file),
        (netI, config.inf_net_file),
    )
    for net, file_name in checkpoints:
        # map_location='cpu' lets checkpoints saved on a GPU load on a
        # CPU-only host; the freshly built nets live on the CPU anyway
        state = torch.load(os.path.join(config.net_dir, file_name), map_location='cpu')
        net.load_state_dict(state)
        net.eval()
    return netG, netE, netI


def main():
    """Run the mode selected by ``config.mode``.

    ``'map'`` prepares the output directory, saves the configuration, loads
    the nets and runs ``CoopNetMapper.map()``. ``'viz'`` forces
    ``continue_elm`` and constructs the explorer and mapper without calling
    ``map()``. Any other value (including ``'tree'``, which the ``--mode``
    help text mentions) matches neither branch and does nothing.
    """
    # start new mapping or continue old mapping
    if config.mode == 'map':
        _prepare_output_dirs()
        _save_config()
        # load nets and coop explorer
        netG, netE, netI = _load_nets()
        explorer = CoopNetExplorer(config, netG, netE, netI)
        # map nets
        mapper = CoopNetMapper(explorer)
        mapper.map()
    # viz images from old mapping
    elif config.mode == 'viz':
        config.continue_elm = True
        explorer = CoopNetExplorer(config)
        # NOTE(review): the mapper is built and discarded - presumably its
        # constructor does the visualisation work; confirm in map_code.coop_elm
        CoopNetMapper(explorer)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment