Created
September 23, 2024 14:20
-
-
Save josherich/47141e71df555b97e6315732858a09fd to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import html | |
import requests | |
import json | |
import base64 | |
import hashlib | |
import google.generativeai as genai | |
from loguru import logger | |
from fasthtml.common import * | |
# ========== Config ==============
# Secrets are taken from the environment when set and otherwise fall back to ''
# (the original hard-coded placeholder).
# NOTE(review): while a key is '', an empty submitted key passes the
# `key != ...` checks used by the routes — configure real values before deploying.
google_api_key = os.environ.get('GOOGLE_API_KEY', '')
screenshot_auth_key = os.environ.get('SCREENSHOT_AUTH_KEY', '')  # bearer token sent to the screenshot service
generate_auth_key = os.environ.get('GENERATE_AUTH_KEY', '')  # auth token to generate a url (a cached result is served if one exists)
regenerate_auth_key = os.environ.get('REGENERATE_AUTH_KEY', '')  # auth token to regenerate a url, bypassing the cache
advanced_mode_key = os.environ.get('ADVANCED_MODE_KEY', '')  # auth token to get advanced mode (custom prompt, model selector)
example_url = 'https://www.alex-hattori.com/blog/stride-mujoco-sim-no-arms'
# Needs a screenshot service at http://localhost:8001/screenshot that takes
# url and thumbnail=True as form fields and returns JSON shaped like
# {"status": "ok", "result": "<base64 png>", "thumbnail": "<base64 jpeg>", "title": "..."}
# ================================
# Page-wide CSS, included in the app's `hdrs` so every page gets it.
# html/body take the full width/height and iframes fill their container, so the
# #remastered wrapper (which hides overflow) shows the generated page full-size.
global_style = Style("""
body, html {
margin: 0;
padding: 0;
width: 100%;
height: 100%;
}
iframe {
width: 100%;
height: 100%;
border: none;
}
main {
padding: 1rem;
}
main input {
display: block;
width: 80%;
border: 1px solid #ccc;
margin: 0.5rem 0;
line-height: 2rem;
outline: none;
padding: 0 0.5rem;
}
main button {
padding: 0.5rem 1rem;
}
#remastered {
height: 100%;
overflow: hidden;
}
#gen-list {
margin-top: 1rem;
}
""")
# Flexbox grid stylesheet providing the row / col-* classes used by the gallery.
gridlink = Link(
    rel="stylesheet",
    href="https://cdnjs.cloudflare.com/ajax/libs/flexboxgrid/6.3.1/flexboxgrid.min.css",
    type="text/css",
)

# Generation history: drop loguru's default handler, then write each record as a
# serialized JSON line to db.log (rotated at 100 MB); read_remastered_logs()
# parses this file back for the gallery.
logger.remove()
logger.add(
    "db.log",
    format="{extra[url]} {extra[title]} {extra[model]} {extra[prompt]} {message}",
    serialize=True,
    rotation="100 MB",
)

page_headers = (picolink, global_style, gridlink)
app = FastHTML(hdrs=page_headers)
rt = app.route

genai.configure(api_key=google_api_key)
def upload_to_gemini(path, mime_type=None):
    """Upload a local file through the Gemini File API and return its handle.

    The returned object can be placed directly in a prompt list next to text.
    See https://ai.google.dev/gemini-api/docs/prompting_with_media
    """
    return genai.upload_file(path, mime_type=mime_type)
# Sampling settings shared by the Gemini models this module creates.
generation_config = dict(
    temperature=1,
    top_p=0.95,
    top_k=64,
    max_output_tokens=8192,
    response_mime_type="text/plain",
)
# Default instruction sent to Gemini together with the page screenshot.
# `{url}` is a placeholder that is substituted with the page URL before sending.
# This text is also the pre-filled value of the advanced-mode prompt editor.
default_prompt = '''
give the html (including inline CSS in <style></style> tags and inline JavaScript code in <script></script> tags) source of the website in the screenshot:
- avoid using external CSS or JavaScript files.
- improve colors and layout and text readability using better font family.
- try to keep the DOM structure simple but capable of expressing the layout accurately.
- use the original source code at {url} for reference.
- use the reference URL to find the links of images, video and other external files, try to make their style and layout correct using inline CSS in <style> tags.
- summarize or rephrase the long text and only show the first 1000 words, and append "..." at the end.
- do not use the original text in the screenshot or url directly.
- specify image size in img tag's inline CSS style, also add 'border: 1px solid;' to show image borders.
- use src= in img tags to replace original images.
'''
default_model = "gemini-1.5-flash"


def _build_gemini_model(model_name):
    """Create a GenerativeModel for `model_name` using the shared generation_config.

    No safety_settings are passed; to adjust them see
    https://ai.google.dev/gemini-api/docs/safety-settings
    """
    return genai.GenerativeModel(
        model_name=model_name,
        generation_config=generation_config,
    )


model_gemini_15_flash = _build_gemini_model("gemini-1.5-flash")
model_gemini_15_pro = _build_gemini_model("gemini-1.5-pro")
def generate_html_from_screenshot(url, path, key, title="Remaster", model_name=default_model, prompt=default_prompt):
    """Ask Gemini to rebuild the page shown in a screenshot as standalone HTML.

    The screenshot at `path` is uploaded and sent together with `prompt`, whose
    literal `{url}` placeholder is replaced by `url`. The reply is HTML-escaped
    so it can be embedded in an iframe `srcdoc="..."` attribute. It is then
    cached in `<md5(url)>.html`, and the generation is logged with the URL hash
    as the message.

    Args:
        url: Original page URL; its MD5 hex digest names the cache file.
        path: Path of the PNG screenshot to upload.
        key: Caller-supplied auth token; must equal the generate or regenerate key.
        title: Page title stored in the log record.
        model_name: "gemini-1.5-flash" or "gemini-1.5-pro"; any other value
            falls back to `default_model`.
        prompt: Instruction text; only the literal `{url}` is substituted.

    Returns:
        The escaped HTML string, or "Invalid key" when `key` is not accepted.
    """
    if key != generate_auth_key and key != regenerate_auth_key:
        return "Invalid key"

    models = {
        "gemini-1.5-flash": model_gemini_15_flash,
        "gemini-1.5-pro": model_gemini_15_pro,
    }
    if model_name not in models:
        # Unknown names (e.g. a hand-crafted form post) use the default model,
        # and the name actually used is the one that gets logged.
        model_name = default_model
    model = models[model_name]

    screenshot_file = upload_to_gemini(path, mime_type="image/png")
    chat_session = model.start_chat(history=[])
    # str.replace instead of str.format: advanced-mode prompts may contain
    # literal braces (CSS, JS, JSON examples) that format() would reject.
    response = chat_session.send_message([screenshot_file, prompt.replace('{url}', url)])

    # Escape so the markup can live inside the double-quoted srcdoc attribute.
    escaped_html = html.escape(response.text)
    # Remove the markdown code block wrapper ``` the model tends to add.
    escaped_html = escaped_html.replace('```html', '').replace('```', '')

    hashed_url = hashlib.md5(url.encode()).hexdigest()
    with open(hashed_url + '.html', 'w') as f:
        f.write(escaped_html)
    logger.info(hashed_url, url=url, title=title, model=model_name, prompt=prompt)
    return escaped_html
def read_remastered_logs(path="db.log"):
    """Parse the JSON-lines generation log written by the loguru "db.log" sink.

    Args:
        path: Log file to read; defaults to "db.log".

    Returns:
        A list of decoded records in file order. The list is empty when the
        file does not exist, e.g. before the first generation or after the log
        was removed.
    """
    try:
        # loguru writes file sinks as UTF-8 by default; don't rely on the locale.
        with open(path, "r", encoding="utf-8") as f:
            # Blank lines carry no record; skip them instead of failing json.loads.
            return [json.loads(line) for line in f if line.strip()]
    except FileNotFoundError:
        return []
# ============ Routes ============
def generation_preview(g):
    """Render one gallery card for a past generation.

    Args:
        g: Record dict with 'hash', 'url' and (possibly None) 'title' keys, as
           built by get_records_from_logs.

    Returns:
        A grid cell with the page thumbnail, a link to the original URL and a
        link to the remastered page. If the thumbnail file is missing, the cell
        holds only a "No image found" message.
    """
    grid_cls = "box col-xs-12 col-sm-6 col-md-4 col-lg-3"
    cell_id = f"gen-{g['hash']}"
    image_path = f"{g['hash']}_thumbnail.jpeg"
    if not os.path.exists(image_path):
        # Keep the grid class so a missing thumbnail doesn't break the row layout.
        return Div("No image found", id=cell_id, cls=grid_cls)
    # The logged title may be missing or None; label with the URL instead of
    # crashing on string concatenation.
    title = g.get('title') or g['url']
    return Div(Card(
        Img(src=image_path, alt="Card image", cls="card-img-top"),
        Div(
            A(f"Original: {title}", href=g['url'], cls="card-link"),
            Div(A('Remastered', href=f'/generate/{g["hash"]}', cls="card-link")),
            cls="card-body"),
    ), id=cell_id, cls=grid_cls)
def get_records_from_logs(logs=None, limit=10):
    """Summarise the most recent generations for the gallery.

    Args:
        logs: Parsed loguru records (as returned by read_remastered_logs); the
            log file is read when omitted.
        limit: How many of the newest log records to consider (<= 0 yields none).

    Returns:
        A list of dicts with 'url', 'title', 'hash' and 'model' keys, one per
        distinct URL hash. It is ordered oldest to newest by each hash's latest
        generation; the index page reverses it to show the newest first.
    """
    if logs is None:
        logs = read_remastered_logs()
    # Take the NEWEST records: log lines are appended, so they sit at the end.
    recent = logs[-limit:] if limit > 0 else []

    latest = {}
    for log in recent:
        record = log['record']
        extra = record['extra']
        url_hash = record['message']
        # Re-insert so a regenerated URL moves to its newest position.
        latest.pop(url_hash, None)
        latest[url_hash] = {
            'url': extra['url'],
            'title': extra['title'],
            'hash': url_hash,
            'model': extra.get('model', ''),
        }
    return list(latest.values())
@rt("/")
def get(req):
    """Home page: URL submission form plus a gallery of recent remasters.

    When the `key` query parameter equals advanced_mode_key, the form also
    offers a prompt editor (pre-filled with default_prompt) and a model selector.
    """
    show_advanced = req.query_params.get('key') == advanced_mode_key

    cards = [generation_preview(g) for g in get_records_from_logs()]
    cards.reverse()  # newest generation first
    gen_list = Div(*cards, id='gen-list', cls="row")

    if show_advanced:
        advanced_mode = Div(
            Textarea(default_prompt, id='prompt', name='prompt', placeholder='Prompt', rows=10),
            Select(
                Option('gemini-1.5-flash', value='gemini-1.5-flash'),
                Option('gemini-1.5-pro', value='gemini-1.5-pro'),
                name='model'),
        )
    else:
        advanced_mode = Div('')

    form = Form(
        Input(type="text", id="url", value=example_url, placeholder="https://...", name="url"),
        Input(type="text", id="key", placeholder="key to generate new remaster", name="key"),
        advanced_mode,
        Button("Submit"),
        action="/generate",
        method="post",
    )
    return Title("Remaster"), Main(
        H2("Remaster"),
        form,
        gen_list,
        Div('', id="result"),
    )
# For images, CSS, etc.
@rt("/{fname:path}.{ext:static}")
def static(fname:str, ext:str):
    """Serve .jpeg thumbnails and .css files from the current working directory.

    Other extensions, and paths that resolve outside the working directory
    (via "../" segments, absolute paths or symlinks), are refused with a
    plain-text message.
    """
    if ext not in ('jpeg', 'css'):
        return "You are naughty."
    # `fname` comes straight from the URL and the `path` converter accepts "/",
    # so confine the resolved file to the working directory.
    base_dir = os.path.realpath(os.getcwd())
    target = os.path.realpath(os.path.join(base_dir, f'{fname}.{ext}'))
    try:
        inside = os.path.commonpath([base_dir, target]) == base_dir
    except ValueError:  # e.g. different drives on Windows
        inside = False
    if not inside:
        return "You are naughty."
    return FileResponse(target)
@rt("/generate/{hash}")
def get(hash:str):
    """Show a previously generated remaster, looked up by its URL hash.

    Args:
        hash: MD5 hex digest of the original URL (the cache file name stem).

    Returns:
        A Div#remastered containing the cached page in an iframe. If the hash
        is malformed or nothing was generated for it, the Div holds a short
        message instead.
    """
    hashed_url = hash
    # Accept only what hashlib.md5().hexdigest() produces, so the value can
    # never name an arbitrary file.
    is_md5_hex = len(hashed_url) == 32 and all(c in '0123456789abcdef' for c in hashed_url)
    if not is_md5_hex or not os.path.exists(hashed_url + '.html'):
        # Render a message rather than returning None (an empty page).
        return Div("No remaster found.", id="remastered")
    with open(hashed_url + '.html', 'r') as f:
        result = f.read()
    # The cached HTML was html.escape()d when written, so it is safe inside the
    # double-quoted srcdoc attribute.
    # NOTE(review): srcdoc frames share this app's origin; consider a `sandbox`
    # attribute since the content is model output derived from untrusted pages.
    iframe_ele = NotStr(f'<iframe srcdoc="{result}"></iframe>')
    return Div(iframe_ele, id="remastered")
@rt("/generate")
def post(url:str, key:str, prompt:str=default_prompt, model:str=default_model):
    """Return the remastered version of `url`, generating it if needed.

    A cached result is served unless the caller holds the regenerate key.
    Otherwise the key is validated before the screenshot service is contacted.
    Then a screenshot and a thumbnail are fetched and saved as
    `<md5(url)>.png` and `<md5(url)>_thumbnail.jpeg`, and Gemini is asked to
    rebuild the page.

    Args:
        url: Page to remaster (form field).
        key: Generate/regenerate auth token (form field).
        prompt: Instruction text (sent only in advanced mode).
        model: Gemini model name (sent only in advanced mode).

    Returns:
        A Div#remastered wrapping an iframe whose srcdoc is the escaped HTML,
        or the "Invalid key" message.

    Raises:
        requests.RequestException: if the screenshot service is unreachable,
            times out or answers with an HTTP error status.
    """
    # Seconds to wait for the screenshot service; rendering a page can be slow,
    # but without a timeout a stuck service would hang this request forever.
    screenshot_timeout = 120

    hashed_url = hashlib.md5(url.encode()).hexdigest()
    if os.path.exists(hashed_url + '.html') and key != regenerate_auth_key:
        with open(hashed_url + '.html', 'r') as f:
            result = f.read()
        iframe_ele = NotStr(f'<iframe srcdoc="{result}"></iframe>')
        return Div(iframe_ele, id="remastered")

    # Reject unauthorised requests before doing any work, so they can neither
    # use the screenshot service nor overwrite thumbnails on disk.
    if key != generate_auth_key and key != regenerate_auth_key:
        iframe_ele = NotStr('<iframe srcdoc="Invalid key"></iframe>')
        return Div(iframe_ele, id="remastered")

    # Get a screenshot of url via POST localhost:8001/screenshot.
    resp = requests.post(
        "http://localhost:8001/screenshot",
        data={"url": url, "thumbnail": True},
        headers={'authorization': 'Bearer ' + screenshot_auth_key},
        timeout=screenshot_timeout,
    )
    resp.raise_for_status()
    screenshot = resp.json()

    # Decode the base64 full screenshot (uploaded to Gemini as PNG) and the
    # JPEG thumbnail shown in the gallery.
    with open(f"{hashed_url}.png", "wb") as f:
        f.write(base64.b64decode(screenshot['result']))
    with open(f"{hashed_url}_thumbnail.jpeg", "wb") as f:
        f.write(base64.b64decode(screenshot['thumbnail']))

    # The service may omit the title (or send null); label the entry with the URL.
    title = screenshot.get('title') or url
    result = generate_html_from_screenshot(url, f'{hashed_url}.png', key, title, model, prompt)
    iframe_ele = NotStr(f'<iframe srcdoc="{result}"></iframe>')
    return Div(iframe_ele, id="remastered")
# Start the FastHTML development server for this app.
# NOTE(review): FastHTML's serve() is expected to launch uvicorn only when this
# module is run directly (as __main__) — confirm for the installed version.
serve()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment