1
0
mirror of https://github.com/github/octodns.git synced 2024-05-11 05:55:00 +00:00
Files
github-octodns/octodns/manager.py
Ross McFarland 371138dbec Fix zone-level always-dry-run functionality
Thanks @offmindby!
2017-06-08 18:34:33 -07:00

360 lines
13 KiB
Python

#
#
#
from __future__ import absolute_import, division, print_function, \
unicode_literals
from StringIO import StringIO
from concurrent.futures import Future, ThreadPoolExecutor
from importlib import import_module
from os import environ
import logging
from .provider.base import BaseProvider
from .provider.yaml import YamlProvider
from .yaml import safe_load
from .zone import Zone
class _AggregateTarget(object):
    '''
    Wraps a collection of targets and answers capability questions on their
    behalf: something is only supported when every wrapped target supports
    it.
    '''
    id = 'aggregate'

    def __init__(self, targets):
        # Kept as given; iterated on every capability check
        self.targets = targets

    def supports(self, record):
        '''
        Returns True when each of the wrapped targets reports that it
        supports `record`, False as soon as one of them does not.
        '''
        return all(member.supports(record) for member in self.targets)

    @property
    def SUPPORTS_GEO(self):
        '''
        True only when every wrapped target has a truthy SUPPORTS_GEO.
        '''
        return all(member.SUPPORTS_GEO for member in self.targets)
class MainThreadExecutor(object):
    '''
    Stand-in executor that does the work synchronously, on the calling
    thread, during the invocation of submit while still handing back a
    future holding the outcome. Callers can therefore be written against the
    future-based (async) interface even when multiple threads/workers aren't
    wanted and things should flow as if traditionally written.
    '''

    def submit(self, func, *args, **kwargs):
        '''
        Calls `func(*args, **kwargs)` right away and returns a completed
        Future carrying either its return value or the Exception it raised.
        '''
        outcome = Future()
        try:
            value = func(*args, **kwargs)
        except Exception as err:
            # TODO: get right stacktrace here
            outcome.set_exception(err)
        else:
            outcome.set_result(value)
        return outcome
class Manager(object):
    '''
    Loads a config file, instantiates the providers it describes and
    coordinates populating, planning, applying, comparing and dumping the
    zones configured in it.
    '''
    log = logging.getLogger('Manager')

    def __init__(self, config_file, max_workers=None):
        '''
        :param config_file: path of the config file to load
        :param max_workers: number of workers used when planning zones. When
            None the `max_workers` value of the config's `manager` section is
            used, defaulting to 1. Anything above 1 uses a thread pool,
            otherwise the work runs on the calling thread.
        :raises Exception: when a provider has no `class`, its class can't be
            found, a referenced env var isn't set or the provider's
            constructor rejects its arguments with a TypeError
        '''
        self.log.info('__init__: config_file=%s', config_file)

        # Read our config file
        with open(config_file, 'r') as fh:
            self.config = safe_load(fh, enforce_order=False)

        manager_config = self.config.get('manager', {})
        if max_workers is None:
            max_workers = manager_config.get('max_workers', 1)
        if max_workers > 1:
            self._executor = ThreadPoolExecutor(max_workers=max_workers)
        else:
            self._executor = MainThreadExecutor()

        self.log.debug('__init__: configuring providers')
        self.providers = {}
        for provider_name, provider_config in self.config['providers'].items():
            # Get our class and remove it from the provider_config (this
            # modifies the loaded config in place)
            try:
                _class = provider_config.pop('class')
            except KeyError:
                self.log.exception('Invalid provider class')
                raise Exception('Provider {} is missing class'
                                .format(provider_name))
            _class = self._get_provider_class(_class)

            # Build up the arguments we need to pass to the provider
            kwargs = {}
            for k, v in provider_config.items():
                kwargs[k] = self._resolve_env(v)

            try:
                self.providers[provider_name] = _class(provider_name, **kwargs)
            except TypeError:
                self.log.exception('Invalid provider config')
                raise Exception('Incorrect provider config for {}'
                                .format(provider_name))

        zone_tree = {}
        # sort by reversed strings so that parent zones always come first
        for name in sorted(self.config['zones'].keys(), key=lambda s: s[::-1]):
            # ignore trailing dots, and reverse
            pieces = name[:-1].split('.')[::-1]
            # where starts out at the top
            where = zone_tree
            # for all the pieces
            for piece in pieces:
                # point where at the piece's value, creating it when our
                # current piece doesn't exist yet
                where = where.setdefault(piece, {})
        self.zone_tree = zone_tree

    def _resolve_env(self, value):
        '''
        Returns `value` unchanged unless it is a string starting with `env/`,
        in which case the value of the environment variable named by the
        remainder is returned.

        :raises Exception: when the named environment variable isn't set
        '''
        try:
            is_env = value.startswith('env/')
        except AttributeError:
            # Not string-like, nothing to substitute
            return value
        if not is_env:
            return value
        env_var = value[4:]
        try:
            return environ[env_var]
        except KeyError:
            self.log.exception('Invalid provider config')
            raise Exception('Incorrect provider config, '
                            'missing env var {}'
                            .format(env_var))

    def _get_provider_class(self, _class):
        '''
        Imports and returns the object named by the dotted path `_class`
        (`package.module.ClassName`).

        :raises Exception: when the path has no module part, the module can't
            be imported or it has no attribute with the given name
        '''
        try:
            module_name, class_name = _class.rsplit('.', 1)
            module = import_module(module_name)
        except (ImportError, ValueError):
            self.log.exception('_get_provider_class: Unable to import '
                               'module %s', _class)
            raise Exception('Unknown provider class: {}'.format(_class))
        try:
            return getattr(module, class_name)
        except AttributeError:
            self.log.exception('_get_provider_class: Unable to get class %s '
                               'from module %s', class_name, module)
            raise Exception('Unknown provider class: {}'.format(_class))

    def _lookup_sources(self, zone_name, sources):
        '''
        Maps the provider names in `sources` to the configured provider
        objects.

        :raises Exception: naming the zone and the first name that isn't a
            configured provider
        '''
        try:
            return [self.providers[source] for source in sources]
        except KeyError as e:
            # The missing name is taken from the KeyError rather than from
            # the comprehension's variable, which isn't visible out here on
            # python 3
            raise Exception('Zone {}, unknown source: {}'.format(zone_name,
                                                                 e.args[0]))

    def configured_sub_zones(self, zone_name):
        '''
        Returns the set of labels that sit directly beneath `zone_name` in
        the tree built from the configured zones, an empty set when
        `zone_name` isn't in the tree.
        '''
        # Reversed pieces of the zone name, the last character (trailing
        # dot) is dropped
        pieces = zone_name[:-1].split('.')[::-1]
        # Point where at the root of the tree
        where = self.zone_tree
        # Until we've hit the bottom of this zone
        try:
            for piece in pieces:
                # Point where at the value of our current piece
                where = where[piece]
        except KeyError:
            self.log.debug('configured_sub_zones: unknown zone, %s, no subs',
                           zone_name)
            return set()
        # We're now pointed at the dict for our name, the keys of which will
        # be any subzones
        sub_zone_names = list(where)
        self.log.debug('configured_sub_zones: subs=%s', sub_zone_names)
        return set(sub_zone_names)

    def _populate_and_plan(self, zone_name, sources, targets):
        '''
        Builds the zone, has every source populate it and then asks each
        target for a plan.

        :returns: list of `(target, plan)` pairs for the targets that
            returned a truthy plan
        '''
        self.log.debug('sync: populating, zone=%s', zone_name)
        zone = Zone(zone_name,
                    sub_zones=self.configured_sub_zones(zone_name))
        for source in sources:
            source.populate(zone)

        self.log.debug('sync: planning, zone=%s', zone_name)
        plans = []
        for target in targets:
            plan = target.plan(zone)
            if plan:
                plans.append((target, plan))
        return plans

    def _plans_report(self, plans):
        '''
        Renders the `(target, plan)` pairs in to the text block that sync
        logs.
        '''
        hr = '*************************************************************' \
            '*******************\n'
        buf = StringIO()
        buf.write('\n')
        if plans:
            current_zone = None
            for target, plan in plans:
                if plan.desired.name != current_zone:
                    # First plan for this zone, emit the zone's header
                    current_zone = plan.desired.name
                    buf.write(hr)
                    buf.write('* ')
                    buf.write(current_zone)
                    buf.write('\n')
                buf.write(hr)
                buf.write('* ')
                buf.write(target.id)
                buf.write(' (')
                # Explicit str() so that this doesn't depend on the buffer
                # converting non-strings for us
                buf.write(str(target))
                buf.write(')\n* ')
                for change in plan.changes:
                    buf.write(change.__repr__(leader='* '))
                    buf.write('\n* ')
                buf.write('Summary: ')
                buf.write(str(plan))
                buf.write('\n')
        else:
            buf.write(hr)
            buf.write('No changes were planned\n')
        buf.write(hr)
        buf.write('\n')
        return buf.getvalue()

    def sync(self, eligible_zones=None, eligible_targets=None, dry_run=True,
             force=False):
        '''
        Plans, and unless `dry_run` applies, changes for the configured
        zones.

        :param eligible_zones: when non-empty only zones named in it are
            considered
        :param eligible_targets: when non-empty only targets named in it are
            considered
        :param dry_run: when True nothing is applied and 0 is returned
        :param force: when True the plans' `raise_if_unsafe` checks are
            skipped
        :returns: the sum of the values returned by each target's `apply`
        :raises Exception: when a zone is missing `sources` or `targets`,
            names an unknown provider or a target isn't a BaseProvider
        '''
        if eligible_zones is None:
            eligible_zones = []
        if eligible_targets is None:
            eligible_targets = []

        self.log.info('sync: eligible_zones=%s, eligible_targets=%s, '
                      'dry_run=%s, force=%s', eligible_zones, eligible_targets,
                      dry_run, force)

        zones = self.config['zones'].items()
        if eligible_zones:
            zones = [(name, config) for name, config in zones
                     if name in eligible_zones]

        futures = []
        for zone_name, config in zones:
            self.log.info('sync: zone=%s', zone_name)
            try:
                sources = config['sources']
            except KeyError:
                raise Exception('Zone {} is missing sources'.format(zone_name))

            try:
                targets = config['targets']
            except KeyError:
                raise Exception('Zone {} is missing targets'.format(zone_name))
            if eligible_targets:
                targets = [t for t in targets if t in eligible_targets]

            self.log.info('sync: sources=%s -> targets=%s', sources, targets)

            sources = self._lookup_sources(zone_name, sources)

            trgs = []
            for target in targets:
                try:
                    trg = self.providers[target]
                except KeyError:
                    raise Exception('Zone {}, unknown target: {}'
                                    .format(zone_name, target))
                if not isinstance(trg, BaseProvider):
                    raise Exception('{} - "{}" does not support targeting'
                                    .format(trg, target))
                trgs.append(trg)
            targets = trgs

            futures.append(self._executor.submit(self._populate_and_plan,
                                                 zone_name, sources, targets))

        # Wait on all results and unpack/flatten them in to a list of target &
        # plan pairs.
        plans = [p for f in futures for p in f.result()]

        self.log.info(self._plans_report(plans))

        if not force:
            self.log.debug('sync: checking safety')
            for target, plan in plans:
                plan.raise_if_unsafe()

        if dry_run:
            return 0

        total_changes = 0
        self.log.debug('sync: applying')
        zones = self.config['zones']
        for target, plan in plans:
            zone_name = plan.existing.name
            if zones[zone_name].get('always-dry-run', False):
                self.log.info('sync: zone=%s skipping always-dry-run',
                              zone_name)
                continue
            total_changes += target.apply(plan)

        self.log.info('sync: %d total changes', total_changes)
        return total_changes

    def compare(self, a, b, zone):
        '''
        Compare zone data between 2 sources.

        Note: only things supported by both sources will be considered

        :param a: provider names used to populate the first copy of the zone
        :param b: provider names used to populate the second copy
        :param zone: name of the zone to compare
        :returns: the result of asking the second copy for its changes
            against the first
        :raises Exception: when a name in `a` or `b` isn't a configured
            provider
        '''
        self.log.info('compare: a=%s, b=%s, zone=%s', a, b, zone)

        try:
            a = [self.providers[source] for source in a]
            b = [self.providers[source] for source in b]
        except KeyError as e:
            raise Exception('Unknown source: {}'.format(e.args[0]))

        sub_zones = self.configured_sub_zones(zone)
        za = Zone(zone, sub_zones)
        for source in a:
            source.populate(za)

        zb = Zone(zone, sub_zones)
        for source in b:
            source.populate(zb)

        return zb.changes(za, _AggregateTarget(a + b))

    def dump(self, zone, output_dir, source, *sources):
        '''
        Dump zone data from the specified source

        :param zone: name of the zone to dump
        :param output_dir: passed to the YamlProvider that does the writing
        :param source: name of the first (required) provider to read from
        :param sources: names of any additional providers to read from
        :raises Exception: when a name isn't a configured provider
        '''
        # We broke out source to force at least one to be passed, add it to any
        # others we got. Done before logging so that the log shows all of them.
        sources = [source] + list(sources)
        self.log.info('dump: zone=%s, sources=%s', zone, sources)

        try:
            sources = [self.providers[s] for s in sources]
        except KeyError as e:
            raise Exception('Unknown source: {}'.format(e.args[0]))

        target = YamlProvider('dump', output_dir)

        zone = Zone(zone, self.configured_sub_zones(zone))
        for source in sources:
            source.populate(zone)

        plan = target.plan(zone)
        target.apply(plan)

    def validate_configs(self):
        '''
        Populates every configured zone from those of its sources that are
        YamlProviders; whatever those raise while populating propagates to
        the caller.

        :raises Exception: when a zone is missing `sources` or names an
            unknown provider
        '''
        for zone_name, config in self.config['zones'].items():
            zone = Zone(zone_name, self.configured_sub_zones(zone_name))

            try:
                sources = config['sources']
            except KeyError:
                raise Exception('Zone {} is missing sources'.format(zone_name))

            for source in self._lookup_sources(zone_name, sources):
                if isinstance(source, YamlProvider):
                    source.populate(zone)