From 6991786cc5fee60552891e4e7367f7785094b098 Mon Sep 17 00:00:00 2001 From: Ross McFarland Date: Fri, 26 Apr 2024 16:01:42 -0700 Subject: [PATCH] Fix for EnsureTrailingDots reverting value types back to strings --- CHANGELOG.md | 2 + octodns/processor/trailing_dots.py | 12 ++- tests/test_octodns_processor_trailing_dots.py | 97 +++++++++++++++++-- 3 files changed, 99 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b8f99c1..cbc69c4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ where lots of changes are expected frequently to live along side zones where little or no churn is expected. * AutoArpa gained support for prioritizing values +* Fix for EnsureTrailingDots reverting value types back to strings which then + failed when rr methods were used on them (e.g. w/octodns-bind) ## v1.6.1 - 2024-03-17 - Didn't we do this already diff --git a/octodns/processor/trailing_dots.py b/octodns/processor/trailing_dots.py index 31ac740..d50775e 100644 --- a/octodns/processor/trailing_dots.py +++ b/octodns/processor/trailing_dots.py @@ -14,7 +14,9 @@ def _ensure_trailing_dots(record, prop): for value in new.values: val = getattr(value, prop) if val[-1] != '.': - setattr(value, prop, f'{val}.') + # these will generally be str, but just in case we'll use the + # constructor + setattr(value, prop, val.__class__(f'{val}.')) return new @@ -24,14 +26,18 @@ class EnsureTrailingDots(BaseProcessor): _type = record._type if _type in ('ALIAS', 'CNAME', 'DNAME') and record.value[-1] != '.': new = record.copy() - new.value = f'{new.value}.' + # we need to preserve the value type (class) here and there's no + # way to change a strings value, these all inherit from string, + # so we need to create a new one of the same type + new.value = new.value.__class__(f'{new.value}.') desired.add_record(new, replace=True) elif _type in ('NS', 'PTR') and any( v[-1] != '.' for v in record.values ): new = record.copy() + klass = new.values[0].__class__ new.values = [ - v if v[-1] == '.' else f'{v}.' for v in record.values + v if v[-1] == '.' else klass(f'{v}.') for v in record.values ] desired.add_record(new, replace=True) elif _type == 'MX' and _no_trailing_dot(record, 'exchange'): diff --git a/tests/test_octodns_processor_trailing_dots.py b/tests/test_octodns_processor_trailing_dots.py index 0118b4b..937f534 100644 --- a/tests/test_octodns_processor_trailing_dots.py +++ b/tests/test_octodns_processor_trailing_dots.py @@ -10,6 +10,11 @@ from octodns.processor.trailing_dots import ( _no_trailing_dot, ) from octodns.record import Record +from octodns.record.alias import AliasValue +from octodns.record.cname import CnameValue +from octodns.record.dname import DnameValue +from octodns.record.ns import NsValue +from octodns.record.ptr import PtrValue from octodns.zone import Zone @@ -37,22 +42,61 @@ class EnsureTrailingDotsTest(TestCase): zone.add_record(missing) got = etd.process_source_zone(zone, None) + self.assertEqual('absolute.target.', _find(got, 'has').value) self.assertEqual('relative.target.', _find(got, 'missing').value) + # ensure types were preserved + self.assertIsInstance(_find(got, 'has').value, CnameValue) + self.assertIsInstance(_find(got, 'missing').value, CnameValue) + + def test_alias(self): + etd = EnsureTrailingDots('test') + + zone = Zone('unit.tests.', []) + has = Record.new( + zone, + 'has', + {'type': 'ALIAS', 'ttl': 42, 'value': 'absolute.target.'}, + lenient=True, + ) + zone.add_record(has) + missing = Record.new( + zone, + 'missing', + {'type': 'ALIAS', 'ttl': 42, 'value': 'relative.target'}, + lenient=True, + ) + zone.add_record(missing) - # HACK: this should never be done to records outside of specific testing - # situations like this - has._type = 'ALIAS' - missing._type = 'ALIAS' got = etd.process_source_zone(zone, None) self.assertEqual('absolute.target.', _find(got, 'has').value) self.assertEqual('relative.target.', _find(got, 'missing').value) + self.assertIsInstance(_find(got, 'has').value, AliasValue) + self.assertIsInstance(_find(got, 'missing').value, AliasValue) + + def test_dname(self): + etd = EnsureTrailingDots('test') + + zone = Zone('unit.tests.', []) + has = Record.new( + zone, + 'has', + {'type': 'DNAME', 'ttl': 42, 'value': 'absolute.target.'}, + ) + zone.add_record(has) + missing = Record.new( + zone, + 'missing', + {'type': 'DNAME', 'ttl': 42, 'value': 'relative.target'}, + lenient=True, + ) + zone.add_record(missing) - has._type = 'DNAME' - missing._type = 'DNAME' got = etd.process_source_zone(zone, None) self.assertEqual('absolute.target.', _find(got, 'has').value) self.assertEqual('relative.target.', _find(got, 'missing').value) + self.assertIsInstance(_find(got, 'has').value, DnameValue) + self.assertIsInstance(_find(got, 'missing').value, DnameValue) def test_mx(self): etd = EnsureTrailingDots('test') @@ -114,13 +158,48 @@ class EnsureTrailingDotsTest(TestCase): got = etd.process_source_zone(zone, None) got = next(iter(got.records)) self.assertEqual(['absolute.target.', 'relative.target.'], got.values) + self.assertIsInstance(got.values[0], NsValue) + self.assertIsInstance(got.values[1], NsValue) + + # again, but this time nothing to fix so that we fully use up the + # generator + zone = Zone('unit.tests.', []) + record = Record.new( + zone, + 'record', + { + 'type': 'NS', + 'ttl': 42, + 'values': ['absolute.target.', 'another.target.'], + }, + ) + zone.add_record(record) + + got = etd.process_source_zone(zone, None) + got = next(iter(got.records)) + self.assertEqual(['absolute.target.', 'another.target.'], got.values) + + def test_ptr(self): + etd = EnsureTrailingDots('test') + + zone = Zone('unit.tests.', []) + record = Record.new( + zone, + 'record', + { + 'type': 'PTR', + 'ttl': 42, + 'values': ['absolute.target.', 'relative.target'], + }, + lenient=True, + ) + zone.add_record(record) - # HACK: this should never be done to records outside of specific testing - # situations like this - record._type = 'PTR' got = etd.process_source_zone(zone, None) got = next(iter(got.records)) self.assertEqual(['absolute.target.', 'relative.target.'], got.values) + self.assertIsInstance(got.values[0], PtrValue) + self.assertIsInstance(got.values[1], PtrValue) def test_srv(self): etd = EnsureTrailingDots('test')