|
import json |
|
import random |
|
import logging |
|
import re |
|
|
|
from azure.cosmosdb.table.tableservice import TableService |
|
from azure.storage.queue import QueueService, QueueMessageFormat |
|
|
|
class JobManager(object):
    """Create, resume and abort jobs tracked in the 'JobStatus' table.

    Every job is one row of 'JobStatus' with PartitionKey = ``job_group`` and
    RowKey = the job id.
    """

    def __init__(self, accound_name: str, account_key: str, job_group: str, job_id: str = None):
        """Connect to the storage account and make sure 'JobStatus' exists.

        :param accound_name: storage account name. The misspelled parameter
            name is kept so existing keyword callers keep working.
        :param account_key: storage account key.
        :param job_group: partition key under which this manager's jobs live.
        :param job_id: accepted for backward compatibility; not used here.
        """
        self.account_key = account_key
        self.accound_name = accound_name
        # Use the actual parameter; ``account_name`` is not defined in this scope.
        self.queue_service = QueueService(account_name=accound_name, account_key=account_key)
        self.queue_service.encode_function = QueueMessageFormat.text_base64encode
        self.table_service = TableService(account_name=accound_name, account_key=account_key)
        self.job_group = job_group

        self.table_service.create_table('JobStatus')

    def decode_job_description(self, job_description: str = "{}"):
        """Parse a JSON job message and return ``(jid, workload)``.

        :raises KeyError: if either ``jid`` or ``workload`` is missing.
        :raises json.JSONDecodeError: if *job_description* is not valid JSON.
        """
        json_jd = json.loads(job_description)
        return (json_jd['jid'], json_jd['workload'])

    def create_job(self, job_id: str = None):
        """Insert a new 'running' job row and return a ``NewJob`` handle for it.

        :param job_id: row key to use; a random 128-bit number (as a decimal
            string) is generated when omitted.
        """
        if job_id is None:
            job_id = str(random.getrandbits(128))

        self.table_service.insert_entity('JobStatus', {
            'PartitionKey': self.job_group,
            'RowKey': job_id,
            'Status': "running"
        })

        return NewJob(self.accound_name, self.account_key, self.job_group, job_id)

    def update_job(self, job_id: str):
        """Return an ``ExistingJob`` handle for *job_id* in this job group."""
        return ExistingJob(self.accound_name, self.account_key, self.job_group, job_id)

    @staticmethod
    def _running_jobs_filter(job_group: str) -> str:
        """Build the OData filter selecting running jobs in *job_group*.

        Single quotes are doubled, which is how OData escapes them inside a
        string literal, so group names containing ``'`` do not break the query.
        """
        escaped_group = job_group.replace("'", "''")
        return "PartitionKey eq '%s' and Status eq 'running'" % escaped_group

    def abort_old_jobs(self, reason: str = "Aborted running jobs"):
        """Mark every 'running' job in this group as 'aborted' with *reason*."""
        running = self.table_service.query_entities(
            'JobStatus', filter=self._running_jobs_filter(self.job_group))
        for entity in running:
            self.table_service.merge_entity('JobStatus', {
                'PartitionKey': entity.PartitionKey,
                'RowKey': entity.RowKey,
                'reason': reason,
                'Status': "aborted"
            })
|
|
|
class __Job(object):
    """Shared state and stage-transition logic for NewJob and ExistingJob.

    A job is one row of the 'JobStatus' table (PartitionKey = job group,
    RowKey = job id). Each stage is a column of that row, named after the
    stage with every non-alphanumeric character removed. The column value is a
    whitespace-separated list of the workloads currently in that stage.
    Moving a workload to a next stage also posts a JSON message
    ``{"workload": ..., "jid": ...}`` to that stage's queue.

    Because the class name starts with two underscores, its private methods
    are name-mangled to ``_Job__<name>``; subclasses call them that way.
    """

    def __init__(self, accound_name: str, account_key: str, job_group: str, job_id: str):
        """Connect to the storage account's queue and table services.

        :param accound_name: storage account name. The misspelled parameter
            name is kept so existing keyword callers keep working.
        :param account_key: storage account key.
        :param job_group: partition key of the job row.
        :param job_id: row key of the job row.
        """
        self.job_group = job_group
        self.job_id = job_id
        # Use the actual parameter; ``account_name`` is not defined in this scope.
        self.queue_service = QueueService(account_name=accound_name, account_key=account_key)
        self.queue_service.encode_function = QueueMessageFormat.text_base64encode
        self.table_service = TableService(account_name=accound_name, account_key=account_key)

    def __exit__(self, exc_type, exc_value, traceback):
        """Leave the ``with`` block; nothing to release, exceptions propagate."""
        return None

    @staticmethod
    def _stage_column(stage: str) -> str:
        """Return the table column name used for *stage* (ASCII alphanumerics only)."""
        return re.sub(r'[^a-zA-Z0-9]', '', stage)

    def entity(self):
        """Fetch this job's row from the 'JobStatus' table."""
        return self.table_service.get_entity('JobStatus', self.job_group, self.job_id)

    def __update_workload(self, workload: str):
        """Move *workload* from the previous stage to the next one.

        If a next stage is set, *workload* is also posted to that stage's
        queue. All whitespace is stripped from *workload* first, because stage
        columns store workloads as a whitespace-separated list.
        """
        compact = "".join(workload.split())
        if compact != workload:
            logging.warning("Workload cannot contain spaces. Trimming...")
            workload = compact

        if hasattr(self, 'next_stage'):
            self.queue_service.put_message(self.next_stage, json.dumps({
                'workload': workload,
                'jid': self.job_id
            }))

        # NOTE(review): this is a read-modify-write and update_entity is called
        # without an if_match/ETag argument, so concurrent workers updating the
        # same job may overwrite each other's stage columns. Confirm whether
        # optimistic concurrency is needed.
        entity = self.entity()
        self.table_service.update_entity('JobStatus', self._Job__transition_entity(entity, workload))

    def __transition_entity(self, entity, workload):
        """Move *workload* between stage columns of *entity*; mutate and return it.

        Membership is by whole token, so removing ``ab`` leaves ``abc`` intact.
        A previous-stage column that becomes empty is removed from *entity*.
        The next-stage column keeps first-seen order and has no duplicates.
        """
        if hasattr(self, 'previous_stage'):
            column = self._stage_column(self.previous_stage)
            remaining = [item for item in entity.get(column, "").split() if item != workload]
            if remaining:
                entity[column] = " ".join(remaining)
            else:
                entity.pop(column, None)

        if hasattr(self, 'next_stage'):
            column = self._stage_column(self.next_stage)
            present = entity.get(column, "").split()
            # dict.fromkeys de-duplicates while keeping a deterministic order.
            entity[column] = " ".join(dict.fromkeys([*present, workload]))

        return entity

    def __set_next_stage(self, stage):
        """Ensure the *stage* queue exists, make it the next stage, return self."""
        self.queue_service.create_queue(stage)
        self.next_stage = stage
        return self
|
|
|
class NewJob(__Job):
    """Handle for a freshly created job.

    The entry stage must be chosen with ``for_stage`` before entering the
    ``with`` block; new workloads are then sent to that stage.
    """

    def for_stage(self, stage):
        """Choose the entry stage (creating its queue) and return ``self``."""
        self._Job__set_next_stage(stage)
        return self

    def create_workload(self, workload: str):
        """Enqueue *workload* on the entry stage and record it in the job row."""
        self._Job__update_workload(workload)

    def __enter__(self):
        if hasattr(self, 'next_stage'):
            return self
        raise Exception("You forgot to specify the initial stage. Specify it using .for_stage(my_stage)")
|
|
|
class ExistingJob(__Job):
    """Handle for a job already in progress.

    Call ``from_stage`` before entering the ``with`` block. Use ``to_stage``
    plus ``update_workload`` to forward a workload, or ``finish_workload`` to
    retire it. The latter marks the job "Done" once no stage columns remain.
    """

    def from_stage(self, stage):
        """Record the stage workloads are taken from and return ``self``."""
        self.previous_stage = stage
        return self

    def to_stage(self, stage):
        """Choose the stage workloads are forwarded to (creating its queue)."""
        self._Job__set_next_stage(stage)
        return self

    def __enter__(self):
        if not hasattr(self, 'previous_stage'):
            raise Exception("You forgot to specify the previous stage. Specify it using .from_stage(my_stage)")
        return self

    def update_workload(self, workload: str):
        """Forward *workload* to the next stage; requires ``to_stage`` first."""
        if hasattr(self, 'next_stage'):
            self._Job__update_workload(workload)
        else:
            raise Exception("You forgot to specify the next stage. Specify it using .to_stage(my_stage)")

    def finish_workload(self, workload: str):
        """Move *workload* out of its stage and mark the job done if nothing is left."""
        self._Job__update_workload(workload)
        self.__update_status_if_done()

    def __update_status_if_done(self):
        # Columns that are row bookkeeping rather than per-stage workload lists.
        bookkeeping = {'PartitionKey', 'RowKey', 'Timestamp', 'Status', 'reason', 'etag'}
        row = self.entity()
        # Any column outside the bookkeeping set is a stage still holding work.
        if all(column in bookkeeping for column in row.keys()):
            row.Status = "Done"
            self.table_service.merge_entity('JobStatus', row)