Created: March 25, 2020 11:14
Save Ehsan1997/dce2cbc529f9b3a9b82a70c8e6eb3bdd to your computer and use it in GitHub Desktop.
Script to Generate Image Classification Dataset from Google Images
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from selenium import webdriver | |
from selenium.webdriver.firefox.options import Options | |
import time | |
import urllib.request | |
from PIL import Image | |
import os | |
def cr_folder(folder_name):
    """Create *folder_name* (and any missing parent directories) if needed.

    Args:
        folder_name: Path of the directory to create.
    """
    # exist_ok=True is atomic with respect to the existence check, avoiding
    # the check-then-create race of the original os.path.exists() guard.
    os.makedirs(folder_name, exist_ok=True)
# ---- Configuration ----

# Seconds to pause after each scroll so newly loaded results can render
# ('2' works on my system).
sleep_time = 2

# Minimum number of images to collect per class (usually a multiple of 100).
# This is a very buggy feature, hit and trial for now.
n_min = 500

# HTML class of the "load more" button (some pages need a click before
# further images appear).
btn_class = "mye4qd"

# HTML class of each image tile, used to locate image elements on the page.
img_class = 'rg_i Q4LuWd tx8vtf'

# Class labels to scrape; one sub-folder is created per label.
classes = ['Torterra', 'Chimchar', 'Monferno', 'Greninja', 'Bidoof']

# Parent folder that will hold all the class sub-folders.
parent_folder = "Pokemon_5_Dataset"

# Target (width, height) for every saved image.
im_size = (80, 80)

cr_folder(parent_folder)

# Run Firefox without a visible window.
options = Options()
options.headless = True

# NOTE(review): url_dict is never referenced later in this script —
# looks like dead state; confirm before removing.
url_dict = {}
# For each class label: open a headless browser on Google Images, scroll
# (and click "load more" when scrolling stalls) until at least n_min
# thumbnails are present, then download, resize and save each thumbnail
# into parent_folder/<class>/.
for c in classes:
    print(f'Processing for class: {c}')
    class_folder = parent_folder + '/' + c
    cr_folder(class_folder)
    browser = webdriver.Firefox(options=options)
    try:
        browser.get(f'https://www.google.com/search?tbm=isch&q={c}')
        ads = browser.find_elements_by_xpath(f"//img[@class='{img_class}']")
        n = len(ads)
        # Scroll until enough tiles are loaded or the page stops growing.
        while n < n_min:
            print('Scrolling Down!!')
            browser.execute_script("window.scrollTo(0, document.body.scrollHeight);")
            time.sleep(sleep_time)
            ads = browser.find_elements_by_xpath(f"//img[@class='{img_class}']")
            n_last = n
            n = len(ads)
            print(n)
            if n == n_last:
                # Page stopped growing: try the "load more" button.
                # Bug fix: the original indexed [0] unconditionally and
                # crashed with IndexError when no button was present.
                buttons = browser.find_elements_by_xpath(f"//input[@class='{btn_class}']")
                if not buttons:
                    break
                browser.execute_script("arguments[0].click();", buttons[0])
                time.sleep(sleep_time)
                ads = browser.find_elements_by_xpath(f"//img[@class='{img_class}']")
                n = len(ads)
                if n == n_last:
                    break
        failed_count = 0
        for i, element in enumerate(ads):
            print(element.get_attribute('alt'))
            # Thumbnails expose the image URL via either 'src' or 'data-src'.
            url = element.get_attribute('src')
            if url is None:
                url = element.get_attribute('data-src')
            print(url)
            if url is not None:
                try:
                    image = Image.open(urllib.request.urlopen(url))
                    image = image.resize(im_size)
                    # Offset by failed_count so saved files are numbered
                    # consecutively with no gaps.
                    image.save(class_folder + '/' + f"{i-failed_count}.jpg")
                    # Bug fix: printing the size outside the try raised
                    # NameError (or printed a stale size) after a failure.
                    print(image.size)
                except Exception:
                    # Narrowed from a bare except: so KeyboardInterrupt and
                    # SystemExit are no longer swallowed.
                    failed_count += 1
            print('_________________')
        print(len(ads))
        print("Failed attempts: ", failed_count)
    finally:
        # Bug fix: always release the browser, even if scraping raised —
        # the original leaked a Firefox process per failed class.
        browser.quit()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment