Skip to content

Instantly share code, notes, and snippets.

View Hanrui-Wang's full-sized avatar
🎯
Focusing

Ryan Hanrui Wang Hanrui-Wang

🎯
Focusing
View GitHub Profile
@Hanrui-Wang
Hanrui-Wang / net.py
Created July 16, 2019 23:10
Basic PyTorch net definition
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# 1 input image channel, 6 output channels, 3x3 square convolution
@Hanrui-Wang
Hanrui-Wang / main.py
Created July 16, 2019 23:12
main feedforward net
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# Device configuration: use a CUDA device when one is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Hyper-parameters
@Hanrui-Wang
Hanrui-Wang / matplot.py
Last active July 17, 2019 18:10
basic plot using matplotlib
import matplotlib.pyplot as plt
import matplotlib
# Open (or reactivate) figure number i and draw a faint background grid.
plt.figure(i, figsize=(5.5, 5))
plt.grid(True, linewidth=0.2)
# Plot reward against step; the line style, width and color are all looked
# up by the method key methods[i][ii].
plt.plot(
    step,
    reward,
    linestyle=line_style[methods[i][ii]],
    linewidth=line_width[methods[i][ii]],
    color=colors[methods[i][ii]],
)
# Per-figure y-range; the x-range is fixed to [0, 10000].
plt.ylim(top=y_top[i], bottom=y_bottom[i])
plt.xlim(left=0, right=10000)
# NOTE(review): the font size is changed only after the figure/plot are
# created; confirm it affects this figure's text, otherwise move it earlier.
plt.rcParams['font.size'] = 12
# NOTE(review): bare expression, its result is discarded. If a type check was
# intended this likely wanted `assert` — confirm.
isinstance(output_size, (int, tuple))
@Hanrui-Wang
Hanrui-Wang / rand.py
Last active July 17, 2019 18:10
random
import numpy as np
# Draw a random integer in [0, h - new_h) (presumably a crop offset).
# NOTE(review): numpy raises ValueError when h - new_h <= 0 — confirm h > new_h.
np.random.randint(low=0, high=h - new_h)
# Sample an in_size x out_size array of standard-normal values.
np.random.randn(in_size, out_size)
@Hanrui-Wang
Hanrui-Wang / transform.py
Created July 17, 2019 00:01
transform_pytorch
class Rescale(object):
"""Rescale the image in a sample to a given size.
Args:
output_size (tuple or int): Desired output size. If tuple, output is
matched to output_size. If int, smaller of image edges is matched
to output_size keeping aspect ratio the same.
"""
def __init__(self, output_size):
@Hanrui-Wang
Hanrui-Wang / dataset.py
Created July 17, 2019 00:08
How to create a dataset in PyTorch
class FaceLandmarksDataset(Dataset):
"""Face Landmarks dataset."""
def __init__(self, csv_file, root_dir, transform=None):
"""
Args:
csv_file (string): Path to the csv file with annotations.
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
@Hanrui-Wang
Hanrui-Wang / plot_image_batch.py
Created July 17, 2019 00:40
Plot an image batch from a PyTorch tensor
# Tile the batch into a single image tensor.
grid = utils.make_grid(images_batch)
plt.imshow(
    # Move axis 0 to the end (presumably channels-first -> channels-last,
    # the layout imshow expects for color images).
    grid.numpy().transpose(1, 2, 0)
)
@Hanrui-Wang
Hanrui-Wang / dataloader.py
Created July 17, 2019 00:43
dataloader in pytorch
# Batch the transformed dataset: 4 samples per batch, reshuffled each epoch,
# loaded by 4 worker processes.
dataloader = DataLoader(
    transformed_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4,
)
@Hanrui-Wang
Hanrui-Wang / image_folder.py
Created July 17, 2019 00:47
Usage of ImageFolder
import torch
from torchvision import transforms, datasets
# Training-time image preprocessing (presumably applied to a
# datasets.ImageFolder dataset): random crop resized to 224x224, random
# horizontal flip, conversion to a tensor, then per-channel normalization.
# RandomResizedCrop replaces RandomSizedCrop, a deprecated alias of the same
# transform that has been removed from newer torchvision releases.
data_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # Presumably the standard ImageNet per-channel mean/std; confirm they match
    # the statistics expected by the model consuming these tensors.
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])