diff --git a/netbox_branching/models/branches.py b/netbox_branching/models/branches.py index 882e6e6..363a452 100644 --- a/netbox_branching/models/branches.py +++ b/netbox_branching/models/branches.py @@ -14,6 +14,7 @@ from django.urls import reverse from django.utils import timezone from django.utils.translation import gettext_lazy as _ +from mptt.models import MPTTModel from core.models import ObjectChange as ObjectChange_ from netbox.config import get_config @@ -281,12 +282,18 @@ def sync(self, user, commit=True): try: with activate_branch(self): with transaction.atomic(using=self.connection_name): + models = set() + # Apply each change from the main schema for change in changes: + models.add(change.changed_object_type.model_class()) change.apply(using=self.connection_name, logger=logger) if not commit: raise AbortTransaction() + # Perform cleanup tasks + self._cleanup(models) + except Exception as e: if err_message := str(e): logger.error(err_message) @@ -345,8 +352,11 @@ def merge(self, user, commit=True): try: with transaction.atomic(): + models = set() + # Apply each change from the Branch for change in changes: + models.add(change.changed_object_type.model_class()) with event_tracking(request): request.id = change.request_id request.user = change.user @@ -354,6 +364,9 @@ def merge(self, user, commit=True): if not commit: raise AbortTransaction() + # Perform cleanup tasks + self._cleanup(models) + except Exception as e: if err_message := str(e): logger.error(err_message) @@ -417,8 +430,11 @@ def revert(self, user, commit=True): try: with transaction.atomic(): + models = set() + # Undo each change from the Branch for change in changes: + models.add(change.changed_object_type.model_class()) with event_tracking(request): request.id = change.request_id request.user = change.user @@ -426,6 +442,9 @@ def revert(self, user, commit=True): if not commit: raise AbortTransaction() + # Perform cleanup tasks + self._cleanup(models) + except Exception as e: if err_message := str(e): logger.error(err_message) @@ -455,6 +474,19 @@ def revert(self, user, commit=True): revert.alters_data = True + def _cleanup(self, models): + """ + Called after syncing, merging, or reverting a branch. + """ + logger = logging.getLogger('netbox_branching.branch') + + for model in models: + + # Recalculate MPTT as needed + if issubclass(model, MPTTModel): + logger.debug(f"Recalculating MPTT for model {model}") + model.objects.rebuild() + def provision(self, user): """ Create the schema & replicate main tables. diff --git a/netbox_branching/models/changes.py b/netbox_branching/models/changes.py index 72cbc9d..08d3c2c 100644 --- a/netbox_branching/models/changes.py +++ b/netbox_branching/models/changes.py @@ -5,13 +5,12 @@ from django.contrib.postgres.fields import ArrayField from django.db import DEFAULT_DB_ALIAS, models from django.utils.translation import gettext_lazy as _ -from mptt.models import MPTTModel from core.choices import ObjectChangeActionChoices from core.models import ObjectChange as ObjectChange_ +from netbox_branching.utilities import update_object from utilities.querysets import RestrictedQuerySet from utilities.serialization import deserialize_object -from netbox_branching.utilities import update_object __all__ = ( 'AppliedChange', @@ -56,10 +55,6 @@ def apply(self, using=DEFAULT_DB_ALIAS, logger=None): except model.DoesNotExist: logger.debug(f'{model._meta.verbose_name} ID {self.changed_object_id} already deleted; skipping') - # Rebuild the MPTT tree where applicable - if issubclass(model, MPTTModel): - model.objects.rebuild() - apply.alters_data = True def undo(self, using=DEFAULT_DB_ALIAS, logger=None): @@ -91,10 +86,6 @@ def undo(self, using=DEFAULT_DB_ALIAS, logger=None): instance.object.full_clean() instance.save(using=using) - # Rebuild the MPTT tree where applicable - if issubclass(model, MPTTModel): - model.objects.rebuild() - undo.alters_data = True