#!/usr/bin/env python3
"""Git hooks that validate DNS zone files and template server config.

The hook to run is selected by the name this script is invoked under
(``pre-commit``, ``update``, ``pre-receive`` or ``post-receive``).
"""

import os
import sys
import subprocess
import re
import time
import datetime
import json
from collections import namedtuple
from hashlib import sha256
from pathlib import Path
from string import Template

# Object name of git's empty tree; used as the diff base when HEAD does
# not exist yet (initial commit).
EMPTY_TREE = "4b825dc642cb6eb9a060e54bf8d69288fbee4904"

# All-zero object name that the receive hooks treat as "no old revision".
NULL_REVISION = "0" * 40

# Result of compile_zone(); serial and zonehash are None on failure.
CompileResults = namedtuple(
    "CompileResults", "success, serial, zonehash, stderr",
)


class HookException(ValueError):
    """Exception raised when there is an error in input data.

    Attributes:
        message -- the cause of problem
        fname -- affected file
        stderr -- output of the specific checker
    """

    def __init__(self, message, fname=None, stderr=None):
        super().__init__(message)
        self.message = message
        self.fname = fname
        self.stderr = stderr

    def __str__(self):
        parts = []
        if self.fname:
            parts.append("{fname}: ".format(fname=self.fname))
        parts.append(self.message)
        parts.append("\n")
        if self.stderr:
            parts.append("\n")
            parts.append(self.stderr)
            parts.append("\n\n")
        return "".join(parts)


def get_head():
    """Return the commit id of HEAD, or the empty tree if HEAD is invalid."""
    r = subprocess.run(
        ["git", "rev-parse", "--verify", "HEAD"],
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
    )
    if r.returncode == 0:
        return r.stdout.decode("utf-8").strip()
    # Initial commit: diff against an empty tree object
    return EMPTY_TREE


def check_whitespace_errors(against, revision=None):
    """Run git's whitespace check and raise HookException if it fails.

    With a revision, `against` is compared to `revision`; otherwise the
    staging area is compared to `against`.
    """
    if revision:
        cmd = ["git", "diff-tree", "--check", against, revision]
    else:
        cmd = ["git", "diff-index", "--check", "--cached", against]
    r = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    if r.returncode != 0:
        raise HookException(
            "Whitespace errors",
            stderr=r.stdout.decode("utf-8"),
        )


def get_file_contents(path, revision=None):
    """Return contents (bytes) of a file in staged env or in some revision.

    Raises subprocess.CalledProcessError when `git show` fails, e.g. when
    the file does not exist in the requested revision.
    """
    # An empty revision makes the object name ":path", i.e. the staged copy.
    revision = "" if revision is None else revision
    r = subprocess.run(
        ["git", "show", "{r}:{p}".format(r=revision, p=path)],
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
        check=True,
    )
    return r.stdout


def compile_zone(zonename, zonedata):
    """Compile the zone with named-compilezone.

    `zonedata` (bytes) is fed to the compiler on standard input.
    Return a CompileResults tuple (success, serial, zonehash, stderr);
    zonehash is the SHA-256 hex digest of the compiler's output.
    """
    r = subprocess.run(
        ["/usr/sbin/named-compilezone", "-o", "-", zonename, "/dev/stdin"],
        input=zonedata,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    stderr = r.stderr.decode("utf-8")
    # Require at least one digit so that an empty serial is never reported.
    m = re.search(r"^zone.*loaded serial ([0-9]+)$", stderr, re.MULTILINE)
    if r.returncode == 0 and m:
        serial = m.group(1)
        zonehash = sha256(r.stdout).hexdigest()
        return CompileResults(True, serial, zonehash, stderr)
    return CompileResults(False, None, None, stderr)


def is_serial_increased(old, new):
    """Return true if serial number was increased using RFC 1982 logic.

    Both arguments may be ints or decimal strings. The serial counts as
    increased when the forward distance modulo 2**32 lies between 1 and
    2**31 - 1 inclusive.
    """
    old, new = int(old), int(new)
    diff = (new - old) % 2**32
    return 0 < diff < 2**31


def get_increased_serial(old):
    """Return increased serial number, automatically recognizing the type.

    The result is returned as a string.
    """
    old = int(old)
    now = int(time.time())
    todayserial = int(datetime.date.today().strftime("%Y%m%d00"))
    # Note to my future self: This is expected to break on 2034-06-16
    # as unix timestamp will become in the same range as YYYYMMDDnn serial
    if 1e9 < old < now:
        # Serial is unix timestamp
        return str(now)
    if 2e9 < old < todayserial:
        # Serial is YYYYMMDDnn, updated before today
        return str(todayserial)
    # No pattern recognized, just increase the number
    return str(old + 1)


def get_altered_files(against, diff_filter=None, revision=None):
    """Return list of changed files as Path objects.

    If revision is None, list changes between staging area and `against`.
    Otherwise differences between the two revisions are computed.
    `diff_filter` is passed to git's --diff-filter option when given.
    """
    cmd = ["git", "diff", "--name-only", "-z"]
    if diff_filter:
        cmd.append("--diff-filter={}".format(diff_filter))
    if revision:
        cmd.append(against)
        cmd.append(revision)
    else:
        cmd.append("--cached")
        cmd.append(against)
    r = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
        check=True,
    )
    if not r.stdout:
        return []
    # Output is NUL-separated (-z) with a trailing NUL.
    names = r.stdout.decode("utf-8").rstrip("\0").split("\0")
    return [Path(p) for p in names]


def get_zone_origin(zonedata):
    """Parse $ORIGIN directive before the SOA record.

    Return zone name (lower-cased) without the trailing dot, or None when
    no $ORIGIN directive precedes the SOA record.
    """
    for line in zonedata.splitlines():
        if re.match(br"^[^\s;]+\s+([0-9]+\s+)?(IN\s+)?SOA\s+", line, re.I):
            break
        m = re.match(br"^\$ORIGIN\s+([^ ]+)\.\s*(;.*)?$", line, re.I)
        if m:
            return m.group(1).decode("utf-8").lower()
    return None


def get_zone_name(path, zonedata):
    """Try to guess zone name from either filename or the first $ORIGIN.

    Throw a HookException if filename and zone ORIGIN differ more than
    in slashes (and the other punctuation characters stripped below).
    """
    stemname = Path(path).stem.lower()
    originname = get_zone_origin(zonedata)
    if not originname:
        return stemname
    # Characters ignored when comparing the file name with the origin.
    tt = str.maketrans("", "", "/_,:-+*%^&#$")
    if stemname.translate(tt) != originname.translate(tt):
        raise HookException(
            "Zone origin {o} differs from zone file.".format(o=originname),
            fname=path,
        )
    return originname


def check_updated_zones(against, revision=None, autoupdate_serial=False):
    """Check whether all updated zone files compile.

    For every added or modified *.zone file, the new version has to
    compile. If an old version exists, compiles and its compiled output
    differs from the new one, the serial number has to be increased.
    With `autoupdate_serial`, the serial in the working copy is bumped
    before the error is reported.

    Raises HookException on the first offending file.
    """
    for f in get_altered_files(against, "AM", revision):
        if f.suffix != ".zone":
            continue
        print("Checking file {f}".format(f=f))
        zonedata = get_file_contents(f, revision)
        zname = get_zone_name(f, zonedata)
        rnew = compile_zone(zname, zonedata)
        if not rnew.success:
            raise HookException(
                "New zone version does not compile",
                f, rnew.stderr,
            )
        try:
            zonedata = get_file_contents(f, against)
        except subprocess.CalledProcessError:
            continue  # Old version of zone did not exist
        zname = get_zone_name(f, zonedata)
        rold = compile_zone(zname, zonedata)
        if (rold.success and rold.zonehash != rnew.zonehash and not
                is_serial_increased(rold.serial, rnew.serial)):
            errmsg = "Zone contents changed without increasing serial."
            diagmsg = "Old revision {}, serial {}, new serial {}".format(
                against, rold.serial, rnew.serial,
            )
            if autoupdate_serial:
                newserial = get_increased_serial(rnew.serial)
                replace_serial(f, rnew.serial, newserial)
                errmsg += " Serial has been automatically increased."
                errmsg += " Check and recommit."
            raise HookException(
                errmsg,
                fname=f,
                stderr=diagmsg,
            )


def get_config(name, type_=None):
    """Return a git config value, or None if `git config` fails.

    `type_` may be bool or int to request a typed value; any other
    truthy value raises ValueError. Without a type a string is returned.
    """
    cmd = ["git", "config"]
    if type_ is bool:
        cmd.append("--bool")
    elif type_ is int:
        cmd.append("--int")
    elif type_:
        raise ValueError("Invalid type supplied")
    cmd.append(name)
    r = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
    )
    if r.returncode != 0:
        return None
    if type_ is bool:
        return r.stdout == b"true\n"
    if type_ is int:
        return int(r.stdout)
    return r.stdout.decode("utf-8").rstrip("\n")


def replace_serial(path, oldserial, newserial):
    """Rewrite the SOA serial in the zone file at `path` (a Path).

    The first occurrence of `oldserial` following an SOA keyword is
    replaced by `newserial`. Raises HookException if nothing matched.
    """
    contents = path.read_text()
    updated, count = re.subn(
        r'(^.*\sSOA\s.+?\s){}([^0-9])'.format(re.escape(str(oldserial))),
        r'\g<1>{}\g<2>'.format(newserial),
        contents,
        count=1,
        flags=re.DOTALL | re.IGNORECASE | re.MULTILINE,
    )
    if count != 1:
        raise HookException("Cannot update zone serial number", fname=path)
    path.write_text(updated)


def template_config(checkoutpath, template, blacklist=None, whitelist=None):
    r""" Recursively find all *.zone files and template config file using
    a simple JSON based template like this:

    {
      "header": "# Managed by dzonegit, do not edit.\n",
      "footer": "",
      "item": " - zone: \"$zonename\"\n   file: \"$zonefile\"\n   $zonevar\n",
      "defaultvar": "template: default",
      "zonevars": {
        "example.com": "template: signed"
      }
    }

    Available placeholders are:
      - $datetime - timestamp of file creation
      - $zonename - zone name, without trailing dot
      - $zonefile - full path to zone file
      - $zonevar - per-zone specific variables, content of `defaultvar` if
                   not defined for current zone

    `blacklist` and `whitelist` are optional collections of zone names.
    Blacklisted zones are skipped; a non-empty whitelist skips every zone
    that is not in it. Return the rendered config as a single string.
    """
    blacklist = blacklist or ()
    whitelist = whitelist or ()
    tpl = json.loads(template)
    headertpl = Template(tpl.get("header", ""))
    footertpl = Template(tpl.get("footer", ""))
    itemtpl = Template(tpl.get("item", ""))
    defaultvar = tpl.get("defaultvar", "")
    zonevars = tpl.get("zonevars", dict())
    out = []
    zones = {}  # zone name -> relative path of the file already used
    mapping = {"datetime": datetime.datetime.now().strftime("%c")}
    out.append(headertpl.substitute(mapping))
    for f in sorted(Path(checkoutpath).glob("**/*.zone")):
        zonename = get_zone_name(f, f.read_bytes())
        if whitelist and zonename not in whitelist:
            print(
                "WARNING: Ignoring zone {} - not whitelisted for "
                "this repository.".format(zonename),
            )
            continue
        if zonename in blacklist:
            print(
                "WARNING: Ignoring zone {} - blacklisted for "
                "this repository.".format(zonename),
            )
            continue
        if zonename in zones:
            print(
                "WARNING: Duplicate zone file found for zone {}. "
                "Using file {}, ignoring {}.".format(
                    zonename, zones[zonename],
                    f.relative_to(checkoutpath),
                ),
            )
            continue
        zones[zonename] = f.relative_to(checkoutpath)
        out.append(itemtpl.substitute(
            mapping,
            zonename=zonename,
            zonefile=str(f),
            zonevar=zonevars.get(zonename, defaultvar),
        ))
    out.append(footertpl.substitute(mapping))
    return "\n".join(out)


def load_set_file(path):
    """Return the set of entries listed in a text file, one per line.

    Surrounding whitespace is stripped; blank lines and lines starting
    with "#" are ignored. A path of None yields an empty set.
    """
    if path is None:
        return set()
    with open(path) as inf:
        entries = (line.strip() for line in inf)
        return {e for e in entries if e and not e.startswith("#")}


def do_commit_checks(against, revision=None, autoupdate_serial=False):
    """Run all checks; print the error and exit with status 1 on failure.

    The whitespace check is skipped when the git config option
    dzonegit.ignorewhitespaceerrors is true.
    """
    try:
        if not get_config("dzonegit.ignorewhitespaceerrors", bool):
            check_whitespace_errors(against, revision=revision)
        check_updated_zones(
            against, revision=revision,
            autoupdate_serial=autoupdate_serial,
        )
    except HookException as e:
        print(e)
        raise SystemExit(1)


def pre_commit():
    """pre-commit hook: check staged changes against HEAD.

    Serial numbers are auto-updated unless the git config option
    dzonegit.noserialupdate is true.
    """
    against = get_head()
    autoupdate_serial = not get_config("dzonegit.noserialupdate", bool)
    do_commit_checks(against, autoupdate_serial=autoupdate_serial)


def update(argv=sys.argv):
    """update hook: check a pushed ref given as `ref oldrev newrev` in argv.

    Only refs/heads/master is accepted.
    """
    if "GIT_DIR" not in os.environ:
        raise SystemExit("Don't run this hook from the command line")
    if len(argv) < 4:
        raise SystemExit(
            "Usage: {} <ref> <oldrev> <newrev>".format(argv[0]),
        )
    refname, against, revision = argv[1:4]
    if against == NULL_REVISION:
        against = get_head()  # Empty commit
    if refname != "refs/heads/master":
        raise SystemExit("Nothing else than master branch is accepted here")
    do_commit_checks(against, revision)


def pre_receive(stdin=sys.stdin):
    """pre-receive hook: check every `oldrev newrev ref` line on stdin.

    Only refs/heads/master is accepted.
    """
    if stdin.isatty():
        raise SystemExit("Don't run this hook from the command line")
    for line in stdin:
        against, revision, refname = line.rstrip().split(" ")
        if refname != "refs/heads/master":
            raise SystemExit(
                "Nothing else than master branch "
                "is accepted here",
            )
        if against == NULL_REVISION:
            against = get_head()  # Empty commit
        do_commit_checks(against, revision)


def post_receive(stdin=sys.stdin):
    """Checkout the repository to a path specified in the config.
    Re-generate config files using defined templates.

    Up to ten config files are templated, configured through the
    dzonegit.conffilepath / dzonegit.conffiletemplate options with an
    optional numeric suffix 1-9.

    Issuing reload commands for modified zone files and a reconfig
    command when zones were added or deleted is not implemented yet;
    the pushed refs are only parsed.
    """
    suffixes = [str(n) if n else "" for n in range(10)]
    checkoutpath = get_config("dzonegit.checkoutpath")
    blacklist = load_set_file(get_config("dzonegit.zoneblacklist"))
    whitelist = load_set_file(get_config("dzonegit.zonewhitelist"))
    if checkoutpath:
        print("Checking out repository into {}…".format(checkoutpath))
        subprocess.run(
            ["git", "checkout", "-f", "master"],
            check=True,
            env=dict(os.environ, GIT_WORK_TREE=checkoutpath),
        )
        # Templating scans the checkout, so it only runs when one exists.
        for s in suffixes:
            cfpath = get_config("dzonegit.conffilepath{}".format(s))
            tplpath = get_config("dzonegit.conffiletemplate{}".format(s))
            if cfpath is None or tplpath is None:
                continue
            print("Templating config file {}…".format(cfpath))
            Path(cfpath).write_text(
                template_config(
                    checkoutpath,
                    Path(tplpath).read_text(),
                    blacklist=blacklist,
                    whitelist=whitelist,
                ),
            )

    if stdin.isatty():
        raise SystemExit(
            "Standard input should be redirected. Not issuing any reload "
            "commands.",
        )
    for line in stdin:
        against, revision, refname = line.rstrip().split(" ")
        if refname != "refs/heads/master":
            continue
        if against == NULL_REVISION:
            against = get_head()  # Empty commit
        # TODO reloads


def main():
    """Dispatch to the hook matching the name this script was run as."""
    name = Path(sys.argv[0]).name
    print(name)
    hooks = {
        "pre-commit": pre_commit,
        "update": update,
        "pre-receive": pre_receive,
        "post-receive": post_receive,
    }
    hook = hooks.get(name)
    if hook is None:
        sys.exit("No valid command found")
    hook()


if __name__ == "__main__":
    main()