-
-
Save azamsharp/1b37e37bf99b8c3306583ad9f84877c1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import requests | |
import os | |
import shutil | |
import numpy as np | |
from numpy import save | |
import turicreate as tc | |
random_state = np.random.RandomState(100) | |
current_directory = os.getcwd() | |
quickdraw_directory = current_directory + '/quickdraw' | |
bitmap_directory = quickdraw_directory + '/bitmap' | |
sframes_directory = quickdraw_directory + '/sframes' | |
npy_ext = '.npy' | |
num_examples_per_class = 100 | |
classes = ["square", "triangle"] | |
num_classes = len(classes) | |
categories = ["cat", "cake"] | |
# remove directories | |
def remove_directory(path): | |
try: | |
shutil.rmtree(path) | |
except: | |
print('Error deleting directory!') | |
# get bitmap SFrame | |
def get_bitmap_sframe(): | |
labels, drawings = [], [] | |
for category in categories: | |
data = np.load( | |
bitmap_directory + '/' + category + '.npy', | |
allow_pickle = True | |
) | |
# shuffle the data | |
random_state.shuffle(data) | |
sampled_data = data[:training_samples] | |
transformed_data = sampled_data.reshape( | |
sampled_data.shape[0], 28, 28, 1 | |
) | |
for pixel_data in transformed_data: | |
image = tc.Image(_image_data=np.invert(pixel_data).tobytes(), | |
_width=pixel_data.shape[1], | |
_height=pixel_data.shape[0], | |
_channels=pixel_data.shape[2], | |
_format_enum=2, | |
_image_data_size=pixel_data.size | |
) | |
drawings.append(image) | |
labels.append(category) | |
return tc.SFrame({'drawing': drawings, 'label': labels}) | |
# create directories | |
def make_directory(path): | |
try: | |
os.makedirs(path) | |
except: | |
print('Exception happened!') | |
# fetch data | |
def fetchData(): | |
url = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap' | |
total_categories = len(categories) | |
for index, category in enumerate(categories): | |
bitmap_filename = '/' + category + '.npy' | |
response = requests.get(url + bitmap_filename) | |
save(bitmap_directory + bitmap_filename,response.content) | |
print(f'Downloaded: {category} - Total Categories: {index + 1}/{total_categories}') | |
print(categories) | |
# setup directories | |
#remove_directory(quickdraw_directory) | |
#remove_directory(bitmap_directory) | |
#make_directory(quickdraw_directory) | |
#make_directory(bitmap_directory) | |
#fetchData() | |
bitmap_sframe = get_bitmap_sframe() | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Traceback (most recent call last): | |
File "train_drawing_classifier.py", line 123, in <module> | |
bitmap_sframe = get_bitmap_sframe() | |
File "train_drawing_classifier.py", line 68, in get_bitmap_sframe | |
random_state.shuffle(data) | |
File "mtrand.pyx", line 4417, in numpy.random.mtrand.RandomState.shuffle | |
TypeError: len() of unsized object |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment