1
0
mirror of https://github.com/github/octodns.git synced 2024-05-11 05:55:00 +00:00

ZoneFileSource: allow users to specify file extension

This commit is contained in:
Adam Smith
2021-01-08 19:36:58 -05:00
parent 97dbb6a782
commit b2eab63d54
3 changed files with 34 additions and 4 deletions

View File

@@ -206,17 +206,24 @@ class ZoneFileSource(AxfrBaseSource):
class: octodns.source.axfr.ZoneFileSource
# The directory holding the zone files
# Filenames should match zone name (eg. example.com.)
# with optional extension specified with file_extension
directory: ./zonefiles
# File extension on zone files
# Appended to zone name to locate file
# (optional, default None)
file_extension: zone
# Should sanity checks of the origin node be done
# (optional, default true)
check_origin: false
'''
def __init__(self, id, directory, check_origin=True):
def __init__(self, id, directory, file_extension=None, check_origin=True):
self.log = logging.getLogger('ZoneFileSource[{}]'.format(id))
self.log.debug('__init__: id=%s, directory=%s, check_origin=%s', id,
directory, check_origin)
self.log.debug('__init__: id=%s, directory=%s, file_extension=%s, '
'check_origin=%s', id,
directory, file_extension, check_origin)
super(ZoneFileSource, self).__init__(id)
self.directory = directory
self.file_extension = file_extension
self.check_origin = check_origin
self._zone_records = {}
@@ -225,7 +232,11 @@ class ZoneFileSource(AxfrBaseSource):
zonefiles = listdir(self.directory)
if zone_name in zonefiles:
try:
z = dns.zone.from_file(join(self.directory, zone_name),
filename = zone_name
if self.file_extension:
filename = '{}{}'.format(zone_name,
self.file_extension.lstrip('.'))
z = dns.zone.from_file(join(self.directory, filename),
zone_name, relativize=False,
check_origin=self.check_origin)
except DNSException as error:

View File

@@ -45,6 +45,13 @@ class TestAxfrSource(TestCase):
class TestZoneFileSource(TestCase):
source = ZoneFileSource('test', './tests/zones')
source_extension = ZoneFileSource('test', './tests/zones', 'extension')
def test_zonefiles_with_extension(self):
# Load zonefiles with a specified file extension
valid = Zone('unit.tests.', [])
self.source_extension.populate(valid)
self.assertEquals(1, len(valid.records))
def test_populate(self):
# Valid zone file in directory

View File

@@ -0,0 +1,12 @@
$ORIGIN unit.tests.
@ 3600 IN SOA ns1.unit.tests. root.unit.tests. (
2018071501 ; Serial
3600 ; Refresh (1 hour)
600 ; Retry (10 minutes)
604800 ; Expire (1 week)
3600 ; NXDOMAIN ttl (1 hour)
)
; NS Records
@ 3600 IN NS ns1.unit.tests.
@ 3600 IN NS ns2.unit.tests.