1
0
mirror of https://github.com/netbox-community/netbox.git synced 2024-05-10 07:54:54 +00:00

6347 Cache the number of each component type assigned to devices/VMs (#12632)

---------

Co-authored-by: Jeremy Stretch <jstretch@netboxlabs.com>
This commit is contained in:
Arthur Hanson
2023-07-25 20:39:05 +07:00
committed by GitHub
parent a4acb50edd
commit 149a496011
23 changed files with 623 additions and 35 deletions

View File

@@ -0,0 +1,93 @@
from django.apps import apps
from django.db.models import F
from django.db.models.signals import post_delete, post_save
from netbox.registry import registry
from .fields import CounterCacheField
def get_counters_for_model(model):
"""
Return field mappings for all counters registered to the given model.
"""
return registry['counter_fields'][model].items()
def update_counter(model, pk, counter_name, value):
"""
Increment or decrement a counter field on an object identified by its model and primary key (PK). Positive values
will increment; negative values will decrement.
"""
model.objects.filter(pk=pk).update(
**{counter_name: F(counter_name) + value}
)
#
# Signal handlers
#
def post_save_receiver(sender, instance, **kwargs):
"""
Update counter fields on related objects when a TrackingModelMixin subclass is created or modified.
"""
for field_name, counter_name in get_counters_for_model(sender):
parent_model = sender._meta.get_field(field_name).related_model
new_pk = getattr(instance, field_name, None)
old_pk = instance.tracker.get(field_name) if field_name in instance.tracker else None
# Update the counters on the old and/or new parents as needed
if old_pk is not None:
update_counter(parent_model, old_pk, counter_name, -1)
if new_pk is not None:
update_counter(parent_model, new_pk, counter_name, 1)
def post_delete_receiver(sender, instance, **kwargs):
"""
Update counter fields on related objects when a TrackingModelMixin subclass is deleted.
"""
for field_name, counter_name in get_counters_for_model(sender):
parent_model = sender._meta.get_field(field_name).related_model
parent_pk = getattr(instance, field_name, None)
# Decrement the parent's counter by one
if parent_pk is not None:
update_counter(parent_model, parent_pk, counter_name, -1)
#
# Registration
#
def connect_counters(*models):
"""
Register counter fields and connect post_save & post_delete signal handlers for the affected models.
"""
for model in models:
# Find all CounterCacheFields on the model
counter_fields = [
field for field in model._meta.get_fields() if type(field) is CounterCacheField
]
for field in counter_fields:
to_model = apps.get_model(field.to_model_name)
# Register the counter in the registry
change_tracking_fields = registry['counter_fields'][to_model]
change_tracking_fields[f"{field.to_field_name}_id"] = field.name
# Connect the post_save and post_delete handlers
post_save.connect(
post_save_receiver,
sender=to_model,
weak=False,
dispatch_uid=f'{model._meta.label}.{field.name}'
)
post_delete.connect(
post_delete_receiver,
sender=to_model,
weak=False,
dispatch_uid=f'{model._meta.label}.{field.name}'
)

View File

@@ -2,6 +2,7 @@ from collections import defaultdict
from django.contrib.contenttypes.fields import GenericForeignKey
from django.db import models
from django.utils.translation import gettext_lazy as _
from utilities.ordering import naturalize
from .forms.widgets import ColorSelect
@@ -9,6 +10,7 @@ from .validators import ColorValidator
__all__ = (
'ColorField',
'CounterCacheField',
'NaturalOrderingField',
'NullableCharField',
'RestrictedGenericForeignKey',
@@ -143,3 +145,43 @@ class RestrictedGenericForeignKey(GenericForeignKey):
self.name,
False,
)
class CounterCacheField(models.BigIntegerField):
"""
Counter field to keep track of related model counts.
"""
def __init__(self, to_model, to_field, *args, **kwargs):
if not isinstance(to_model, str):
raise TypeError(
_("%s(%r) is invalid. to_model parameter to CounterCacheField must be "
"a string in the format 'app.model'")
% (
self.__class__.__name__,
to_model,
)
)
if not isinstance(to_field, str):
raise TypeError(
_("%s(%r) is invalid. to_field parameter to CounterCacheField must be "
"a string in the format 'field'")
% (
self.__class__.__name__,
to_field,
)
)
self.to_model_name = to_model
self.to_field_name = to_field
kwargs['default'] = kwargs.get('default', 0)
kwargs['editable'] = False
super().__init__(*args, **kwargs)
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
kwargs["to_model"] = self.to_model_name
kwargs["to_field"] = self.to_field_name
return name, path, args, kwargs

View File

View File

@@ -0,0 +1,52 @@
from collections import defaultdict
from django.core.management.base import BaseCommand
from django.db.models import Count, OuterRef, Subquery
from netbox.registry import registry
class Command(BaseCommand):
help = "Force a recalculation of all cached counter fields"
@staticmethod
def collect_models():
"""
Query the registry to find all models which have one or more counter fields. Return a mapping of counter fields
to related query names for each model.
"""
models = defaultdict(dict)
for model, field_mappings in registry['counter_fields'].items():
for field_name, counter_name in field_mappings.items():
fk_field = model._meta.get_field(field_name) # Interface.device
parent_model = fk_field.related_model # Device
related_query_name = fk_field.related_query_name() # 'interfaces'
models[parent_model][counter_name] = related_query_name
return models
def update_counts(self, model, field_name, related_query):
"""
Perform a bulk update for the given model and counter field. For example,
update_counts(Device, '_interface_count', 'interfaces')
will effectively set
Device.objects.update(_interface_count=Count('interfaces'))
"""
self.stdout.write(f'Updating {model.__name__} {field_name}...')
subquery = Subquery(
model.objects.filter(pk=OuterRef('pk')).annotate(_count=Count(related_query)).values('_count')
)
return model.objects.update(**{
field_name: subquery
})
def handle(self, *model_names, **options):
for model, mappings in self.collect_models().items():
for field_name, related_query in mappings.items():
self.update_counts(model, field_name, related_query)
self.stdout.write(self.style.SUCCESS('Finished.'))

View File

@@ -0,0 +1,69 @@
from django.test import TestCase
from dcim.models import *
from utilities.testing.utils import create_test_device
class CountersTest(TestCase):
"""
Validate the operation of dict_to_filter_params().
"""
@classmethod
def setUpTestData(cls):
# Create devices
device1 = create_test_device('Device 1')
device2 = create_test_device('Device 2')
# Create interfaces
Interface.objects.create(device=device1, name='Interface 1')
Interface.objects.create(device=device1, name='Interface 2')
Interface.objects.create(device=device2, name='Interface 3')
Interface.objects.create(device=device2, name='Interface 4')
def test_interface_count_creation(self):
"""
When a tracked object (Interface) is added the tracking counter should be updated.
"""
device1, device2 = Device.objects.all()
self.assertEqual(device1.interface_count, 2)
self.assertEqual(device2.interface_count, 2)
Interface.objects.create(device=device1, name='Interface 5')
Interface.objects.create(device=device2, name='Interface 6')
device1.refresh_from_db()
device2.refresh_from_db()
self.assertEqual(device1.interface_count, 3)
self.assertEqual(device2.interface_count, 3)
def test_interface_count_deletion(self):
"""
When a tracked object (Interface) is deleted the tracking counter should be updated.
"""
device1, device2 = Device.objects.all()
self.assertEqual(device1.interface_count, 2)
self.assertEqual(device2.interface_count, 2)
Interface.objects.get(name='Interface 1').delete()
Interface.objects.get(name='Interface 3').delete()
device1.refresh_from_db()
device2.refresh_from_db()
self.assertEqual(device1.interface_count, 1)
self.assertEqual(device2.interface_count, 1)
def test_interface_count_move(self):
"""
When a tracked object (Interface) is moved the tracking counter should be updated.
"""
device1, device2 = Device.objects.all()
self.assertEqual(device1.interface_count, 2)
self.assertEqual(device2.interface_count, 2)
interface1 = Interface.objects.get(name='Interface 1')
interface1.device = device2
interface1.save()
device1.refresh_from_db()
device2.refresh_from_db()
self.assertEqual(device1.interface_count, 1)
self.assertEqual(device2.interface_count, 3)

View File

@@ -0,0 +1,78 @@
from django.db.models.query_utils import DeferredAttribute
from netbox.registry import registry
class Tracker:
"""
An ephemeral instance employed to record which tracked fields on an instance have been modified.
"""
def __init__(self):
self._changed_fields = {}
def __contains__(self, item):
return item in self._changed_fields
def set(self, name, value):
"""
Mark an attribute as having been changed and record its original value.
"""
self._changed_fields[name] = value
def get(self, name):
"""
Return the original value of a changed field. Raises KeyError if name is not found.
"""
return self._changed_fields[name]
def clear(self, *names):
"""
Clear any fields that were recorded as having been changed.
"""
for name in names:
self._changed_fields.pop(name, None)
else:
self._changed_fields = {}
class TrackingModelMixin:
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Mark the instance as initialized, to enable our custom __setattr__()
self._initialized = True
@property
def tracker(self):
"""
Return the Tracker instance for this instance, first creating it if necessary.
"""
if not hasattr(self._state, "_tracker"):
self._state._tracker = Tracker()
return self._state._tracker
def save(self, *args, **kwargs):
super().save(*args, **kwargs)
# Clear any tracked fields now that changes have been saved
update_fields = kwargs.get('update_fields', [])
self.tracker.clear(*update_fields)
def __setattr__(self, name, value):
if hasattr(self, "_initialized"):
# Record any changes to a tracked field
if name in registry['counter_fields'][self.__class__]:
if name not in self.tracker:
# The attribute has been created or changed
if name in self.__dict__:
old_value = getattr(self, name)
if value != old_value:
self.tracker.set(name, old_value)
else:
self.tracker.set(name, DeferredAttribute)
elif value == self.tracker.get(name):
# A previously changed attribute has been restored
self.tracker.clear(name)
super().__setattr__(name, value)