import json
import logging
import math
from contextlib import contextmanager

from django.core.management import call_command
from django.core.management.base import CommandError
from django.db import OperationalError
from django.db import transaction
from morango.models import Filter
from morango.models import InstanceIDModel
from morango.models import ScopeDefinition
from morango.sync.controller import MorangoProfileController

from ..utils import create_superuser_and_provision_device
from ..utils import get_baseurl
from ..utils import get_client_and_server_certs
from ..utils import get_dataset_id
from kolibri.core.auth.constants.morango_sync import PROFILE_FACILITY_DATA
from kolibri.core.auth.constants.morango_sync import ScopeDefinitions
from kolibri.core.auth.constants.morango_sync import State
from kolibri.core.auth.management.utils import get_facility
from kolibri.core.auth.management.utils import run_once
from kolibri.core.auth.models import dataset_cache
from kolibri.core.logger.utils.data import bytes_for_humans
from kolibri.core.tasks.exceptions import UserCancelledError
from kolibri.core.tasks.management.commands.base import AsyncCommand
from kolibri.core.tasks.utils import db_task_write_lock
from kolibri.utils import conf

DATA_PORTAL_SYNCING_BASE_URL = conf.OPTIONS["Urls"]["DATA_PORTAL_SYNCING_BASE_URL"]
TRANSFER_MESSAGE = "{records_transferred}/{records_total}, {transfer_total}"

logger = logging.getLogger(__name__)
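
# Illustrative invocations (assumed; the actual flags are the ones registered in
# Command.add_arguments below):
#
#   # sync a facility's data with Kolibri Data Portal (the default --baseurl)
#   kolibri manage sync --facility <facility_id> --noninteractive
#
#   # peer-to-peer sync with another Kolibri device on the local network
#   kolibri manage sync --baseurl http://192.168.0.10:8080 --facility <facility_id> \
#       --username <superuser> --password <password>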


class Command(AsyncCommand):
    help = "Allow the syncing of facility data with Kolibri Data Portal or another Kolibri device."

    def add_arguments(self, parser):
        parser.add_argument(
            "--facility", action="store", type=str, help="ID of facility to sync"
        )
        parser.add_argument(
            "--baseurl", type=str, default=DATA_PORTAL_SYNCING_BASE_URL, dest="baseurl"
        )
        parser.add_argument("--noninteractive", action="store_true")
        parser.add_argument(
            "--chunk-size",
            type=int,
            default=500,
            help="Chunk size of records to send/retrieve per request",
        )
        parser.add_argument(
            "--no-push", action="store_true", help="Do not push data to the server"
        )
        parser.add_argument(
            "--no-pull", action="store_true", help="Do not pull data from the server"
        )
        parser.add_argument(
            "--username",
            type=str,
            help="username of superuser on server we are syncing with",
        )
        parser.add_argument(
            "--password",
            type=str,
            help="password of superuser on server we are syncing with",
        )
        parser.add_argument(
            "--no-provision",
            action="store_true",
            help="do not create a facility and temporary superuser",
        )
        # parser.add_argument("--scope-id", type=str, default=FULL_FACILITY)
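
    # handle_async() drives the whole sync flow: resolve certificates for either
    # Data Portal or peer-to-peer syncing, open a Morango sync session, pull and
    # then push facility data (unless --no-pull / --no-push), and finally
    # provision a superuser and device unless --no-provision is given.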

    def handle_async(self, *args, **options):  # noqa C901
        (
            baseurl,
            facility_id,
            chunk_size,
            username,
            password,
            no_push,
            no_pull,
            noninteractive,
            no_provision,
        ) = (
            options["baseurl"],
            options["facility"],
            options["chunk_size"],
            options["username"],
            options["password"],
            options["no_push"],
            options["no_pull"],
            options["noninteractive"],
            options["no_provision"],
        )

        PORTAL_SYNC = baseurl == DATA_PORTAL_SYNCING_BASE_URL

        # validate url that is passed in
        if not PORTAL_SYNC:
            baseurl = get_baseurl(baseurl)

        # call this in case user directly syncs without migrating database
        if not ScopeDefinition.objects.filter():
            call_command("loaddata", "scopedefinitions")

        dataset_cache.clear()

        # try to connect to server
        controller = MorangoProfileController(PROFILE_FACILITY_DATA)
        network_connection = controller.create_network_connection(baseurl)

        # if instance_ids are equal, this means device is trying to sync with itself, which we don't allow
        if (
            InstanceIDModel.get_or_create_current_instance()[0].id
            == network_connection.server_info["instance_id"]
        ):
            raise CommandError(
                "Device can not sync with itself. Please recheck base URL and try again."
            )

        if PORTAL_SYNC:  # do portal sync setup
            facility = get_facility(
                facility_id=facility_id, noninteractive=noninteractive
            )

            # check for the certs we own for the specific facility
            client_cert = (
                facility.dataset.get_owned_certificates()
                .filter(scope_definition_id=ScopeDefinitions.FULL_FACILITY)
                .first()
            )
            if not client_cert:
                raise CommandError(
                    "This device does not own a certificate for Facility: {}".format(
                        facility.name
                    )
                )

            # get primary partition
            scope_params = json.loads(client_cert.scope_params)
            dataset_id = scope_params["dataset_id"]

            # check if the server already has a cert for this facility
            server_certs = network_connection.get_remote_certificates(
                dataset_id, scope_def_id=ScopeDefinitions.FULL_FACILITY
            )

            # if necessary, push a cert up to the server
            server_cert = (
                server_certs[0]
                if server_certs
                else network_connection.push_signed_client_certificate_chain(
                    local_parent_cert=client_cert,
                    scope_definition_id=ScopeDefinitions.FULL_FACILITY,
                    scope_params=scope_params,
                )
            )
        else:  # do P2P setup
            dataset_id = get_dataset_id(
                baseurl, identifier=facility_id, noninteractive=noninteractive
            )
            client_cert, server_cert, username = get_client_and_server_certs(
                username,
                password,
                dataset_id,
                network_connection,
                noninteractive=noninteractive,
            )

        logger.info("Syncing has been initiated (this may take a while)...")

        sync_session_client = network_connection.create_sync_session(
            client_cert, server_cert, chunk_size=chunk_size
        )

        try:
            # pull from server
            if not no_pull:
                self._handle_pull(sync_session_client, noninteractive, dataset_id)
            # and push our own data to server
            if not no_push:
                self._handle_push(sync_session_client, noninteractive, dataset_id)

            if not no_provision:
                with self._lock():
                    create_superuser_and_provision_device(
                        username, dataset_id, noninteractive=noninteractive
                    )
        except UserCancelledError:
            if self.job:
                self.job.extra_metadata.update(sync_state=State.CANCELLED)
                self.job.save_meta()
            logger.info("Syncing has been cancelled.")
            return

        network_connection.close()

        if self.job:
            self.job.extra_metadata.update(sync_state=State.COMPLETED)
            self.job.save_meta()

        logger.info("Syncing has been completed.")

    @contextmanager
    def _lock(self):
        cancellable = False
        # job can't be cancelled while locked
        if self.job:
            cancellable = self.job.cancellable
            self.job.save_as_cancellable(cancellable=False)

        while True:
            try:
                with transaction.atomic():
                    # evaluate the queryset so the SELECT ... FOR UPDATE is actually
                    # issued and the row lock is held for the rest of the block
                    list(InstanceIDModel.objects.all().select_for_update())
                    yield
                break
            except OperationalError:
                # the lock couldn't be acquired (e.g. SQLite's "database is locked");
                # keep retrying until it can
                pass

        if self.job:
            self.job.save_as_cancellable(cancellable=cancellable)

    def _raise_cancel(self, *args, **kwargs):
        if self.is_cancelled() and (not self.job or self.job.cancellable):
            raise UserCancelledError()

    def _handle_pull(self, sync_session_client, noninteractive, dataset_id):
        """
        :type sync_session_client: morango.sync.syncsession.SyncSessionClient
        :type noninteractive: bool
        :type dataset_id: str
        """
        sync_client = sync_session_client.get_pull_client()
        sync_client.signals.queuing.connect(self._raise_cancel)
        sync_client.signals.transferring.connect(self._raise_cancel)

        self._queueing_tracker_adapter(
            sync_client.signals.queuing,
            "Remotely preparing data",
            State.REMOTE_QUEUING,
            noninteractive,
        )
        self._transfer_tracker_adapter(
            sync_client.signals.transferring,
            "Receiving data ({})".format(TRANSFER_MESSAGE),
            State.PULLING,
            noninteractive,
        )
        self._queueing_tracker_adapter(
            sync_client.signals.dequeuing,
            "Locally integrating received data",
            State.LOCAL_DEQUEUING,
            noninteractive,
        )
        self._session_tracker_adapter(
            sync_client.signals.session,
            "Creating pull transfer session",
            "Completed pull transfer session",
        )

        sync_client.initialize(Filter(dataset_id))
        sync_client.run()
        with self._lock():
            sync_client.finalize()

    def _handle_push(self, sync_session_client, noninteractive, dataset_id):
        """
        :type sync_session_client: morango.sync.syncsession.SyncSessionClient
        :type noninteractive: bool
        :type dataset_id: str
        """
        sync_client = sync_session_client.get_push_client()
        sync_client.signals.transferring.connect(self._raise_cancel)

        self._queueing_tracker_adapter(
            sync_client.signals.queuing,
            "Locally preparing data to send",
            State.LOCAL_QUEUING,
            noninteractive,
        )
        self._transfer_tracker_adapter(
            sync_client.signals.transferring,
            "Sending data ({})".format(TRANSFER_MESSAGE),
            State.PUSHING,
            noninteractive,
        )
        self._queueing_tracker_adapter(
            sync_client.signals.dequeuing,
            "Remotely integrating data",
            State.REMOTE_DEQUEUING,
            noninteractive,
        )
        self._session_tracker_adapter(
            sync_client.signals.session,
            "Creating push transfer session",
            "Completed push transfer session",
        )

        with self._lock():
            sync_client.initialize(Filter(dataset_id))

        sync_client.run()

        # we can't cancel remotely integrating data
        if self.job:
            self.job.save_as_cancellable(cancellable=False)

        # allow server timeout since remotely integrating data can take a while and the request
        # could timeout. In that case, we'll assume everything is good.
        sync_client.finalize(allow_server_timeout=True)

    def _update_all_progress(self, progress_fraction, progress):
        """
        Override parent progress update callback to report from the progress tracker we're sent
        """
        if self.job:
            self.job.update_progress(progress_fraction, 1.0)
            self.job.extra_metadata.update(progress.extra_data)
            self.job.save_meta()
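
    # The *_tracker_adapter helpers below attach handlers to Morango's sync
    # signals (session, queuing, transferring, dequeuing), turning them into log
    # messages and job progress updates for both interactive and task-based runs.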

    def _session_tracker_adapter(self, signal_group, started_msg, completed_msg):
        """
        Attaches a signal handler to session creation signals
        :type signal_group: morango.sync.syncsession.SyncSignalGroup
        :type started_msg: str
        :type completed_msg: str
        """

        @run_once
        def session_creation(transfer_session):
            """
            A session is created individually for pushing and pulling
            """
            logger.info(started_msg)
            if self.job:
                self.job.extra_metadata.update(sync_state=State.SESSION_CREATION)

        @run_once
        def session_destruction(transfer_session):
            if transfer_session.records_total == 0:
                logger.info("There are no records to transfer")
            logger.info(completed_msg)

        signal_group.started.connect(session_creation)
        signal_group.completed.connect(session_destruction)

    def _transfer_tracker_adapter(
        self, signal_group, message, sync_state, noninteractive
    ):
        """
        Attaches a signal handler to pushing/pulling signals
        :type signal_group: morango.sync.syncsession.SyncSignalGroup
        :type message: str
        :type sync_state: str
        :type noninteractive: bool
        """
        tracker = self.start_progress(total=100)

        def stats_msg(transfer_session):
            transfer_total = (
                transfer_session.bytes_sent + transfer_session.bytes_received
            )
            return message.format(
                records_transferred=transfer_session.records_transferred,
                records_total=transfer_session.records_total,
                transfer_total=bytes_for_humans(transfer_total),
            )

        def stats(transfer_session):
            logger.info(stats_msg(transfer_session))

        def handler(transfer_session):
            """
            :type transfer_session: morango.models.core.TransferSession
            """
            progress = (
                100
                * transfer_session.records_transferred
                / float(transfer_session.records_total)
            )
            tracker.update_progress(
                increment=math.ceil(progress - tracker.progress),
                message=stats_msg(transfer_session),
                extra_data=dict(
                    bytes_sent=transfer_session.bytes_sent,
                    bytes_received=transfer_session.bytes_received,
                    sync_state=sync_state,
                ),
            )

        if noninteractive or tracker.progressbar is None:
            signal_group.started.connect(stats)
            signal_group.in_progress.connect(stats)

        signal_group.connect(handler)

        # log one more time at end to capture in logging output
        signal_group.completed.connect(stats)

    def _queueing_tracker_adapter(
        self, signal_group, message, sync_state, noninteractive
    ):
        """
        Attaches a signal handler to queuing/dequeuing signals
        :type signal_group: morango.sync.syncsession.SyncSignalGroup
        :type message: str
        :type sync_state: str
        :type noninteractive: bool
        """
        tracker = self.start_progress(total=2)

        def started(transfer_session):
            dataset_cache.clear()
            if noninteractive or tracker.progressbar is None:
                logger.info(message)

        def handler(transfer_session):
            tracker.update_progress(
                message=message, extra_data=dict(sync_state=sync_state)
            )

        # `started` already decides internally whether to log, so it only needs to
        # be connected once; `handler` updates progress both when queuing starts
        # and when it completes.
        signal_group.started.connect(started)
        signal_group.started.connect(handler)
        signal_group.completed.connect(handler)