Skip to content

Instantly share code, notes, and snippets.

@pleabargain
Created October 31, 2024 07:50
Show Gist options
  • Save pleabargain/ca5ef9d71e7e41ffccbc8a77e3cba7cf to your computer and use it in GitHub Desktop.
Save pleabargain/ca5ef9d71e7e41ffccbc8a77e3cba7cf to your computer and use it in GitHub Desktop.
Streamlit app that calls the Replicate API to generate a video (supply your own API key).
import streamlit as st
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
import json
from time import sleep
import os
from datetime import datetime
import traceback
# Page-level settings: browser-tab title and icon, plus the narrow centered layout.
# Keep this ahead of other st.* calls; older Streamlit versions reject it otherwise.
st.set_page_config(
    layout="centered",
    page_icon="🎬",
    page_title="AI Video Generator",
)
def create_requests_session(
    *,
    total=5,
    backoff_factor=1,
    status_forcelist=(408, 429, 500, 502, 503, 504),
    allowed_methods=("HEAD", "GET", "POST"),
):
    """Create a ``requests.Session`` whose adapter retries transient HTTP failures.

    The retry policy is applied by urllib3 underneath every request made through
    the returned session, for both ``http://`` and ``https://`` URLs.

    Args:
        total: Maximum number of retries per request.
        backoff_factor: urllib3 exponential-backoff base. urllib3 does not sleep
            before the first retry, so a factor of 1 waits roughly 0, 2, 4, 8,
            16 seconds between successive retries.
        status_forcelist: HTTP status codes that trigger a retry.
        allowed_methods: HTTP verbs that may be retried. The default keeps POST
            for backward compatibility. Retrying a POST after a 5xx may submit
            the same job twice if the server had already accepted it.

    Returns:
        A configured ``requests.Session``.
    """
    session = requests.Session()
    retries = Retry(
        total=total,
        backoff_factor=backoff_factor,
        status_forcelist=list(status_forcelist),
        allowed_methods=list(allowed_methods),
    )
    # One adapter instance serves both schemes so they share the same policy.
    adapter = HTTPAdapter(max_retries=retries)
    session.mount('http://', adapter)
    session.mount('https://', adapter)
    return session
def safe_request(method, url, session=None, max_retries=3, **kwargs):
    """Send an HTTP request, retrying transient failures with a linear backoff.

    This loop runs on top of the session adapter's own retry policy. A single
    iteration may therefore issue several HTTP attempts.

    Args:
        method: HTTP verb, e.g. ``"GET"`` or ``"POST"``.
        url: Target URL.
        session: Session to reuse. A new retrying session is created when ``None``.
        max_retries: Number of attempts made by this function.
        **kwargs: Forwarded to ``session.request``. ``timeout`` defaults to 30
            seconds and may be overridden by the caller.

    Returns:
        The successful ``requests.Response``, or ``None`` when ``max_retries`` < 1.

    Raises:
        requests.exceptions.RequestException: The last error once all attempts
            are used up. Also raised immediately for a client error (status
            400-499 other than 408/429), which cannot succeed on retry.
    """
    if session is None:
        session = create_requests_session()
    # Passing timeout=30 next to **kwargs made a caller-supplied timeout raise
    # TypeError ("got multiple values for keyword argument 'timeout'").
    kwargs.setdefault("timeout", 30)
    for attempt in range(max_retries):
        try:
            response = session.request(method, url, **kwargs)
            response.raise_for_status()
            return response
        except requests.exceptions.RequestException as e:
            failed = getattr(e, "response", None)
            status = failed.status_code if failed is not None else None
            # Bad credentials / invalid input will not be fixed by waiting.
            if status is not None and 400 <= status < 500 and status not in (408, 429):
                raise
            if attempt == max_retries - 1:  # Last attempt
                raise
            if failed is not None:
                # Discarded response: release its connection (matters for stream=True).
                failed.close()
            sleep_time = (attempt + 1) * 2
            st.warning(f"Request failed, retrying in {sleep_time} seconds... (Attempt {attempt + 1}/{max_retries})")
            sleep(sleep_time)
    return None
def download_video(url, file_path):
    """Download a generated video to ``file_path``, showing progress in the UI.

    Args:
        url: Video URL, or a list of URLs. Only the first entry of a list is
            downloaded.
        file_path: Local destination. An existing file is overwritten.

    Returns:
        ``True`` once the whole body has been written, otherwise ``False``.
        Failures are reported via ``st.error``, and a partially written file
        is deleted.
    """
    progress_text = None
    progress_bar = None
    partial_file = False  # True while file_path holds an incomplete download
    try:
        if isinstance(url, list):
            if not url:
                st.error("Error downloading video: the API returned an empty list of video URLs")
                return False
            st.info(f"Received {len(url)} video segments. Downloading the first segment...")
            url = url[0]
        st.info(f"Downloading video from: {url}")
        session = create_requests_session()
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        }
        response = safe_request('GET', url, session=session, headers=headers, stream=True)
        if response is None:
            return False
        try:
            total_size = int(response.headers.get('content-length', 0))
            progress_text = st.empty()
            progress_bar = st.progress(0)
            with open(file_path, 'wb') as f:
                partial_file = True
                downloaded_size = 0
                chunk_size = 1024 * 1024  # 1MB chunks
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if not chunk:  # skip empty keep-alive chunks
                        continue
                    f.write(chunk)
                    downloaded_size += len(chunk)
                    if total_size:
                        # content-length can be the compressed size while iter_content
                        # yields decoded bytes; st.progress rejects values above 1.0.
                        progress = min(downloaded_size / total_size, 1.0)
                        progress_bar.progress(progress)
                        progress_text.text(f"Downloading: {downloaded_size}/{total_size} bytes ({progress:.1%})")
            partial_file = False
        finally:
            # A stream=True response holds its connection until it is closed.
            response.close()
        progress_text.text("Download completed successfully!")
        sleep(1)
        progress_text.empty()
        progress_bar.empty()
        return True
    except Exception as e:
        st.error(f"Error downloading video: {str(e)}")
        st.error(f"Detailed error: {traceback.format_exc()}")
        if partial_file:
            try:
                os.remove(file_path)
            except OSError:
                pass  # best-effort cleanup; the error has already been reported
        for widget in (progress_text, progress_bar):
            if widget is not None:
                widget.empty()
        return False
def generate_video(
    api_key,
    prompt,
    fps,
    *,
    width=1024,
    height=576,
    guidance_scale=17.5,
    negative_prompt="very blue, dust, noisy, washed out, ugly, distorted, broken",
    model_version="9f747673945c62801b13b84701c783929c0ee784e4748ec062204894dda1a351",
):
    """Create a video-generation prediction on the Replicate API.

    This only starts the prediction. The returned ``id`` must be polled to find
    out when the video is ready.

    Args:
        api_key: Replicate API token, sent as a Bearer token.
        prompt: Text description of the desired video.
        fps: Sent as the model's ``fps`` input.
        width: Sent as the model's ``width`` input.
        height: Sent as the model's ``height`` input.
        guidance_scale: Sent as the model's ``guidance_scale`` input.
        negative_prompt: Sent as the model's ``negative_prompt`` input.
        model_version: Replicate model-version hash to run.

    Returns:
        The parsed JSON body of the create-prediction response (normally a dict
        containing ``id``), or ``None`` if the request failed. Failures are
        reported in the UI.
    """
    st.info("Sending request to Replicate API...")
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "version": model_version,
        "input": {
            "fps": fps,
            "width": width,
            "height": height,
            "prompt": prompt,
            "guidance_scale": guidance_scale,
            "negative_prompt": negative_prompt,
        },
    }
    try:
        session = create_requests_session()
        response = safe_request('POST', "https://api.replicate.com/v1/predictions",
                                session=session, headers=headers, json=payload)
        if response is None:
            return None
        return response.json()
    except Exception as e:
        st.error(f"Error making prediction request: {str(e)}")
        st.error(f"Detailed error: {traceback.format_exc()}")
        return None
def check_prediction_status(api_key, prediction_id):
    """Fetch the current state of a Replicate prediction.

    Args:
        api_key: Replicate API token, sent as a Bearer token.
        prediction_id: Identifier returned when the prediction was created.

    Returns:
        The prediction JSON on success. On failure, a dict of the form
        ``{'status': 'error', 'error': <message>}``; exception details are
        also shown in the UI.
    """
    auth_headers = {"Authorization": f"Bearer {api_key}"}
    try:
        endpoint = f"https://api.replicate.com/v1/predictions/{prediction_id}"
        resp = safe_request('GET', endpoint,
                            session=create_requests_session(), headers=auth_headers)
        if resp is not None:
            return resp.json()
        return {'status': 'error', 'error': 'Failed to check prediction status'}
    except Exception as exc:
        st.error(f"Error checking prediction status: {str(exc)}")
        st.error(f"Detailed error: {traceback.format_exc()}")
        return {'status': 'error', 'error': str(exc)}
def _ensure_videos_dir(videos_dir):
    """Make sure ``videos_dir`` exists as a directory; report and return False on failure."""
    if os.path.isdir(videos_dir):
        return True
    try:
        # exist_ok tolerates the directory appearing between the check and this call;
        # a plain file at that path still raises and is reported below.
        os.makedirs(videos_dir, exist_ok=True)
    except OSError as e:
        st.error(f"Error creating videos directory: {str(e)}")
        return False
    st.info(f"Created directory: {videos_dir}")
    return True


def _show_video(video_path, timestamp, video_container, download_container):
    """Read the saved video and render a player and a download button in the placeholders."""
    try:
        with open(video_path, 'rb') as video_file:
            video_bytes = video_file.read()
        # st.empty() holds a single element, so group message + player in a nested
        # container; otherwise st.video would replace the success message.
        with video_container.container():
            st.success("Video ready for playback!")
            st.video(video_bytes)
        with download_container:
            st.download_button(
                label="💾 Download Video",
                data=video_bytes,
                file_name=f"generated_video_{timestamp}.mp4",
                mime="video/mp4",
                key=f"download_{timestamp}",
            )
        st.success(f"Video saved locally at: {video_path}")
    except Exception as e:
        st.error(f"Error processing video file: {str(e)}")
        st.error(f"Detailed error: {traceback.format_exc()}")


def _poll_prediction(api_key, prediction_id, videos_dir, video_container,
                     download_container, max_attempts=60):
    """Poll a prediction until it finishes, then download and display the video.

    Makes at most ``max_attempts`` status checks. They are 2 s apart while the
    prediction is running and 5 s apart after a failed check. A timeout is
    reported if no terminal state is seen.
    """
    progress_bar = st.progress(0)
    status_text = st.empty()
    stage_progress = {'starting': 10, 'processing': 50, 'pushing': 80}
    for _ in range(max_attempts):
        status = check_prediction_status(api_key, prediction_id)
        state = status.get('status') if status else None
        if not status or state == 'error':
            sleep(5)  # Wait before retry
            continue
        if state == 'succeeded':
            progress_bar.progress(100)
            status_text.success("Video generation complete!")
            video_url = status.get('output')
            if not video_url:
                st.error("Video generation finished but the API returned no video output.")
                return
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            video_path = os.path.join(videos_dir, f'generated_video_{timestamp}.mp4')
            if download_video(video_url, video_path):
                _show_video(video_path, timestamp, video_container, download_container)
            else:
                st.error("Failed to download the video. Please try again.")
            return
        if state in ('failed', 'canceled'):
            # 'canceled' is terminal too; previously it was polled until timeout.
            progress_bar.empty()
            fallback = 'Prediction was canceled' if state == 'canceled' else 'Unknown error'
            status_text.error(f"Video generation failed: {status.get('error') or fallback}")
            return
        progress_bar.progress(stage_progress.get(state, 0))
        # str(...) guards against a null/non-string status, which crashed .title().
        status_message = f"Status: {str(state or 'Unknown').title()}"
        if status.get('logs'):
            status_message += f"\nLogs: {status['logs']}"
        status_text.info(status_message)
        sleep(2)
    st.error("Timeout waiting for video generation. Please try again.")


def main():
    """Render the AI Video Generator page and run a generation when the form is submitted."""
    st.title("🎬 AI Video Generator")
    videos_dir = 'generated_videos'
    if not _ensure_videos_dir(videos_dir):
        return
    with st.container():
        st.markdown("""
Welcome to the AI Video Generator! This tool helps you create amazing AI-generated videos
using state-of-the-art technology.
""")
        # Drop stray whitespace from pasting; it would otherwise be sent in the
        # Authorization header.
        api_key = st.text_input("Enter your Replicate API Key", type="password").strip()
        # Create containers for video and download button
        video_container = st.empty()
        download_container = st.empty()
        with st.form("video_generation_form"):
            prompt = st.text_area(
                "Describe the video you want to make today",
                placeholder="Example: A majestic eagle soaring through snow-capped mountains at sunset",
                help="Be descriptive and specific about what you want to see in your video"
            )
            col1, _ = st.columns(2)  # the slider occupies only the left half
            with col1:
                fps = st.slider("Frames Per Second (FPS)",
                                min_value=1,
                                max_value=60,
                                value=24,
                                help="Higher FPS means smoother video but longer generation time")
            submit_button = st.form_submit_button("Generate Video")
        if submit_button:
            if not api_key:
                st.error("Please enter your Replicate API key")
            elif not prompt.strip():
                st.error("Please enter a description for your video")
            else:
                with st.spinner("Initiating video generation..."):
                    result = generate_video(api_key, prompt, fps)
                if result and 'id' in result:
                    prediction_id = result['id']
                    st.info(f"Prediction ID: {prediction_id}")
                    _poll_prediction(api_key, prediction_id, videos_dir,
                                     video_container, download_container)
                elif result:
                    # generate_video reports request failures itself; this covers a
                    # reply that arrived but carried no prediction id.
                    st.error("Unexpected response from Replicate API: no prediction ID returned.")
    st.divider()
    with st.expander("Tips for better results"):
        st.markdown("""
- Be specific about the scene, including details about lighting, atmosphere, and movement
- Mention camera angles or movements if you have specific preferences
- Include artistic style references (e.g., cinematic, anime, photorealistic)
- Specify time of day or weather conditions for more accurate results
""")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment