Last active
September 18, 2023 05:43
-
-
Save rominirani/aaca4b4f8ebee01902ee5ee1260d8bac to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime
import os
from typing import Mapping, Union

import flask
import google.auth.transport.requests
import google.oauth2.id_token
import requests
from google.cloud import logging

from utils import coerce_datetime_zulu
from vertex_llm import predict_large_language_model
# Name of the Cloud Logging log that the handlers in this module write to.
_FUNCTIONS_VERTEX_EVENT_LOGGER = "my-llm-usecase"
# Deployment settings are read from the environment at import time, so a
# missing variable raises KeyError as soon as the module is loaded.
_PROJECT_ID = os.environ["PROJECT_ID"]
# NOTE(review): only referenced by the commented-out processing call in
# cloud_event_entrypoint within the visible code.
_OUTPUT_BUCKET = os.environ["OUTPUT_BUCKET"]
_LOCATION = os.environ["LOCATION"]
def default_marshaller(o: object) -> str:
    """Return a string form of *o*, using ISO 8601 for dates and datetimes.

    Any value that is not a ``datetime.date`` (or subclass) is rendered with
    ``str(o)``.

    NOTE(review): presumably meant as a ``default=`` hook for ``json.dumps``;
    nothing in the visible code calls it -- confirm against callers.
    """
    # datetime.datetime is a subclass of datetime.date, so one check covers both.
    if not isinstance(o, datetime.date):
        return str(o)
    return o.isoformat()
def redirect_and_reply(previous_data):
    """Re-post a storage event to this same function and reply without waiting.

    The function's own HTTPS endpoint is built from ``_LOCATION``,
    ``_PROJECT_ID`` and the ``K_SERVICE`` environment variable. An ID token is
    fetched for that endpoint and sent as a bearer token, together with the
    ``name``, ``id``, ``bucket`` and ``timeCreated`` fields of the event.

    The POST uses a one-second timeout on purpose: the second invocation does
    the real work, and this one only needs to hand the event over.

    Args:
        previous_data: Mapping holding at least the keys ``name``, ``id``,
            ``bucket`` and ``timeCreated``.

    Returns:
        ``flask.Response`` with status 200 when the request was handed over
        (answered in time, or still running when the read timed out), or
        status 500 when it could not be delivered.

    Raises:
        KeyError: If ``K_SERVICE`` is not set or ``previous_data`` lacks one
            of the required keys.
    """
    endpoint = f'https://{_LOCATION}-{_PROJECT_ID}.cloudfunctions.net/{os.environ["K_SERVICE"]}'
    logging_client = logging.Client()
    logger = logging_client.logger(_FUNCTIONS_VERTEX_EVENT_LOGGER)

    auth_req = google.auth.transport.requests.Request()
    id_token = google.oauth2.id_token.fetch_id_token(auth_req, endpoint)

    # Forward only the fields the REST branch of the entry point consumes.
    data = {
        key: previous_data[key]
        for key in ("name", "id", "bucket", "timeCreated")
    }
    headers = {"Authorization": f"Bearer {id_token}"}

    logger.log(f"TRIGGERING JOB FLOW: {endpoint}")
    try:
        requests.post(
            endpoint,
            json=data,
            timeout=1,
            headers=headers,
        )
    except requests.exceptions.ReadTimeout:
        # Expected: the request was sent and the job is still running; we
        # deliberately do not wait for its answer.
        return flask.Response(status=200)
    except Exception as exc:
        # Includes ConnectTimeout: the request never reached the function, so
        # reporting success here would silently drop the event.
        logger.log(
            f"TRIGGERING JOB FLOW FAILED: {endpoint}: {exc!r}",
            severity="ERROR",
        )
        return flask.Response(status=500)
    return flask.Response(status=200)
def entrypoint(request: flask.Request) -> Union[Mapping[str, str], flask.Response]:
    """Dispatch an incoming HTTP request to the matching handler.

    Two payload shapes are recognised:

    * ``{"kind": "storage#object", ...}`` -- a storage event; it is passed to
      ``redirect_and_reply``, which re-posts it to this function.
    * an object carrying ``bucket``, ``name``, ``id`` and ``timeCreated`` -- a
      direct REST call (for example the one made by ``redirect_and_reply``);
      it is processed by ``cloud_event_entrypoint``.

    Args:
        request: The incoming Flask request; its body is read as JSON.

    Returns:
        The handler's result, ``flask.Response(status=400)`` when the body is
        not a JSON object, or ``flask.Response(status=500)`` when the object
        matches neither shape.
    """
    data = request.get_json()
    if not isinstance(data, dict):
        # A missing body or a JSON array/scalar has no fields to dispatch on.
        return flask.Response(status=400)

    if data.get("kind") == "storage#object":
        # Entrypoint called by Pub-Sub (Eventarc)
        return redirect_and_reply(data)

    if all(key in data for key in ("bucket", "name", "id", "timeCreated")):
        # Entrypoint called by REST (possibly by redirect_and_reply)
        return cloud_event_entrypoint(
            name=data["name"],
            event_id=data["id"],
            bucket=data["bucket"],
            time_created=coerce_datetime_zulu(data["timeCreated"]),
        )

    return flask.Response(status=500)
def cloud_event_entrypoint(event_id, bucket, name, time_created):
    """Log the arrival and processing of an uploaded object, then acknowledge.

    Two INFO entries are written to the ``_FUNCTIONS_VERTEX_EVENT_LOGGER``
    log, both tagged with ``event_id``. The processing step between them is
    currently a commented-out placeholder.

    Args:
        event_id: Identifier used to tag the log entries.
        bucket: Bucket name of the uploaded object.
        name: Object name within the bucket.
        time_created: Accepted but not used by the current body.

    Returns:
        The dict ``{"response": "ok"}``.
    """
    logger = logging.Client().logger(_FUNCTIONS_VERTEX_EVENT_LOGGER)
    orig_file_uri = f"gs://{bucket}/{name}"

    logger.log(
        f"cloud_event_id({event_id}): UPLOAD {orig_file_uri}",
        severity="INFO",
    )

    # Placeholder for the real processing step:
    # extracted_text = call_some_function(bucket, name, output_bucket=_OUTPUT_BUCKET)

    logger.log(
        f"cloud_event_id({event_id}): OCR gs://{bucket}/{name}",
        severity="INFO",
    )
    return {"response": "ok"}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment