1
0
mirror of https://github.com/github/octodns.git synced 2024-05-11 05:55:00 +00:00

Implement black formatting

This commit is contained in:
Ross McFarland
2022-07-04 12:27:39 -07:00
parent 392d8b516f
commit e116d26eec
101 changed files with 6403 additions and 5490 deletions

View File

@@ -2,8 +2,12 @@
#
#
from __future__ import absolute_import, division, print_function, \
unicode_literals
from __future__ import (
absolute_import,
division,
print_function,
unicode_literals,
)
import dns.name
import dns.query
@@ -25,8 +29,22 @@ class AxfrBaseSource(BaseSource):
SUPPORTS_GEO = False
SUPPORTS_DYNAMIC = False
SUPPORTS = set(('A', 'AAAA', 'CAA', 'CNAME', 'LOC', 'MX', 'NS', 'PTR',
'SPF', 'SRV', 'SSHFP', 'TXT'))
SUPPORTS = set(
(
'A',
'AAAA',
'CAA',
'CNAME',
'LOC',
'MX',
'NS',
'PTR',
'SPF',
'SRV',
'SSHFP',
'TXT',
)
)
def __init__(self, id):
super(AxfrBaseSource, self).__init__(id)
@@ -35,7 +53,7 @@ class AxfrBaseSource(BaseSource):
return {
'ttl': records[0]['ttl'],
'type': _type,
'values': [r['value'] for r in records]
'values': [r['value'] for r in records],
}
_data_for_A = _data_for_multiple
@@ -46,75 +64,62 @@ class AxfrBaseSource(BaseSource):
values = []
for record in records:
flags, tag, value = record['value'].split(' ', 2)
values.append({
'flags': flags,
'tag': tag,
'value': value.replace('"', '')
})
return {
'ttl': records[0]['ttl'],
'type': _type,
'values': values
}
values.append(
{'flags': flags, 'tag': tag, 'value': value.replace('"', '')}
)
return {'ttl': records[0]['ttl'], 'type': _type, 'values': values}
def _data_for_LOC(self, _type, records):
values = []
for record in records:
lat_degrees, lat_minutes, lat_seconds, lat_direction, \
long_degrees, long_minutes, long_seconds, long_direction, \
altitude, size, precision_horz, precision_vert = \
record['value'].replace('m', '').split(' ', 11)
values.append({
'lat_degrees': lat_degrees,
'lat_minutes': lat_minutes,
'lat_seconds': lat_seconds,
'lat_direction': lat_direction,
'long_degrees': long_degrees,
'long_minutes': long_minutes,
'long_seconds': long_seconds,
'long_direction': long_direction,
'altitude': altitude,
'size': size,
'precision_horz': precision_horz,
'precision_vert': precision_vert,
})
return {
'ttl': records[0]['ttl'],
'type': _type,
'values': values
}
(
lat_degrees,
lat_minutes,
lat_seconds,
lat_direction,
long_degrees,
long_minutes,
long_seconds,
long_direction,
altitude,
size,
precision_horz,
precision_vert,
) = (record['value'].replace('m', '').split(' ', 11))
values.append(
{
'lat_degrees': lat_degrees,
'lat_minutes': lat_minutes,
'lat_seconds': lat_seconds,
'lat_direction': lat_direction,
'long_degrees': long_degrees,
'long_minutes': long_minutes,
'long_seconds': long_seconds,
'long_direction': long_direction,
'altitude': altitude,
'size': size,
'precision_horz': precision_horz,
'precision_vert': precision_vert,
}
)
return {'ttl': records[0]['ttl'], 'type': _type, 'values': values}
def _data_for_MX(self, _type, records):
values = []
for record in records:
preference, exchange = record['value'].split(' ', 1)
values.append({
'preference': preference,
'exchange': exchange,
})
return {
'ttl': records[0]['ttl'],
'type': _type,
'values': values
}
values.append({'preference': preference, 'exchange': exchange})
return {'ttl': records[0]['ttl'], 'type': _type, 'values': values}
def _data_for_TXT(self, _type, records):
values = [value['value'].replace(';', '\\;') for value in records]
return {
'ttl': records[0]['ttl'],
'type': _type,
'values': values
}
return {'ttl': records[0]['ttl'], 'type': _type, 'values': values}
_data_for_SPF = _data_for_TXT
def _data_for_single(self, _type, records):
record = records[0]
return {
'ttl': record['ttl'],
'type': _type,
'value': record['value']
}
return {'ttl': record['ttl'], 'type': _type, 'value': record['value']}
_data_for_CNAME = _data_for_single
_data_for_PTR = _data_for_single
@@ -123,37 +128,38 @@ class AxfrBaseSource(BaseSource):
values = []
for record in records:
priority, weight, port, target = record['value'].split(' ', 3)
values.append({
'priority': priority,
'weight': weight,
'port': port,
'target': target,
})
return {
'type': _type,
'ttl': records[0]['ttl'],
'values': values
}
values.append(
{
'priority': priority,
'weight': weight,
'port': port,
'target': target,
}
)
return {'type': _type, 'ttl': records[0]['ttl'], 'values': values}
def _data_for_SSHFP(self, _type, records):
values = []
for record in records:
algorithm, fingerprint_type, fingerprint = \
record['value'].split(' ', 2)
values.append({
'algorithm': algorithm,
'fingerprint_type': fingerprint_type,
'fingerprint': fingerprint,
})
return {
'type': _type,
'ttl': records[0]['ttl'],
'values': values
}
algorithm, fingerprint_type, fingerprint = record['value'].split(
' ', 2
)
values.append(
{
'algorithm': algorithm,
'fingerprint_type': fingerprint_type,
'fingerprint': fingerprint,
}
)
return {'type': _type, 'ttl': records[0]['ttl'], 'values': values}
def populate(self, zone, target=False, lenient=False):
self.log.debug('populate: name=%s, target=%s, lenient=%s', zone.name,
target, lenient)
self.log.debug(
'populate: name=%s, target=%s, lenient=%s',
zone.name,
target,
lenient,
)
values = defaultdict(lambda: defaultdict(list))
for record in self.zone_records(zone):
@@ -167,12 +173,18 @@ class AxfrBaseSource(BaseSource):
for name, types in values.items():
for _type, records in types.items():
data_for = getattr(self, f'_data_for_{_type}')
record = Record.new(zone, name, data_for(_type, records),
source=self, lenient=lenient)
record = Record.new(
zone,
name,
data_for(_type, records),
source=self,
lenient=lenient,
)
zone.add_record(record, lenient=lenient)
self.log.info('populate: found %s records',
len(zone.records) - before)
self.log.info(
'populate: found %s records', len(zone.records) - before
)
class AxfrSourceException(Exception):
@@ -180,10 +192,10 @@ class AxfrSourceException(Exception):
class AxfrSourceZoneTransferFailed(AxfrSourceException):
def __init__(self):
super(AxfrSourceZoneTransferFailed, self).__init__(
'Unable to Perform Zone Transfer')
'Unable to Perform Zone Transfer'
)
class AxfrSource(AxfrBaseSource):
@@ -195,6 +207,7 @@ class AxfrSource(AxfrBaseSource):
# The address of nameserver to perform zone transfer against
master: ns1.example.com
'''
def __init__(self, id, master):
self.log = logging.getLogger(f'AxfrSource[{id}]')
self.log.debug('__init__: id=%s, master=%s', id, master)
@@ -203,9 +216,10 @@ class AxfrSource(AxfrBaseSource):
def zone_records(self, zone):
try:
z = dns.zone.from_xfr(dns.query.xfr(self.master, zone.name,
relativize=False),
relativize=False)
z = dns.zone.from_xfr(
dns.query.xfr(self.master, zone.name, relativize=False),
relativize=False,
)
except DNSException:
raise AxfrSourceZoneTransferFailed()
@@ -213,12 +227,14 @@ class AxfrSource(AxfrBaseSource):
for (name, ttl, rdata) in z.iterate_rdatas():
rdtype = dns.rdatatype.to_text(rdata.rdtype)
records.append({
"name": name.to_text(),
"ttl": ttl,
"type": rdtype,
"value": rdata.to_text()
})
records.append(
{
"name": name.to_text(),
"ttl": ttl,
"type": rdtype,
"value": rdata.to_text(),
}
)
return records
@@ -228,14 +244,11 @@ class ZoneFileSourceException(Exception):
class ZoneFileSourceNotFound(ZoneFileSourceException):
def __init__(self):
super(ZoneFileSourceNotFound, self).__init__(
'Zone file not found')
super(ZoneFileSourceNotFound, self).__init__('Zone file not found')
class ZoneFileSourceLoadFailure(ZoneFileSourceException):
def __init__(self, error):
super(ZoneFileSourceLoadFailure, self).__init__(str(error))
@@ -258,11 +271,17 @@ class ZoneFileSource(AxfrBaseSource):
# (optional, default true)
check_origin: false
'''
def __init__(self, id, directory, file_extension='.', check_origin=True):
self.log = logging.getLogger(f'ZoneFileSource[{id}]')
self.log.debug('__init__: id=%s, directory=%s, file_extension=%s, '
'check_origin=%s', id,
directory, file_extension, check_origin)
self.log.debug(
'__init__: id=%s, directory=%s, file_extension=%s, '
'check_origin=%s',
id,
directory,
file_extension,
check_origin,
)
super(ZoneFileSource, self).__init__(id)
self.directory = directory
self.file_extension = file_extension
@@ -275,9 +294,12 @@ class ZoneFileSource(AxfrBaseSource):
zonefiles = listdir(self.directory)
if zone_filename in zonefiles:
try:
z = dns.zone.from_file(join(self.directory, zone_filename),
zone_name, relativize=False,
check_origin=self.check_origin)
z = dns.zone.from_file(
join(self.directory, zone_filename),
zone_name,
relativize=False,
check_origin=self.check_origin,
)
except DNSException as error:
raise ZoneFileSourceLoadFailure(error)
else:
@@ -292,12 +314,14 @@ class ZoneFileSource(AxfrBaseSource):
records = []
for (name, ttl, rdata) in z.iterate_rdatas():
rdtype = dns.rdatatype.to_text(rdata.rdtype)
records.append({
"name": name.to_text(),
"ttl": ttl,
"type": rdtype,
"value": rdata.to_text()
})
records.append(
{
"name": name.to_text(),
"ttl": ttl,
"type": rdtype,
"value": rdata.to_text(),
}
)
self._zone_records[zone.name] = records
except ZoneFileSourceNotFound: