diff --git a/CHANGELOG.md b/CHANGELOG.md index 8dde47b..b4cd240 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,10 @@ * Include the octodns special section info in Record __repr__, makes it easier to debug things with providers that have special functionality configured there. +* Most processor.filter processors now support an include_target flag that can + be set to False to leave the target zone data untouched, thus remove any + existing filtered records. Default behavior is unchanged and filtered records + will be completely invisible to octoDNS ## v1.2.1 - 2023-09-29 - Now with fewer stale files diff --git a/octodns/processor/filter.py b/octodns/processor/filter.py index 2de3b23..eeab039 100644 --- a/octodns/processor/filter.py +++ b/octodns/processor/filter.py @@ -9,6 +9,20 @@ from ..record.exception import ValidationError from .base import BaseProcessor +class _FilterProcessor(BaseProcessor): + def __init__(self, name, include_target=True, **kwargs): + super().__init__(name, **kwargs) + self.include_target = include_target + + def process_source_zone(self, *args, **kwargs): + return self._process(*args, **kwargs) + + def process_target_zone(self, existing, *args, **kwargs): + if self.include_target: + return self._process(existing, *args, **kwargs) + return existing + + class AllowsMixin: def matches(self, zone, record): pass @@ -25,9 +39,9 @@ class RejectsMixin: pass -class _TypeBaseFilter(BaseProcessor): - def __init__(self, name, _list): - super().__init__(name) +class _TypeBaseFilter(_FilterProcessor): + def __init__(self, name, _list, **kwargs): + super().__init__(name, **kwargs) self._list = set(_list) def _process(self, zone, *args, **kwargs): @@ -39,9 +53,6 @@ class _TypeBaseFilter(BaseProcessor): return zone - process_source_zone = _process - process_target_zone = _process - class TypeAllowlistFilter(_TypeBaseFilter, AllowsMixin): '''Only manage records of the specified type(s). @@ -65,8 +76,8 @@ class TypeAllowlistFilter(_TypeBaseFilter, AllowsMixin): - ns1 ''' - def __init__(self, name, allowlist): - super().__init__(name, allowlist) + def __init__(self, name, allowlist, **kwargs): + super().__init__(name, allowlist, **kwargs) class TypeRejectlistFilter(_TypeBaseFilter, RejectsMixin): @@ -90,13 +101,13 @@ class TypeRejectlistFilter(_TypeBaseFilter, RejectsMixin): - route53 ''' - def __init__(self, name, rejectlist): - super().__init__(name, rejectlist) + def __init__(self, name, rejectlist, **kwargs): + super().__init__(name, rejectlist, **kwargs) -class _NameBaseFilter(BaseProcessor): - def __init__(self, name, _list): - super().__init__(name) +class _NameBaseFilter(_FilterProcessor): + def __init__(self, name, _list, **kwargs): + super().__init__(name, **kwargs) exact = set() regex = [] for pattern in _list: @@ -121,9 +132,6 @@ class _NameBaseFilter(BaseProcessor): return zone - process_source_zone = _process - process_target_zone = _process - class NameAllowlistFilter(_NameBaseFilter, AllowsMixin): '''Only manage records with names that match the provider patterns @@ -269,7 +277,7 @@ class ExcludeRootNsChanges(BaseProcessor): return plan -class ZoneNameFilter(BaseProcessor): +class ZoneNameFilter(_FilterProcessor): '''Filter or error on record names that contain the zone name Example usage: @@ -291,8 +299,8 @@ class ZoneNameFilter(BaseProcessor): - azure ''' - def __init__(self, name, error=True): - super().__init__(name) + def __init__(self, name, error=True, **kwargs): + super().__init__(name, **kwargs) self.error = error def _process(self, zone, *args, **kwargs): @@ -314,6 +322,3 @@ class ZoneNameFilter(BaseProcessor): zone.remove_record(record) return zone - - process_source_zone = _process - process_target_zone = _process diff --git a/tests/test_octodns_processor_filter.py b/tests/test_octodns_processor_filter.py index 2d9b881..d4c36ef 100644 --- a/tests/test_octodns_processor_filter.py +++ b/tests/test_octodns_processor_filter.py @@ -54,6 +54,22 @@ class TestTypeAllowListFilter(TestCase): ['a', 'a2', 'aaaa'], sorted([r.name for r in got.records]) ) + def test_include_target(self): + filter_txt = TypeAllowlistFilter( + 'only-txt', ['TXT'], include_target=False + ) + + # as a source we don't see them + got = filter_txt.process_source_zone(zone.copy()) + self.assertEqual(['txt', 'txt2'], sorted([r.name for r in got.records])) + + # but as a target we do b/c it's not included + got = filter_txt.process_target_zone(zone.copy()) + self.assertEqual( + ['a', 'a2', 'aaaa', 'txt', 'txt2'], + sorted([r.name for r in got.records]), + ) + class TestTypeRejectListFilter(TestCase): def test_basics(self):