1
0
mirror of https://github.com/github/octodns.git synced 2024-05-11 05:55:00 +00:00

Support dynamic records in Azure DNS

This commit is contained in:
Viranch Mehta
2021-05-11 15:39:00 -07:00
parent 056cec8935
commit d619025040
4 changed files with 1439 additions and 81 deletions

View File

@@ -192,7 +192,7 @@ The above command pulled the existing data out of Route53 and placed the results
| Provider | Requirements | Record Support | Dynamic | Notes | | Provider | Requirements | Record Support | Dynamic | Notes |
|--|--|--|--|--| |--|--|--|--|--|
| [AzureProvider](/octodns/provider/azuredns.py) | azure-identity, azure-mgmt-dns | A, AAAA, CAA, CNAME, MX, NS, PTR, SRV, TXT | No | | | [AzureProvider](/octodns/provider/azuredns.py) | azure-identity, azure-mgmt-dns, azure-mgmt-trafficmanager | A, AAAA, CAA, CNAME, MX, NS, PTR, SRV, TXT | Yes (CNAMEs only) | |
| [Akamai](/octodns/provider/edgedns.py) | edgegrid-python | A, AAAA, CNAME, MX, NAPTR, NS, PTR, SPF, SRV, SSHFP, TXT | No | | | [Akamai](/octodns/provider/edgedns.py) | edgegrid-python | A, AAAA, CNAME, MX, NAPTR, NS, PTR, SPF, SRV, SSHFP, TXT | No | |
| [CloudflareProvider](/octodns/provider/cloudflare.py) | | A, AAAA, ALIAS, CAA, CNAME, LOC, MX, NS, PTR, SPF, SRV, TXT | No | CAA tags restricted | | [CloudflareProvider](/octodns/provider/cloudflare.py) | | A, AAAA, ALIAS, CAA, CNAME, LOC, MX, NS, PTR, SPF, SRV, TXT | No | CAA tags restricted |
| [ConstellixProvider](/octodns/provider/constellix.py) | | A, AAAA, ALIAS (ANAME), CAA, CNAME, MX, NS, PTR, SPF, SRV, TXT | No | CAA tags restricted | | [ConstellixProvider](/octodns/provider/constellix.py) | | A, AAAA, ALIAS (ANAME), CAA, CNAME, MX, NS, PTR, SPF, SRV, TXT | No | CAA tags restricted |

View File

@@ -5,18 +5,28 @@
from __future__ import absolute_import, division, print_function, \ from __future__ import absolute_import, division, print_function, \
unicode_literals unicode_literals
from collections import defaultdict
from azure.identity import ClientSecretCredential from azure.identity import ClientSecretCredential
from azure.common.credentials import ServicePrincipalCredentials
from azure.mgmt.dns import DnsManagementClient from azure.mgmt.dns import DnsManagementClient
from azure.mgmt.trafficmanager import TrafficManagerManagementClient
from azure.mgmt.dns.models import ARecord, AaaaRecord, CaaRecord, \ from azure.mgmt.dns.models import ARecord, AaaaRecord, CaaRecord, \
CnameRecord, MxRecord, SrvRecord, NsRecord, PtrRecord, TxtRecord, Zone CnameRecord, MxRecord, SrvRecord, NsRecord, PtrRecord, TxtRecord, Zone
from azure.mgmt.trafficmanager.models import Profile, DnsConfig, \
MonitorConfig, Endpoint, MonitorConfigCustomHeadersItem
import logging import logging
from functools import reduce from functools import reduce
from ..record import Record from ..record import Record, Update, GeoCodes
from .base import BaseProvider from .base import BaseProvider
class AzureException(Exception):
pass
def escape_semicolon(s): def escape_semicolon(s):
assert s assert s
return s.replace(';', '\\;') return s.replace(';', '\\;')
@@ -67,7 +77,8 @@ class _AzureRecord(object):
'TXT': TxtRecord 'TXT': TxtRecord
} }
def __init__(self, resource_group, record, delete=False): def __init__(self, resource_group, record, delete=False,
traffic_manager=None):
'''Constructor for _AzureRecord. '''Constructor for _AzureRecord.
Notes on Azure records: An Azure record set has the form Notes on Azure records: An Azure record set has the form
@@ -94,9 +105,11 @@ class _AzureRecord(object):
self.log = logging.getLogger('AzureRecord') self.log = logging.getLogger('AzureRecord')
self.resource_group = resource_group self.resource_group = resource_group
self.zone_name = record.zone.name[:len(record.zone.name) - 1] self.zone_name = record.zone.name[:-1]
self.relative_record_set_name = record.name or '@' self.relative_record_set_name = record.name or '@'
self.record_type = record._type self.record_type = record._type
self._record = record
self.traffic_manager = traffic_manager
if delete: if delete:
return return
@@ -104,11 +117,11 @@ class _AzureRecord(object):
# Refer to function docstring for key_name and class_name. # Refer to function docstring for key_name and class_name.
key_name = '{}_records'.format(self.record_type).lower() key_name = '{}_records'.format(self.record_type).lower()
if record._type == 'CNAME': if record._type == 'CNAME':
key_name = key_name[:len(key_name) - 1] key_name = key_name[:-1]
azure_class = self.TYPE_MAP[self.record_type] azure_class = self.TYPE_MAP[self.record_type]
self.params = getattr(self, '_params_for_{}'.format(record._type)) params_for = getattr(self, '_params_for_{}'.format(record._type))
self.params = self.params(record.data, key_name, azure_class) self.params = params_for(record.data, key_name, azure_class)
self.params['ttl'] = record.ttl self.params['ttl'] = record.ttl
def _params_for_A(self, data, key_name, azure_class): def _params_for_A(self, data, key_name, azure_class):
@@ -139,6 +152,9 @@ class _AzureRecord(object):
return {key_name: params} return {key_name: params}
def _params_for_CNAME(self, data, key_name, azure_class): def _params_for_CNAME(self, data, key_name, azure_class):
if self._record.dynamic and self.traffic_manager:
return {'target_resource': self.traffic_manager}
return {key_name: azure_class(cname=data['value'])} return {key_name: azure_class(cname=data['value'])}
def _params_for_MX(self, data, key_name, azure_class): def _params_for_MX(self, data, key_name, azure_class):
@@ -227,25 +243,6 @@ class _AzureRecord(object):
(parse_dict(self.params) == parse_dict(b.params)) & \ (parse_dict(self.params) == parse_dict(b.params)) & \
(self.relative_record_set_name == b.relative_record_set_name) (self.relative_record_set_name == b.relative_record_set_name)
def __str__(self):
'''String representation of an _AzureRecord.
:type return: str
'''
string = 'Zone: {}; '.format(self.zone_name)
string += 'Name: {}; '.format(self.relative_record_set_name)
string += 'Type: {}; '.format(self.record_type)
if not hasattr(self, 'params'):
return string
string += 'Ttl: {}; '.format(self.params['ttl'])
for char in self.params:
if char != 'ttl':
try:
for rec in self.params[char]:
string += 'Record: {}; '.format(rec.__dict__)
except:
string += 'Record: {}; '.format(self.params[char].__dict__)
return string
def _check_endswith_dot(string): def _check_endswith_dot(string):
return string if string.endswith('.') else string + '.' return string if string.endswith('.') else string + '.'
@@ -259,15 +256,89 @@ def _parse_azure_type(string):
:type return: str :type return: str
''' '''
return string.split('/')[len(string.split('/')) - 1] return string.split('/')[-1]
def _check_for_alias(azrecord): def _traffic_manager_suffix(record):
if (azrecord.target_resource.id and not azrecord.a_records and not return record.fqdn[:-1].replace('.', '-')
azrecord.cname_record):
return True
def _get_monitor(record):
monitor = MonitorConfig(
protocol=record.healthcheck_protocol,
port=record.healthcheck_port,
path=record.healthcheck_path,
)
host = record.healthcheck_host
if host:
monitor.custom_headers = [MonitorConfigCustomHeadersItem(
name='Host', value=host
)]
return monitor
def _profile_is_match(have, desired):
if have is None or desired is None:
return False return False
# compare basic attributes
if have.name != desired.name or \
have.traffic_routing_method != desired.traffic_routing_method or \
have.dns_config.ttl != desired.dns_config.ttl or \
len(have.endpoints) != len(desired.endpoints):
return False
# compare monitoring configuration
monitor_have = have.monitor_config
monitor_desired = desired.monitor_config
if monitor_have.protocol != monitor_desired.protocol or \
monitor_have.port != monitor_desired.port or \
monitor_have.path != monitor_desired.path or \
monitor_have.custom_headers != monitor_desired.custom_headers:
return False
# compare endpoints
method = have.traffic_routing_method
if method == 'Priority':
have_endpoints = sorted(have.endpoints, key=lambda e: e.priority)
desired_endpoints = sorted(desired.endpoints,
key=lambda e: e.priority)
elif method == 'Weighted':
have_endpoints = sorted(have.endpoints, key=lambda e: e.target)
desired_endpoints = sorted(desired.endpoints, key=lambda e: e.target)
else:
have_endpoints = have.endpoints
desired_endpoints = desired.endpoints
endpoints = zip(have_endpoints, desired_endpoints)
for have_endpoint, desired_endpoint in endpoints:
if have_endpoint.name != desired_endpoint.name or \
have_endpoint.type != desired_endpoint.type:
return False
target_type = have_endpoint.type.split('/')[-1]
if target_type == 'externalEndpoints':
# compare value, weight, priority
if have_endpoint.target != desired_endpoint.target:
return False
if method == 'Weighted' and \
have_endpoint.weight != desired_endpoint.weight:
return False
elif target_type == 'nestedEndpoints':
# compare targets
if have_endpoint.target_resource_id != \
desired_endpoint.target_resource_id:
return False
# compare geos
if method == 'Geographic':
have_geos = sorted(have_endpoint.geo_mapping)
desired_geos = sorted(desired_endpoint.geo_mapping)
if have_geos != desired_geos:
return False
else:
# unexpected, give up
return False
return True
class AzureProvider(BaseProvider): class AzureProvider(BaseProvider):
''' '''
@@ -318,7 +389,7 @@ class AzureProvider(BaseProvider):
possible to also hard-code into the config file: eg, resource_group. possible to also hard-code into the config file: eg, resource_group.
''' '''
SUPPORTS_GEO = False SUPPORTS_GEO = False
SUPPORTS_DYNAMIC = False SUPPORTS_DYNAMIC = True
SUPPORTS = set(('A', 'AAAA', 'CAA', 'CNAME', 'MX', 'NS', 'PTR', 'SRV', SUPPORTS = set(('A', 'AAAA', 'CAA', 'CNAME', 'MX', 'NS', 'PTR', 'SRV',
'TXT')) 'TXT'))
@@ -336,24 +407,45 @@ class AzureProvider(BaseProvider):
self._dns_client_directory_id = directory_id self._dns_client_directory_id = directory_id
self._dns_client_subscription_id = sub_id self._dns_client_subscription_id = sub_id
self.__dns_client = None self.__dns_client = None
self.__tm_client = None
self._resource_group = resource_group self._resource_group = resource_group
self._azure_zones = set() self._azure_zones = set()
self._traffic_managers = dict()
@property @property
def _dns_client(self): def _dns_client(self):
if self.__dns_client is None: if self.__dns_client is None:
credential = ClientSecretCredential( # Azure's logger spits out a lot of debug messages at 'INFO'
# level, override it by re-assigning `info` method to `debug`
# (ugly hack until I find a better way)
logger_name = 'azure.core.pipeline.policies.http_logging_policy'
logger = logging.getLogger(logger_name)
logger.info = logger.debug
self.__dns_client = DnsManagementClient(
credential=ClientSecretCredential(
client_id=self._dns_client_client_id, client_id=self._dns_client_client_id,
client_secret=self._dns_client_key, client_secret=self._dns_client_key,
tenant_id=self._dns_client_directory_id tenant_id=self._dns_client_directory_id,
) logger=logger,
self.__dns_client = DnsManagementClient( ),
credential=credential, subscription_id=self._dns_client_subscription_id,
subscription_id=self._dns_client_subscription_id
) )
return self.__dns_client return self.__dns_client
@property
def _tm_client(self):
if self.__tm_client is None:
self.__tm_client = TrafficManagerManagementClient(
ServicePrincipalCredentials(
self._dns_client_client_id,
secret=self._dns_client_key,
tenant=self._dns_client_directory_id,
),
self._dns_client_subscription_id,
)
return self.__tm_client
def _populate_zones(self): def _populate_zones(self):
self.log.debug('azure_zones: loading') self.log.debug('azure_zones: loading')
list_zones = self._dns_client.zones.list_by_resource_group list_zones = self._dns_client.zones.list_by_resource_group
@@ -388,6 +480,42 @@ class AzureProvider(BaseProvider):
# Else return nothing (aka false) # Else return nothing (aka false)
return return
def _populate_traffic_managers(self):
self.log.debug('traffic managers: loading')
list_profiles = self._tm_client.profiles.list_by_resource_group
for profile in list_profiles(self._resource_group):
self._traffic_managers[profile.id] = profile
# link nested profiles in advance for convenience
for _, profile in self._traffic_managers.items():
self._populate_nested_profiles(profile)
def _populate_nested_profiles(self, profile):
for ep in profile.endpoints:
target_id = ep.target_resource_id
if target_id and target_id in self._traffic_managers:
target = self._traffic_managers[target_id]
ep.target_resource = self._populate_nested_profiles(target)
return profile
def _get_tm_profile_by_id(self, resource_id):
if not self._traffic_managers:
self._populate_traffic_managers()
return self._traffic_managers.get(resource_id)
def _profile_name_to_id(self, name):
return '/subscriptions/' + self._dns_client_subscription_id + \
'/resourceGroups/' + self._resource_group + \
'/providers/Microsoft.Network/trafficManagerProfiles/' + \
name
def _get_tm_profile_by_name(self, name):
profile_id = self._profile_name_to_id(name)
return self._get_tm_profile_by_id(profile_id)
def _get_tm_for_dynamic_record(self, record):
name = _traffic_manager_suffix(record)
return self._get_tm_profile_by_name(name)
def populate(self, zone, target=False, lenient=False): def populate(self, zone, target=False, lenient=False):
'''Required function of manager.py to collect records from zone. '''Required function of manager.py to collect records from zone.
@@ -417,40 +545,35 @@ class AzureProvider(BaseProvider):
exists = False exists = False
before = len(zone.records) before = len(zone.records)
zone_name = zone.name[:len(zone.name) - 1] zone_name = zone.name[:-1]
self._populate_zones() self._populate_zones()
self._check_zone(zone_name)
_records = []
records = self._dns_client.record_sets.list_by_dns_zone records = self._dns_client.record_sets.list_by_dns_zone
if self._check_zone(zone_name): if self._check_zone(zone_name):
exists = True exists = True
for azrecord in records(self._resource_group, zone_name): for azrecord in records(self._resource_group, zone_name):
if _parse_azure_type(azrecord.type) in self.SUPPORTS:
_records.append(azrecord)
for azrecord in _records:
record_name = azrecord.name if azrecord.name != '@' else ''
typ = _parse_azure_type(azrecord.type) typ = _parse_azure_type(azrecord.type)
if typ not in self.SUPPORTS:
continue
if typ in ['A', 'CNAME']: record = self._populate_record(zone, azrecord, lenient)
if _check_for_alias(azrecord):
self.log.debug(
'Skipping - ALIAS. zone=%s record=%s, type=%s',
zone_name, record_name, typ) # pragma: no cover
continue # pragma: no cover
data = getattr(self, '_data_for_{}'.format(typ))
data = data(azrecord)
data['type'] = typ
data['ttl'] = azrecord.ttl
record = Record.new(zone, record_name, data, source=self)
zone.add_record(record, lenient=lenient) zone.add_record(record, lenient=lenient)
self.log.info('populate: found %s records, exists=%s', self.log.info('populate: found %s records, exists=%s',
len(zone.records) - before, exists) len(zone.records) - before, exists)
return exists return exists
def _populate_record(self, zone, azrecord, lenient=False):
record_name = azrecord.name if azrecord.name != '@' else ''
typ = _parse_azure_type(azrecord.type)
data_for = getattr(self, '_data_for_{}'.format(typ))
data = data_for(azrecord)
data['type'] = typ
data['ttl'] = azrecord.ttl
return Record.new(zone, record_name, data, source=self,
lenient=lenient)
def _data_for_A(self, azrecord): def _data_for_A(self, azrecord):
return {'values': [ar.ipv4_address for ar in azrecord.a_records]} return {'values': [ar.ipv4_address for ar in azrecord.a_records]}
@@ -470,6 +593,9 @@ class AzureProvider(BaseProvider):
:type return: dict :type return: dict
''' '''
if azrecord.cname_record is None and azrecord.target_resource.id:
return self._data_for_dynamic(azrecord)
return {'value': _check_endswith_dot(azrecord.cname_record.cname)} return {'value': _check_endswith_dot(azrecord.cname_record.cname)}
def _data_for_MX(self, azrecord): def _data_for_MX(self, azrecord):
@@ -495,6 +621,322 @@ class AzureProvider(BaseProvider):
ar.value)) ar.value))
for ar in azrecord.txt_records]} for ar in azrecord.txt_records]}
def _data_for_dynamic(self, azrecord):
default = set()
pools = defaultdict(lambda: {'fallback': None, 'values': []})
rules = []
# top level geo profile
geo_profile = self._get_tm_profile_by_id(azrecord.target_resource.id)
for geo_ep in geo_profile.endpoints:
rule = {}
# resolve list of regions
geo_map = list(geo_ep.geo_mapping)
if geo_map != ['WORLD']:
if 'GEO-ME' in geo_map:
# Azure treats Middle East as a separate group, but
# its part of Asia in octoDNS, so we need to remove GEO-ME
# if GEO-AS is also in the list
# Throw exception otherwise, it should not happen if the
# profile was generated by octoDNS
if 'GEO-AS' not in geo_map:
msg = '_data_for_dynamic: Profile={}: '.format(
geo_profile.name)
msg += 'Middle East (GEO-ME) is not supported by ' + \
'octoDNS. It needs to be either paired ' + \
'with Asia (GEO-AS) or expanded into ' + \
'individual list of countries.'
raise AzureException(msg)
geo_map.remove('GEO-ME')
geos = rule.setdefault('geos', [])
for code in geo_map:
if code.startswith('GEO-'):
geos.append(code[len('GEO-'):])
elif '-' in code:
country, province = code.split('-', 1)
country = GeoCodes.country_to_code(country)
geos.append('{}-{}'.format(country, province))
else:
geos.append(GeoCodes.country_to_code(code))
# second level priority profile
pool = None
rule_endpoints = geo_ep.target_resource.endpoints
rule_endpoints = sorted(rule_endpoints, key=lambda e: e.priority)
for rule_ep in rule_endpoints:
pool_name = rule_ep.name
# third (and last) level weighted RR profile
# these should be leaf node profiles with no further nesting
pool_profile = rule_ep.target_resource
# last/default pool
if pool_name == '--default--':
for pool_ep in pool_profile.endpoints:
default.add(pool_ep.target)
# this should be the last one, so let's break here
break
# set first priority endpoint as the rule's primary pool
if 'pool' not in rule:
rule['pool'] = pool_name
if pool:
# set current pool as fallback of the previous pool
pool['fallback'] = pool_name
pool = pools[pool_name]
for pool_ep in pool_profile.endpoints:
val = pool_ep.target
value_dict = {
'value': _check_endswith_dot(val),
'weight': pool_ep.weight,
}
if value_dict not in pool['values']:
pool['values'].append(value_dict)
if 'pool' not in rule or not default:
# this will happen if the priority profile does not have
# enough endpoints
msg = 'Expected at least 2 endpoints in {}, got {}'.format(
geo_ep.target_resource.name, len(rule_endpoints)
)
raise AzureException(msg)
rules.append(rule)
# Order and convert to a list
default = sorted(default)
data = {
'dynamic': {
'pools': pools,
'rules': rules,
},
'value': _check_endswith_dot(default[0]),
}
return data
def _extra_changes(self, existing, desired, changes):
changed = set()
# Abort if there are non-CNAME dynamic records
for change in changes:
record = change.record
changed.add(record)
typ = record._type
dynamic = getattr(record, 'dynamic', False)
if dynamic and typ != 'CNAME':
msg = '{}: Dynamic records in Azure must be of type CNAME'
msg = msg.format(record.fqdn)
raise AzureException(msg)
log = self.log.info
extra = []
for record in desired.records:
if not getattr(record, 'dynamic', False):
# Already changed, or not dynamic, no need to check it
continue
# let's walk through and show what will be changed even if
# the record is already be in list of changes
added = (record in changed)
active = set()
profiles = self._generate_traffic_managers(record)
for profile in profiles:
name = profile.name
active.add(name)
existing_profile = self._get_tm_profile_by_name(name)
if not _profile_is_match(existing_profile, profile):
log('_extra_changes: Profile name=%s will be synced',
name)
if not added:
extra.append(Update(record, record))
added = True
existing_profiles = self._find_traffic_managers(record)
for name in existing_profiles - active:
log('_extra_changes: Profile name=%s will be destroyed', name)
if not added:
extra.append(Update(record, record))
added = True
return extra
def _generate_tm_profile(self, name, routing, endpoints, record):
# set appropriate endpoint types
endpoint_type_prefix = 'Microsoft.Network/trafficManagerProfiles/'
for ep in endpoints:
if ep.target_resource_id:
ep.type = endpoint_type_prefix + 'nestedEndpoints'
elif ep.target:
ep.type = endpoint_type_prefix + 'externalEndpoints'
else:
msg = ('_generate_tm_profile: Invalid endpoint {} ' +
'in profile {}, needs to have either target or ' +
'target_resource_id').format(ep.name, name)
raise AzureException(msg)
# build and return
return Profile(
id=self._profile_name_to_id(name),
name=name,
traffic_routing_method=routing,
dns_config=DnsConfig(
relative_name=name,
ttl=record.ttl,
),
monitor_config=_get_monitor(record),
endpoints=endpoints,
location='global',
)
def _generate_traffic_managers(self, record):
traffic_managers = []
pools = record.dynamic.pools
tm_suffix = _traffic_manager_suffix(record)
profile = self._generate_tm_profile
# construct the default pool that will be used at the end of
# all rules
target = record.value[:-1]
default_endpoints = [Endpoint(
name=target,
target=target,
weight=1,
)]
default_profile_name = 'default--{}'.format(tm_suffix)
default_profile = profile(default_profile_name, 'Weighted',
default_endpoints, record)
traffic_managers.append(default_profile)
geo_endpoints = []
for rule in record.dynamic.rules:
pool_name = rule.data['pool']
rule_endpoints = []
priority = 1
while pool_name:
# iterate until we reach end of fallback chain
pool = pools[pool_name].data
profile_name = 'pool-{}--{}'.format(pool_name, tm_suffix)
endpoints = []
for val in pool['values']:
target = val['value']
# strip trailing dot from CNAME value
target = target[:-1]
endpoints.append(Endpoint(
name=target,
target=target,
weight=val.get('weight', 1),
))
pool_profile = profile(profile_name, 'Weighted', endpoints,
record)
traffic_managers.append(pool_profile)
# append pool to endpoint list of fallback rule profile
rule_endpoints.append(Endpoint(
name=pool_name,
target_resource_id=pool_profile.id,
priority=priority,
))
priority += 1
pool_name = pool.get('fallback')
# append default profile to the end
rule_endpoints.append(Endpoint(
name='--default--',
target_resource_id=default_profile.id,
priority=priority,
))
# create rule profile with fallback chain
rule_profile_name = 'rule-{}--{}'.format(rule.data['pool'],
tm_suffix)
rule_profile = profile(rule_profile_name, 'Priority',
rule_endpoints, record)
traffic_managers.append(rule_profile)
# append rule profile to top-level geo profile
rule_geos = rule.data.get('geos', [])
geos = []
if len(rule_geos) > 0:
for geo in rule_geos:
if '-' in geo:
geos.append(geo.split('-', 1)[-1])
else:
geos.append('GEO-{}'.format(geo))
if geo == 'AS':
# Middle East is part of Asia in octoDNS, but
# Azure treats it as a separate "group", so let's
# add it in the list of geo mappings. We will drop
# it when we later parse the list of regions.
geos.append('GEO-ME')
else:
geos.append('WORLD')
geo_endpoints.append(Endpoint(
name='rule-{}'.format(rule.data['pool']),
target_resource_id=rule_profile.id,
geo_mapping=geos,
))
geo_profile = profile(tm_suffix, 'Geographic', geo_endpoints, record)
traffic_managers.append(geo_profile)
return traffic_managers
def _sync_traffic_managers(self, record):
desired_profiles = self._generate_traffic_managers(record)
seen = set()
tm_sync = self._tm_client.profiles.create_or_update
populate = self._populate_nested_profiles
for desired in desired_profiles:
name = desired.name
if name in seen:
continue
existing = self._get_tm_profile_by_name(name)
if not _profile_is_match(existing, desired):
self.log.info(
'_sync_traffic_managers: Syncing profile=%s', name)
profile = tm_sync(self._resource_group, name, desired)
self._traffic_managers[profile.id] = populate(profile)
else:
self.log.debug(
'_sync_traffic_managers: Skipping profile=%s: up to date',
name)
seen.add(name)
return seen
def _find_traffic_managers(self, record):
tm_suffix = _traffic_manager_suffix(record)
profiles = set()
for profile_id in self._traffic_managers:
# match existing profiles with record's suffix
name = profile_id.split('/')[-1]
if name == tm_suffix or \
name.endswith('--{}'.format(tm_suffix)):
profiles.add(name)
return profiles
def _traffic_managers_gc(self, record, active_profiles):
existing_profiles = self._find_traffic_managers(record)
# delete unused profiles
for profile_name in existing_profiles - active_profiles:
self.log.info('_traffic_managers_gc: Deleting profile=%s',
profile_name)
self._tm_client.profiles.delete(self._resource_group, profile_name)
def _apply_Create(self, change): def _apply_Create(self, change):
'''A record from change must be created. '''A record from change must be created.
@@ -503,7 +945,15 @@ class AzureProvider(BaseProvider):
:type return: void :type return: void
''' '''
ar = _AzureRecord(self._resource_group, change.new) record = change.new
dynamic = getattr(record, 'dynamic', False)
if dynamic:
self._sync_traffic_managers(record)
profile = self._get_tm_for_dynamic_record(record)
ar = _AzureRecord(self._resource_group, record,
traffic_manager=profile)
create = self._dns_client.record_sets.create_or_update create = self._dns_client.record_sets.create_or_update
create(resource_group_name=ar.resource_group, create(resource_group_name=ar.resource_group,
@@ -512,17 +962,71 @@ class AzureProvider(BaseProvider):
record_type=ar.record_type, record_type=ar.record_type,
parameters=ar.params) parameters=ar.params)
self.log.debug('* Success Create/Update: {}'.format(ar)) self.log.debug('* Success Create: {}'.format(record))
_apply_Update = _apply_Create def _apply_Update(self, change):
'''A record from change must be created.
:param change: a change object
:type change: octodns.record.Change
:type return: void
'''
existing = change.existing
new = change.new
existing_is_dynamic = getattr(existing, 'dynamic', False)
new_is_dynamic = getattr(new, 'dynamic', False)
update_record = True
if new_is_dynamic:
active = self._sync_traffic_managers(new)
# only TTL is configured in record, everything else goes inside
# traffic managers, so no need to update if TTL is unchanged
# and existing record is already aliased to its traffic manager
if existing.ttl == new.ttl and existing_is_dynamic:
update_record = False
if update_record:
profile = self._get_tm_for_dynamic_record(new)
ar = _AzureRecord(self._resource_group, new,
traffic_manager=profile)
update = self._dns_client.record_sets.create_or_update
update(resource_group_name=ar.resource_group,
zone_name=ar.zone_name,
relative_record_set_name=ar.relative_record_set_name,
record_type=ar.record_type,
parameters=ar.params)
if new_is_dynamic:
# let's cleanup unused traffic managers
self._traffic_managers_gc(new, active)
elif existing_is_dynamic:
# cleanup traffic managers when a dynamic record gets
# changed to a simple record
self._traffic_managers_gc(existing, set())
self.log.debug('* Success Update: {}'.format(new))
def _apply_Delete(self, change): def _apply_Delete(self, change):
ar = _AzureRecord(self._resource_group, change.existing, delete=True) '''A record from change must be deleted.
:param change: a change object
:type change: octodns.record.Change
:type return: void
'''
record = change.record
ar = _AzureRecord(self._resource_group, record, delete=True)
delete = self._dns_client.record_sets.delete delete = self._dns_client.record_sets.delete
delete(self._resource_group, ar.zone_name, ar.relative_record_set_name, delete(self._resource_group, ar.zone_name, ar.relative_record_set_name,
ar.record_type) ar.record_type)
if getattr(record, 'dynamic', False):
self._traffic_managers_gc(record, set())
self.log.debug('* Success Delete: {}'.format(ar)) self.log.debug('* Success Delete: {}'.format(ar))
def _apply(self, plan): def _apply(self, plan):

View File

@@ -2,6 +2,7 @@ PyYaml==5.4
azure-common==1.1.27 azure-common==1.1.27
azure-identity==1.5.0 azure-identity==1.5.0
azure-mgmt-dns==8.0.0 azure-mgmt-dns==8.0.0
azure-mgmt-trafficmanager==0.51.0
boto3==1.15.9 boto3==1.15.9
botocore==1.18.9 botocore==1.18.9
dnspython==1.16.0 dnspython==1.16.0

View File

@@ -5,19 +5,23 @@
from __future__ import absolute_import, division, print_function, \ from __future__ import absolute_import, division, print_function, \
unicode_literals unicode_literals
from octodns.record import Create, Delete, Record from octodns.record import Create, Update, Delete, Record
from octodns.provider.azuredns import _AzureRecord, AzureProvider, \ from octodns.provider.azuredns import _AzureRecord, AzureProvider, \
_check_endswith_dot, _parse_azure_type, _check_for_alias _check_endswith_dot, _parse_azure_type, _traffic_manager_suffix, \
_get_monitor, _profile_is_match, AzureException
from octodns.zone import Zone from octodns.zone import Zone
from octodns.provider.base import Plan from octodns.provider.base import Plan
from azure.mgmt.dns.models import ARecord, AaaaRecord, CaaRecord, \ from azure.mgmt.dns.models import ARecord, AaaaRecord, CaaRecord, \
CnameRecord, MxRecord, SrvRecord, NsRecord, PtrRecord, TxtRecord, \ CnameRecord, MxRecord, SrvRecord, NsRecord, PtrRecord, TxtRecord, \
RecordSet, SoaRecord, SubResource, Zone as AzureZone RecordSet, SoaRecord, SubResource, Zone as AzureZone
from azure.mgmt.trafficmanager.models import Profile, DnsConfig, \
MonitorConfig, Endpoint, MonitorConfigCustomHeadersItem
from msrestazure.azure_exceptions import CloudError from msrestazure.azure_exceptions import CloudError
from six import text_type
from unittest import TestCase from unittest import TestCase
from mock import Mock, patch from mock import Mock, patch, call
zone = Zone(name='unit.tests.', sub_zones=[]) zone = Zone(name='unit.tests.', sub_zones=[])
@@ -343,6 +347,43 @@ class Test_AzureRecord(TestCase):
assert(azure_records[i]._equals(octo)) assert(azure_records[i]._equals(octo))
class Test_DynamicAzureRecord(TestCase):
def test_azure_record(self):
tm_profile = Profile()
data = {
'ttl': 60,
'type': 'CNAME',
'value': 'default.unit.tests.',
'dynamic': {
'pools': {
'one': {
'values': [
{'value': 'one.unit.tests.', 'weight': 1}
],
'fallback': 'two',
},
'two': {
'values': [
{'value': 'two.unit.tests.', 'weight': 1}
],
},
},
'rules': [
{'geos': ['AF'], 'pool': 'one'},
{'pool': 'two'},
],
}
}
octo_record = Record.new(zone, 'foo', data)
azure_record = _AzureRecord('TestAzure', octo_record,
traffic_manager=tm_profile)
self.assertEqual(azure_record.zone_name, zone.name[:-1])
self.assertEqual(azure_record.relative_record_set_name, 'foo')
self.assertEqual(azure_record.record_type, 'CNAME')
self.assertEqual(azure_record.params['ttl'], 60)
self.assertEqual(azure_record.params['target_resource'], tm_profile)
class Test_ParseAzureType(TestCase): class Test_ParseAzureType(TestCase):
def test_parse_azure_type(self): def test_parse_azure_type(self):
for expected, test in [['A', 'Microsoft.Network/dnszones/A'], for expected, test in [['A', 'Microsoft.Network/dnszones/A'],
@@ -361,40 +402,369 @@ class Test_CheckEndswithDot(TestCase):
self.assertEquals(expected, _check_endswith_dot(test)) self.assertEquals(expected, _check_endswith_dot(test))
class Test_CheckAzureAlias(TestCase): class Test_TrafficManagerSuffix(TestCase):
def test_check_for_alias(self): def test_traffic_manager_suffix(self):
alias_record = type('C', (object,), {}) test = Record.new(zone, 'foo', data={
alias_record.target_resource = type('C', (object,), {}) 'ttl': 60, 'type': 'CNAME', 'value': 'default.unit.tests.',
alias_record.target_resource.id = "/subscriptions/x/resourceGroups/y/z" })
alias_record.a_records = None self.assertEqual(_traffic_manager_suffix(test), 'foo-unit-tests')
alias_record.cname_record = None
self.assertEquals(_check_for_alias(alias_record), True)
class Test_GetMonitor(TestCase):
def test_get_monitor(self):
record = Record.new(zone, 'foo', data={
'type': 'CNAME', 'ttl': 60, 'value': 'default.unit.tests.',
'octodns': {
'healthcheck': {
'path': '/_ping',
'port': 4443,
'protocol': 'HTTPS',
}
},
})
monitor = _get_monitor(record)
self.assertEqual(monitor.protocol, 'HTTPS')
self.assertEqual(monitor.port, 4443)
self.assertEqual(monitor.path, '/_ping')
headers = monitor.custom_headers
self.assertIsInstance(headers, list)
self.assertEquals(len(headers), 1)
headers = headers[0]
self.assertEqual(headers.name, 'Host')
self.assertEqual(headers.value, record.healthcheck_host)
# test TCP monitor
record._octodns['healthcheck']['protocol'] = 'TCP'
monitor = _get_monitor(record)
self.assertEqual(monitor.protocol, 'TCP')
self.assertIsNone(monitor.custom_headers)
class Test_ProfileIsMatch(TestCase):
def test_profile_is_match(self):
is_match = _profile_is_match
self.assertFalse(is_match(None, Profile()))
# Profile object builder with default property values that can be
# overridden for testing below
def profile(
name = 'foo-unit-tests',
ttl = 60,
method = 'Geographic',
monitor_proto = 'HTTPS',
monitor_port = 4443,
monitor_path = '/_ping',
endpoints = 1,
endpoint_name = 'name',
endpoint_type = 'profile/nestedEndpoints',
target = 'target.unit.tests',
target_id = 'resource/id',
geos = ['GEO-AF'],
weight = 1,
priority = 1,
):
dns = DnsConfig(ttl=ttl)
return Profile(
name=name, traffic_routing_method=method, dns_config=dns,
monitor_config=MonitorConfig(
protocol=monitor_proto,
port=monitor_port,
path=monitor_path,
),
endpoints=[Endpoint(
name=endpoint_name,
type=endpoint_type,
target=target,
target_resource_id=target_id,
geo_mapping=geos,
weight=weight,
priority=priority,
)] + [Endpoint()] * (endpoints - 1),
)
self.assertTrue(is_match(profile(), profile()))
self.assertFalse(is_match(profile(), profile(name='two')))
self.assertFalse(is_match(profile(), profile(endpoints=2)))
self.assertFalse(is_match(profile(), profile(monitor_proto='HTTP')))
self.assertFalse(is_match(profile(), profile(endpoint_name='a')))
self.assertFalse(is_match(profile(), profile(endpoint_type='b')))
self.assertFalse(
is_match(profile(endpoint_type='b'), profile(endpoint_type='b'))
)
self.assertFalse(is_match(profile(), profile(target_id='rsrc/id2')))
self.assertFalse(is_match(profile(), profile(geos=['IN'])))
def wprofile(**kwargs):
kwargs['method'] = 'Weighted'
kwargs['endpoint_type'] = 'profile/externalEndpoints'
return profile(**kwargs)
self.assertFalse(is_match(wprofile(), wprofile(target='bar.unit')))
self.assertFalse(is_match(wprofile(), wprofile(weight=3)))
class TestAzureDnsProvider(TestCase): class TestAzureDnsProvider(TestCase):
def _provider(self): def _provider(self):
return self._get_provider('mock_spc', 'mock_dns_client') return self._get_provider('mock_spc', 'mock_dns_client')
@patch('octodns.provider.azuredns.TrafficManagerManagementClient')
@patch('octodns.provider.azuredns.DnsManagementClient') @patch('octodns.provider.azuredns.DnsManagementClient')
@patch('octodns.provider.azuredns.ClientSecretCredential') @patch('octodns.provider.azuredns.ClientSecretCredential')
def _get_provider(self, mock_spc, mock_dns_client): @patch('octodns.provider.azuredns.ServicePrincipalCredentials')
def _get_provider(self, mock_spc, mock_css, mock_dns_client,
mock_tm_client):
'''Returns a mock AzureProvider object to use in testing. '''Returns a mock AzureProvider object to use in testing.
:param mock_spc: placeholder :param mock_spc: placeholder
:type mock_spc: str :type mock_spc: str
:param mock_dns_client: placeholder :param mock_dns_client: placeholder
:type mock_dns_client: str :type mock_dns_client: str
:param mock_tm_client: placeholder
:type mock_tm_client: str
:type return: AzureProvider :type return: AzureProvider
''' '''
provider = AzureProvider('mock_id', 'mock_client', 'mock_key', provider = AzureProvider('mock_id', 'mock_client', 'mock_key',
'mock_directory', 'mock_sub', 'mock_rg' 'mock_directory', 'mock_sub', 'mock_rg'
) )
# Fetch the client to force it to load the creds # Fetch the client to force it to load the creds
provider._dns_client provider._dns_client
# set critical functions to return properly
tm_list = provider._tm_client.profiles.list_by_resource_group
tm_list.return_value = []
tm_sync = provider._tm_client.profiles.create_or_update
def side_effect(rg, name, profile):
return profile
tm_sync.side_effect = side_effect
return provider return provider
def _get_dynamic_record(self, zone):
return Record.new(zone, 'foo', data={
'type': 'CNAME',
'ttl': 60,
'value': 'default.unit.tests.',
'dynamic': {
'pools': {
'one': {
'values': [
{'value': 'one.unit.tests.', 'weight': 11},
],
'fallback': 'two',
},
'two': {
'values': [
{'value': 'two1.unit.tests.', 'weight': 3},
{'value': 'two2.unit.tests.', 'weight': 4},
],
'fallback': 'three',
},
'three': {
'values': [
{'value': 'three.unit.tests.', 'weight': 13},
],
},
},
'rules': [
{'geos': ['AF', 'EU-DE', 'NA-US-CA'], 'pool': 'one'},
{'pool': 'three'},
],
},
'octodns': {
'healthcheck': {
'path': '/_ping',
'port': 4443,
'protocol': 'HTTPS',
}
},
})
def _get_tm_profiles(self, provider):
sub = provider._dns_client_subscription_id
rg = provider._resource_group
base_id = '/subscriptions/' + sub + \
'/resourceGroups/' + rg + \
'/providers/Microsoft.Network/trafficManagerProfiles/'
suffix = 'foo-unit-tests'
id_format = base_id + '{}--' + suffix
name_format = '{}--' + suffix
dns = DnsConfig(ttl=60)
header = MonitorConfigCustomHeadersItem(name='Host',
value='foo.unit.tests')
monitor = MonitorConfig(protocol='HTTPS', port=4443, path='/_ping',
custom_headers=[header])
external = 'Microsoft.Network/trafficManagerProfiles/externalEndpoints'
nested = 'Microsoft.Network/trafficManagerProfiles/nestedEndpoints'
return [
Profile(
id=id_format.format('default'),
name=name_format.format('default'),
traffic_routing_method='Weighted',
dns_config=dns,
monitor_config=monitor,
endpoints=[
Endpoint(
name='default.unit.tests',
type=external,
target='default.unit.tests',
weight=1,
),
],
),
Profile(
id=id_format.format('pool-one'),
name=name_format.format('pool-one'),
traffic_routing_method='Weighted',
dns_config=dns,
monitor_config=monitor,
endpoints=[
Endpoint(
name='one.unit.tests',
type=external,
target='one.unit.tests',
weight=11,
),
],
),
Profile(
id=id_format.format('pool-two'),
name=name_format.format('pool-two'),
traffic_routing_method='Weighted',
dns_config=dns,
monitor_config=monitor,
endpoints=[
Endpoint(
name='two1.unit.tests',
type=external,
target='two1.unit.tests',
weight=3,
),
Endpoint(
name='two2.unit.tests',
type=external,
target='two2.unit.tests',
weight=4,
),
],
),
Profile(
id=id_format.format('pool-three'),
name=name_format.format('pool-three'),
traffic_routing_method='Weighted',
dns_config=dns,
monitor_config=monitor,
endpoints=[
Endpoint(
name='three.unit.tests',
type=external,
target='three.unit.tests',
weight=13,
),
],
),
Profile(
id=id_format.format('rule-one'),
name=name_format.format('rule-one'),
traffic_routing_method='Priority',
dns_config=dns,
monitor_config=monitor,
endpoints=[
Endpoint(
name='one',
type=nested,
target_resource_id=id_format.format('pool-one'),
priority=1,
),
Endpoint(
name='two',
type=nested,
target_resource_id=id_format.format('pool-two'),
priority=2,
),
Endpoint(
name='three',
type=nested,
target_resource_id=id_format.format('pool-three'),
priority=3,
),
Endpoint(
name='--default--',
type=nested,
target_resource_id=id_format.format('default'),
priority=4,
),
],
),
Profile(
id=id_format.format('rule-three'),
name=name_format.format('rule-three'),
traffic_routing_method='Priority',
dns_config=dns,
monitor_config=monitor,
endpoints=[
Endpoint(
name='three',
type=nested,
target_resource_id=id_format.format('pool-three'),
priority=1,
),
Endpoint(
name='--default--',
type=nested,
target_resource_id=id_format.format('default'),
priority=2,
),
],
),
Profile(
id=base_id + suffix,
name=suffix,
traffic_routing_method='Geographic',
dns_config=dns,
monitor_config=monitor,
endpoints=[
Endpoint(
geo_mapping=['GEO-AF', 'DE', 'US-CA'],
name='rule-one',
type=nested,
target_resource_id=id_format.format('rule-one'),
),
Endpoint(
geo_mapping=['WORLD'],
name='rule-three',
type=nested,
target_resource_id=id_format.format('rule-three'),
),
],
),
]
def _get_dynamic_package(self):
'''Convenience function to setup a sample dynamic record.
'''
provider = self._get_provider()
# setup traffic manager profiles
tm_list = provider._tm_client.profiles.list_by_resource_group
tm_list.return_value = self._get_tm_profiles(provider)
# setup zone with dynamic record
zone = Zone(name='unit.tests.', sub_zones=[])
record = self._get_dynamic_record(zone)
zone.add_record(record)
# return everything
return provider, zone, record
def test_populate_records(self): def test_populate_records(self):
provider = self._get_provider() provider = self._get_provider()
@@ -510,6 +880,121 @@ class TestAzureDnsProvider(TestCase):
self.assertEquals(len(zone.records), 17) self.assertEquals(len(zone.records), 17)
self.assertTrue(exists) self.assertTrue(exists)
def test_populate_dynamic(self):
# Middle east without Asia raises exception
provider, zone, record = self._get_dynamic_package()
tm_suffix = _traffic_manager_suffix(record)
tm_id = provider._profile_name_to_id
tm_list = provider._tm_client.profiles.list_by_resource_group
rule_name = 'rule-one--{}'.format(tm_suffix)
nested = 'Microsoft.Network/trafficManagerProfiles/nestedEndpoints'
tm_list.return_value = [
Profile(
id=tm_id(tm_suffix),
name=tm_suffix,
traffic_routing_method='Geographic',
endpoints=[
Endpoint(
geo_mapping=['GEO-ME'],
),
],
),
]
azrecord = RecordSet(
ttl=60,
target_resource=SubResource(id=tm_id(tm_suffix)),
)
azrecord.name = record.name or '@'
azrecord.type = 'Microsoft.Network/dnszones/{}'.format(record._type)
with self.assertRaises(AzureException) as ctx:
provider._populate_record(zone, azrecord)
self.assertTrue(text_type(ctx).startswith(
'Middle East (GEO-ME) is not supported'
))
# empty priority profile raises exception
provider, zone, record = self._get_dynamic_package()
tm_list = provider._tm_client.profiles.list_by_resource_group
rule_name = 'rule-one--{}'.format(tm_suffix)
nested = 'Microsoft.Network/trafficManagerProfiles/nestedEndpoints'
tm_list.return_value = [
Profile(
id=tm_id(rule_name),
name=rule_name,
traffic_routing_method='Priority',
endpoints=[],
),
Profile(
id=tm_id(tm_suffix),
name=tm_suffix,
traffic_routing_method='Geographic',
endpoints=[
Endpoint(
geo_mapping=['WORLD'],
name='rule-one',
type=nested,
target_resource_id=tm_id(rule_name),
),
],
),
]
with self.assertRaises(AzureException) as ctx:
provider._populate_record(zone, azrecord)
self.assertTrue(text_type(ctx).startswith(
'Expected at least 2 endpoints'
))
# valid set of profiles produce expected dynamic record
provider, zone, record = self._get_dynamic_package()
root_profile_id = provider._profile_name_to_id(
_traffic_manager_suffix(record)
)
azrecord = RecordSet(
ttl=60,
target_resource=SubResource(id=root_profile_id),
)
azrecord.name = record.name or '@'
azrecord.type = 'Microsoft.Network/dnszones/{}'.format(record._type)
record = provider._populate_record(zone, azrecord)
self.assertEqual(record.name, 'foo')
self.assertEqual(record.ttl, 60)
self.assertEqual(record.value, 'default.unit.tests.')
self.assertEqual(record.dynamic._data(), {
'pools': {
'one': {
'values': [
{'value': 'one.unit.tests.', 'weight': 11},
],
'fallback': 'two',
},
'two': {
'values': [
{'value': 'two1.unit.tests.', 'weight': 3},
{'value': 'two2.unit.tests.', 'weight': 4},
],
'fallback': 'three',
},
'three': {
'values': [
{'value': 'three.unit.tests.', 'weight': 13},
],
'fallback': None,
},
},
'rules': [
{'geos': ['AF', 'EU-DE', 'NA-US-CA'], 'pool': 'one'},
{'pool': 'three'},
],
})
# valid profiles with Middle East test case
geo_profile = provider._get_tm_for_dynamic_record(record)
geo_profile.endpoints[0].geo_mapping.extend(['GEO-ME', 'GEO-AS'])
record = provider._populate_record(zone, azrecord)
self.assertIn('AS', record.dynamic.rules[0].data['geos'])
self.assertNotIn('ME', record.dynamic.rules[0].data['geos'])
def test_populate_zone(self): def test_populate_zone(self):
provider = self._get_provider() provider = self._get_provider()
@@ -541,20 +1026,388 @@ class TestAzureDnsProvider(TestCase):
None None
) )
def test_extra_changes(self):
provider, existing, record = self._get_dynamic_package()
# test simple records produce no extra changes
desired = Zone(name=existing.name, sub_zones=[])
desired.add_record(Record.new(desired, 'simple', data={
'type': record._type,
'ttl': record.ttl,
'value': record.value,
}))
extra = provider._extra_changes(desired, desired, [])
self.assertEqual(len(extra), 0)
# test an unchanged dynamic record produces no extra changes
desired.add_record(record)
extra = provider._extra_changes(existing, desired, [])
self.assertEqual(len(extra), 0)
# test unused TM produces the extra change for clean up
sample_profile = self._get_tm_profiles(provider)[0]
tm_id = provider._profile_name_to_id
root_profile_name = _traffic_manager_suffix(record)
extra_profile = Profile(
id=tm_id('random--{}'.format(root_profile_name)),
name='random--{}'.format(root_profile_name),
traffic_routing_method='Weighted',
dns_config=sample_profile.dns_config,
monitor_config=sample_profile.monitor_config,
endpoints=sample_profile.endpoints,
)
tm_list = provider._tm_client.profiles.list_by_resource_group
tm_list.return_value.append(extra_profile)
provider._populate_traffic_managers()
extra = provider._extra_changes(existing, desired, [])
self.assertEqual(len(extra), 1)
extra = extra[0]
self.assertIsInstance(extra, Update)
self.assertEqual(extra.new, record)
desired._remove_record(record)
tm_list.return_value.pop()
# test new dynamic record does not produce an extra change for it
new_dynamic = Record.new(desired, record.name + '2', data={
'type': record._type,
'ttl': record.ttl,
'value': record.value,
'dynamic': record.dynamic._data(),
'octodns': record._octodns,
})
# test change in healthcheck by using a different port number
update_dynamic = Record.new(desired, record.name, data={
'type': record._type,
'ttl': record.ttl,
'value': record.value,
'dynamic': record.dynamic._data(),
'octodns': {
'healthcheck': {
'path': '/_ping',
'port': 443,
'protocol': 'HTTPS',
},
},
})
desired.add_record(new_dynamic)
desired.add_record(update_dynamic)
changes = [Create(new_dynamic)]
extra = provider._extra_changes(existing, desired, changes)
# implicitly asserts that new_dynamic was not added to extra changes
# as it was already in the `changes` list
self.assertEqual(len(extra), 1)
extra = extra[0]
self.assertIsInstance(extra, Update)
self.assertEqual(extra.new, update_dynamic)
# test non-CNAME dynamic record throws exception
a_dynamic = Record.new(desired, record.name + '3', data={
'type': 'A',
'ttl': record.ttl,
'values': ['1.1.1.1'],
'dynamic': {
'pools': {
'one': {'values': [{'value': '2.2.2.2'}]},
},
'rules': [
{'pool': 'one'},
],
},
})
desired.add_record(a_dynamic)
changes.append(Create(a_dynamic))
with self.assertRaises(AzureException):
provider._extra_changes(existing, desired, changes)
def test_generate_tm_profile(self):
provider, zone, record = self._get_dynamic_package()
profile_gen = provider._generate_tm_profile
name = 'foobar'
routing = 'Priority'
endpoints = [
Endpoint(target='one.unit.tests'),
Endpoint(target_resource_id='/s/1/rg/foo/tm/foobar2'),
Endpoint(name='invalid'),
]
# invalid endpoint raises exception
with self.assertRaises(AzureException):
profile_gen(name, routing, endpoints, record)
# regular test
endpoints.pop()
profile = profile_gen(name, routing, endpoints, record)
# implicitly tests _profile_name_to_id
sub = provider._dns_client_subscription_id
rg = provider._resource_group
expected_id = '/subscriptions/' + sub + \
'/resourceGroups/' + rg + \
'/providers/Microsoft.Network/trafficManagerProfiles/' + name
self.assertEqual(profile.id, expected_id)
self.assertEqual(profile.name, name)
self.assertEqual(profile.traffic_routing_method, routing)
self.assertEqual(profile.dns_config.ttl, record.ttl)
self.assertEqual(len(profile.endpoints), len(endpoints))
self.assertEqual(
profile.endpoints[0].type,
'Microsoft.Network/trafficManagerProfiles/externalEndpoints'
)
self.assertEqual(
profile.endpoints[1].type,
'Microsoft.Network/trafficManagerProfiles/nestedEndpoints'
)
def test_generate_traffic_managers(self):
provider, zone, record = self._get_dynamic_package()
profiles = provider._generate_traffic_managers(record)
deduped = []
seen = set()
for profile in profiles:
if profile.name not in seen:
deduped.append(profile)
seen.add(profile.name)
# check that every profile is a match with what we expect
expected_profiles = self._get_tm_profiles(provider)
self.assertEqual(len(expected_profiles), len(deduped))
for have, expected in zip(deduped, expected_profiles):
self.assertTrue(_profile_is_match(have, expected))
# check Asia/Middle East test case
record.dynamic._data()['rules'][0]['geos'].append('AS')
profiles = provider._generate_traffic_managers(record)
geo_profile_name = _traffic_manager_suffix(record)
geo_profile = next(
profile
for profile in profiles
if profile.name == geo_profile_name
)
self.assertIn('GEO-ME', geo_profile.endpoints[0].geo_mapping)
self.assertIn('GEO-AS', geo_profile.endpoints[0].geo_mapping)
def test_sync_traffic_managers(self):
provider, zone, record = self._get_dynamic_package()
provider._populate_traffic_managers()
tm_sync = provider._tm_client.profiles.create_or_update
suffix = 'foo-unit-tests'
expected_seen = {
suffix, 'default--{}'.format(suffix),
'rule-one--{}'.format(suffix), 'rule-three--{}'.format(suffix),
'pool-one--{}'.format(suffix), 'pool-two--{}'.format(suffix),
'pool-three--{}'.format(suffix),
}
# test no change
seen = provider._sync_traffic_managers(record)
self.assertEqual(seen, expected_seen)
tm_sync.assert_not_called()
# test that changing weight causes update API call
dynamic = record.dynamic._data()
dynamic['pools']['one']['values'][0]['weight'] = 14
data = {
'type': 'CNAME',
'ttl': record.ttl,
'value': record.value,
'dynamic': dynamic,
'octodns': record._octodns,
}
new_record = Record.new(zone, record.name, data)
tm_sync.reset_mock()
seen2 = provider._sync_traffic_managers(new_record)
self.assertEqual(seen2, expected_seen)
tm_sync.assert_called_once()
# test that new profile was successfully inserted in cache
new_profile = provider._get_tm_profile_by_name(
'pool-one--{}'.format(suffix)
)
self.assertEqual(new_profile.endpoints[0].weight, 14)
def test_find_traffic_managers(self):
provider, zone, record = self._get_dynamic_package()
# insert a non-matching profile
sample_profile = self._get_tm_profiles(provider)[0]
# dummy record for generating suffix
record2 = Record.new(zone, record.name + '2', data={
'type': record._type,
'ttl': record.ttl,
'value': record.value,
})
suffix2 = _traffic_manager_suffix(record2)
tm_id = provider._profile_name_to_id
extra_profile = Profile(
id=tm_id('random--{}'.format(suffix2)),
name='random--{}'.format(suffix2),
traffic_routing_method='Weighted',
dns_config=sample_profile.dns_config,
monitor_config=sample_profile.monitor_config,
endpoints=sample_profile.endpoints,
)
tm_list = provider._tm_client.profiles.list_by_resource_group
tm_list.return_value.append(extra_profile)
provider._populate_traffic_managers()
# implicitly asserts that non-matching profile is not included
suffix = _traffic_manager_suffix(record)
self.assertEqual(provider._find_traffic_managers(record), {
suffix, 'default--{}'.format(suffix),
'rule-one--{}'.format(suffix), 'rule-three--{}'.format(suffix),
'pool-one--{}'.format(suffix), 'pool-two--{}'.format(suffix),
'pool-three--{}'.format(suffix),
})
def test_traffic_manager_gc(self):
provider, zone, record = self._get_dynamic_package()
provider._populate_traffic_managers()
profiles = provider._find_traffic_managers(record)
profile_delete_mock = provider._tm_client.profiles.delete
provider._traffic_managers_gc(record, profiles)
profile_delete_mock.assert_not_called()
profile_delete_mock.reset_mock()
remove = list(profiles)[3]
profiles.discard(remove)
provider._traffic_managers_gc(record, profiles)
profile_delete_mock.assert_has_calls(
[call(provider._resource_group, remove)]
)
def test_apply(self): def test_apply(self):
provider = self._get_provider() provider = self._get_provider()
changes = [] half = int(len(octo_records) / 2)
deletes = [] changes = [Create(r) for r in octo_records[:half]] + \
for i in octo_records: [Update(r, r) for r in octo_records[half:]]
changes.append(Create(i)) deletes = [Delete(r) for r in octo_records]
deletes.append(Delete(i))
self.assertEquals(19, provider.apply(Plan(None, zone, self.assertEquals(19, provider.apply(Plan(None, zone,
changes, True))) changes, True)))
self.assertEquals(19, provider.apply(Plan(zone, zone, self.assertEquals(19, provider.apply(Plan(zone, zone,
deletes, True))) deletes, True)))
def test_apply_create_dynamic(self):
provider = self._get_provider()
tm_list = provider._tm_client.profiles.list_by_resource_group
tm_list.return_value = []
tm_sync = provider._tm_client.profiles.create_or_update
zone = Zone(name='unit.tests.', sub_zones=[])
record = self._get_dynamic_record(zone)
profiles = self._get_tm_profiles(provider)
provider._apply_Create(Create(record))
# create was called as many times as number of profiles required for
# the dynamic record
self.assertEqual(tm_sync.call_count, len(profiles))
create = provider._dns_client.record_sets.create_or_update
create.assert_called_once()
def test_apply_update_dynamic(self):
# existing is simple, new is dynamic
provider = self._get_provider()
tm_list = provider._tm_client.profiles.list_by_resource_group
tm_list.return_value = []
profiles = self._get_tm_profiles(provider)
dynamic_record = self._get_dynamic_record(zone)
simple_record = Record.new(zone, dynamic_record.name, data={
'type': 'CNAME',
'ttl': 3600,
'value': 'cname.unit.tests.',
})
change = Update(simple_record, dynamic_record)
provider._apply_Update(change)
tm_sync, dns_update, tm_delete = (
provider._tm_client.profiles.create_or_update,
provider._dns_client.record_sets.create_or_update,
provider._tm_client.profiles.delete
)
self.assertEqual(tm_sync.call_count, len(profiles))
dns_update.assert_called_once()
tm_delete.assert_not_called()
# existing is dynamic, new is simple
provider, existing, dynamic_record = self._get_dynamic_package()
profiles = self._get_tm_profiles(provider)
change = Update(dynamic_record, simple_record)
provider._apply_Update(change)
tm_sync, dns_update, tm_delete = (
provider._tm_client.profiles.create_or_update,
provider._dns_client.record_sets.create_or_update,
provider._tm_client.profiles.delete
)
tm_sync.assert_not_called()
dns_update.assert_called_once()
self.assertEqual(tm_delete.call_count, len(profiles))
# both are dynamic, healthcheck port is changed
provider, existing, dynamic_record = self._get_dynamic_package()
profiles = self._get_tm_profiles(provider)
dynamic_record2 = self._get_dynamic_record(existing)
dynamic_record2._octodns['healthcheck']['port'] += 1
change = Update(dynamic_record, dynamic_record2)
provider._apply_Update(change)
tm_sync, dns_update, tm_delete = (
provider._tm_client.profiles.create_or_update,
provider._dns_client.record_sets.create_or_update,
provider._tm_client.profiles.delete
)
self.assertEqual(tm_sync.call_count, len(profiles))
dns_update.assert_not_called()
tm_delete.assert_not_called()
# both are dynamic, extra profile should be deleted
provider, existing, dynamic_record = self._get_dynamic_package()
sample_profile = self._get_tm_profiles(provider)[0]
tm_id = provider._profile_name_to_id
root_profile_name = _traffic_manager_suffix(dynamic_record)
extra_profile = Profile(
id=tm_id('random--{}'.format(root_profile_name)),
name='random--{}'.format(root_profile_name),
traffic_routing_method='Weighted',
dns_config=sample_profile.dns_config,
monitor_config=sample_profile.monitor_config,
endpoints=sample_profile.endpoints,
)
tm_list = provider._tm_client.profiles.list_by_resource_group
tm_list.return_value.append(extra_profile)
change = Update(dynamic_record, dynamic_record)
provider._apply_Update(change)
tm_sync, dns_update, tm_delete = (
provider._tm_client.profiles.create_or_update,
provider._dns_client.record_sets.create_or_update,
provider._tm_client.profiles.delete
)
tm_sync.assert_not_called()
dns_update.assert_not_called()
tm_delete.assert_called_once()
def test_apply_delete_dynamic(self):
provider, existing, record = self._get_dynamic_package()
provider._populate_traffic_managers()
profiles = self._get_tm_profiles(provider)
change = Delete(record)
provider._apply_Delete(change)
dns_delete, tm_delete = (
provider._dns_client.record_sets.delete,
provider._tm_client.profiles.delete
)
dns_delete.assert_called_once()
self.assertEqual(tm_delete.call_count, len(profiles))
def test_create_zone(self): def test_create_zone(self):
provider = self._get_provider() provider = self._get_provider()