Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save Sandy4321/796b577b6858d35f1ebac88843178129 to your computer and use it in GitHub Desktop.
Save Sandy4321/796b577b6858d35f1ebac88843178129 to your computer and use it in GitHub Desktop.
import googleapiclient.discovery
import argparse
import pandas as pd
import json
import logging
# Module-level logger for this script. Its own threshold is INFO, so the
# logger.info(...) calls below are emitted. The records propagate to the root
# logger, which basicConfig() (called with no arguments) gives a default
# stderr handler if the root has no handlers yet.
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
logging.basicConfig()
def parse_args(argv=None):
    """Parse the command-line options for an online prediction request.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``csv_data_path``, ``model_name``,
        ``project_id``, ``version_name`` and ``num_instances``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--csv-data-path',
                        dest='csv_data_path', required=True)
    parser.add_argument('--model-name',
                        dest='model_name', required=True)
    parser.add_argument('--project-id',
                        dest='project_id', required=True)
    parser.add_argument('--version-name',
                        dest='version_name', required=True)
    # Optional; the default of 5 matches the DataFrame.head() call this
    # script always used, so existing command lines behave the same.
    parser.add_argument('--num-instances',
                        dest='num_instances', type=int, default=5,
                        help='Number of leading CSV rows to send (default: 5).')
    return parser.parse_args(argv)


def build_instances(df, num_instances=5):
    """Turn the first rows of *df* into a list of per-row instance dicts.

    Args:
        df: Input pandas DataFrame; its column labels become the dict keys.
        num_instances: How many leading rows to convert (default 5).

    Returns:
        list of dict, one ``{column: value}`` dict per row, in row order.

    Raises:
        ValueError: if ``num_instances`` < 1 or *df* has no rows.
    """
    if num_instances < 1:
        raise ValueError(
            'num_instances must be >= 1, got {}'.format(num_instances))
    # orient='records' gives [{col: value, ...}, ...]. The previous
    # orient='lines' is not a supported orientation: older pandas matched it
    # by prefix to 'list' (a column-oriented dict, not a list of instances),
    # and newer pandas rejects it with ValueError.
    # NOTE(review): sending named-feature dicts assumes the deployed version
    # accepts them; some model types expect plain lists of feature values
    # instead -- confirm against the deployed model.
    instances = df.head(num_instances).to_dict(orient='records')
    if not instances:
        raise ValueError('no rows available to send for prediction')
    return instances


def version_resource_name(project_id, model_name, version_name):
    """Return the ``projects/{p}/models/{m}/versions/{v}`` resource name."""
    return 'projects/{}/models/{}/versions/{}'.format(
        project_id, model_name, version_name)


def predict(project_id, model_name, version_name, instances):
    """Send *instances* to the given model version and return its predictions.

    Uses the ``ml`` v1 discovery client's ``projects().predict`` call.

    Raises:
        RuntimeError: if the response body contains an ``error`` entry.
    """
    service = googleapiclient.discovery.build('ml', 'v1', cache_discovery=False)
    name = version_resource_name(project_id, model_name, version_name)
    request_body = {'instances': instances}
    response = service.projects().predict(
        name=name,
        body=request_body
    ).execute()
    if 'error' in response:
        raise RuntimeError(response['error'])
    return response['predictions']


def main(argv=None):
    """Read the CSV, send its leading rows for prediction, and log the result."""
    args = parse_args(argv)
    df_raw = pd.read_csv(args.csv_data_path)
    logger.info(df_raw.head())
    instances = build_instances(df_raw, args.num_instances)
    predictions = predict(args.project_id, args.model_name,
                          args.version_name, instances)
    logger.info(predictions)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment