diff --git a/netbox/extras/context_managers.py b/netbox/extras/context_managers.py new file mode 100644 index 000000000..4a33f28ef --- /dev/null +++ b/netbox/extras/context_managers.py @@ -0,0 +1,32 @@ +from contextlib import contextmanager + +from django.db.models.signals import m2m_changed, pre_delete, post_save + +from extras.signals import _handle_changed_object, _handle_deleted_object +from utilities.utils import curry + + +@contextmanager +def change_logging(request): + """ + Enable change logging by connecting the appropriate signals to their receivers before code is run, and + disconnecting them afterward. + + :param request: WSGIRequest object with a unique `id` set + """ + # Curry signals receivers to pass the current request + handle_changed_object = curry(_handle_changed_object, request) + handle_deleted_object = curry(_handle_deleted_object, request) + + # Connect our receivers to the post_save and post_delete signals. + post_save.connect(handle_changed_object, dispatch_uid='handle_changed_object') + m2m_changed.connect(handle_changed_object, dispatch_uid='handle_changed_object') + pre_delete.connect(handle_deleted_object, dispatch_uid='handle_deleted_object') + + yield + + # Disconnect change logging signals. This is necessary to avoid recording any errant + # changes during test cleanup. + post_save.disconnect(handle_changed_object, dispatch_uid='handle_changed_object') + m2m_changed.disconnect(handle_changed_object, dispatch_uid='handle_changed_object') + pre_delete.disconnect(handle_deleted_object, dispatch_uid='handle_deleted_object') diff --git a/netbox/extras/middleware.py b/netbox/extras/middleware.py index 71200faf3..f7be829cd 100644 --- a/netbox/extras/middleware.py +++ b/netbox/extras/middleware.py @@ -1,9 +1,6 @@ import uuid -from django.db.models.signals import m2m_changed, pre_delete, post_save - -from utilities.utils import curry -from .signals import _handle_changed_object, _handle_deleted_object +from .context_managers import change_logging class ObjectChangeMiddleware(object): @@ -24,27 +21,12 @@ class ObjectChangeMiddleware(object): self.get_response = get_response def __call__(self, request): - # Assign a random unique ID to the request. This will be used to associate multiple object changes made during # the same request. request.id = uuid.uuid4() - # Curry signals receivers to pass the current request - handle_changed_object = curry(_handle_changed_object, request) - handle_deleted_object = curry(_handle_deleted_object, request) - - # Connect our receivers to the post_save and post_delete signals. - post_save.connect(handle_changed_object, dispatch_uid='handle_changed_object') - m2m_changed.connect(handle_changed_object, dispatch_uid='handle_changed_object') - pre_delete.connect(handle_deleted_object, dispatch_uid='handle_deleted_object') - - # Process the request - response = self.get_response(request) - - # Disconnect change logging signals. This is necessary to avoid recording any errant - # changes during test cleanup. - post_save.disconnect(handle_changed_object, dispatch_uid='handle_changed_object') - m2m_changed.disconnect(handle_changed_object, dispatch_uid='handle_changed_object') - pre_delete.disconnect(handle_deleted_object, dispatch_uid='handle_deleted_object') + # Process the request with change logging enabled + with change_logging(request): + response = self.get_response(request) return response diff --git a/netbox/extras/scripts.py b/netbox/extras/scripts.py index 8ef742939..1d7229089 100644 --- a/netbox/extras/scripts.py +++ b/netbox/extras/scripts.py @@ -12,19 +12,17 @@ from django import forms from django.conf import settings from django.core.validators import RegexValidator from django.db import transaction -from django.db.models.signals import m2m_changed, pre_delete, post_save from django.utils.functional import classproperty from django_rq import job from extras.api.serializers import ScriptOutputSerializer from extras.choices import JobResultStatusChoices, LogLevelChoices from extras.models import JobResult -from extras.signals import _handle_changed_object, _handle_deleted_object from ipam.formfields import IPAddressFormField, IPNetworkFormField from ipam.validators import MaxPrefixLengthValidator, MinPrefixLengthValidator, prefix_validator from utilities.exceptions import AbortTransaction from utilities.forms import DynamicModelChoiceField, DynamicModelMultipleChoiceField -from utilities.utils import curry +from .context_managers import change_logging from .forms import ScriptForm __all__ = [ @@ -443,51 +441,39 @@ def run_script(data, request, commit=True, *args, **kwargs): f"with NetBox v2.10." ) - # Curry changelog signal receivers to pass the current request - handle_changed_object = curry(_handle_changed_object, request) - handle_deleted_object = curry(_handle_deleted_object, request) + with change_logging(request): - # Connect object modification signals to their respective receivers - post_save.connect(handle_changed_object) - m2m_changed.connect(handle_changed_object) - pre_delete.connect(handle_deleted_object) + try: + with transaction.atomic(): + script.output = script.run(**kwargs) - try: - with transaction.atomic(): - script.output = script.run(**kwargs) + if not commit: + raise AbortTransaction() + + except AbortTransaction: + pass + + except Exception as e: + stacktrace = traceback.format_exc() + script.log_failure( + "An exception occurred: `{}: {}`\n```\n{}\n```".format(type(e).__name__, e, stacktrace) + ) + logger.error(f"Exception raised during script execution: {e}") + commit = False + job_result.set_status(JobResultStatusChoices.STATUS_ERRORED) + + finally: + if job_result.status != JobResultStatusChoices.STATUS_ERRORED: + job_result.data = ScriptOutputSerializer(script).data + job_result.set_status(JobResultStatusChoices.STATUS_COMPLETED) if not commit: - raise AbortTransaction() + # Delete all pending changelog entries + script.log_info( + "Database changes have been reverted automatically." + ) - except AbortTransaction: - pass - - except Exception as e: - stacktrace = traceback.format_exc() - script.log_failure( - "An exception occurred: `{}: {}`\n```\n{}\n```".format(type(e).__name__, e, stacktrace) - ) - logger.error(f"Exception raised during script execution: {e}") - commit = False - job_result.set_status(JobResultStatusChoices.STATUS_ERRORED) - - finally: - if job_result.status != JobResultStatusChoices.STATUS_ERRORED: - job_result.data = ScriptOutputSerializer(script).data - job_result.set_status(JobResultStatusChoices.STATUS_COMPLETED) - - if not commit: - # Delete all pending changelog entries - script.log_info( - "Database changes have been reverted automatically." - ) - - logger.info(f"Script completed in {job_result.duration}") - - # Disconnect signals - post_save.disconnect(handle_changed_object) - m2m_changed.disconnect(handle_changed_object) - pre_delete.disconnect(handle_deleted_object) + logger.info(f"Script completed in {job_result.duration}") # Delete any previous terminal state results JobResult.objects.filter(