1
0
mirror of https://github.com/github/octodns.git synced 2024-05-11 05:55:00 +00:00

Major refactoring of record validation to better support (planned) complex/dynamic record types

This commit is contained in:
Ross McFarland
2018-11-29 15:15:12 -08:00
parent c41824c3e9
commit 2829862ea5
3 changed files with 349 additions and 224 deletions

View File

@@ -253,6 +253,7 @@ class _ValuesMixin(object):
@classmethod
def validate(cls, name, data):
reasons = super(_ValuesMixin, cls).validate(name, data)
values = []
try:
values = data['values']
@@ -279,12 +280,11 @@ class _ValuesMixin(object):
reasons.append('empty value')
values = []
else:
values = [value]
values = value
except KeyError:
reasons.append('missing value(s)')
for value in values:
reasons.extend(cls._validate_value(value))
reasons.extend(cls._value_type.validate(values))
return reasons
@@ -339,8 +339,7 @@ class _GeoMixin(_ValuesMixin):
# TODO: validate legal codes
for code, values in geo.items():
reasons.extend(GeoValue._validate_geo(code))
for value in values:
reasons.extend(cls._validate_value(value))
reasons.extend(cls._value_type.validate(values))
except KeyError:
pass
return reasons
@@ -353,6 +352,8 @@ class _GeoMixin(_ValuesMixin):
super(_GeoMixin, self).__init__(zone, name, data, *args, **kwargs)
try:
self.geo = dict(data['geo'])
self.log.warn("'geo' support has been deprecated, "
"transition %s to use 'dynamic'", name)
except KeyError:
self.geo = {}
for code, values in self.geo.items():
@@ -397,7 +398,7 @@ class _ValueMixin(object):
except KeyError:
reasons.append('missing value')
if value:
reasons.extend(cls._validate_value(value))
reasons.extend(cls._value_type.validate(value, cls))
return reasons
def __init__(self, zone, name, data, source=None):
@@ -421,59 +422,97 @@ class _ValueMixin(object):
self.fqdn, self.value)
class _DynamicBaseMixin(object):
pass
class _DynamicValuesMixin(_DynamicBaseMixin, _GeoMixin):
pass
class _DynamicValueMixin(_DynamicBaseMixin, _ValueMixin):
pass
class ARecord(_DynamicValuesMixin, Record):
_type = 'A'
class _DynamicMixin(object):
@classmethod
def _validate_value(self, value):
reasons = []
def validate(cls, name, data):
reasons = super(_DynamicMixin, cls).validate(name, data)
try:
IPv4Address(unicode(value))
except Exception:
reasons.append('invalid ip address "{}"'.format(value))
pools = data['dynamic']['pools']
except KeyError:
pools = {}
for pool in sorted(pools.values()):
reasons.extend(cls._value_type.validate(pool))
return reasons
def __init__(self, zone, name, data, *args, **kwargs):
super(_DynamicMixin, self).__init__(zone, name, data, *args,
**kwargs)
try:
self.dynamic = dict(data['dynamic'])
except:
self.dynamic = {}
# TODO:
class Ipv4List(object):
@classmethod
def validate(cls, data):
if not isinstance(data, (list, tuple)):
data = (data,)
reasons = []
for value in data:
try:
IPv4Address(unicode(value))
except Exception:
reasons.append('invalid IPv4 address "{}"'.format(value))
return reasons
class Ipv6List(object):
@classmethod
def validate(cls, data):
if not isinstance(data, (list, tuple)):
data = (data,)
reasons = []
for value in data:
try:
IPv6Address(unicode(value))
except Exception:
reasons.append('invalid IPv6 address "{}"'.format(value))
return reasons
class _TargetValue(object):
@classmethod
def validate(cls, data, record_cls):
reasons = []
if not data.endswith('.'):
reasons.append('{} value "{}" missing trailing .'
.format(record_cls._type, data))
return reasons
class CnameValue(_TargetValue):
pass
class ARecord(_DynamicMixin, _GeoMixin, Record):
_type = 'A'
_value_type = Ipv4List
def _process_values(self, values):
return values
class AaaaRecord(_GeoMixin, Record):
_type = 'AAAA'
@classmethod
def _validate_value(self, value):
reasons = []
try:
IPv6Address(unicode(value))
except Exception:
reasons.append('invalid ip address "{}"'.format(value))
return reasons
_value_type = Ipv6List
def _process_values(self, values):
return values
class AliasValue(_TargetValue):
pass
class AliasRecord(_ValueMixin, Record):
_type = 'ALIAS'
@classmethod
def _validate_value(self, value):
reasons = []
if not value.endswith('.'):
reasons.append('missing trailing .')
return reasons
_value_type = AliasValue
def _process_value(self, value):
return value
@@ -483,20 +522,22 @@ class CaaValue(object):
# https://tools.ietf.org/html/rfc6844#page-5
@classmethod
def _validate_value(cls, value):
def validate(cls, data):
if not isinstance(data, (list, tuple)):
data = (data,)
reasons = []
try:
flags = int(value.get('flags', 0))
if flags < 0 or flags > 255:
reasons.append('invalid flags "{}"'.format(flags))
except ValueError:
reasons.append('invalid flags "{}"'.format(value['flags']))
if 'tag' not in value:
reasons.append('missing tag')
if 'value' not in value:
reasons.append('missing value')
for value in data:
try:
flags = int(value.get('flags', 0))
if flags < 0 or flags > 255:
reasons.append('invalid flags "{}"'.format(flags))
except ValueError:
reasons.append('invalid flags "{}"'.format(value['flags']))
if 'tag' not in value:
reasons.append('missing tag')
if 'value' not in value:
reasons.append('missing value')
return reasons
def __init__(self, value):
@@ -525,10 +566,7 @@ class CaaValue(object):
class CaaRecord(_ValuesMixin, Record):
_type = 'CAA'
@classmethod
def _validate_value(cls, value):
return CaaValue._validate_value(value)
_value_type = CaaValue
def _process_values(self, values):
return [CaaValue(v) for v in values]
@@ -536,6 +574,7 @@ class CaaRecord(_ValuesMixin, Record):
class CnameRecord(_ValueMixin, Record):
_type = 'CNAME'
_value_type = CnameValue
@classmethod
def validate(cls, name, data):
@@ -545,13 +584,6 @@ class CnameRecord(_ValueMixin, Record):
reasons.extend(super(CnameRecord, cls).validate(name, data))
return reasons
@classmethod
def _validate_value(cls, value):
reasons = []
if not value.endswith('.'):
reasons.append('missing trailing .')
return reasons
def _process_value(self, value):
return value
@@ -559,25 +591,29 @@ class CnameRecord(_ValueMixin, Record):
class MxValue(object):
@classmethod
def _validate_value(cls, value):
def validate(cls, data):
if not isinstance(data, (list, tuple)):
data = (data,)
reasons = []
try:
for value in data:
try:
int(value['preference'])
try:
int(value['preference'])
except KeyError:
int(value['priority'])
except KeyError:
int(value['priority'])
except KeyError:
reasons.append('missing preference')
except ValueError:
reasons.append('invalid preference "{}"'
.format(value['preference']))
exchange = None
try:
exchange = value.get('exchange', None) or value['value']
if not exchange.endswith('.'):
reasons.append('missing trailing .')
except KeyError:
reasons.append('missing exchange')
reasons.append('missing preference')
except ValueError:
reasons.append('invalid preference "{}"'
.format(value['preference']))
exchange = None
try:
exchange = value.get('exchange', None) or value['value']
if not exchange.endswith('.'):
reasons.append('MX value "{}" missing trailing .'
.format(exchange))
except KeyError:
reasons.append('missing exchange')
return reasons
def __init__(self, value):
@@ -612,10 +648,7 @@ class MxValue(object):
class MxRecord(_ValuesMixin, Record):
_type = 'MX'
@classmethod
def _validate_value(cls, value):
return MxValue._validate_value(value)
_value_type = MxValue
def _process_values(self, values):
return [MxValue(v) for v in values]
@@ -625,32 +658,36 @@ class NaptrValue(object):
VALID_FLAGS = ('S', 'A', 'U', 'P')
@classmethod
def _validate_value(cls, data):
def validate(cls, data):
if not isinstance(data, (list, tuple)):
data = (data,)
reasons = []
try:
int(data['order'])
except KeyError:
reasons.append('missing order')
except ValueError:
reasons.append('invalid order "{}"'.format(data['order']))
try:
int(data['preference'])
except KeyError:
reasons.append('missing preference')
except ValueError:
reasons.append('invalid preference "{}"'
.format(data['preference']))
try:
flags = data['flags']
if flags not in cls.VALID_FLAGS:
reasons.append('unrecognized flags "{}"'.format(flags))
except KeyError:
reasons.append('missing flags')
for value in data:
try:
int(value['order'])
except KeyError:
reasons.append('missing order')
except ValueError:
reasons.append('invalid order "{}"'.format(value['order']))
try:
int(value['preference'])
except KeyError:
reasons.append('missing preference')
except ValueError:
reasons.append('invalid preference "{}"'
.format(value['preference']))
try:
flags = value['flags']
if flags not in cls.VALID_FLAGS:
reasons.append('unrecognized flags "{}"'.format(flags))
except KeyError:
reasons.append('missing flags')
# TODO: validate these... they're non-trivial
for k in ('service', 'regexp', 'replacement'):
if k not in value:
reasons.append('missing {}'.format(k))
# TODO: validate these... they're non-trivial
for k in ('service', 'regexp', 'replacement'):
if k not in data:
reasons.append('missing {}'.format(k))
return reasons
def __init__(self, value):
@@ -696,38 +733,41 @@ class NaptrValue(object):
class NaptrRecord(_ValuesMixin, Record):
_type = 'NAPTR'
@classmethod
def _validate_value(cls, value):
return NaptrValue._validate_value(value)
_value_type = NaptrValue
def _process_values(self, values):
return [NaptrValue(v) for v in values]
class NsRecord(_ValuesMixin, Record):
_type = 'NS'
class _NsValue(object):
@classmethod
def _validate_value(cls, value):
def validate(cls, data):
if not isinstance(data, (list, tuple)):
data = (data,)
reasons = []
if not value.endswith('.'):
reasons.append('missing trailing .')
for value in data:
if not value.endswith('.'):
reasons.append('NS value "{}" missing trailing .'
.format(value))
return reasons
class NsRecord(_ValuesMixin, Record):
_type = 'NS'
_value_type = _NsValue
def _process_values(self, values):
return values
class PtrValue(_TargetValue):
pass
class PtrRecord(_ValueMixin, Record):
_type = 'PTR'
@classmethod
def _validate_value(cls, value):
reasons = []
if not value.endswith('.'):
reasons.append('missing trailing .')
return reasons
_value_type = PtrValue
def _process_value(self, value):
return value
@@ -738,28 +778,33 @@ class SshfpValue(object):
VALID_FINGERPRINT_TYPES = (1, 2)
@classmethod
def _validate_value(cls, value):
def validate(cls, data):
if not isinstance(data, (list, tuple)):
data = (data,)
reasons = []
try:
algorithm = int(value['algorithm'])
if algorithm not in cls.VALID_ALGORITHMS:
reasons.append('unrecognized algorithm "{}"'.format(algorithm))
except KeyError:
reasons.append('missing algorithm')
except ValueError:
reasons.append('invalid algorithm "{}"'.format(value['algorithm']))
try:
fingerprint_type = int(value['fingerprint_type'])
if fingerprint_type not in cls.VALID_FINGERPRINT_TYPES:
reasons.append('unrecognized fingerprint_type "{}"'
.format(fingerprint_type))
except KeyError:
reasons.append('missing fingerprint_type')
except ValueError:
reasons.append('invalid fingerprint_type "{}"'
.format(value['fingerprint_type']))
if 'fingerprint' not in value:
reasons.append('missing fingerprint')
for value in data:
try:
algorithm = int(value['algorithm'])
if algorithm not in cls.VALID_ALGORITHMS:
reasons.append('unrecognized algorithm "{}"'
.format(algorithm))
except KeyError:
reasons.append('missing algorithm')
except ValueError:
reasons.append('invalid algorithm "{}"'
.format(value['algorithm']))
try:
fingerprint_type = int(value['fingerprint_type'])
if fingerprint_type not in cls.VALID_FINGERPRINT_TYPES:
reasons.append('unrecognized fingerprint_type "{}"'
.format(fingerprint_type))
except KeyError:
reasons.append('missing fingerprint_type')
except ValueError:
reasons.append('invalid fingerprint_type "{}"'
.format(value['fingerprint_type']))
if 'fingerprint' not in value:
reasons.append('missing fingerprint')
return reasons
def __init__(self, value):
@@ -789,26 +834,15 @@ class SshfpValue(object):
class SshfpRecord(_ValuesMixin, Record):
_type = 'SSHFP'
@classmethod
def _validate_value(cls, value):
return SshfpValue._validate_value(value)
_value_type = SshfpValue
def _process_values(self, values):
return [SshfpValue(v) for v in values]
_unescaped_semicolon_re = re.compile(r'\w;')
class _ChunkedValuesMixin(_ValuesMixin):
CHUNK_SIZE = 255
@classmethod
def _validate_value(cls, value):
if _unescaped_semicolon_re.search(value):
return ['unescaped ;']
return []
_unescaped_semicolon_re = re.compile(r'\w;')
def _process_values(self, values):
ret = []
@@ -830,39 +864,59 @@ class _ChunkedValuesMixin(_ValuesMixin):
return values
class _ChunkedValue(object):
_unescaped_semicolon_re = re.compile(r'\w;')
@classmethod
def validate(cls, data):
if not isinstance(data, (list, tuple)):
data = (data,)
reasons = []
for value in data:
if cls._unescaped_semicolon_re.search(value):
reasons.append('unescaped ; in "{}"'.format(value))
return reasons
class SpfRecord(_ChunkedValuesMixin, Record):
_type = 'SPF'
_value_type = _ChunkedValue
class SrvValue(object):
@classmethod
def _validate_value(self, value):
def validate(cls, data):
if not isinstance(data, (list, tuple)):
data = (data,)
reasons = []
# TODO: validate algorithm and fingerprint_type values
try:
int(value['priority'])
except KeyError:
reasons.append('missing priority')
except ValueError:
reasons.append('invalid priority "{}"'.format(value['priority']))
try:
int(value['weight'])
except KeyError:
reasons.append('missing weight')
except ValueError:
reasons.append('invalid weight "{}"'.format(value['weight']))
try:
int(value['port'])
except KeyError:
reasons.append('missing port')
except ValueError:
reasons.append('invalid port "{}"'.format(value['port']))
try:
if not value['target'].endswith('.'):
reasons.append('missing trailing .')
except KeyError:
reasons.append('missing target')
for value in data:
# TODO: validate algorithm and fingerprint_type values
try:
int(value['priority'])
except KeyError:
reasons.append('missing priority')
except ValueError:
reasons.append('invalid priority "{}"'
.format(value['priority']))
try:
int(value['weight'])
except KeyError:
reasons.append('missing weight')
except ValueError:
reasons.append('invalid weight "{}"'.format(value['weight']))
try:
int(value['port'])
except KeyError:
reasons.append('missing port')
except ValueError:
reasons.append('invalid port "{}"'.format(value['port']))
try:
if not value['target'].endswith('.'):
reasons.append('SRV value "{}" missing trailing .'
.format(value['target']))
except KeyError:
reasons.append('missing target')
return reasons
def __init__(self, value):
@@ -896,6 +950,7 @@ class SrvValue(object):
class SrvRecord(_ValuesMixin, Record):
_type = 'SRV'
_value_type = SrvValue
_name_re = re.compile(r'^_[^\.]+\.[^\.]+')
@classmethod
@@ -906,13 +961,14 @@ class SrvRecord(_ValuesMixin, Record):
reasons.extend(super(SrvRecord, cls).validate(name, data))
return reasons
@classmethod
def _validate_value(cls, value):
return SrvValue._validate_value(value)
def _process_values(self, values):
return [SrvValue(v) for v in values]
class _TxtValue(_ChunkedValue):
pass
class TxtRecord(_ChunkedValuesMixin, Record):
_type = 'TXT'
_value_type = _TxtValue

View File

@@ -2,17 +2,12 @@
a:
dynamic:
pools:
ams:
values:
- 1.1.1.1
ams: 1.1.1.1
iad:
values:
- 2.2.2.2
- 3.3.3.3
lax:
value: 4.4.4.4
sea:
value: 5.5.5.5
lax: 4.4.4.4
sea: 5.5.5.5
rules:
- geo: EU-UK
pools:
@@ -28,8 +23,7 @@ a:
pools:
25: iad
75: sea
- default:
pool: iad
- pool: iad
type: A
values:
- 2.2.2.2
@@ -60,7 +54,7 @@ cname:
pools:
12: sea
250: iad
- default:
- pools:
1: sea
4: iad
type: CNAME
@@ -69,12 +63,12 @@ simple-weighted:
dynamic:
pools:
one:
one.unit.tests.
value: one.unit.tests.
two:
two.unit.tests.
value: two.unit.tests.
rules:
- default:
100: one
200: two
- pools:
100: one
200: two
type: CNAME
value: default.unit.tests.

View File

@@ -941,7 +941,7 @@ class TestRecordValidation(TestCase):
'ttl': 600,
'value': 'hello'
})
self.assertEquals(['invalid ip address "hello"'],
self.assertEquals(['invalid IPv4 address "hello"'],
ctx.exception.reasons)
# invalid ip addresses
@@ -952,8 +952,8 @@ class TestRecordValidation(TestCase):
'values': ['hello', 'goodbye']
})
self.assertEquals([
'invalid ip address "hello"',
'invalid ip address "goodbye"'
'invalid IPv4 address "hello"',
'invalid IPv4 address "goodbye"'
], ctx.exception.reasons)
# invalid & valid ip addresses, no ttl
@@ -964,7 +964,7 @@ class TestRecordValidation(TestCase):
})
self.assertEquals([
'missing ttl',
'invalid ip address "hello"',
'invalid IPv4 address "hello"',
], ctx.exception.reasons)
def test_geo(self):
@@ -989,7 +989,7 @@ class TestRecordValidation(TestCase):
'ttl': 600,
'value': '1.2.3.4',
})
self.assertEquals(['invalid ip address "hello"'],
self.assertEquals(['invalid IPv4 address "hello"'],
ctx.exception.reasons)
# invalid geo code
@@ -1016,8 +1016,8 @@ class TestRecordValidation(TestCase):
'value': '1.2.3.4',
})
self.assertEquals([
'invalid ip address "hello"',
'invalid ip address "goodbye"'
'invalid IPv4 address "hello"',
'invalid IPv4 address "goodbye"'
], ctx.exception.reasons)
# invalid healthcheck protocol
@@ -1062,16 +1062,21 @@ class TestRecordValidation(TestCase):
'ttl': 600,
'value': 'hello'
})
self.assertEquals(['invalid ip address "hello"'],
self.assertEquals(['invalid IPv6 address "hello"'],
ctx.exception.reasons)
with self.assertRaises(ValidationError) as ctx:
Record.new(self.zone, '', {
'type': 'AAAA',
'ttl': 600,
'value': '1.2.3.4'
'values': [
'1.2.3.4',
'2.3.4.5',
],
})
self.assertEquals(['invalid ip address "1.2.3.4"'],
ctx.exception.reasons)
self.assertEquals([
'invalid IPv6 address "1.2.3.4"',
'invalid IPv6 address "2.3.4.5"',
], ctx.exception.reasons)
# invalid ip addresses
with self.assertRaises(ValidationError) as ctx:
@@ -1081,8 +1086,8 @@ class TestRecordValidation(TestCase):
'values': ['hello', 'goodbye']
})
self.assertEquals([
'invalid ip address "hello"',
'invalid ip address "goodbye"'
'invalid IPv6 address "hello"',
'invalid IPv6 address "goodbye"'
], ctx.exception.reasons)
def test_ALIAS_and_value_mixin(self):
@@ -1126,7 +1131,8 @@ class TestRecordValidation(TestCase):
'ttl': 600,
'value': 'foo.bar.com',
})
self.assertEquals(['missing trailing .'], ctx.exception.reasons)
self.assertEquals(['ALIAS value "foo.bar.com" missing trailing .'],
ctx.exception.reasons)
def test_CAA(self):
# doesn't blow up
@@ -1221,7 +1227,8 @@ class TestRecordValidation(TestCase):
'ttl': 600,
'value': 'foo.bar.com',
})
self.assertEquals(['missing trailing .'], ctx.exception.reasons)
self.assertEquals(['CNAME value "foo.bar.com" missing trailing .'],
ctx.exception.reasons)
def test_MX(self):
# doesn't blow up
@@ -1278,7 +1285,8 @@ class TestRecordValidation(TestCase):
'exchange': 'foo.bar.com'
}
})
self.assertEquals(['missing trailing .'], ctx.exception.reasons)
self.assertEquals(['MX value "foo.bar.com" missing trailing .'],
ctx.exception.reasons)
def test_NXPTR(self):
# doesn't blow up
@@ -1375,7 +1383,8 @@ class TestRecordValidation(TestCase):
'ttl': 600,
'value': 'foo.bar',
})
self.assertEquals(['missing trailing .'], ctx.exception.reasons)
self.assertEquals(['NS value "foo.bar" missing trailing .'],
ctx.exception.reasons)
def test_PTR(self):
# doesn't blow up (name & zone here don't make any sense, but not
@@ -1401,7 +1410,8 @@ class TestRecordValidation(TestCase):
'ttl': 600,
'value': 'foo.bar',
})
self.assertEquals(['missing trailing .'], ctx.exception.reasons)
self.assertEquals(['PTR value "foo.bar" missing trailing .'],
ctx.exception.reasons)
def test_SSHFP(self):
# doesn't blow up
@@ -1534,7 +1544,8 @@ class TestRecordValidation(TestCase):
'ttl': 600,
'value': 'this has some; semi-colons\\; in it',
})
self.assertEquals(['unescaped ;'], ctx.exception.reasons)
self.assertEquals(['unescaped ; in "this has some; '
'semi-colons\\; in it"'], ctx.exception.reasons)
def test_SRV(self):
# doesn't blow up
@@ -1666,7 +1677,7 @@ class TestRecordValidation(TestCase):
'target': 'foo.bar.baz'
}
})
self.assertEquals(['missing trailing .'],
self.assertEquals(['SRV value "foo.bar.baz" missing trailing .'],
ctx.exception.reasons)
def test_TXT(self):
@@ -1696,7 +1707,8 @@ class TestRecordValidation(TestCase):
'ttl': 600,
'value': 'this has some; semi-colons\\; in it',
})
self.assertEquals(['unescaped ;'], ctx.exception.reasons)
self.assertEquals(['unescaped ; in "this has some; semi-colons\\; '
'in it"'], ctx.exception.reasons)
def test_TXT_long_value_chunking(self):
expected = '"Lorem ipsum dolor sit amet, consectetur adipiscing ' \
@@ -1757,3 +1769,66 @@ class TestRecordValidation(TestCase):
self.assertEquals(single.values, chunked.values)
# should be chunked values, with quoting
self.assertEquals(single.chunked_values, chunked.chunked_values)
class TestDynamicRecords(TestCase):
zone = Zone('unit.tests.', [])
def test_simple_a_weighted(self):
a_data = {
'dynamic': {
'pools': {
'one': '3.3.3.3',
'two': [
'4.4.4.4',
'5.5.5.5',
],
},
'rules': [{
'pools': {
100: 'one',
200: 'two',
}
}],
},
'ttl': 60,
'values': [
'1.1.1.1',
'2.2.2.2',
],
}
a = ARecord(self.zone, 'weighted', a_data)
self.assertEquals('A', a._type)
self.assertEquals(a_data['ttl'], a.ttl)
self.assertEquals(a_data['values'], a.values)
self.assertEquals(a_data['dynamic'], a.dynamic)
def test_a_validation(self):
a_data = {
'dynamic': {
'pools': {
'one': 'this-aint-right',
'two': [
'4.4.4.4',
'nor-is-this',
],
},
'rules': [{
'pools': {
100: '5.5.5.5',
200: '6.6.6.6',
}
}],
},
'ttl': 60,
'type': 'A',
'values': [
'1.1.1.1',
'2.2.2.2',
],
}
with self.assertRaises(ValidationError) as ctx:
Record.new(self.zone, 'bad', a_data)
self.assertEquals(['invalid IPv4 address "nor-is-this"',
'invalid IPv4 address "this-aint-right"'],
ctx.exception.reasons)