diff --git a/octodns/equality.py b/octodns/equality.py new file mode 100644 index 0000000..965b8a0 --- /dev/null +++ b/octodns/equality.py @@ -0,0 +1,30 @@ +# +# +# + +from __future__ import absolute_import, division, print_function, \ + unicode_literals + + +class EqualityTupleMixin: + + def _equality_tuple(self): + raise NotImplementedError('_equality_tuple method not implemented') + + def __eq__(self, other): + return self._equality_tuple() == other._equality_tuple() + + def __ne__(self, other): + return self._equality_tuple() != other._equality_tuple() + + def __lt__(self, other): + return self._equality_tuple() < other._equality_tuple() + + def __le__(self, other): + return self._equality_tuple() <= other._equality_tuple() + + def __gt__(self, other): + return self._equality_tuple() > other._equality_tuple() + + def __ge__(self, other): + return self._equality_tuple() >= other._equality_tuple() diff --git a/octodns/provider/route53.py b/octodns/provider/route53.py index 968d8b8..66da6b5 100644 --- a/octodns/provider/route53.py +++ b/octodns/provider/route53.py @@ -16,6 +16,7 @@ import re from six import text_type +from ..equality import EqualityTupleMixin from ..record import Record, Update from ..record.geo import GeoCodes from .base import BaseProvider @@ -29,7 +30,7 @@ def _octal_replace(s): return octal_re.sub(lambda m: chr(int(m.group(1), 8)), s) -class _Route53Record(object): +class _Route53Record(EqualityTupleMixin): @classmethod def _new_dynamic(cls, provider, record, hosted_zone_id, creating): @@ -157,29 +158,10 @@ class _Route53Record(object): return '{}:{}'.format(self.fqdn, self._type).__hash__() def _equality_tuple(self): + '''Sub-classes should call up to this and return its value and add + any additional fields they need to hav considered.''' return (self.__class__.__name__, self.fqdn, self._type) - def __eq__(self, other): - '''Sub-classes should call up to this and return its value if true. - When it's false they should compute their own __eq__, same for other - ordering methods.''' - return self._equality_tuple() == other._equality_tuple() - - def __ne__(self, other): - return self._equality_tuple() != other._equality_tuple() - - def __lt__(self, other): - return self._equality_tuple() < other._equality_tuple() - - def __le__(self, other): - return self._equality_tuple() <= other._equality_tuple() - - def __gt__(self, other): - return self._equality_tuple() > other._equality_tuple() - - def __ge__(self, other): - return self._equality_tuple() >= other._equality_tuple() - def __repr__(self): return '_Route53Record<{} {} {} {}>'.format(self.fqdn, self._type, self.ttl, self.values) diff --git a/octodns/record/__init__.py b/octodns/record/__init__.py index c0f6482..0847018 100644 --- a/octodns/record/__init__.py +++ b/octodns/record/__init__.py @@ -11,6 +11,7 @@ import re from six import string_types, text_type +from ..equality import EqualityTupleMixin from .geo import GeoCodes @@ -76,7 +77,7 @@ class ValidationError(Exception): self.reasons = reasons -class Record(object): +class Record(EqualityTupleMixin): log = getLogger('Record') @classmethod @@ -209,30 +210,15 @@ class Record(object): def __hash__(self): return '{}:{}'.format(self.name, self._type).__hash__() - def __eq__(self, other): - return ((self.name, self._type) == (other.name, other._type)) - - def __ne__(self, other): - return ((self.name, self._type) != (other.name, other._type)) - - def __lt__(self, other): - return ((self.name, self._type) < (other.name, other._type)) - - def __le__(self, other): - return ((self.name, self._type) <= (other.name, other._type)) - - def __gt__(self, other): - return ((self.name, self._type) > (other.name, other._type)) - - def __ge__(self, other): - return ((self.name, self._type) >= (other.name, other._type)) + def _equality_tuple(self): + return (self.name, self._type) def __repr__(self): # Make sure this is always overridden raise NotImplementedError('Abstract base class, __repr__ required') -class GeoValue(object): +class GeoValue(EqualityTupleMixin): geo_re = re.compile(r'^(?P\w\w)(-(?P\w\w)' r'(-(?P\w\w))?)?$') @@ -259,35 +245,9 @@ class GeoValue(object): yield '-'.join(bits) bits.pop() - def __eq__(self, other): - return ((self.continent_code, self.country_code, self.subdivision_code, - self.values) == (other.continent_code, other.country_code, - other.subdivision_code, other.values)) - - def __ne__(self, other): - return ((self.continent_code, self.country_code, self.subdivision_code, - self.values) != (other.continent_code, other.country_code, - other.subdivision_code, other.values)) - - def __lt__(self, other): - return ((self.continent_code, self.country_code, self.subdivision_code, - self.values) < (other.continent_code, other.country_code, - other.subdivision_code, other.values)) - - def __le__(self, other): - return ((self.continent_code, self.country_code, self.subdivision_code, - self.values) <= (other.continent_code, other.country_code, - other.subdivision_code, other.values)) - - def __gt__(self, other): - return ((self.continent_code, self.country_code, self.subdivision_code, - self.values) > (other.continent_code, other.country_code, - other.subdivision_code, other.values)) - - def __ge__(self, other): - return ((self.continent_code, self.country_code, self.subdivision_code, - self.values) >= (other.continent_code, other.country_code, - other.subdivision_code, other.values)) + def _equality_tuple(self): + return (self.continent_code, self.country_code, self.subdivision_code, + self.values) def __repr__(self): return "'Geo {} {} {} {}'".format(self.continent_code, @@ -787,7 +747,7 @@ class AliasRecord(_ValueMixin, Record): _value_type = AliasValue -class CaaValue(object): +class CaaValue(EqualityTupleMixin): # https://tools.ietf.org/html/rfc6844#page-5 @classmethod @@ -826,29 +786,8 @@ class CaaValue(object): 'value': self.value, } - def __eq__(self, other): - return ((self.flags, self.tag, self.value) == - (other.flags, other.tag, other.value)) - - def __ne__(self, other): - return ((self.flags, self.tag, self.value) != - (other.flags, other.tag, other.value)) - - def __lt__(self, other): - return ((self.flags, self.tag, self.value) < - (other.flags, other.tag, other.value)) - - def __le__(self, other): - return ((self.flags, self.tag, self.value) <= - (other.flags, other.tag, other.value)) - - def __gt__(self, other): - return ((self.flags, self.tag, self.value) > - (other.flags, other.tag, other.value)) - - def __ge__(self, other): - return ((self.flags, self.tag, self.value) >= - (other.flags, other.tag, other.value)) + def _equality_tuple(self): + return (self.flags, self.tag, self.value) def __repr__(self): return '{} {} "{}"'.format(self.flags, self.tag, self.value) @@ -872,7 +811,7 @@ class CnameRecord(_DynamicMixin, _ValueMixin, Record): return reasons -class MxValue(object): +class MxValue(EqualityTupleMixin): @classmethod def validate(cls, data, _type): @@ -928,29 +867,8 @@ class MxValue(object): def __hash__(self): return hash((self.preference, self.exchange)) - def __eq__(self, other): - return ((self.preference, self.exchange) == - (other.preference, other.exchange)) - - def __ne__(self, other): - return ((self.preference, self.exchange) != - (other.preference, other.exchange)) - - def __lt__(self, other): - return ((self.preference, self.exchange) < - (other.preference, other.exchange)) - - def __le__(self, other): - return ((self.preference, self.exchange) <= - (other.preference, other.exchange)) - - def __gt__(self, other): - return ((self.preference, self.exchange) > - (other.preference, other.exchange)) - - def __ge__(self, other): - return ((self.preference, self.exchange) >= - (other.preference, other.exchange)) + def _equality_tuple(self): + return (self.preference, self.exchange) def __repr__(self): return "'{} {}'".format(self.preference, self.exchange) @@ -961,7 +879,7 @@ class MxRecord(_ValuesMixin, Record): _value_type = MxValue -class NaptrValue(object): +class NaptrValue(EqualityTupleMixin): VALID_FLAGS = ('S', 'A', 'U', 'P') @classmethod @@ -1023,41 +941,9 @@ class NaptrValue(object): def __hash__(self): return hash(self.__repr__()) - def __eq__(self, other): - return ((self.order, self.preference, self.flags, self.service, - self.regexp, self.replacement) == - (other.order, other.preference, other.flags, other.service, - other.regexp, other.replacement)) - - def __ne__(self, other): - return ((self.order, self.preference, self.flags, self.service, - self.regexp, self.replacement) != - (other.order, other.preference, other.flags, other.service, - other.regexp, other.replacement)) - - def __lt__(self, other): - return ((self.order, self.preference, self.flags, self.service, - self.regexp, self.replacement) < - (other.order, other.preference, other.flags, other.service, - other.regexp, other.replacement)) - - def __le__(self, other): - return ((self.order, self.preference, self.flags, self.service, - self.regexp, self.replacement) <= - (other.order, other.preference, other.flags, other.service, - other.regexp, other.replacement)) - - def __gt__(self, other): - return ((self.order, self.preference, self.flags, self.service, - self.regexp, self.replacement) > - (other.order, other.preference, other.flags, other.service, - other.regexp, other.replacement)) - - def __ge__(self, other): - return ((self.order, self.preference, self.flags, self.service, - self.regexp, self.replacement) >= - (other.order, other.preference, other.flags, other.service, - other.regexp, other.replacement)) + def _equality_tuple(self): + return (self.order, self.preference, self.flags, self.service, + self.regexp, self.replacement) def __repr__(self): flags = self.flags if self.flags is not None else '' @@ -1107,7 +993,7 @@ class PtrRecord(_ValueMixin, Record): _value_type = PtrValue -class SshfpValue(object): +class SshfpValue(EqualityTupleMixin): VALID_ALGORITHMS = (1, 2, 3, 4) VALID_FINGERPRINT_TYPES = (1, 2) @@ -1161,29 +1047,8 @@ class SshfpValue(object): def __hash__(self): return hash(self.__repr__()) - def __eq__(self, other): - return ((self.algorithm, self.fingerprint_type, self.fingerprint) == - (other.algorithm, other.fingerprint_type, other.fingerprint)) - - def __ne__(self, other): - return ((self.algorithm, self.fingerprint_type, self.fingerprint) != - (other.algorithm, other.fingerprint_type, other.fingerprint)) - - def __lt__(self, other): - return ((self.algorithm, self.fingerprint_type, self.fingerprint) < - (other.algorithm, other.fingerprint_type, other.fingerprint)) - - def __le__(self, other): - return ((self.algorithm, self.fingerprint_type, self.fingerprint) <= - (other.algorithm, other.fingerprint_type, other.fingerprint)) - - def __gt__(self, other): - return ((self.algorithm, self.fingerprint_type, self.fingerprint) > - (other.algorithm, other.fingerprint_type, other.fingerprint)) - - def __ge__(self, other): - return ((self.algorithm, self.fingerprint_type, self.fingerprint) >= - (other.algorithm, other.fingerprint_type, other.fingerprint)) + def _equality_tuple(self): + return (self.algorithm, self.fingerprint_type, self.fingerprint) def __repr__(self): return "'{} {} {}'".format(self.algorithm, self.fingerprint_type, @@ -1244,7 +1109,7 @@ class SpfRecord(_ChunkedValuesMixin, Record): _value_type = _ChunkedValue -class SrvValue(object): +class SrvValue(EqualityTupleMixin): @classmethod def validate(cls, data, _type): @@ -1302,29 +1167,8 @@ class SrvValue(object): def __hash__(self): return hash(self.__repr__()) - def __eq__(self, other): - return ((self.priority, self.weight, self.port, self.target) == - (other.priority, other.weight, other.port, other.target)) - - def __ne__(self, other): - return ((self.priority, self.weight, self.port, self.target) != - (other.priority, other.weight, other.port, other.target)) - - def __lt__(self, other): - return ((self.priority, self.weight, self.port, self.target) < - (other.priority, other.weight, other.port, other.target)) - - def __le__(self, other): - return ((self.priority, self.weight, self.port, self.target) <= - (other.priority, other.weight, other.port, other.target)) - - def __gt__(self, other): - return ((self.priority, self.weight, self.port, self.target) > - (other.priority, other.weight, other.port, other.target)) - - def __ge__(self, other): - return ((self.priority, self.weight, self.port, self.target) >= - (other.priority, other.weight, other.port, other.target)) + def _equality_tuple(self): + return (self.priority, self.weight, self.port, self.target) def __repr__(self): return "'{} {} {} {}'".format(self.priority, self.weight, self.port, diff --git a/tests/test_octodns_equality.py b/tests/test_octodns_equality.py new file mode 100644 index 0000000..dcdc460 --- /dev/null +++ b/tests/test_octodns_equality.py @@ -0,0 +1,68 @@ +# +# +# + +from __future__ import absolute_import, division, print_function, \ + unicode_literals + +from unittest import TestCase + +from octodns.equality import EqualityTupleMixin + + +class TestEqualityTupleMixin(TestCase): + + def test_basics(self): + + class Simple(EqualityTupleMixin): + + def __init__(self, a, b, c): + self.a = a + self.b = b + self.c = c + + def _equality_tuple(self): + return (self.a, self.b) + + one = Simple(1, 2, 3) + same = Simple(1, 2, 3) + matches = Simple(1, 2, 'ignored') + doesnt = Simple(2, 3, 4) + + # equality + self.assertEquals(one, one) + self.assertEquals(one, same) + self.assertEquals(same, one) + # only a & c are considered + self.assertEquals(one, matches) + self.assertEquals(matches, one) + self.assertNotEquals(one, doesnt) + self.assertNotEquals(doesnt, one) + + # lt + self.assertTrue(one < doesnt) + self.assertFalse(doesnt < one) + self.assertFalse(one < same) + + # le + self.assertTrue(one <= doesnt) + self.assertFalse(doesnt <= one) + self.assertTrue(one <= same) + + # gt + self.assertFalse(one > doesnt) + self.assertTrue(doesnt > one) + self.assertFalse(one > same) + + # ge + self.assertFalse(one >= doesnt) + self.assertTrue(doesnt >= one) + self.assertTrue(one >= same) + + def test_not_implemented(self): + + class MissingMethod(EqualityTupleMixin): + pass + + with self.assertRaises(NotImplementedError): + MissingMethod() == MissingMethod()