This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Early stopping details
# NOTE(review): presumably the number of consecutive epochs without
# validation-loss improvement tolerated before stopping — confirm in the loop.
n_epochs_stop = 5
# Best validation loss seen so far, initialised to +infinity so any finite
# loss compares lower. Uses `np.inf`: the `np.Inf` alias was removed in
# NumPy 2.0 and raises AttributeError there.
min_val_loss = np.inf
# Counter of epochs without improvement (presumably reset on improvement).
epochs_no_improve = 0
# Main loop | |
for epoch in range(n_epochs): | |
# Initialize validation loss for epoch | |
val_loss = 0 | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Train for `n_epochs` passes over `trainloader`.
# NOTE(review): model, criterion, optimizer, trainloader and n_epochs are
# defined elsewhere and are not visible in this snippet.
for epoch in range(n_epochs):
    for data, targets in trainloader:
        # Clear gradients from the previous batch. backward() adds into each
        # parameter's .grad, so without this every step would apply the
        # running sum of all earlier batches' gradients.
        optimizer.zero_grad()
        # Generate predictions
        out = model(data)
        # Calculate loss
        loss = criterion(out, targets)
        # Backpropagation
        loss.backward()
        # Update model parameters
        optimizer.step()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torchvision import datasets | |
from torch.utils.data import DataLoader | |
# Datasets from folders: one ImageFolder per split. Each split reads its
# images from its own root directory and uses the transform stored under the
# same key in `image_transforms`.
data = {
    split: datasets.ImageFolder(root=root_dir, transform=image_transforms[split])
    for split, root_dir in (('train', traindir), ('valid', validdir))
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torchvision import transforms | |
# Image transformations | |
image_transforms = { | |
# Train uses data augmentation | |
'train': | |
transforms.Compose([ | |
transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)), | |
transforms.RandomRotation(degrees=15), | |
transforms.ColorJitter(), |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html> | |
<html> | |
<header> | |
<title>Random Starting Abstract | |
</title> | |
<link rel="stylesheet" href="/static/css/main.css"> | |
<link rel="shortcut icon" href="/static/images/lstm.ico"> | |
<ul> |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.models import load_model | |
import tensorflow as tf | |
def load_keras_model(model_path='../models/train-embeddings-rnn.h5'):
    """Load in the pre-trained model and capture the TensorFlow default graph.

    Side effects: binds two module-level globals.
      model -- the object returned by ``load_model(model_path)``
      graph -- the result of ``tf.get_default_graph()`` at load time

    Args:
        model_path: Path to the saved model file. The default is the original
            hard-coded location. It is a relative path, so it is resolved
            against the current working directory, not against this file.
            Pass an absolute path to run from elsewhere.

    Returns:
        None. Results are published through the globals above.
    """
    global model, graph
    model = load_model(model_path)
    # Required for model to work
    # NOTE(review): presumably callers later run predictions under
    # `with graph.as_default():` — confirm. `tf.get_default_graph` is a
    # TensorFlow 1.x API; in TF 2.x it is `tf.compat.v1.get_default_graph`.
    graph = tf.get_default_graph()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from flask import request | |
# User defined utility functions | |
from utils import generate_random_start, generate_from_seed | |
# Home page | |
@app.route("/", methods=['GET', 'POST']) | |
def home(): | |
"""Home page of app with form""" | |
# Create form |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html> | |
<html> | |
<head> | |
<title>RNN Patent Writing</title> | |
<link rel="stylesheet" href="/static/css/main.css"> | |
<link rel="shortcut icon" href="/static/images/lstm.ico"> | |
</head> |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from flask import render_template | |
# Home page | |
@app.route("/", methods=['GET', 'POST']) | |
def home(): | |
"""Home page of app with form""" | |
# Create form | |
form = ReusableForm(request.form) | |
# Send template information to index.html |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from wtforms import (Form, TextField, validators, SubmitField, | |
DecimalField, IntegerField) | |
class ReusableForm(Form): | |
"""User entry form for entering specifics for generation""" | |
# Starting seed | |
seed = TextField("Enter a seed string or 'random':", validators=[ | |
validators.InputRequired()]) | |
# Diversity of predictions | |
diversity = DecimalField('Enter diversity:', default=0.8, |