A basic pipeline for AWS Forecast
import pprint
import time

import boto3

import util

try:
    from .s3utils import upload_csv
except ImportError:
    from s3utils import upload_csv  # when run as a plain script rather than a package
class Pipeline:
    """End-to-end Amazon Forecast pipeline: upload data to S3, create the
    datasets and dataset group, import the data, train a predictor, run a
    forecast and optionally export or delete everything."""

    def __init__(self, name, target, freq, horizon, s3_bucket, related=None, domain="RETAIL",
                 s3_prefix="forecast", role_name="ForecastExecution", aws_region="eu-west-1",
                 algorithm="arn:aws:forecast:::algorithm/CNN-QR", autoML=False):
        self.name = name
        self.target = target
        self.freq = freq
        self.horizon = horizon
        self.s3_bucket = s3_bucket
        self.related = related
        self.domain = domain
        self.s3_prefix = s3_prefix
        self.role_name = role_name
        self.aws_region = aws_region
        self.algorithm = algorithm
        self.autoML = autoML
        self.key_target = f"{s3_prefix}/{name}/target/{target.name}.csv"
        if related:
            self.key_related = f"{s3_prefix}/{name}/related/{related.name}.csv"
        else:
            self.key_related = None
        self.session = boto3.Session(region_name=self.aws_region)
        self.forecast = self.session.client(service_name='forecast')
        self.s3 = self.session.client('s3')
        # ARNs of the resources created (or found) along the way, so re-runs
        # can reuse them and delete_all() can clean them up.
        self.role_arn = None
        self.ds_arns = []
        self.dg_arn = None
        self.ds_import_arns = []
        self.predictor_arn = None
        self.forecast_arn = None
        self.export_arn = None
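    # The `target` and `related` arguments are duck-typed dataset wrappers.
    # Their class is not part of this gist; from the calls made below, the
    # assumed interface is (names are descriptive, not authoritative):
    #
    #     obj.name              -> str, used to build the S3 key
    #     obj.get_df()          -> pandas.DataFrame with the series to upload
    #     obj.get_schema()      -> dict, an Amazon Forecast schema (see below)
    #     obj.timestamp_format  -> str, e.g. "yyyy-MM-dd HH:mm:ss"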
    def run(self, create=True, dont_upload=False, delete_after=False, delete_all=False):
        """Run the whole pipeline. With create=False only existing resources
        are looked up; delete_after/delete_all tear everything down at the end."""
        if delete_all:
            create = False
        self.set_role(create=create)
        if not dont_upload and not delete_all:
            self.upload_data()
        self.set_datasets(create=create)
        self.set_dataset_group(create=create)
        self.do_import(create=create)
        self.set_predictor(create=create)
        self.do_forecast(create=create)
        if delete_after or delete_all:
            self.delete_all()
    def set_role(self, create=True):
        if not create:
            return
        self.role_arn = util.get_or_create_iam_role(role_name=self.role_name)
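    # util.get_or_create_iam_role comes from a local helper module that is not
    # included in this gist (it resembles the helper in AWS's forecast samples).
    # A minimal equivalent, sketched here purely as an assumption, would create
    # a role that the Forecast service can assume to read the S3 bucket:
    #
    #     iam = boto3.client("iam")
    #     role = iam.create_role(
    #         RoleName=role_name,
    #         AssumeRolePolicyDocument=json.dumps({   # requires `import json`
    #             "Version": "2012-10-17",
    #             "Statement": [{"Effect": "Allow",
    #                            "Principal": {"Service": "forecast.amazonaws.com"},
    #                            "Action": "sts:AssumeRole"}],
    #         }))
    #     iam.attach_role_policy(RoleName=role_name,
    #                            PolicyArn="arn:aws:iam::aws:policy/AmazonS3FullAccess")
    #     return role["Role"]["Arn"]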
    def upload_data(self):
        print("Uploading data to", self.s3_bucket)
        upload_csv(self.target.get_df(), self.s3_bucket, self.key_target, s3_client=self.s3)
        if self.related:
            upload_csv(self.related.get_df(), self.s3_bucket, self.key_related, s3_client=self.s3)
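    # upload_csv lives in the sibling s3utils module, which is also not part
    # of the gist. Judging only by the call signature, a minimal sketch:
    #
    #     def upload_csv(df, bucket, key, s3_client):
    #         # Forecast matches CSV columns to the schema by position,
    #         # so no header row and no index column are written
    #         body = df.to_csv(index=False, header=False).encode("utf-8")
    #         s3_client.put_object(Bucket=bucket, Key=key, Body=body)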
    def set_datasets(self, create=True):
        # Create the target (and optionally related) datasets
        ds_target = f"{self.name}_target_ds"
        print("Processing target dataset", ds_target)
        ds_dict = dict()
        resp = self.forecast.list_datasets()  # note: only the first page of results is checked
        for ds in resp['Datasets']:
            ds_dict[ds['DatasetName']] = ds['DatasetArn']
        if ds_target in ds_dict:
            self.ds_arns.append(ds_dict[ds_target])
        elif create:
            response = self.forecast.create_dataset(
                Domain=self.domain,
                DatasetType='TARGET_TIME_SERIES',
                DatasetName=ds_target,
                DataFrequency=self.freq,
                Schema=self.target.get_schema()
            )
            self.ds_arns.append(response['DatasetArn'])
            print("Created target dataset", self.ds_arns[0])
        if not self.related:
            return
        ds_related = f"{self.name}_related_ds"
        print("Processing related dataset", ds_related)
        if ds_related in ds_dict:
            self.ds_arns.append(ds_dict[ds_related])
        elif create:
            response = self.forecast.create_dataset(
                Domain=self.domain,
                DatasetType='RELATED_TIME_SERIES',
                DatasetName=ds_related,
                DataFrequency=self.freq,
                Schema=self.related.get_schema()
            )
            self.ds_arns.append(response['DatasetArn'])
            print("Created related dataset", self.ds_arns[1])
    def set_dataset_group(self, create=True):
        # Create the dataset group that ties the datasets together
        dg_name = f"{self.name}_dg"
        print("Processing Dataset Group", dg_name)
        dg_dict = dict()
        resp = self.forecast.list_dataset_groups()
        for ds in resp['DatasetGroups']:
            dg_dict[ds['DatasetGroupName']] = ds['DatasetGroupArn']
        if dg_name in dg_dict:
            self.dg_arn = dg_dict[dg_name]
        elif create:
            response = self.forecast.create_dataset_group(DatasetGroupName=dg_name,
                                                          Domain=self.domain,
                                                          DatasetArns=self.ds_arns)
            self.dg_arn = response['DatasetGroupArn']
            print("Created Dataset Group", self.dg_arn)
    def do_import(self, create=True):
        # Import jobs: load the uploaded CSVs into the Forecast datasets
        ds_target_import_name = f"target_{self.name}"
        print("Importing target data", ds_target_import_name)
        job_arn = ''
        jobs = self.forecast.list_dataset_import_jobs()['DatasetImportJobs']
        # Target import
        for j in jobs:
            if j['DatasetImportJobName'] == ds_target_import_name:
                job_arn = j['DatasetImportJobArn']
                print(f"Skipping existing target import: {job_arn}")
                break
        if not job_arn and create:
            resp = self.forecast.create_dataset_import_job(
                DatasetImportJobName=ds_target_import_name,
                DatasetArn=self.ds_arns[0],
                DataSource={
                    "S3Config": {
                        "Path": f"s3://{self.s3_bucket}/{self.key_target}",
                        "RoleArn": self.role_arn
                    }
                },
                TimestampFormat=self.target.timestamp_format
            )
            job_arn = resp['DatasetImportJobArn']
            print(f"Created target import: {job_arn}")
        if job_arn:
            self.ds_import_arns.append(job_arn)
        # Related import (only when a related dataset was given)
        if self.related:
            job_arn = ''
            ds_related_import_name = f"related_{self.name}"
            print("Importing related data", ds_related_import_name)
            for j in jobs:
                if j['DatasetImportJobName'] == ds_related_import_name:
                    job_arn = j['DatasetImportJobArn']
                    print(f"Skipping existing related import: {job_arn}")
                    break
            if not job_arn and create:
                resp = self.forecast.create_dataset_import_job(
                    DatasetImportJobName=ds_related_import_name,
                    DatasetArn=self.ds_arns[1],
                    DataSource={
                        "S3Config": {
                            "Path": f"s3://{self.s3_bucket}/{self.key_related}",
                            "RoleArn": self.role_arn
                        }
                    },
                    TimestampFormat=self.related.timestamp_format
                )
                job_arn = resp['DatasetImportJobArn']
                print(f"Created related import: {job_arn}")
            if job_arn:
                self.ds_import_arns.append(job_arn)
        # Poll until every import job has finished, whatever their number
        statuses = [''] * len(self.ds_import_arns)
        wait_start = time.time()
        while create and self.ds_import_arns:
            for i, arn in enumerate(self.ds_import_arns):
                statuses[i] = self.forecast.describe_dataset_import_job(
                    DatasetImportJobArn=arn)['Status']
            if all(s in ('ACTIVE', 'CREATE_FAILED') for s in statuses):
                break
            print(f"Waiting for import jobs, {int(time.time()-wait_start)} secs: {statuses}")
            time.sleep(10)
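    # The imported target CSV must match the schema and the TimestampFormat.
    # An illustrative example for the RETAIL schema above (values made up),
    # with columns in schema order (timestamp, item_id, demand):
    #
    #     2021-01-01 00:00:00,sku-1,12.0
    #     2021-01-02 00:00:00,sku-1,15.0
    #
    # paired with timestamp_format = "yyyy-MM-dd HH:mm:ss".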
    def set_predictor(self, create=True):
        # Train the predictor
        predictor_name = self.name + '_predictor'
        print("Creating predictor", predictor_name)
        predictors = self.forecast.list_predictors()['Predictors']
        for e in predictors:
            if e['PredictorName'] == predictor_name:
                self.predictor_arn = e['PredictorArn']
                print(f"Skipping existing predictor creation: {self.predictor_arn}")
                break
        if not self.predictor_arn and create:
            params = dict(
                PredictorName=predictor_name,
                ForecastHorizon=self.horizon,
                PerformAutoML=self.autoML,
                PerformHPO=False,
                EvaluationParameters={
                    "NumberOfBacktestWindows": 1,
                    "BackTestWindowOffset": self.horizon
                },
                InputDataConfig={"DatasetGroupArn": self.dg_arn},
                FeaturizationConfig={
                    "ForecastFrequency": self.freq,
                    # "Featurizations": [target.FEATURES, covid.FEATURES]
                },
                TrainingParameters={'use_related_data': 'ALL'}
            )
            if not self.autoML:
                # AlgorithmArn may only be given when AutoML is off
                params['AlgorithmArn'] = self.algorithm
            resp = self.forecast.create_predictor(**params)
            self.predictor_arn = resp['PredictorArn']
        wait_start = time.time()
        pred_info = {}
        while create:
            pred_info = self.forecast.describe_predictor(PredictorArn=self.predictor_arn)
            if pred_info['Status'] in ('ACTIVE', 'CREATE_FAILED'):
                if 'AlgorithmArn' in pred_info:
                    alg = pred_info['AlgorithmArn']
                elif 'AutoMLAlgorithmArns' in pred_info:
                    alg = pred_info['AutoMLAlgorithmArns']
                else:
                    alg = "unknown"
                print(f"Ready, algorithm: {alg}, forecast types: {pred_info['ForecastTypes']}")
                break
            print(f"Waiting for predictor, {int(time.time()-wait_start)} secs: {pred_info['Status']}")
            time.sleep(10)
        if pred_info:
            pprint.pprint(pred_info)
    def do_forecast(self, create=True):
        # Generate the forecast from the trained predictor
        forecast_name = self.name + '_forecast'
        print("Executing forecast", forecast_name)
        forecasts = self.forecast.list_forecasts()['Forecasts']
        for e in forecasts:
            if e['ForecastName'] == forecast_name:
                self.forecast_arn = e['ForecastArn']
                print(f"Skipping existing forecast creation: {self.forecast_arn}")
                break
        if not self.forecast_arn and create:
            resp = self.forecast.create_forecast(ForecastName=forecast_name,
                                                 PredictorArn=self.predictor_arn)
            self.forecast_arn = resp['ForecastArn']
        wait_start = time.time()
        while create:
            status = self.forecast.describe_forecast(ForecastArn=self.forecast_arn)['Status']
            if status in ('ACTIVE', 'CREATE_FAILED'):
                break
            print(f"Waiting for forecast, {int(time.time()-wait_start)} secs: {status}")
            time.sleep(10)
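    # Illustrative helper, not part of the original gist: once the forecast is
    # ACTIVE, its values can be read directly through the forecastquery
    # service instead of exporting to S3. The method name is ours.
    def query(self, item_id):
        fq = self.session.client('forecastquery')
        resp = fq.query_forecast(ForecastArn=self.forecast_arn,
                                 Filters={"item_id": item_id})
        # Predictions are keyed by quantile, e.g. "p10", "p50", "p90"
        return resp['Forecast']['Predictions']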
    def export(self, create=True):
        # Export job: dump the forecast as CSV files to S3
        export_name = self.name + '_export'
        print("Processing export", export_name)
        export_prefix = f"{self.s3_prefix}/{self.name}/output"
        export_path = f"s3://{self.s3_bucket}/{export_prefix}/"
        exports = self.forecast.list_forecast_export_jobs()['ForecastExportJobs']
        for e in exports:
            if e['ForecastExportJobName'] == export_name:
                self.export_arn = e['ForecastExportJobArn']
                print(f"Skipping existing export creation: {self.export_arn}")
                break
        if not self.export_arn and create:
            resp = self.forecast.create_forecast_export_job(
                ForecastExportJobName=export_name,
                ForecastArn=self.forecast_arn,
                Destination={
                    "S3Config": {
                        "Path": export_path,
                        "RoleArn": self.role_arn
                    }
                }
            )
            self.export_arn = resp['ForecastExportJobArn']
        wait_start = time.time()
        while create:
            status = self.forecast.describe_forecast_export_job(
                ForecastExportJobArn=self.export_arn)['Status']
            if status in ('ACTIVE', 'CREATE_FAILED'):
                break
            print(f"Waiting for export, {int(time.time()-wait_start)} secs: {status}")
            time.sleep(10)
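    # The export lands as one or more CSV part files under export_path. One
    # way to pick them up afterwards (a sketch, using the prefix built above):
    #
    #     resp = self.s3.list_objects_v2(Bucket=self.s3_bucket,
    #                                    Prefix=f"{self.s3_prefix}/{self.name}/output")
    #     for obj in resp.get('Contents', []):
    #         self.s3.download_file(self.s3_bucket, obj['Key'],
    #                               obj['Key'].split('/')[-1])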
    def delete_all(self):
        # Tear everything down, in reverse order of creation
        # Delete forecast export
        if self.export_arn:
            print("Deleting", self.export_arn)
            util.wait_till_delete(lambda: self.forecast.delete_forecast_export_job(
                ForecastExportJobArn=self.export_arn))
        # Delete forecast
        if self.forecast_arn:
            print("Deleting", self.forecast_arn)
            util.wait_till_delete(lambda: self.forecast.delete_forecast(ForecastArn=self.forecast_arn))
        # Delete predictor
        if self.predictor_arn:
            print("Deleting", self.predictor_arn)
            util.wait_till_delete(lambda: self.forecast.delete_predictor(PredictorArn=self.predictor_arn))
        # Delete import jobs (arn is bound as a default so each lambda keeps its own)
        for arn in self.ds_import_arns:
            print("Deleting", arn)
            util.wait_till_delete(lambda arn=arn: self.forecast.delete_dataset_import_job(DatasetImportJobArn=arn))
        # Delete the datasets
        for arn in self.ds_arns:
            print("Deleting", arn)
            util.wait_till_delete(lambda arn=arn: self.forecast.delete_dataset(DatasetArn=arn))
        # Delete Dataset Group
        if self.dg_arn:
            print("Deleting", self.dg_arn)
            util.wait_till_delete(lambda: self.forecast.delete_dataset_group(DatasetGroupArn=self.dg_arn))
        # Delete the IAM role, if desired
        # util.delete_iam_role(self.role_name)
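
# ---------------------------------------------------------------------------
# Example usage (illustrative, not part of the original pipeline): DemoTarget
# is a hypothetical stand-in for the dataset wrapper described above, and the
# bucket name is a placeholder you would replace with your own.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import pandas as pd

    class DemoTarget:
        name = "demo_sales"
        timestamp_format = "yyyy-MM-dd"

        def get_df(self):
            # 90 days of a single, trivially increasing daily series
            dates = pd.date_range("2021-01-01", periods=90, freq="D")
            return pd.DataFrame({
                "timestamp": dates.strftime("%Y-%m-%d"),
                "item_id": "sku-1",
                "demand": [float(i) for i in range(90)],
            })

        def get_schema(self):
            return {"Attributes": [
                {"AttributeName": "timestamp", "AttributeType": "timestamp"},
                {"AttributeName": "item_id", "AttributeType": "string"},
                {"AttributeName": "demand", "AttributeType": "float"},
            ]}

    pipeline = Pipeline(name="demo", target=DemoTarget(), freq="D", horizon=14,
                        s3_bucket="my-forecast-bucket")
    pipeline.run()      # creates everything and waits for training and forecasting
    pipeline.export()   # writes the forecast CSVs back to S3
    # pipeline.run(delete_all=True)  # tear all resources down again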