Created
August 18, 2018 00:14
-
-
Save enijkamp/55cef624a0fba8a4e2ca221b386b0d98 to your computer and use it in GitHub Desktop.
acd fashion-mnist
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Map Cooperative Networks | |
import argparse
import csv
import os
import pickle
import shutil

import torch

from acd.models import _netE, _netG, _netI
from map_code.coop_elm import *
parser = argparse.ArgumentParser() | |
# mode (map, viz, or tree) | |
parser.add_argument('--mode', default='map', help='mode for code (map, viz, or tree)') | |
parser.add_argument('--delete_log', type=bool, default=True, help='delete previous ELM record') | |
# experiment and file info | |
parser.add_argument('--map_exp_id', type=int, default=1, help='ID of ELM experiment') | |
parser.add_argument('--category', default='cifar', help='Image category') | |
parser.add_argument('--net_dir', default='.', help='location of nets') | |
parser.add_argument('--gen_net_file', default='netG_epoch_199.pth', help='file for trained generator net') | |
parser.add_argument('--des_net_file', default='netE_epoch_199.pth', help='file for trained descriptor net') | |
parser.add_argument('--inf_net_file', default='netI_epoch_199.pth', help='file for trained descriptor net') | |
parser.add_argument('--load_state_dict', type=bool, default=True, help='load nets from state_dict or not') | |
parser.add_argument('--elm_out', default='./map_out3', help='height / width of the input image to network') | |
parser.add_argument('--elm_file_in', default='ELM', help='name of file for loading ELM') | |
parser.add_argument('--elm_file_out', default='ELM', help='name of file for ELM results') | |
parser.add_argument('--gpu_id', type=int, default=3, help='id for gpu (-1 for cpu, otherwise gpu id >=0)') | |
# mapping parameters | |
parser.add_argument('--num_steps', type=int, default=1000, help='number of mapping iterations') | |
parser.add_argument('--num_mins', type=int, default=1000, help='number of minima in record') | |
parser.add_argument('--continue_elm', type=bool, default=False, help='load previous ELM record') | |
parser.add_argument('--log_step', type=int, default=1, help='frequency for logging ELM record') | |
parser.add_argument('--en_min_limit', type=float, default=-40000, help='cutoff for min en (ignore degeneracy)') | |
# parameters for gaussian energy prior | |
parser.add_argument('--ref_normal', default='none', help='normal prior space (none, latent, image, both)') | |
parser.add_argument('--z_sigma_sq', type=float, default=1.0, help='variance of latent prior') | |
parser.add_argument('--im_sigma_sq', type=float, default=.10, help='variance of image prior') | |
#parser.add_argument('--vae_penalty', type=float, default=4e+4) | |
parser.add_argument('--vae_penalty', type=float, default=1e+6) | |
# size of image and latent space | |
parser.add_argument('--image_size', type=int, default=32, help='image height/width') | |
parser.add_argument('--num_channels', type=int, default=3, help='number of color channels (1 for gray, 3 for RGB') | |
parser.add_argument('--z_size', type=int, default=20, help='dimension of latent space') | |
# min search parameters | |
#parser.add_argument('--min_eps', type=float, default=5e-9, help='step size for min search') | |
parser.add_argument('--min_eps', type=float, default=1e-12, help='step size for min search') | |
parser.add_argument('--min_limit', type=int, default=5000, help='max number of steps for min search') | |
parser.add_argument('--min_window', type=int, default=250, help='energy improvement window for min search') | |
# AD parameters | |
parser.add_argument('--ad_eps', type=float, default=0.01, help='step size for AD trials') | |
parser.add_argument('--ad_alpha', type=float, default=100.0, help='magnetization strength for AD trials') | |
parser.add_argument('--ad_temp', type=float, default=150.0, help='temperature for AD trials') | |
parser.add_argument('--ad_noise', type=float, default=0.02, help='noise magnitude for ad langevin') | |
parser.add_argument('--dist_res', type=float, default=0.5, help='max separation for successful AD trial') | |
parser.add_argument('--ad_limit', type=int, default=50000, help='max number of steps for AD trial') | |
parser.add_argument('--ad_window', type=int, default=200, help='distance improvement window for AD trial') | |
parser.add_argument('--max_ad_pairs', type=int, default=10, help='max number of AD pairings in each iteration') | |
parser.add_argument('--ad_reps', type=int, default=0, help='number of AD trials between each pair') | |
parser.add_argument('--update_mins', type=bool, default=True, help='update basin representatives during training') | |
parser.add_argument('--print_ad_log', type=bool, default=True, help='print long log of AD trials') | |
parser.add_argument('--ad_type', default='mh', help='mode for AD trials (langevin or mh)') | |
# get config for mapping | |
config = parser.parse_args() | |
config.elm_out_dir = os.path.join(config.elm_out, config.category, 'exp{}/'.format(config.map_exp_id)) | |
config.elm_out_ims = os.path.join(config.elm_out_dir, 'ims/') | |
config.elm_out_data = os.path.join(config.elm_out_dir, 'maps/') | |
def _prepare_output_dirs():
    """Optionally wipe the previous run's output, then create the output directory tree."""
    # only wipe old results when starting over; continuing needs the old record
    if config.delete_log and not config.continue_elm:
        if os.path.exists(config.elm_out_dir):
            shutil.rmtree(config.elm_out_dir)
    # ims/ and maps/ live inside elm_out_dir (= elm_out/category/expN/);
    # os.makedirs creates every missing parent, so the two leaves cover the tree
    for out_dir in (config.elm_out_ims, config.elm_out_data):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)


def _save_config():
    """Save ``config`` into elm_out_dir as ``config<map_exp_id>.csv`` and ``config.pkl``."""
    csv_path = os.path.join(config.elm_out_dir, 'config{}.csv'.format(config.map_exp_id))
    with open(csv_path, 'w') as f:
        writer = csv.writer(f)
        for name, value in vars(config).items():
            writer.writerow([name, value])
    with open(os.path.join(config.elm_out_dir, 'config.pkl'), 'wb') as f:
        pickle.dump(config, f)


def _load_nets():
    """Build the generator, descriptor and inference nets and load their weights.

    Weights are read from ``config.net_dir`` using ``config.gen_net_file``,
    ``config.des_net_file`` and ``config.inf_net_file``; every net is put in
    eval mode.

    Returns:
        tuple: ``(netG, netE, netI)``.
    """
    # latent size and channel count follow the config; their defaults (20, 3)
    # are the values that were previously hard-coded here
    nz = config.z_size
    nc = config.num_channels
    # presumably feature-map widths (ngf: generator, ndf: descriptor,
    # nif: inference net) and the descriptor's output size (nez) for the
    # acd.models constructors -- confirm against acd.models
    ngf = 64
    ndf = 64
    nif = 64
    nez = 1
    netG = _netG(nz, nc, ngf)
    netE = _netE(nc, nez, ndf)
    netI = _netI(nc, nz, nif)

    # keep loaded storages on the CPU so GPU-saved checkpoints also load on
    # CPU-only machines (--gpu_id -1); load_state_dict copies into the nets
    def _to_cpu(storage, _location):
        return storage

    # NOTE: torch.load unpickles the file -- only load trusted checkpoints
    for net, file_name in ((netG, config.gen_net_file),
                           (netE, config.des_net_file),
                           (netI, config.inf_net_file)):
        state = torch.load(os.path.join(config.net_dir, file_name), map_location=_to_cpu)
        net.load_state_dict(state)
        net.eval()
    return netG, netE, netI


def main():
    """Run a new/continued ELM mapping ('map') or visualize an old one ('viz').

    Driven entirely by the module-level ``config``.

    'map': prepares the output directories, saves the config, loads the three
    nets and runs ``CoopNetMapper(...).map()``.
    'viz': forces ``config.continue_elm`` and builds the explorer/mapper from
    the saved record.

    Raises:
        NotImplementedError: if ``config.load_state_dict`` is false in 'map' mode.
        ValueError: if ``config.mode`` is neither 'map' nor 'viz'.
    """
    if config.mode == 'map':
        _prepare_output_dirs()
        _save_config()
        if not config.load_state_dict:
            raise NotImplementedError('loading nets without load_state_dict is not implemented')
        netG, netE, netI = _load_nets()
        explorer = CoopNetExplorer(config, netG, netE, netI)
        mapper = CoopNetMapper(explorer)
        mapper.map()
    elif config.mode == 'viz':
        config.continue_elm = True
        explorer = CoopNetExplorer(config)
        # NOTE(review): no method is called on the mapper; presumably its
        # constructor performs the visualization -- confirm in map_code.coop_elm
        CoopNetMapper(explorer)
    else:
        raise ValueError("unsupported mode {!r}; expected 'map' or 'viz'".format(config.mode))
# Script entry point: run only when executed directly, not when imported.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment