Created
July 3, 2023 08:44
-
-
Save feliche93/0c928a9ca2ee8bc9b907173a007b3868 to your computer and use it in GitHub Desktop.
Discord Midjourney Image Automation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import os | |
from getpass import getpass | |
from pathlib import Path | |
from typing import Dict, List, Optional | |
import boto3 | |
import requests | |
from dotenv import load_dotenv | |
from playwright.async_api import Page, async_playwright | |
from sqlalchemy import create_engine, text | |
from sqlalchemy.engine.base import Engine | |
import time | |
from sqlalchemy.exc import OperationalError | |
from sqlalchemy.exc import OperationalError | |
load_dotenv(override=True) | |
def download_image(image_url: str, image_path: str, timeout: int = 5) -> str:
    """
    Downloads an image from a provided URL and saves it to a local path.

    The directory that will contain the image is created if it is missing,
    because ``open()`` does not create intermediate directories.

    Args:
        image_url (str): URL of the image to download.
        image_path (str): Local path where the image will be saved, including the image file name.
            Path-like objects are accepted as well.
        timeout (int): Maximum time, in seconds, to wait for the server's response. Default is 5 seconds.

    Raises:
        HTTPError: If there was an unsuccessful HTTP response.
        Timeout: If the request times out.

    Returns:
        str: Local path where the image has been saved (the ``image_path`` argument, unchanged).
    """
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status()  # Raise exception if invalid response.

    # Make sure the target folder exists before writing; a fresh checkout has no
    # download directory yet and open() would fail with FileNotFoundError.
    Path(image_path).parent.mkdir(parents=True, exist_ok=True)

    with open(image_path, "wb") as f:
        f.write(response.content)

    return image_path
def upload_to_s3(
    image_path: str,
    bucket: str,
    s3_image_name: str,
    aws_access_key_id: str,
    aws_secret_access_key: str,
    region_name: str,
) -> str:
    """
    Uploads an image file to an S3 bucket and returns the URL of the uploaded file.

    The object is stored under the ``blog_post_covers/`` prefix. After a successful
    upload the local file is deleted.

    Args:
        image_path (str): Path to the image file to upload.
        bucket (str): Name of the S3 bucket to upload to.
        s3_image_name (str): Name to give to the file once it's uploaded.
        aws_access_key_id (str): AWS access key ID.
        aws_secret_access_key (str): AWS secret access key.
        region_name (str): The name of the AWS region where the S3 bucket is located.

    Returns:
        str: URL of the uploaded image in the S3 bucket.

    Raises:
        ClientError: If there was an error uploading the file to S3.
    """
    # Configure the client with the same region that is used to build the URL below,
    # instead of relying on whatever default region the environment provides.
    s3 = boto3.client(
        "s3",
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        region_name=region_name,
    )

    s3_path = "blog_post_covers/" + s3_image_name  # prepend the S3 'folder' name
    with open(image_path, "rb") as f:
        s3.upload_fileobj(f, bucket, s3_path)

    # Remove the image from the local filesystem only after the handle is closed;
    # deleting a file that is still open fails on some platforms.
    os.remove(image_path)

    url = f"https://{bucket}.s3.{region_name}.amazonaws.com/{s3_path}"
    return url
async def login_to_discord(
    page: Page,
    server_id: str,
    channel_id: str,
    email: Optional[str] = None,
    password: Optional[str] = None,
    auth_code: Optional[str] = None,
) -> None:
    """
    Sign in to Discord through a Playwright page and land on a given channel.

    Any credential that is not supplied is requested interactively on the
    terminal at the moment it is needed (the password via ``getpass``).

    Args:
        page (Page): Playwright browser page instance.
        server_id (str): Discord server ID used to build the channel URL.
        channel_id (str): Discord channel ID used to build the channel URL.
        email (Optional[str], optional): Login email. Prompted for when falsy.
        password (Optional[str], optional): Login password. Prompted for when falsy.
        auth_code (Optional[str], optional): Authentication code. Prompted for when falsy.

    Raises:
        TimeoutError: If any of the page actions do not complete within the default timeout period.
    """
    await page.goto(f"https://discord.com/channels/{server_id}/{channel_id}")
    await page.get_by_role("button", name="Continue in browser").click()

    # Locators are lazy, so each one is created once and reused.
    email_field = page.get_by_label("Email or Phone Number*")
    log_in_button = page.get_by_role("button", name="Log In")

    # Step 1: email
    await email_field.click()
    email = email or input("Please enter your email: ")
    await email_field.fill(email)
    await email_field.press("Tab")

    # Step 2: password
    password = password or getpass("Please enter your password: ")
    await page.get_by_label("Password*").fill(password)
    await log_in_button.click()

    # Step 3: authentication code
    auth_code = auth_code or input("Please enter your authentication code: ")
    code_field = page.get_by_placeholder("6-digit authentication code/8-digit backup code")
    await code_field.fill(auth_code)
    await log_in_button.click()
async def post_prompt(
    page: Page, prompt: str, prompt_selector: str = "[class*='optionPillValue']"
) -> None:
    """
    Post a prompt message in Discord via a Playwright browser page.

    Types the ``/imagine `` slash command into the ``#general`` message box, fills the
    prompt into the command's option field and submits it with Enter.

    Args:
        page (Page): Playwright browser page instance.
        prompt (str): The prompt to be posted in the message box.
        prompt_selector (str): CSS selector of the slash-command option field. Discord
            appends a build-specific suffix to its class names (for example
            ``optionPillValue-2uxsMp``), so the default matches any class containing
            ``optionPillValue`` rather than one exact suffix.

    Raises:
        TimeoutError: If any of the page actions do not complete within the default timeout period.
    """
    message_box = page.get_by_role("textbox", name="Message #general").nth(0)
    await message_box.fill("/imagine ")

    prompt_input = page.locator(prompt_selector).nth(0)
    await prompt_input.fill(prompt, timeout=2000)
    await message_box.press("Enter", timeout=2000)
async def upscale_image(page: Page) -> None:
    """
    Click the U1 (upscale) button on the last list item of the page.

    The button is polled every 5 seconds until it becomes visible; there is
    no upper bound on the wait.

    Args:
        page (Page): Playwright browser page instance.

    Raises:
        TimeoutError: If any of the page actions do not complete within the default timeout period.
    """
    newest_message = page.locator(selector="li").last
    u1_button = newest_message.locator("button", has_text="U1")

    # Poll until the button shows up.
    while True:
        if await u1_button.is_visible():
            break
        print("Upscale button is not yet available, waiting...")
        await asyncio.sleep(5)  # pause between checks

    print("Upscale button is now available, clicking...")
    await u1_button.click(timeout=1000)
async def get_image_url(
    page: Page, timeout: int = 1000, check_interval: int = 5, max_wait: int = 30
) -> str:
    """
    Get the href attribute of the last image link on the page, retrying until it exists and the 'Vary (Strong)' button is visible.

    Args:
        page (Page): Playwright browser page instance.
        timeout (int): Maximum time, in milliseconds, to wait for the image link. Default is 1000 milliseconds.
        check_interval (int): Time, in seconds, to wait between checks for the button and image link. Default is 5 seconds.
        max_wait (int): Maximum time, in seconds, to wait before giving up. Default is 30 seconds.

    Returns:
        str: The href attribute of the last image link.

    Raises:
        TimeoutError: If the image link does not appear within the maximum wait time.
    """
    last_message = page.locator(selector="li").last
    vary_strong = last_message.locator("button", has_text="Vary (Strong)")
    image_links = last_message.locator("xpath=//a[starts-with(@class, 'originalLink-')]")

    # A monotonic clock keeps the elapsed-time check immune to system clock adjustments.
    start_time = time.monotonic()

    # Wait for the 'Vary (Strong)' button and an image link to appear
    while True:
        if await vary_strong.is_visible() and await image_links.count() > 0:
            last_image_link = await image_links.last.get_attribute("href", timeout=timeout)
            print("Image link is present, returning it.")
            return last_image_link

        print("Waiting for 'Vary (Strong)' button to appear and for image link to appear...")

        # If the maximum wait time has been reached, raise an exception.
        # The message reports the configured limit rather than a fixed number.
        if time.monotonic() - start_time > max_wait:
            raise TimeoutError(
                f"Waited for {max_wait} seconds but 'Vary (Strong)' button did not appear and image link did not appear."
            )

        await asyncio.sleep(check_interval)  # wait before the next check
def update_db_record(
    engine: Engine, s3_path: str, keyword_value: str, max_retries: int = 5, retry_wait: int = 2
) -> None:
    """
    Update a database record's blog_post_cover_image_url field with an S3 URL.

    The statement runs inside ``engine.begin()`` so it is committed on success.
    ``OperationalError`` is retried; once every attempt has failed, the last
    ``OperationalError`` is raised to the caller.

    Args:
        engine (Engine): SQLAlchemy Engine instance.
        s3_path (str): S3 URL to be added to the blog_post_cover_image_url field.
        keyword_value (str): Keyword value (matched against ``slug``) to identify the record to be updated.
        max_retries (int): Maximum number of attempts in case of failure. Default is 5.
        retry_wait (int): Time, in seconds, to wait between retries. Default is 2 seconds.

    Raises:
        OperationalError: If the update still fails after ``max_retries`` attempts.
        SQLAlchemyError: Any other SQLAlchemy error is propagated immediately.
    """
    query = text(
        "UPDATE keywords SET blog_post_cover_image_url = :s3_path WHERE slug = :keyword_value"
    )
    # Bound parameters are passed as a single mapping; passing them as keyword
    # arguments to execute() is not supported by SQLAlchemy 2.0.
    params = {"s3_path": s3_path, "keyword_value": keyword_value}

    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            # begin() commits when the block exits cleanly and rolls back on error.
            with engine.begin() as connection:
                connection.execute(query, params)
            return  # success, nothing left to do
        except OperationalError as exc:
            # Keep a reference: the name bound by "as" is cleared when the handler ends.
            last_error = exc
            print(f"OperationalError occurred. Retry {attempt} of {max_retries}.")
            if attempt < max_retries:
                time.sleep(retry_wait)

    # All attempts failed: surface the real database error instead of a bare
    # "raise" outside an except block (which would be a RuntimeError).
    if last_error is not None:
        raise last_error
def get_records_with_null_cover_image(engine: Engine) -> List[Dict[str, str]]:
    """
    Fetch every ``keywords`` row whose blog_post_cover_image_url is NULL.

    Args:
        engine (Engine): SQLAlchemy Engine instance.

    Returns:
        List[Dict[str, str]]: One dictionary per row, with the keys
        'slug' and 'blog_post_cover_prompt'.

    Raises:
        SQLAlchemyError: If any SQLAlchemy error occurs while retrieving the records.
    """
    select_missing_covers = text(
        "SELECT slug, blog_post_cover_prompt FROM keywords WHERE blog_post_cover_image_url IS NULL"
    )

    records = []
    with engine.connect() as connection:
        # Rows are consumed while the connection is still open.
        for slug, cover_prompt in connection.execute(select_missing_covers):
            records.append({"slug": slug, "blog_post_cover_prompt": cover_prompt})
    return records
# --- Configuration read from the environment (None when a variable is unset) ---
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID")
S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY")
S3_REGION_NAME = os.getenv("S3_REGION_NAME")
DATABASE_URL = os.getenv("DATABASE_URL")

# --- Discord target used to build the channel URL ---
DISCORD_SERVER_ID = "1124815914815201481"
DISCORD_CHANEL_ID = "1124815915297542217"

# --- Folder next to this file where images are stored before upload ---
IMAGE_PATH = Path(__file__).parent.joinpath("temp_images")
async def main(start_index: int = 181) -> None:
    """
    Generate, upload and record a cover image for every keyword that lacks one.

    For each selected record the prompt is posted to Discord, the first upscale
    is requested, the resulting image is downloaded, uploaded to S3, and the
    S3 URL is written back to the database.

    Args:
        start_index (int): Number of leading records (in query order) to skip.
            Defaults to 181, the offset that used to be hard-coded in the loop.
    """
    # The download step writes into this folder; create it up front so the
    # first download does not fail on a missing directory.
    IMAGE_PATH.mkdir(parents=True, exist_ok=True)

    async with async_playwright() as playwright:
        engine = create_engine(DATABASE_URL)
        browser = await playwright.chromium.launch(headless=False)
        try:
            context = await browser.new_context()
            page = await context.new_page()

            records = get_records_with_null_cover_image(engine)

            await login_to_discord(
                page=page,
                server_id=DISCORD_SERVER_ID,
                channel_id=DISCORD_CHANEL_ID,
            )

            for record in records[start_index:]:
                slug = record["slug"]
                prompt = record["blog_post_cover_prompt"]

                await post_prompt(
                    page=page,
                    prompt=prompt,
                )
                await upscale_image(page=page)
                image_url = await get_image_url(page=page)

                local_image_path = IMAGE_PATH / f"{slug}.png"
                image_path = download_image(image_url=image_url, image_path=local_image_path)

                s3_path = upload_to_s3(
                    image_path=image_path,
                    aws_access_key_id=S3_ACCESS_KEY_ID,
                    aws_secret_access_key=S3_SECRET_ACCESS_KEY,
                    bucket=S3_BUCKET_NAME,
                    region_name=S3_REGION_NAME,
                    s3_image_name=f"{slug}.png",
                )

                update_db_record(
                    engine=engine,
                    s3_path=s3_path,
                    keyword_value=slug,
                )

            await context.close()
        finally:
            # Close the browser even when a record fails part-way through.
            await browser.close()
# Run the automation only when this file is executed as a script, so that
# importing it (for example to reuse a helper) does not open a browser.
if __name__ == "__main__":
    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
OK, so I solved this problem by inspecting my own HTML elements and finding the new locator value for optionPillValue. You need to change it to your own value, e.g. optionPillValue_a8sfx6.