from google.cloud import aiplatform
import os
import openai
import httpx
import tempfile
import vertexai
import logging
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile, status
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from vertexai.preview.vision_models import Image, ImageGenerationModel

# Grab values from the environment if available
from dotenv import load_dotenv
load_dotenv()

# Set the OpenAI API key and model accordingly
openai.api_key = os.getenv("OPENAI_API_KEY", "").rstrip()
openai.api_model = os.getenv("OPENAI_API_MODEL", "gpt-35-turbo")
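
# Example .env file for the variables read above (illustrative placeholder values only):
#   OPENAI_API_KEY=sk-...
#   OPENAI_API_MODEL=gpt-35-turbo
#   GOOGLE_PROJECT_ID=my-gcp-project
#   GOOGLE_REGION=us-central1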

# Set Google parameters; google_model is required by the /image/google endpoints below
google_scope = os.getenv("GOOGLE_SCOPE", "https://www.googleapis.com/auth/cloud-platform")
google_project_id = os.getenv("GOOGLE_PROJECT_ID")
google_region = os.getenv("GOOGLE_REGION", "us-central1")

vertexai.init(project=google_project_id, location=google_region)
google_model = ImageGenerationModel.from_pretrained("imagegeneration@002")

# Set the default prompt and the supported providers
DEFAULT_PROMPT = "Colgate toothpaste next to a toothbrush"
DEFAULT_PROVIDERS = ["google", "openai"]

# Create uploads folder if it doesn't exist
if not os.path.exists('uploads'):
    os.makedirs('uploads')


class ImageRequest(BaseModel):
    # Plain defaults: this model is parsed from a JSON body, so Form() defaults do not apply here
    prompt: str = DEFAULT_PROMPT
    n: int = 1
    size: str = "1024x1024"


class ImageVariationRequest(ImageRequest):
    image_file: UploadFile = File(...)

def process_google_image(response):
    # TODO: Fix: take only first image
    response_image = response.images[0]
    # Save the generated image to a temporary PNG file so it can be streamed back
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp:
        response_image.save(location=temp.name)
        temp_path = temp.name

    def iterfile():  # This generator reads the image in chunks
        try:
            with open(temp_path, "rb") as f:
                while True:
                    chunk = f.read(8192)  # 8K chunks
                    if not chunk:
                        break
                    yield chunk
        finally:
            # Once the image has been streamed, the temporary file can be removed
            os.remove(temp_path)

    try:
        return StreamingResponse(iterfile(), media_type="image/png")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# FastAPI based application
app = FastAPI(
    title="Colgate GenAI Service",
    description="TBD",
    summary="TBD",
    version="0.0.1",
    terms_of_service="http://example.com/terms/",
    license_info={
        "name": "Apache 2.0",
        "url": "https://www.apache.org/licenses/LICENSE-2.0.html",
    },
)

origins = [
    "http://localhost.tiangolo.com",
    "https://localhost.tiangolo.com",
    "http://localhost",
    "http://localhost:8080",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
def home():
    return "Hello, FastAPI!"
@app.post("/image/openai/edit") | |
def generative_image_openai_edit(prompt: str = Form(DEFAULT_PROMPT), | |
n: int = Form(1), | |
size: str = Form("1024x1024"), | |
image_file: UploadFile = File(...), | |
mask_file: UploadFile = File(...)): | |
image_filename = os.path.join('uploads', image_file.filename) | |
with open(image_filename, "wb") as buffer: | |
buffer.write(image_file.file.read()) | |
mask_filename = os.path.join('uploads', mask_file.filename) | |
with open(mask_filename, "wb") as buffer: | |
buffer.write(mask_file.file.read()) | |
response = openai.Image.create_edit(image=open(image_filename, "rb"), mask=open(mask_filename, "rb"), | |
prompt=prompt, n=n, size=size) | |
response_result = response['data'][0]['url'] | |
with httpx.Client() as client: | |
response = client.get(response_result) | |
response.stream = True | |
return StreamingResponse(response.iter_bytes(), media_type=response.headers["Content-Type"]) | |
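
# Example invocation (illustrative; assumes the service runs on localhost:6000 and that
# image.png and mask.png exist locally):
#   curl -X POST http://localhost:6000/image/openai/edit \
#        -F "prompt=Colgate toothpaste next to a toothbrush" \
#        -F "n=1" -F "size=1024x1024" \
#        -F "image_file=@image.png" -F "mask_file=@mask.png" \
#        --output edited.png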
@app.post("/image/openai/variation") | |
def generative_image_openai_variation(n: int = Form(1), | |
size: str = Form("1024x1024"), | |
image_file: UploadFile = File(...)): | |
image_filename = os.path.join('uploads', image_file.filename) | |
with open(image_filename, "wb") as buffer: | |
buffer.write(image_file.file.read()) | |
response = openai.Image.create_variation(image=open(image_filename, "rb"), | |
n=n, size=size) | |
response_result = response['data'][0]['url'] | |
with httpx.Client() as client: | |
response = client.get(response_result) | |
response.stream = True | |
return StreamingResponse(response.iter_bytes(), media_type=response.headers["Content-Type"]) | |
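
# Example invocation (illustrative):
#   curl -X POST http://localhost:6000/image/openai/variation \
#        -F "n=1" -F "size=1024x1024" \
#        -F "image_file=@image.png" --output variation.png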

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    # Collapse the multi-line validation error into a single log line
    exc_str = f'{exc}'.replace('\n', ' ').replace('   ', ' ')
    logging.error(f"{request}: {exc_str}")
    content = {'status_code': 10422, 'message': exc_str, 'data': None}
    return JSONResponse(content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
@app.post("/image/google/edit") | |
def generative_image_google_edit(prompt: str = Form(DEFAULT_PROMPT), | |
n: int = Form(1), | |
image_file: UploadFile = File(...), | |
mask_file: UploadFile = File(...)): | |
image_filename = os.path.join('uploads', image_file.filename) | |
with open(image_filename, "wb") as buffer: | |
buffer.write(image_file.file.read()) | |
mask_filename = os.path.join('uploads', mask_file.filename) | |
with open(mask_filename, "wb") as buffer: | |
buffer.write(mask_file.file.read()) | |
response = google_model.edit_image( | |
base_image=Image.load_from_file(image_filename), | |
mask=Image.load_from_file(mask_filename), | |
prompt=prompt, | |
number_of_images=n | |
) | |
return process_google_image(response) | |
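
# Example invocation (illustrative):
#   curl -X POST http://localhost:6000/image/google/edit \
#        -F "prompt=Colgate toothpaste next to a toothbrush" -F "n=1" \
#        -F "image_file=@image.png" -F "mask_file=@mask.png" --output edited.png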
@app.post("/image/google/variation") | |
def generative_image_google_variation(prompt: str = Form(DEFAULT_PROMPT), | |
n: int = Form(1), | |
image_file: UploadFile = File(...)): | |
image_filename = os.path.join('uploads', image_file.filename) | |
with open(image_filename, "wb") as buffer: | |
buffer.write(image_file.file.read()) | |
response = google_model.edit_image( | |
base_image=Image.load_from_file(image_filename), | |
prompt=prompt, | |
number_of_images=n | |
) | |
return process_google_image(response) | |
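
# Example invocation (illustrative):
#   curl -X POST http://localhost:6000/image/google/variation \
#        -F "prompt=Colgate toothpaste next to a toothbrush" -F "n=1" \
#        -F "image_file=@image.png" --output variation.png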

###
# Generate an Image
###
def generative_image_openai(request: ImageRequest):
    response = openai.Image.create(prompt=request.prompt,
                                   n=request.n,
                                   size=request.size)
    # Download the generated image from the returned URL and stream it back to the client
    response_result = response['data'][0]['url']
    with httpx.Client() as client:
        response = client.get(response_result)
    return StreamingResponse(response.iter_bytes(), media_type=response.headers["Content-Type"])


def generative_image_google(request: ImageRequest):
    response = google_model.generate_images(
        prompt=request.prompt,
        number_of_images=request.n
    )
    return process_google_image(response)
@app.post("/image/{provider}") | |
def generative_image_generic(request: ImageRequest, | |
provider: str = "google"): | |
# Check that it is a valid provider | |
if provider not in DEFAULT_PROVIDERS: | |
return JSONResponse(content={"message": f"Supported service providers are currently {', '.join(DEFAULT_PROVIDERS)}"}, status_code=400) | |
generative_image_function = f"generative_image_{provider}" | |
return globals()[generative_image_function](request) | |
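
# Example invocation (illustrative; the body is JSON because ImageRequest is a Pydantic model):
#   curl -X POST http://localhost:6000/image/google \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Colgate toothpaste next to a toothbrush", "n": 1, "size": "1024x1024"}' \
#        --output generated.png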

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=6000)
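
# To run the service directly (assuming this file is saved as main.py):
#   python main.py
# or equivalently with the uvicorn CLI:
#   uvicorn main:app --host 0.0.0.0 --port 6000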