Created
October 31, 2024 07:50
-
-
Save pleabargain/ca5ef9d71e7e41ffccbc8a77e3cba7cf to your computer and use it in GitHub Desktop.
Streamlit app that calls the Replicate API to generate a video (supply your own API key)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import streamlit as st | |
import requests | |
from requests.adapters import HTTPAdapter | |
from urllib3.util.retry import Retry | |
import json | |
from time import sleep | |
import os | |
from datetime import datetime | |
import traceback | |
# Browser-tab title/icon and page layout for the app. Streamlit's docs ask for
# this call to come before other st.* commands in the script.
st.set_page_config(
    layout="centered",
    page_icon="🎬",
    page_title="AI Video Generator",
)
def create_requests_session(
    total=5,
    backoff_factor=1,
    status_forcelist=(408, 429, 500, 502, 503, 504),
    allowed_methods=("HEAD", "GET", "POST"),
):
    """Create a ``requests.Session`` whose HTTP(S) adapters retry automatically.

    All arguments are forwarded to ``urllib3.util.retry.Retry``; the defaults
    reproduce the app's original retry policy.

    Args:
        total: Maximum number of retries per request.
        backoff_factor: Scale of the exponential backoff. urllib3 retries
            immediately the first time and then sleeps
            ``backoff_factor * 2 ** (retry_number - 1)`` seconds, i.e.
            0, 2, 4, 8, 16 s with the defaults (capped by ``Retry``'s
            backoff maximum; a ``Retry-After`` header on 429/503 responses
            takes precedence).
        status_forcelist: Response status codes that trigger a retry.
        allowed_methods: HTTP methods eligible for retries on read errors and
            on the statuses in ``status_forcelist``.

    Returns:
        A new ``requests.Session`` with the retrying adapter mounted for both
        ``http://`` and ``https://``.
    """
    # NOTE(review): POST is not idempotent. Retrying a create-prediction POST
    # after a 5xx may start a duplicate (billed) prediction; callers that care
    # can pass allowed_methods without "POST".
    retries = Retry(
        total=total,
        backoff_factor=backoff_factor,
        status_forcelist=list(status_forcelist),
        allowed_methods=list(allowed_methods),
    )
    adapter = HTTPAdapter(max_retries=retries)
    session = requests.Session()
    session.mount('http://', adapter)
    session.mount('https://', adapter)
    return session
def safe_request(method, url, session=None, max_retries=3, **kwargs):
    """Send an HTTP request, retrying transient failures at the application level.

    This loop runs on top of any retry policy the session already has
    (``create_requests_session`` mounts urllib3-level retries). Errors that a
    retry cannot fix are re-raised immediately instead of being repeated.

    Args:
        method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
        url: Target URL.
        session: Session to use; a retrying session is created when omitted.
        max_retries: Number of attempts made by this loop (retries performed
            inside the session are not counted).
        **kwargs: Passed through to ``session.request``. A ``timeout`` given
            here overrides the 30-second default.

    Returns:
        The successful ``requests.Response``, or ``None`` if ``max_retries``
        is less than 1 and no attempt is made.

    Raises:
        requests.exceptions.RequestException: the last error once all attempts
            fail. It is raised immediately for a client error (4xx other than
            408/429, e.g. 401 for a bad API key) and for
            ``requests.exceptions.RetryError`` (the session already exhausted
            its own retries).
    """
    if session is None:
        session = create_requests_session()
    # Default the timeout instead of hard-wiring it, so a caller-supplied
    # timeout= no longer raises "got multiple values for keyword argument".
    kwargs.setdefault("timeout", 30)
    for attempt in range(max_retries):
        try:
            response = session.request(method, url, **kwargs)
            response.raise_for_status()
            return response
        except requests.exceptions.RequestException as e:
            failed_response = e.response
            status_code = getattr(failed_response, "status_code", None)
            client_error = (
                status_code is not None
                and 400 <= status_code < 500
                and status_code not in (408, 429)
            )
            session_gave_up = isinstance(e, requests.exceptions.RetryError)
            if client_error or session_gave_up or attempt == max_retries - 1:
                raise
            # Release the pooled connection held by a failed (possibly
            # streamed) response before trying again.
            if failed_response is not None:
                failed_response.close()
            sleep_time = (attempt + 1) * 2
            st.warning(f"Request failed, retrying in {sleep_time} seconds... (Attempt {attempt + 1}/{max_retries})")
            sleep(sleep_time)
    return None
def download_video(url, file_path):
    """Download a generated video and save it to *file_path*.

    The body is streamed to disk in 1 MB chunks while a Streamlit progress bar
    is shown. If *url* is a list (several outputs were returned), only the
    first entry is downloaded.

    Args:
        url: Video URL, or a list of URLs.
        file_path: Destination path; overwritten if it already exists.

    Returns:
        True on success. False on any failure: the error is shown in the UI
        and any partially written file is removed.
    """
    file_created = False
    try:
        if isinstance(url, list):
            if not url:
                st.error("Error downloading video: no video URL was returned")
                return False
            st.info(f"Received {len(url)} video segments. Downloading the first segment...")
            url = url[0]
        st.info(f"Downloading video from: {url}")
        session = create_requests_session()
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        }
        response = safe_request('GET', url, session=session, headers=headers, stream=True)
        if response is None:
            return False
        total_size = int(response.headers.get('content-length', 0))
        progress_text = st.empty()
        progress_bar = st.progress(0)
        # Using the response as a context manager closes the stream and
        # returns its connection to the pool, even if writing fails.
        with response, open(file_path, 'wb') as f:
            file_created = True
            downloaded_size = 0
            chunk_size = 1024 * 1024  # 1MB chunks
            for chunk in response.iter_content(chunk_size=chunk_size):
                if not chunk:
                    continue
                f.write(chunk)
                downloaded_size += len(chunk)
                if total_size:
                    # Content-Length counts encoded bytes while iter_content
                    # yields decoded ones, so the ratio can overshoot 1.0;
                    # st.progress rejects values above 1.0.
                    progress = min(downloaded_size / total_size, 1.0)
                    progress_bar.progress(progress)
                    progress_text.text(f"Downloading: {downloaded_size}/{total_size} bytes ({progress:.1%})")
        progress_text.text("Download completed successfully!")
        sleep(1)
        progress_text.empty()
        progress_bar.empty()
        return True
    except Exception as e:
        st.error(f"Error downloading video: {str(e)}")
        st.error(f"Detailed error: {traceback.format_exc()}")
        if file_created:
            # Best-effort cleanup of a truncated file; the download error has
            # already been reported above.
            try:
                os.remove(file_path)
            except OSError:
                pass
        return False
def generate_video(api_key, prompt, fps):
    """Start a text-to-video prediction on Replicate.

    Args:
        api_key: Replicate API token, sent as a Bearer token.
        prompt: Text description of the desired video.
        fps: Frames per second passed to the model.

    Returns:
        The decoded JSON body of the create-prediction response, or None if
        the request or JSON decoding failed (the error is shown in the UI).
    """
    st.info("Sending request to Replicate API...")
    auth_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    model_input = {
        "fps": fps,
        "width": 1024,
        "height": 576,
        "prompt": prompt,
        "guidance_scale": 17.5,
        "negative_prompt": "very blue, dust, noisy, washed out, ugly, distorted, broken",
    }
    request_body = {
        "version": "9f747673945c62801b13b84701c783929c0ee784e4748ec062204894dda1a351",
        "input": model_input,
    }
    endpoint = "https://api.replicate.com/v1/predictions"
    try:
        reply = safe_request(
            'POST',
            endpoint,
            session=create_requests_session(),
            headers=auth_headers,
            json=request_body,
        )
        return reply.json() if reply is not None else None
    except Exception as exc:
        st.error(f"Error making prediction request: {str(exc)}")
        st.error(f"Detailed error: {traceback.format_exc()}")
        return None
def check_prediction_status(api_key, prediction_id):
    """Fetch the current state of a Replicate prediction.

    Args:
        api_key: Replicate API token, sent as a Bearer token.
        prediction_id: Id of the prediction to look up.

    Returns:
        The decoded prediction JSON. On failure it returns a stand-in
        ``{'status': 'error', 'error': ...}`` dict instead of raising, and
        unexpected exceptions are also reported in the UI.
    """
    auth = {"Authorization": f"Bearer {api_key}"}
    try:
        reply = safe_request(
            'GET',
            f"https://api.replicate.com/v1/predictions/{prediction_id}",
            session=create_requests_session(),
            headers=auth,
        )
        if reply is not None:
            return reply.json()
        return {'status': 'error', 'error': 'Failed to check prediction status'}
    except Exception as exc:
        st.error(f"Error checking prediction status: {str(exc)}")
        st.error(f"Detailed error: {traceback.format_exc()}")
        return {'status': 'error', 'error': str(exc)}
def _poll_prediction(api_key, prediction_id, max_attempts=60, poll_interval=2,
                     error_retry_delay=5):
    """Poll a Replicate prediction until it finishes, showing live progress.

    Args:
        api_key: Replicate API token.
        prediction_id: Id returned when the prediction was created.
        max_attempts: Maximum number of status checks; failed checks count too.
        poll_interval: Seconds to wait between checks while it is still running.
        error_retry_delay: Seconds to wait after a failed status check.

    Returns:
        The last prediction dict when its status is 'succeeded', 'failed' or
        'canceled'. None if no terminal status was seen within max_attempts.
    """
    progress_bar = st.progress(0)
    status_text = st.empty()
    logs_box = st.empty()
    # Rough progress estimates for the non-terminal statuses reported back.
    stage_progress = {'starting': 10, 'processing': 50, 'pushing': 80}
    for _ in range(max_attempts):
        status = check_prediction_status(api_key, prediction_id)
        state = status.get('status') if status else None
        if not status or state == 'error':
            # The status check itself failed; wait a bit longer, then retry.
            sleep(error_retry_delay)
            continue
        if state == 'succeeded':
            progress_bar.progress(100)
            status_text.success("Video generation complete!")
            return status
        if state == 'failed':
            progress_bar.empty()
            # Fall back even when the 'error' key is present but null/empty.
            status_text.error(f"Video generation failed: {status.get('error') or 'Unknown error'}")
            return status
        if state == 'canceled':
            # Terminal too; without this branch the loop spun until timeout.
            progress_bar.empty()
            status_text.warning("Video generation was canceled.")
            return status
        progress_bar.progress(stage_progress.get(state, 0))
        status_text.info(f"Status: {str(state or 'Unknown').title()}")
        logs = status.get('logs')
        if logs:
            # st.code shows the raw log text (newlines, symbols) verbatim.
            logs_box.code(logs)
        sleep(poll_interval)
    return None


def _save_and_show_video(video_url, videos_dir, video_container, download_container):
    """Download the generated video, then show a player and a download button.

    Args:
        video_url: Prediction output (a URL or a list of URLs).
        videos_dir: Directory where a timestamped .mp4 copy is saved.
        video_container: ``st.empty()`` placeholder for the player.
        download_container: ``st.empty()`` placeholder for the download button.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    file_name = f"generated_video_{timestamp}.mp4"
    video_path = os.path.join(videos_dir, file_name)
    if not download_video(video_url, video_path):
        st.error("Failed to download the video. Please try again.")
        return
    try:
        with open(video_path, 'rb') as video_file:
            video_bytes = video_file.read()
        # st.empty() holds a single element: writing two elements straight
        # into it made st.video replace the success message. A child
        # container keeps both.
        with video_container.container():
            st.success("Video ready for playback!")
            st.video(video_bytes)
        download_container.download_button(
            label="💾 Download Video",
            data=video_bytes,
            file_name=file_name,
            mime="video/mp4",
            key=f"download_{timestamp}",
        )
        st.success(f"Video saved locally at: {video_path}")
    except Exception as e:
        st.error(f"Error processing video file: {str(e)}")
        st.error(f"Detailed error: {traceback.format_exc()}")


def main():
    """Render the AI Video Generator page and handle one generation request.

    Ensures the ``generated_videos`` directory exists and collects the API
    key, prompt and FPS. On submit it starts a Replicate prediction, polls it,
    then saves and displays the resulting video.
    """
    st.title("🎬 AI Video Generator")
    videos_dir = 'generated_videos'
    try:
        # Create directly rather than exists()-then-makedirs(): concurrent
        # Streamlit sessions could otherwise race and fail with FileExistsError.
        os.makedirs(videos_dir)
    except FileExistsError:
        pass
    except OSError as e:
        st.error(f"Error creating videos directory: {str(e)}")
        return
    else:
        st.info(f"Created directory: {videos_dir}")

    with st.container():
        st.markdown("""
        Welcome to the AI Video Generator! This tool helps you create amazing AI-generated videos
        using state-of-the-art technology.
        """)
        api_key = st.text_input("Enter your Replicate API Key", type="password")
        # Placeholders so the finished video and download button appear above the form.
        video_container = st.empty()
        download_container = st.empty()
        with st.form("video_generation_form"):
            prompt = st.text_area(
                "Describe the video you want to make today",
                placeholder="Example: A majestic eagle soaring through snow-capped mountains at sunset",
                help="Be descriptive and specific about what you want to see in your video"
            )
            # Only the left column is used; the split keeps the slider half-width.
            fps_col, _ = st.columns(2)
            with fps_col:
                fps = st.slider("Frames Per Second (FPS)",
                                min_value=1,
                                max_value=60,
                                value=24,
                                help="Higher FPS means smoother video but longer generation time")
            submit_button = st.form_submit_button("Generate Video")

        if submit_button:
            if not api_key:
                st.error("Please enter your Replicate API key")
            elif not prompt.strip():
                st.error("Please enter a description for your video")
            else:
                with st.spinner("Initiating video generation..."):
                    result = generate_video(api_key, prompt, fps)
                if result and 'id' in result:
                    prediction_id = result['id']
                    st.info(f"Prediction ID: {prediction_id}")
                    final = _poll_prediction(api_key, prediction_id)
                    if final is None:
                        st.error("Timeout waiting for video generation. Please try again.")
                    elif final.get('status') == 'succeeded':
                        if final.get('output'):
                            _save_and_show_video(final['output'], videos_dir,
                                                 video_container, download_container)
                        else:
                            st.error("Video generation succeeded but returned no output.")
                elif result is not None:
                    # generate_video already reported errors when it returned None.
                    st.error(f"Unexpected response from Replicate (no prediction id): {result}")

    st.divider()
    with st.expander("Tips for better results"):
        st.markdown("""
        - Be specific about the scene, including details about lighting, atmosphere, and movement
        - Mention camera angles or movements if you have specific preferences
        - Include artistic style references (e.g., cinematic, anime, photorealistic)
        - Specify time of day or weather conditions for more accurate results
        """)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment