diff --git a/octodns/provider/azuredns.py b/octodns/provider/azuredns.py index 6a2db41..cf0e2d5 100644 --- a/octodns/provider/azuredns.py +++ b/octodns/provider/azuredns.py @@ -128,6 +128,8 @@ class _AzureRecord(object): string = 'Zone: {}; '.format(self.zone_name) string += 'Name: {}; '.format(self.relative_record_set_name) string += 'Type: {}; '.format(self.record_type) + if not hasattr(self, 'params'): + return string string += 'Ttl: {}; '.format(self.params['ttl']) for char in self.params: if char != 'ttl': @@ -139,7 +141,7 @@ class _AzureRecord(object): return string -def period_validate(string): +def _validate_per(string): return string if string.endswith('.') else string + '.' @@ -209,20 +211,18 @@ class AzureProvider(BaseProvider): try: if name in self._azure_zones: return name - if self._dns_client.zones.get(self._resource_group, name): - self._azure_zones.add(name) - return name - except: # TODO: figure out what location should be + self._dns_client.zones.get(self._resource_group, name) + self._azure_zones.add(name) + return name + except: if create: - try: - self.log.debug('_check_zone:no matching zone; creating %s', - name) - create_zone = self._dns_client.zones.create_or_update - if create_zone(self._resource_group, name, Zone('global')): - return name - except: - raise - return None + self.log.debug('_check_zone:no matching zone; creating %s', + name) + create_zone = self._dns_client.zones.create_or_update + create_zone(self._resource_group, name, Zone('global')) + return name + else: + raise def populate(self, zone, target=False, lenient=False): ''' @@ -254,17 +254,21 @@ class AzureProvider(BaseProvider): before = len(zone.records) self._populate_zones() - if self._check_zone(zone_name): - for typ in self.SUPPORTS: - records = self._dns_client.record_sets.list_by_type - for azrecord in records(self._resource_group, zone_name, typ): - record_name = azrecord.name if azrecord.name != '@' else '' - data = getattr(self, '_data_for_{}'.format(typ))(azrecord) - data['type'] = typ - data['ttl'] = azrecord.ttl + self._check_zone(zone_name) - record = Record.new(zone, record_name, data, source=self) - zone.add_record(record) + _records = set() + records = self._dns_client.record_sets.list_by_dns_zone + for azrecord in records(self._resource_group, zone_name): + if azrecord.type in self.SUPPORTS: + _records.add(azrecord) + for azrecord in _records: + record_name = azrecord.name if azrecord.name != '@' else '' + data = getattr(self, '_data_for_{}'.format(azrecord.type)) + data = data(azrecord) + data['type'] = azrecord.type + data['ttl'] = azrecord.ttl + record = Record.new(zone, record_name, data, source=self) + zone.add_record(record) self.log.info('populate: found %s records', len(zone.records) - before) @@ -289,13 +293,13 @@ class AzureProvider(BaseProvider): records. Refer to population comment. ''' try: - return {'value': period_validate(azrecord.cname_record.cname)} + return {'value': _validate_per(azrecord.cname_record.cname)} except: return {'value': '.'} def _data_for_PTR(self, azrecord): try: - return {'value': period_validate(azrecord.ptr_records[0].ptdrname)} + return {'value': _validate_per(azrecord.ptr_records[0].ptdrname)} except: return {'value': '.'} @@ -311,7 +315,7 @@ class AzureProvider(BaseProvider): def _data_for_NS(self, azrecord): vals = [ar.nsdname for ar in azrecord.ns_records] - return {'values': [period_validate(val) for val in vals]} + return {'values': [_validate_per(val) for val in vals]} def _apply_Create(self, change): '''A record from change must be created. @@ -330,7 +334,7 @@ class AzureProvider(BaseProvider): record_type=ar.record_type, parameters=ar.params) - print('* Success Create/Update: {}'.format(ar), file=sys.stderr) + self.log.debug('* Success Create/Update: {}'.format(ar)) _apply_Update = _apply_Create @@ -341,7 +345,7 @@ class AzureProvider(BaseProvider): delete(self._resource_group, ar.zone_name, ar.relative_record_set_name, ar.record_type) - print('* Success Delete: {}'.format(ar), file=sys.stderr) + self.log.debug('* Success Delete: {}'.format(ar)) def _apply(self, plan): ''' diff --git a/tests/test_octodns_provider_azuredns.py b/tests/test_octodns_provider_azuredns.py index db4d293..edb2db2 100644 --- a/tests/test_octodns_provider_azuredns.py +++ b/tests/test_octodns_provider_azuredns.py @@ -5,167 +5,321 @@ from __future__ import absolute_import, division, print_function, \ unicode_literals -from octodns.record import Create, Delete, Record, Update -from octodns.provider.azuredns import _AzureRecord, AzureProvider +from octodns.record import Create, Delete, Record +from octodns.provider.azuredns import _AzureRecord, AzureProvider, \ + _validate_per from octodns.zone import Zone +from octodns.provider.base import Plan from azure.mgmt.dns.models import ARecord, AaaaRecord, CnameRecord, MxRecord, \ - SrvRecord, NsRecord, PtrRecord, TxtRecord, Zone as AzureZone + SrvRecord, NsRecord, PtrRecord, TxtRecord, RecordSet, SoaRecord, \ + Zone as AzureZone +from msrestazure.azure_exceptions import CloudError -from octodns.zone import Zone from unittest import TestCase -import sys +from mock import Mock, patch + + +zone = Zone(name='unit.tests.', sub_zones=[]) +octo_records = [] +octo_records.append(Record.new(zone, '', { + 'ttl': 0, + 'type': 'A', + 'values': ['1.2.3.4', '10.10.10.10']})) +octo_records.append(Record.new(zone, 'a', { + 'ttl': 1, + 'type': 'A', + 'values': ['1.2.3.4', '1.1.1.1']})) +octo_records.append(Record.new(zone, 'aa', { + 'ttl': 9001, + 'type': 'A', + 'values': ['1.2.4.3']})) +octo_records.append(Record.new(zone, 'aaa', { + 'ttl': 2, + 'type': 'A', + 'values': ['1.1.1.3']})) +octo_records.append(Record.new(zone, 'cname', { + 'ttl': 3, + 'type': 'CNAME', + 'value': 'a.unit.tests.'})) +octo_records.append(Record.new(zone, 'mx1', { + 'ttl': 3, + 'type': 'MX', + 'values': [{ + 'priority': 10, + 'value': 'mx1.unit.tests.', + }, { + 'priority': 20, + 'value': 'mx2.unit.tests.', + }]})) +octo_records.append(Record.new(zone, 'mx2', { + 'ttl': 3, + 'type': 'MX', + 'values': [{ + 'priority': 10, + 'value': 'mx1.unit.tests.', + }]})) +octo_records.append(Record.new(zone, '', { + 'ttl': 4, + 'type': 'NS', + 'values': ['ns1.unit.tests.', 'ns2.unit.tests.']})) +octo_records.append(Record.new(zone, 'foo', { + 'ttl': 5, + 'type': 'NS', + 'value': 'ns1.unit.tests.'})) +octo_records.append(Record.new(zone, '_srv._tcp', { + 'ttl': 6, + 'type': 'SRV', + 'values': [{ + 'priority': 10, + 'weight': 20, + 'port': 30, + 'target': 'foo-1.unit.tests.', + }, { + 'priority': 12, + 'weight': 30, + 'port': 30, + 'target': 'foo-2.unit.tests.', + }]})) +octo_records.append(Record.new(zone, '_srv2._tcp', { + 'ttl': 7, + 'type': 'SRV', + 'values': [{ + 'priority': 12, + 'weight': 17, + 'port': 1, + 'target': 'srvfoo.unit.tests.', + }]})) + +azure_records = [] +_base0 = _AzureRecord('TestAzure', octo_records[0]) +_base0.zone_name = 'unit.tests' +_base0.relative_record_set_name = '@' +_base0.record_type = 'A' +_base0.params['ttl'] = 0 +_base0.params['arecords'] = [ARecord('1.2.3.4'), ARecord('10.10.10.10')] +azure_records.append(_base0) + +_base1 = _AzureRecord('TestAzure', octo_records[1]) +_base1.zone_name = 'unit.tests' +_base1.relative_record_set_name = 'a' +_base1.record_type = 'A' +_base1.params['ttl'] = 1 +_base1.params['arecords'] = [ARecord('1.2.3.4'), ARecord('1.1.1.1')] +azure_records.append(_base1) + +_base2 = _AzureRecord('TestAzure', octo_records[2]) +_base2.zone_name = 'unit.tests' +_base2.relative_record_set_name = 'aa' +_base2.record_type = 'A' +_base2.params['ttl'] = 9001 +_base2.params['arecords'] = ARecord('1.2.4.3') +azure_records.append(_base2) + +_base3 = _AzureRecord('TestAzure', octo_records[3]) +_base3.zone_name = 'unit.tests' +_base3.relative_record_set_name = 'aaa' +_base3.record_type = 'A' +_base3.params['ttl'] = 2 +_base3.params['arecords'] = ARecord('1.1.1.3') +azure_records.append(_base3) + +_base4 = _AzureRecord('TestAzure', octo_records[4]) +_base4.zone_name = 'unit.tests' +_base4.relative_record_set_name = 'cname' +_base4.record_type = 'CNAME' +_base4.params['ttl'] = 3 +_base4.params['cname_record'] = CnameRecord('a.unit.tests.') +azure_records.append(_base4) + +_base5 = _AzureRecord('TestAzure', octo_records[5]) +_base5.zone_name = 'unit.tests' +_base5.relative_record_set_name = 'mx1' +_base5.record_type = 'MX' +_base5.params['ttl'] = 3 +_base5.params['mx_records'] = [MxRecord(10, 'mx1.unit.tests.'), + MxRecord(20, 'mx2.unit.tests.')] +azure_records.append(_base5) + +_base6 = _AzureRecord('TestAzure', octo_records[6]) +_base6.zone_name = 'unit.tests' +_base6.relative_record_set_name = 'mx2' +_base6.record_type = 'MX' +_base6.params['ttl'] = 3 +_base6.params['mx_records'] = [MxRecord(10, 'mx1.unit.tests.')] +azure_records.append(_base6) + +_base7 = _AzureRecord('TestAzure', octo_records[7]) +_base7.zone_name = 'unit.tests' +_base7.relative_record_set_name = '@' +_base7.record_type = 'NS' +_base7.params['ttl'] = 4 +_base7.params['ns_records'] = [NsRecord('ns1.unit.tests.'), + NsRecord('ns2.unit.tests.')] +azure_records.append(_base7) + +_base8 = _AzureRecord('TestAzure', octo_records[8]) +_base8.zone_name = 'unit.tests' +_base8.relative_record_set_name = 'foo' +_base8.record_type = 'NS' +_base8.params['ttl'] = 5 +_base8.params['ns_records'] = [NsRecord('ns1.unit.tests.')] +azure_records.append(_base8) + +_base9 = _AzureRecord('TestAzure', octo_records[9]) +_base9.zone_name = 'unit.tests' +_base9.relative_record_set_name = '_srv._tcp' +_base9.record_type = 'SRV' +_base9.params['ttl'] = 6 +_base9.params['srv_records'] = [SrvRecord(10, 20, 30, 'foo-1.unit.tests.'), + SrvRecord(12, 30, 30, 'foo-2.unit.tests.')] +azure_records.append(_base9) + +_base10 = _AzureRecord('TestAzure', octo_records[10]) +_base10.zone_name = 'unit.tests' +_base10.relative_record_set_name = '_srv2._tcp' +_base10.record_type = 'SRV' +_base10.params['ttl'] = 7 +_base10.params['srv_records'] = [SrvRecord(12, 17, 1, 'srvfoo.unit.tests.')] +azure_records.append(_base10) class Test_AzureRecord(TestCase): - zone = Zone(name='unit.tests.', sub_zones=[]) - octo_records = [] - octo_records.append(Record.new(zone, '', { - 'ttl': 0, - 'type': 'A', - 'values': ['1.2.3.4', '10.10.10.10'] - })) - octo_records.append(Record.new(zone, 'a', { - 'ttl': 1, - 'type': 'A', - 'values': ['1.2.3.4', '1.1.1.1'], - })) - octo_records.append(Record.new(zone, 'aa', { - 'ttl': 9001, - 'type': 'A', - 'values': ['1.2.4.3'] - })) - octo_records.append(Record.new(zone, 'aaa', { - 'ttl': 2, - 'type': 'A', - 'values': ['1.1.1.3'] - })) - octo_records.append(Record.new(zone, 'cname', { - 'ttl': 3, - 'type': 'CNAME', - 'value': 'a.unit.tests.', - })) - octo_records.append(Record.new(zone, '', { - 'ttl': 3, - 'type': 'MX', - 'values': [{ - 'priority': 10, - 'value': 'mx1.unit.tests.', - }, { - 'priority': 20, - 'value': 'mx2.unit.tests.', - }] - })) - octo_records.append(Record.new(zone, '', { - 'ttl': 4, - 'type': 'NS', - 'values': ['ns1.unit.tests.', 'ns2.unit.tests.'], - })) - octo_records.append(Record.new(zone, '', { - 'ttl': 5, - 'type': 'NS', - 'value': 'ns1.unit.tests.', - })) - octo_records.append(Record.new(zone, '_srv._tcp', { - 'ttl': 6, - 'type': 'SRV', - 'values': [{ - 'priority': 10, - 'weight': 20, - 'port': 30, - 'target': 'foo-1.unit.tests.', - }, { - 'priority': 12, - 'weight': 30, - 'port': 30, - 'target': 'foo-2.unit.tests.', - }] - })) - - azure_records = [] - _base0 = _AzureRecord('TestAzure', octo_records[0]) - _base0.zone_name = 'unit.tests' - _base0.relative_record_set_name = '@' - _base0.record_type = 'A' - _base0.params['ttl'] = 0 - _base0.params['arecords'] = [ARecord('1.2.3.4'), ARecord('10.10.10.10')] - azure_records.append(_base0) - - _base1 = _AzureRecord('TestAzure', octo_records[1]) - _base1.zone_name = 'unit.tests' - _base1.relative_record_set_name = 'a' - _base1.record_type = 'A' - _base1.params['ttl'] = 1 - _base1.params['arecords'] = [ARecord('1.2.3.4'), ARecord('1.1.1.1')] - azure_records.append(_base1) - - _base2 = _AzureRecord('TestAzure', octo_records[2]) - _base2.zone_name = 'unit.tests' - _base2.relative_record_set_name = 'aa' - _base2.record_type = 'A' - _base2.params['ttl'] = 9001 - _base2.params['arecords'] = ARecord('1.2.4.3') - azure_records.append(_base2) - - _base3 = _AzureRecord('TestAzure', octo_records[3]) - _base3.zone_name = 'unit.tests' - _base3.relative_record_set_name = 'aaa' - _base3.record_type = 'A' - _base3.params['ttl'] = 2 - _base3.params['arecords'] = ARecord('1.1.1.3') - azure_records.append(_base3) - - _base4 = _AzureRecord('TestAzure', octo_records[4]) - _base4.zone_name = 'unit.tests' - _base4.relative_record_set_name = 'cname' - _base4.record_type = 'CNAME' - _base4.params['ttl'] = 3 - _base4.params['cname_record'] = CnameRecord('a.unit.tests.') - azure_records.append(_base4) - - _base5 = _AzureRecord('TestAzure', octo_records[5]) - _base5.zone_name = 'unit.tests' - _base5.relative_record_set_name = '@' - _base5.record_type = 'MX' - _base5.params['ttl'] = 3 - _base5.params['mx_records'] = [MxRecord(10, 'mx1.unit.tests.'), - MxRecord(20, 'mx2.unit.tests.')] - azure_records.append(_base5) - - _base6 = _AzureRecord('TestAzure', octo_records[6]) - _base6.zone_name = 'unit.tests' - _base6.relative_record_set_name = '@' - _base6.record_type = 'NS' - _base6.params['ttl'] = 4 - _base6.params['ns_records'] = [NsRecord('ns1.unit.tests.'), - NsRecord('ns2.unit.tests.')] - azure_records.append(_base6) - - _base7 = _AzureRecord('TestAzure', octo_records[7]) - _base7.zone_name = 'unit.tests' - _base7.relative_record_set_name = '@' - _base7.record_type = 'NS' - _base7.params['ttl'] = 5 - _base7.params['ns_records'] = [NsRecord('ns1.unit.tests.')] - azure_records.append(_base7) - - _base8 = _AzureRecord('TestAzure', octo_records[8]) - _base8.zone_name = 'unit.tests' - _base8.relative_record_set_name = '_srv._tcp' - _base8.record_type = 'SRV' - _base8.params['ttl'] = 6 - _base8.params['srv_records'] = [SrvRecord(10, 20, 30, 'foo-1.unit.tests.'), - SrvRecord(12, 30, 30, 'foo-2.unit.tests.')] - azure_records.append(_base8) - def test_azure_record(self): - assert(len(self.azure_records) == len(self.octo_records)) - for i in range(len(self.azure_records)): - octo = _AzureRecord('TestAzure', self.octo_records[i]) - assert(self.azure_records[i]._equals(octo)) + assert(len(azure_records) == len(octo_records)) + for i in range(len(azure_records)): + octo = _AzureRecord('TestAzure', octo_records[i]) + assert(azure_records[i]._equals(octo)) + string = str(azure_records[i]) + assert(('Ttl: ' in string)) + + +class TestValidatePeriod(TestCase): + def test_validate_per(self): + for expected, test in [['a.', 'a'], + ['a.', 'a.'], + ['foo.bar.', 'foo.bar.'], + ['foo.bar.', 'foo.bar']]: + self.assertEquals(expected, _validate_per(test)) class TestAzureDnsProvider(TestCase): - def test_populate(self): - pass # placeholder + def _provider(self): + return self._get_provider('mock_spc', 'mock_dns_client') + + @patch('octodns.provider.azuredns.DnsManagementClient') + @patch('octodns.provider.azuredns.ServicePrincipalCredentials') + def _get_provider(self, mock_spc, mock_dns_client): + '''Returns a mock AzureProvider object to use in testing. + + :param mock_spc: placeholder + :type mock_spc: str + :param mock_dns_client: placeholder + :type mock_dns_client: str + + :type return: AzureProvider + ''' + return AzureProvider('mock_id', 'mock_client', 'mock_key', + 'mock_directory', 'mock_sub', 'mock_rg') + + def test_populate_records(self): + provider = self._get_provider() + + rs = [] + rs.append(RecordSet(name='a1', ttl=0, type='A', + arecords=[ARecord('1.1.1.1')])) + rs.append(RecordSet(name='a2', ttl=1, type='A', + arecords=[ARecord('1.1.1.1'), + ARecord('2.2.2.2')])) + rs.append(RecordSet(name='aaaa1', ttl=2, type='AAAA', + aaaa_records=[AaaaRecord('1:1ec:1::1')])) + rs.append(RecordSet(name='aaaa2', ttl=3, type='AAAA', + aaaa_records=[AaaaRecord('1:1ec:1::1'), + AaaaRecord('1:1ec:1::2')])) + rs.append(RecordSet(name='cname1', ttl=4, type='CNAME', + cname_record=CnameRecord('cname.unit.test.'))) + rs.append(RecordSet(name='cname2', ttl=5, type='CNAME', + cname_record=None)) + rs.append(RecordSet(name='mx1', ttl=6, type='MX', + mx_records=[MxRecord(10, 'mx1.unit.test.')])) + rs.append(RecordSet(name='mx2', ttl=7, type='MX', + mx_records=[MxRecord(10, 'mx1.unit.test.'), + MxRecord(11, 'mx2.unit.test.')])) + rs.append(RecordSet(name='ns1', ttl=8, type='NS', + ns_records=[NsRecord('ns1.unit.test.')])) + rs.append(RecordSet(name='ns2', ttl=9, type='NS', + ns_records=[NsRecord('ns1.unit.test.'), + NsRecord('ns2.unit.test.')])) + rs.append(RecordSet(name='ptr1', ttl=10, type='PTR', + ptr_records=[PtrRecord('ptr1.unit.test.')])) + rs.append(RecordSet(name='ptr2', ttl=11, type='PTR', + ptr_records=[PtrRecord('ptr1.unit.test.'), + PtrRecord('ptr2.unit.test.')])) + rs.append(RecordSet(name='_srv1._tcp', ttl=12, type='SRV', + srv_records=[SrvRecord(1, 2, 3, '1unit.tests.')])) + rs.append(RecordSet(name='_srv2._tcp', ttl=13, type='SRV', + srv_records=[SrvRecord(1, 2, 3, '1unit.tests.'), + SrvRecord(4, 5, 6, '2unit.tests.')])) + rs.append(RecordSet(name='txt1', ttl=14, type='TXT', + txt_records=[TxtRecord('sample text1')])) + rs.append(RecordSet(name='txt2', ttl=15, type='TXT', + txt_records=[TxtRecord('sample text1'), + TxtRecord('sample text2')])) + rs.append(RecordSet(name='', ttl=16, type='SOA', + soa_record=[SoaRecord()])) + + record_list = provider._dns_client.record_sets.list_by_dns_zone + record_list.return_value = rs + + provider.populate(zone) + + self.assertEquals(len(zone.records), 16) + + def test_populate_zone(self): + provider = self._get_provider() + + zone_list = provider._dns_client.zones.list_by_resource_group + zone_list.return_value = [AzureZone(location='global'), + AzureZone(location='global')] + + provider._populate_zones() + + self.assertEquals(len(provider._azure_zones), 1) + + def test_bad_zone_response(self): + provider = self._get_provider() + + _get = provider._dns_client.zones.get + _get.side_effect = CloudError(Mock(status=404), 'Azure Error') + trip = False + try: + provider._check_zone('unit.test', create=False) + except CloudError: + trip = True + self.assertEquals(trip, True) + + def test_apply(self): + provider = self._get_provider() + + changes = [] + deletes = [] + for i in octo_records: + changes.append(Create(i)) + deletes.append(Delete(i)) + + self.assertEquals(11, provider.apply(Plan(None, zone, changes))) + self.assertEquals(11, provider.apply(Plan(zone, zone, deletes))) + + def test_create_zone(self): + provider = self._get_provider() + + changes = [] + for i in octo_records: + changes.append(Create(i)) + desired = Zone('unit2.test.', []) + + _get = provider._dns_client.zones.get + _get.side_effect = CloudError(Mock(status=404), 'Azure Error') + + self.assertEquals(11, provider.apply(Plan(None, desired, changes)))