diff --git a/octodns/record/base.py b/octodns/record/base.py index ffea1b6..a1b26b9 100644 --- a/octodns/record/base.py +++ b/octodns/record/base.py @@ -5,6 +5,7 @@ from collections import defaultdict from logging import getLogger +from ..context import ContextDict from ..equality import EqualityTupleMixin from ..idna import IdnaError, idna_decode, idna_encode from .change import Update @@ -73,7 +74,7 @@ class Record(EqualityTupleMixin): ) else: raise ValidationError(fqdn, reasons, context) - return _class(zone, name, data, source=source) + return _class(zone, name, data, source=source, context=context) @classmethod def validate(cls, name, fqdn, data): @@ -136,7 +137,7 @@ class Record(EqualityTupleMixin): def parse_rdata_texts(cls, rdatas): return [cls._value_type.parse_rdata_text(r) for r in rdatas] - def __init__(self, zone, name, data, source=None): + def __init__(self, zone, name, data, source=None, context=None): self.zone = zone if name: # internally everything is idna @@ -152,11 +153,14 @@ class Record(EqualityTupleMixin): self.decoded_name, ) self.source = source + self.context = context self.ttl = int(data['ttl']) self._octodns = data.get('octodns', {}) def _data(self): + if self.context: + return ContextDict({'ttl': self.ttl}, context=self.context) return {'ttl': self.ttl} @property @@ -225,6 +229,7 @@ class Record(EqualityTupleMixin): return Update(self, other) def copy(self, zone=None): + # data, via _data(), will preserve context data = self.data data['type'] = self._type data['octodns'] = self._octodns @@ -271,8 +276,8 @@ class ValuesMixin(object): values = [cls._value_type.parse_rdata_text(rr.rdata) for rr in rrs] return {'ttl': rr.ttl, 'type': rr._type, 'values': values} - def __init__(self, zone, name, data, source=None): - super().__init__(zone, name, data, source=source) + def __init__(self, zone, name, data, source=None, context=None): + super().__init__(zone, name, data, source=source, context=context) try: values = data['values'] except KeyError: @@ -333,8 +338,8 @@ class ValueMixin(object): 'value': cls._value_type.parse_rdata_text(rr.rdata), } - def __init__(self, zone, name, data, source=None): - super().__init__(zone, name, data, source=source) + def __init__(self, zone, name, data, source=None, context=None): + super().__init__(zone, name, data, source=source, context=context) self.value = self._value_type.process(data['value']) def changes(self, other, target): diff --git a/octodns/zone.py b/octodns/zone.py index 8855986..a1ee51d 100644 --- a/octodns/zone.py +++ b/octodns/zone.py @@ -11,15 +11,45 @@ from .record import Create, Delete class SubzoneRecordException(Exception): - pass + def __init__(self, msg, record): + self.record = record + + if record.context: + msg += f', {record.context}' + + super().__init__(msg) class DuplicateRecordException(Exception): - pass + def __init__(self, msg, existing, new): + self.existing = existing + self.new = new + + if existing.context: + if new.context: + # both have context + msg += f'\n existing: {existing.context}\n new: {new.context}' + else: + # only existing has context + msg += ( + f'\n existing: {existing.context}\n new: [UNKNOWN]' + ) + elif new.context: + # only new has context + msg += f'\n existing: [UNKNOWN]\n new: {new.context}' + # else no one has context + + super().__init__(msg) class InvalidNodeException(Exception): - pass + def __init__(self, msg, record): + self.record = record + + if record.context: + msg += f', {record.context}' + + super().__init__(msg) class Zone(object): @@ -113,7 +143,8 @@ class Zone(object): if not record._type == 'NS': # and not a NS record, this should be in the sub raise SubzoneRecordException( - f'Record {record.fqdn} is a managed sub-zone and not of type NS' + f'Record {record.fqdn} is a managed sub-zone and not of type NS', + record, ) else: # It's not an exact match so there has to be a `.` before the @@ -122,7 +153,8 @@ class Zone(object): if name.endswith(f'.{sub_zone}'): # this should be in a sub raise SubzoneRecordException( - f'Record {record.fqdn} is under a managed subzone' + f'Record {record.fqdn} is under a managed subzone', + record, ) if replace: @@ -132,8 +164,11 @@ class Zone(object): node = self._records[name] if record in node: # We already have a record at this node of this type + existing = [c for c in node if c == record][0] raise DuplicateRecordException( - f'Duplicate record {record.fqdn}, ' f'type {record._type}' + f'Duplicate record {record.fqdn}, type {record._type}', + existing, + record, ) elif not lenient and ( (record._type == 'CNAME' and len(node) > 0) @@ -142,9 +177,8 @@ class Zone(object): # We're adding a CNAME to existing records or adding to an existing # CNAME raise InvalidNodeException( - 'Invalid state, CNAME at ' - f'{record.fqdn} cannot coexist with ' - 'other records' + f'Invalid state, CNAME at {record.fqdn} cannot coexist with other records', + record, ) if record._type == 'NS' and record.name == '': diff --git a/tests/test_octodns_provider_yaml.py b/tests/test_octodns_provider_yaml.py index 1cf017a..af89f78 100644 --- a/tests/test_octodns_provider_yaml.py +++ b/tests/test_octodns_provider_yaml.py @@ -260,10 +260,13 @@ xn--dj-kia8a: zone = Zone('unit.tests.', ['sub']) with self.assertRaises(SubzoneRecordException) as ctx: source.populate(zone) - self.assertEqual( - 'Record www.sub.unit.tests. is under a managed subzone', - str(ctx.exception), + msg = str(ctx.exception) + self.assertTrue( + msg.startswith( + 'Record www.sub.unit.tests. is under a managed subzone' + ) ) + self.assertTrue(msg.endswith('unit.tests.yaml, line 201, column 3')) def test_SUPPORTS(self): source = YamlProvider('test', join(dirname(__file__), 'config')) @@ -536,10 +539,13 @@ class TestSplitYamlProvider(TestCase): zone = Zone('unit.tests.', ['sub']) with self.assertRaises(SubzoneRecordException) as ctx: source.populate(zone) - self.assertEqual( - 'Record www.sub.unit.tests. is under a managed subzone', - str(ctx.exception), + msg = str(ctx.exception) + self.assertTrue( + msg.startswith( + 'Record www.sub.unit.tests. is under a managed subzone' + ) ) + self.assertTrue(msg.endswith('www.sub.yaml, line 3, column 3')) def test_copy(self): # going to put some sentinal values in here to ensure, these aren't diff --git a/tests/test_octodns_record.py b/tests/test_octodns_record.py index 29561c3..4aaa989 100644 --- a/tests/test_octodns_record.py +++ b/tests/test_octodns_record.py @@ -628,3 +628,13 @@ class TestRecordValidation(TestCase): ContextDict({'ttl': 42, 'value': '1.2.3.4'}, context='needle'), ) self.assertTrue('needle' in str(ctx.exception)) + + def test_context_copied_to_record(self): + record = Record.new( + self.zone, + 'www', + ContextDict( + {'ttl': 42, 'type': 'A', 'value': '1.2.3.4'}, context='needle' + ), + ) + self.assertEqual('needle', record.context) diff --git a/tests/test_octodns_zone.py b/tests/test_octodns_zone.py index 6563105..a01b6e9 100644 --- a/tests/test_octodns_zone.py +++ b/tests/test_octodns_zone.py @@ -6,6 +6,7 @@ from unittest import TestCase from helpers import SimpleProvider +from octodns.context import ContextDict from octodns.idna import idna_encode from octodns.record import ( AaaaRecord, @@ -106,6 +107,62 @@ class TestZone(TestCase): zone.add_record(b) self.assertEqual(zone.records, set([a, b])) + def test_duplicate_context_handling(self): + zone = Zone('unit.tests.', []) + + # these will be ==, but one has context and the other doesn't + no_context = ARecord(zone, 'a', {'ttl': 42, 'value': '1.1.1.1'}) + has_context = ARecord( + zone, 'a', {'ttl': 42, 'value': '1.1.1.1'}, context='hello world' + ) + + # both have context + zone.add_record(has_context) + with self.assertRaises(DuplicateRecordException) as ctx: + zone.add_record(has_context) + self.assertEqual(has_context, ctx.exception.existing) + self.assertEqual(has_context, ctx.exception.new) + zone.remove_record(has_context) + self.assertEqual( + [ + 'Duplicate record a.unit.tests., type A', + ' existing: hello world', + ' new: hello world', + ], + str(ctx.exception).split('\n'), + ) + + # new has context + zone.add_record(no_context) + with self.assertRaises(DuplicateRecordException) as ctx: + zone.add_record(has_context) + self.assertEqual(no_context, ctx.exception.existing) + self.assertEqual(has_context, ctx.exception.new) + zone.remove_record(no_context) + self.assertEqual( + [ + 'Duplicate record a.unit.tests., type A', + ' existing: [UNKNOWN]', + ' new: hello world', + ], + str(ctx.exception).split('\n'), + ) + + # existing has context + zone.add_record(has_context) + with self.assertRaises(DuplicateRecordException) as ctx: + zone.add_record(no_context) + self.assertEqual(has_context, ctx.exception.existing) + self.assertEqual(no_context, ctx.exception.new) + self.assertEqual( + [ + 'Duplicate record a.unit.tests., type A', + ' existing: hello world', + ' new: [UNKNOWN]', + ], + str(ctx.exception).split('\n'), + ) + def test_changes(self): before = Zone('unit.tests.', []) a = ARecord(before, 'a', {'ttl': 42, 'value': '1.1.1.1'}) @@ -242,9 +299,11 @@ class TestZone(TestCase): 'sub', {'ttl': 3600, 'type': 'A', 'values': ['1.2.3.4', '2.3.4.5']}, ) + record.context = 'added context' with self.assertRaises(SubzoneRecordException) as ctx: zone.add_record(record) self.assertTrue('not of type NS', str(ctx.exception)) + self.assertTrue(', added context' in str(ctx.exception)) # Can add it w/lenient zone.add_record(record, lenient=True) self.assertEqual(set([record]), zone.records) @@ -328,11 +387,13 @@ class TestZone(TestCase): cname = Record.new( zone, 'www', {'ttl': 60, 'type': 'CNAME', 'value': 'foo.bar.com.'} ) + cname.context = 'has some context' # add cname to a zone.add_record(a) - with self.assertRaises(InvalidNodeException): + with self.assertRaises(InvalidNodeException) as ctx: zone.add_record(cname) + self.assertTrue(', has some context' in str(ctx.exception)) self.assertEqual(set([a]), zone.records) zone.add_record(cname, lenient=True) self.assertEqual(set([a, cname]), zone.records) @@ -501,6 +562,22 @@ class TestZone(TestCase): # Doesn't the second self.assertFalse(copy.hydrate()) + def test_copy_context(self): + zone = Zone('unit.tests.', []) + + no_context = Record.new( + zone, 'a', {'ttl': 42, 'type': 'A', 'value': '1.1.1.1'} + ) + self.assertFalse(no_context.context) + self.assertFalse(no_context.copy().context) + + data = ContextDict( + {'ttl': 42, 'type': 'A', 'value': '1.1.1.1'}, context='hello world' + ) + has_context = Record.new(zone, 'a', data) + self.assertTrue(has_context.context) + self.assertTrue(has_context.copy().context) + def test_root_ns(self): zone = Zone('unit.tests.', [])