mirror of
				https://github.com/github/octodns.git
				synced 2024-05-11 05:55:00 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			378 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			378 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #
 | |
| #
 | |
| #
 | |
| from __future__ import absolute_import, division, print_function, \
 | |
|     unicode_literals
 | |
| 
 | |
| from requests import HTTPError, Session, post
 | |
| from collections import defaultdict
 | |
| import logging
 | |
| import string
 | |
| import time
 | |
| 
 | |
| from ..record import Record
 | |
| from .base import BaseProvider
 | |
| 
 | |
| 
 | |
| def add_trailing_dot(s):
 | |
|     assert s
 | |
|     assert s[-1] != '.'
 | |
|     return s + '.'
 | |
| 
 | |
| 
 | |
| def remove_trailing_dot(s):
 | |
|     assert s
 | |
|     assert s[-1] == '.'
 | |
|     return s[:-1]
 | |
| 
 | |
| 
 | |
| def escape_semicolon(s):
 | |
|     assert s
 | |
|     return string.replace(s, ';', '\\;')
 | |
| 
 | |
| 
 | |
| def unescape_semicolon(s):
 | |
|     assert s
 | |
|     return string.replace(s, '\\;', ';')
 | |
| 
 | |
| 
 | |
| class RackspaceProvider(BaseProvider):
 | |
|     SUPPORTS_GEO = False
 | |
|     SUPPORTS_DYNAMIC = False
 | |
|     SUPPORTS = set(('A', 'AAAA', 'ALIAS', 'CNAME', 'MX', 'NS', 'PTR', 'SPF',
 | |
|                     'TXT'))
 | |
|     TIMEOUT = 5
 | |
| 
 | |
|     def __init__(self, id, username, api_key, ratelimit_delay=0.0, *args,
 | |
|                  **kwargs):
 | |
|         '''
 | |
|         Rackspace API v1 Provider
 | |
| 
 | |
|         rackspace:
 | |
|             class: octodns.provider.rackspace.RackspaceProvider
 | |
|             # The the username to authenticate with (required)
 | |
|             username: username
 | |
|             # The api key that grants access for that user (required)
 | |
|             api_key: api-key
 | |
|         '''
 | |
|         self.log = logging.getLogger('RackspaceProvider[{}]'.format(id))
 | |
|         super(RackspaceProvider, self).__init__(id, *args, **kwargs)
 | |
| 
 | |
|         auth_token, dns_endpoint = self._get_auth_token(username, api_key)
 | |
|         self.dns_endpoint = dns_endpoint
 | |
| 
 | |
|         self.ratelimit_delay = float(ratelimit_delay)
 | |
| 
 | |
|         sess = Session()
 | |
|         sess.headers.update({'X-Auth-Token': auth_token})
 | |
|         self._sess = sess
 | |
| 
 | |
|         # Map record type, name, and data to an id when populating so that
 | |
|         # we can find the id for update and delete operations.
 | |
|         self._id_map = {}
 | |
| 
 | |
|     def _get_auth_token(self, username, api_key):
 | |
|         ret = post('https://identity.api.rackspacecloud.com/v2.0/tokens',
 | |
|                    json={"auth": {
 | |
|                        "RAX-KSKEY:apiKeyCredentials": {"username": username,
 | |
|                                                        "apiKey": api_key}}},
 | |
|                    )
 | |
|         cloud_dns_endpoint = \
 | |
|             [x for x in ret.json()['access']['serviceCatalog'] if
 | |
|              x['name'] == 'cloudDNS'][0]['endpoints'][0]['publicURL']
 | |
|         return ret.json()['access']['token']['id'], cloud_dns_endpoint
 | |
| 
 | |
|     def _get_zone_id_for(self, zone):
 | |
|         ret = self._request('GET', 'domains', pagination_key='domains')
 | |
|         return [x for x in ret if x['name'] == zone.name[:-1]][0]['id']
 | |
| 
 | |
|     def _request(self, method, path, data=None, pagination_key=None):
 | |
|         self.log.debug('_request: method=%s, path=%s', method, path)
 | |
|         url = '{}/{}'.format(self.dns_endpoint, path)
 | |
| 
 | |
|         if pagination_key:
 | |
|             resp = self._paginated_request_for_url(method, url, data,
 | |
|                                                    pagination_key)
 | |
|         else:
 | |
|             resp = self._request_for_url(method, url, data)
 | |
|         time.sleep(self.ratelimit_delay)
 | |
|         return resp
 | |
| 
 | |
|     def _request_for_url(self, method, url, data):
 | |
|         resp = self._sess.request(method, url, json=data, timeout=self.TIMEOUT)
 | |
|         self.log.debug('_request:   status=%d', resp.status_code)
 | |
|         resp.raise_for_status()
 | |
|         return resp
 | |
| 
 | |
|     def _paginated_request_for_url(self, method, url, data, pagination_key):
 | |
|         acc = []
 | |
| 
 | |
|         resp = self._sess.request(method, url, json=data, timeout=self.TIMEOUT)
 | |
|         self.log.debug('_request:   status=%d', resp.status_code)
 | |
|         resp.raise_for_status()
 | |
|         acc.extend(resp.json()[pagination_key])
 | |
| 
 | |
|         next_page = [x for x in resp.json().get('links', []) if
 | |
|                      x['rel'] == 'next']
 | |
|         if next_page:
 | |
|             url = next_page[0]['href']
 | |
|             acc.extend(self._paginated_request_for_url(method, url, data,
 | |
|                                                        pagination_key))
 | |
|             return acc
 | |
|         else:
 | |
|             return acc
 | |
| 
 | |
|     def _post(self, path, data=None):
 | |
|         return self._request('POST', path, data=data)
 | |
| 
 | |
|     def _put(self, path, data=None):
 | |
|         return self._request('PUT', path, data=data)
 | |
| 
 | |
|     def _delete(self, path, data=None):
 | |
|         return self._request('DELETE', path, data=data)
 | |
| 
 | |
|     @classmethod
 | |
|     def _key_for_record(cls, rs_record):
 | |
|         return rs_record['type'], rs_record['name'], rs_record['data']
 | |
| 
 | |
|     def _data_for_multiple(self, rrset):
 | |
|         return {
 | |
|             'type': rrset[0]['type'],
 | |
|             'values': [r['data'] for r in rrset],
 | |
|             'ttl': rrset[0]['ttl']
 | |
|         }
 | |
| 
 | |
|     _data_for_A = _data_for_multiple
 | |
|     _data_for_AAAA = _data_for_multiple
 | |
| 
 | |
|     def _data_for_NS(self, rrset):
 | |
|         return {
 | |
|             'type': rrset[0]['type'],
 | |
|             'values': [add_trailing_dot(r['data']) for r in rrset],
 | |
|             'ttl': rrset[0]['ttl']
 | |
|         }
 | |
| 
 | |
|     def _data_for_single(self, record):
 | |
|         return {
 | |
|             'type': record[0]['type'],
 | |
|             'value': add_trailing_dot(record[0]['data']),
 | |
|             'ttl': record[0]['ttl']
 | |
|         }
 | |
| 
 | |
|     _data_for_ALIAS = _data_for_single
 | |
|     _data_for_CNAME = _data_for_single
 | |
|     _data_for_PTR = _data_for_single
 | |
| 
 | |
|     def _data_for_textual(self, rrset):
 | |
|         return {
 | |
|             'type': rrset[0]['type'],
 | |
|             'values': [escape_semicolon(r['data']) for r in rrset],
 | |
|             'ttl': rrset[0]['ttl']
 | |
|         }
 | |
| 
 | |
|     _data_for_SPF = _data_for_textual
 | |
|     _data_for_TXT = _data_for_textual
 | |
| 
 | |
|     def _data_for_MX(self, rrset):
 | |
|         values = []
 | |
|         for record in rrset:
 | |
|             values.append({
 | |
|                 'priority': record['priority'],
 | |
|                 'value': add_trailing_dot(record['data']),
 | |
|             })
 | |
|         return {
 | |
|             'type': rrset[0]['type'],
 | |
|             'values': values,
 | |
|             'ttl': rrset[0]['ttl']
 | |
|         }
 | |
| 
 | |
|     def populate(self, zone, target=False, lenient=False):
 | |
|         self.log.debug('populate: name=%s', zone.name)
 | |
|         resp_data = None
 | |
|         try:
 | |
|             domain_id = self._get_zone_id_for(zone)
 | |
|             resp_data = self._request('GET',
 | |
|                                       'domains/{}/records'.format(domain_id),
 | |
|                                       pagination_key='records')
 | |
|             self.log.debug('populate:   loaded')
 | |
|         except HTTPError as e:
 | |
|             if e.response.status_code == 401:
 | |
|                 # Nicer error message for auth problems
 | |
|                 raise Exception('Rackspace request unauthorized')
 | |
|             elif e.response.status_code == 404:
 | |
|                 # Zone not found leaves the zone empty instead of failing.
 | |
|                 return False
 | |
|             raise
 | |
| 
 | |
|         before = len(zone.records)
 | |
| 
 | |
|         if resp_data:
 | |
|             records = self._group_records(resp_data)
 | |
|             for record_type, records_of_type in records.items():
 | |
|                 for raw_record_name, record_set in records_of_type.items():
 | |
|                     data_for = getattr(self,
 | |
|                                        '_data_for_{}'.format(record_type))
 | |
|                     record_name = zone.hostname_from_fqdn(raw_record_name)
 | |
|                     record = Record.new(zone, record_name,
 | |
|                                         data_for(record_set),
 | |
|                                         source=self)
 | |
|                     zone.add_record(record, lenient=lenient)
 | |
| 
 | |
|         self.log.info('populate:   found %s records, exists=True',
 | |
|                       len(zone.records) - before)
 | |
|         return True
 | |
| 
 | |
|     def _group_records(self, all_records):
 | |
|         records = defaultdict(lambda: defaultdict(list))
 | |
|         for record in all_records:
 | |
|             self._id_map[self._key_for_record(record)] = record['id']
 | |
|             records[record['type']][record['name']].append(record)
 | |
|         return records
 | |
| 
 | |
|     @staticmethod
 | |
|     def _record_for_single(record, value):
 | |
|         return {
 | |
|             'name': remove_trailing_dot(record.fqdn),
 | |
|             'type': record._type,
 | |
|             'data': value,
 | |
|             'ttl': max(record.ttl, 300),
 | |
|         }
 | |
| 
 | |
|     _record_for_A = _record_for_single
 | |
|     _record_for_AAAA = _record_for_single
 | |
| 
 | |
|     @staticmethod
 | |
|     def _record_for_named(record, value):
 | |
|         return {
 | |
|             'name': remove_trailing_dot(record.fqdn),
 | |
|             'type': record._type,
 | |
|             'data': remove_trailing_dot(value),
 | |
|             'ttl': max(record.ttl, 300),
 | |
|         }
 | |
| 
 | |
|     _record_for_NS = _record_for_named
 | |
|     _record_for_ALIAS = _record_for_named
 | |
|     _record_for_CNAME = _record_for_named
 | |
|     _record_for_PTR = _record_for_named
 | |
| 
 | |
|     @staticmethod
 | |
|     def _record_for_textual(record, value):
 | |
|         return {
 | |
|             'name': remove_trailing_dot(record.fqdn),
 | |
|             'type': record._type,
 | |
|             'data': unescape_semicolon(value),
 | |
|             'ttl': max(record.ttl, 300),
 | |
|         }
 | |
| 
 | |
|     _record_for_SPF = _record_for_textual
 | |
|     _record_for_TXT = _record_for_textual
 | |
| 
 | |
|     @staticmethod
 | |
|     def _record_for_MX(record, value):
 | |
|         return {
 | |
|             'name': remove_trailing_dot(record.fqdn),
 | |
|             'type': record._type,
 | |
|             'data': remove_trailing_dot(value.exchange),
 | |
|             'ttl': max(record.ttl, 300),
 | |
|             'priority': value.preference
 | |
|         }
 | |
| 
 | |
|     def _get_values(self, record):
 | |
|         try:
 | |
|             return record.values
 | |
|         except AttributeError:
 | |
|             return [record.value]
 | |
| 
 | |
|     def _mod_Create(self, change):
 | |
|         return self._create_given_change_values(change,
 | |
|                                                 self._get_values(change.new))
 | |
| 
 | |
|     def _create_given_change_values(self, change, values):
 | |
|         transformer = getattr(self, "_record_for_{}".format(change.new._type))
 | |
|         return [transformer(change.new, v) for v in values]
 | |
| 
 | |
|     def _mod_Update(self, change):
 | |
|         existing_values = self._get_values(change.existing)
 | |
|         new_values = self._get_values(change.new)
 | |
| 
 | |
|         # A reduction in number of values in an update record needs
 | |
|         # to get upgraded into a Delete change for the removed values.
 | |
|         deleted_values = set(existing_values) - set(new_values)
 | |
|         delete_out = self._delete_given_change_values(change, deleted_values)
 | |
| 
 | |
|         # An increase in number of values in an update record needs
 | |
|         # to get upgraded into a Create change for the added values.
 | |
|         create_values = set(new_values) - set(existing_values)
 | |
|         create_out = self._create_given_change_values(change, create_values)
 | |
| 
 | |
|         update_out = []
 | |
|         update_values = set(new_values).intersection(set(existing_values))
 | |
|         for value in update_values:
 | |
|             transformer = getattr(self,
 | |
|                                   "_record_for_{}".format(change.new._type))
 | |
|             prior_rs_record = transformer(change.existing, value)
 | |
|             prior_key = self._key_for_record(prior_rs_record)
 | |
|             next_rs_record = transformer(change.new, value)
 | |
|             next_key = self._key_for_record(next_rs_record)
 | |
|             next_rs_record["id"] = self._id_map[prior_key]
 | |
|             del next_rs_record["type"]
 | |
|             update_out.append(next_rs_record)
 | |
|             self._id_map[next_key] = self._id_map[prior_key]
 | |
|             del self._id_map[prior_key]
 | |
|         return create_out, update_out, delete_out
 | |
| 
 | |
|     def _mod_Delete(self, change):
 | |
|         return self._delete_given_change_values(change, self._get_values(
 | |
|             change.existing))
 | |
| 
 | |
|     def _delete_given_change_values(self, change, values):
 | |
|         transformer = getattr(self, "_record_for_{}".format(
 | |
|             change.existing._type))
 | |
|         out = []
 | |
|         for value in values:
 | |
|             rs_record = transformer(change.existing, value)
 | |
|             key = self._key_for_record(rs_record)
 | |
|             out.append('id=' + self._id_map[key])
 | |
|             del self._id_map[key]
 | |
|         return out
 | |
| 
 | |
|     def _apply(self, plan):
 | |
|         desired = plan.desired
 | |
|         changes = plan.changes
 | |
|         self.log.debug('_apply: zone=%s, len(changes)=%d', desired.name,
 | |
|                        len(changes))
 | |
| 
 | |
|         # Creates, updates, and deletes are processed by different endpoints
 | |
|         # and are broken out by record-set entries; pre-process everything
 | |
|         # into these buckets in order to minimize the number of API calls.
 | |
|         domain_id = self._get_zone_id_for(desired)
 | |
|         creates = []
 | |
|         updates = []
 | |
|         deletes = []
 | |
|         for change in changes:
 | |
|             if change.__class__.__name__ == 'Create':
 | |
|                 creates += self._mod_Create(change)
 | |
|             elif change.__class__.__name__ == 'Update':
 | |
|                 add_creates, add_updates, add_deletes = self._mod_Update(
 | |
|                     change)
 | |
|                 creates += add_creates
 | |
|                 updates += add_updates
 | |
|                 deletes += add_deletes
 | |
|             else:
 | |
|                 assert change.__class__.__name__ == 'Delete'
 | |
|                 deletes += self._mod_Delete(change)
 | |
| 
 | |
|         if deletes:
 | |
|             params = "&".join(sorted(deletes))
 | |
|             self._delete('domains/{}/records?{}'.format(domain_id, params))
 | |
| 
 | |
|         if updates:
 | |
|             data = {"records": sorted(updates, key=lambda v: v['name'])}
 | |
|             self._put('domains/{}/records'.format(domain_id), data=data)
 | |
| 
 | |
|         if creates:
 | |
|             data = {"records": sorted(creates, key=lambda v: v['type'] +
 | |
|                                       v['name'] +
 | |
|                                       v.get('data', ''))}
 | |
|             self._post('domains/{}/records'.format(domain_id), data=data)
 |