#!/usr/bin/env python3
"""Git hooks validating DNS zone files (``*.zone``).

The same script is installed as the ``pre-commit``, ``update`` and
``pre-receive`` hook; ``main()`` dispatches on the name it was invoked as.
"""

import os
import sys
import subprocess
import re
import time
import datetime
from collections import namedtuple
from hashlib import sha256
from pathlib import Path

# Object name of git's empty tree; used as the diff base when HEAD does not
# exist yet (initial commit).
EMPTY_TREE = "4b825dc642cb6eb9a060e54bf8d69288fbee4904"

# Object name git hooks receive for a ref that did not exist before.
NULL_REV = "0000000000000000000000000000000000000000"

# Outcome of compiling one zone; serial/zonehash are None on failure.
CompileResults = namedtuple(
    "CompileResults",
    "success, serial, zonehash, stderr",
)


class HookException(ValueError):
    """Exception raised when there is an error in input data.

    Attributes:
        message -- the cause of the problem
        fname -- affected file (optional)
        stderr -- output of the specific checker (optional)
    """

    def __init__(self, message, fname=None, stderr=None):
        # Populate ``args`` so repr(), pickling and generic handlers work.
        super().__init__(message)
        self.message = message
        self.fname = fname
        self.stderr = stderr

    def __str__(self):
        r = []
        if self.fname:
            r.append("{fname}: ".format(fname=self.fname))
        r.append(self.message)
        r.append("\n")
        if self.stderr:
            r.append("\n")
            r.append(self.stderr)
            r.append("\n\n")
        return "".join(r)


def get_head():
    """Return the hash of HEAD, or the empty tree hash if HEAD is missing."""
    r = subprocess.run(
        ["git", "rev-parse", "--verify", "HEAD"],
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
    )
    if r.returncode == 0:
        return r.stdout.decode("utf-8").strip()
    # Initial commit: diff against an empty tree object
    return EMPTY_TREE


def check_whitespace_errors(against, revision=None):
    """Run ``git diff --check`` between *against* and *revision*.

    When *revision* is None, the staging area is checked instead.

    :raises HookException: if git reports whitespace errors
    """
    if revision:
        cmd = ["git", "diff-tree", "--check", against, revision]
    else:
        cmd = ["git", "diff-index", "--check", "--cached", against]
    r = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    if r.returncode != 0:
        raise HookException(
            "Whitespace errors",
            # Output quotes file contents, which need not be valid UTF-8.
            stderr=r.stdout.decode("utf-8", errors="replace"),
        )


def get_file_contents(path, revision=None):
    """Return contents (bytes) of a file in the staging area or a revision.

    :raises subprocess.CalledProcessError: if the file does not exist there
    """
    revision = "" if revision is None else revision
    # "git show :path" reads from the index, "git show rev:path" from a rev.
    r = subprocess.run(
        ["git", "show", "{r}:{p}".format(r=revision, p=path)],
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
        check=True,
    )
    return r.stdout


def compile_zone(zonename, zonedata):
    """Compile the zone with named-compilezone.

    :param zonename: zone origin passed to the compiler
    :param zonedata: raw zone file contents (bytes)
    :return: CompileResults; on success ``zonehash`` is the SHA-256 of the
        compiled (canonical) output, so formatting-only edits compare equal.
    """
    r = subprocess.run(
        ["/usr/sbin/named-compilezone", "-o", "-", zonename, "/dev/stdin"],
        input=zonedata,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    stderr = r.stderr.decode("utf-8", errors="replace")
    # Not anchored at end of line: BIND may append a suffix after the serial
    # (e.g. "(DNSSEC signed)").
    m = re.search(r"^zone.*loaded serial ([0-9]+)", stderr, re.MULTILINE)
    if r.returncode == 0 and m:
        serial = m.group(1)
        zonehash = sha256(r.stdout).hexdigest()
        return CompileResults(True, serial, zonehash, stderr)
    return CompileResults(False, None, None, stderr)


def is_serial_increased(old, new):
    """Return True if serial *new* is greater than *old* per RFC 1982.

    With 32-bit serials, *new* is greater when the forward distance
    ``(new - old) mod 2**32`` lies strictly between 0 and 2**31.
    """
    old, new = (int(n) for n in [old, new])
    diff = (new - old) % 2**32
    return 0 < diff < 2**31


def get_increased_serial(old):
    """Return an increased serial number, recognizing the serial scheme.

    Unix-timestamp serials become the current time, YYYYMMDDnn serials from
    an earlier day become today's ``YYYYMMDD00``; anything else gets +1.
    """
    old = int(old)
    now = int(time.time())
    todayserial = int(datetime.date.today().strftime("%Y%m%d00"))
    # Note to my future self: This is expected to break on 2034-06-16
    # as unix timestamp will become in the same range as YYMMDDnn serial
    if 1e9 < old < now:
        # Serial is unix timestamp
        return str(now)
    elif 2e9 < old < todayserial:
        # Serial is YYYYMMDDnn, updated before today
        return str(todayserial)
    # No pattern recognized, just increase the number
    return str(old + 1)


def get_altered_files(against, diff_filter=None, revision=None):
    """Return a list of changed files as Path objects.

    If revision is None, list changes between the staging area and
    *against*. Otherwise differences between the two revisions are computed.
    """
    # --no-renames: otherwise a renamed (and edited) file has status R and
    # would be skipped by an "AM" filter; this way it shows up as added.
    cmd = ["git", "diff", "--name-only", "--no-renames", "-z"]
    if diff_filter:
        cmd.append("--diff-filter={}".format(diff_filter))
    if revision:
        cmd.append(against)
        cmd.append(revision)
    else:
        cmd.append("--cached")
        cmd.append(against)
    r = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
        check=True,
    )
    if not r.stdout:
        return []
    names = r.stdout.decode("utf-8").rstrip("\0").split("\0")
    return [Path(p) for p in names]


def get_zone_origin(zonedata):
    """Parse the $ORIGIN directive appearing before the SOA record.

    Return the zone name lowercased and without the trailing dot, or None
    if no such directive precedes the SOA record.
    """
    for line in zonedata.splitlines():
        if re.match(br"^[^\s;]+\s+([0-9]+\s+)?(IN\s+)?SOA\s+", line, re.I):
            break
        m = re.match(br"^\$ORIGIN\s+(\S+)\.\s*(;.*)?$", line, re.I)
        if m:
            return m.group(1).decode("utf-8").lower()
    return None


def get_zone_name(path, zonedata):
    """Guess the zone name from the first $ORIGIN, else from the filename.

    :raises HookException: if filename and zone ORIGIN differ in more than
        the punctuation characters stripped below
    """
    stemname = Path(path).stem.lower()
    originname = get_zone_origin(zonedata)
    if not originname:
        return stemname
    tt = str.maketrans("", "", "/_,:-+*%^&#$")
    sn, on = [s.translate(tt) for s in [stemname, originname]]
    if sn != on:
        raise HookException(
            "Zone origin {o} differs from zone file.".format(o=originname),
            fname=path,
        )
    return originname


def check_updated_zones(against, revision=None, autoupdate_serial=False):
    """Check that all added/modified zone files compile.

    For zones that already existed in *against*, also require that a change
    of the compiled contents comes with an RFC 1982 serial increase.

    :param against: revision to compare with
    :param revision: revision to check; None means the staging area
    :param autoupdate_serial: if the serial was not increased, rewrite it in
        the working-tree file before failing
    :raises HookException: on the first zone failing a check
    """
    for f in get_altered_files(against, "AM", revision):
        if f.suffix != ".zone":
            continue
        print("Checking file {f}".format(f=f))
        zonedata = get_file_contents(f, revision)
        zname = get_zone_name(f, zonedata)
        rnew = compile_zone(zname, zonedata)
        if not rnew.success:
            raise HookException(
                "New zone version does not compile",
                f,
                rnew.stderr,
            )
        try:
            olddata = get_file_contents(f, against)
        except subprocess.CalledProcessError:
            continue  # Old version of zone did not exist
        rold = compile_zone(get_zone_name(f, olddata), olddata)
        if (rold.success and rold.zonehash != rnew.zonehash and
                not is_serial_increased(rold.serial, rnew.serial)):
            errmsg = "Zone contents changed without increasing serial."
            diagmsg = "Old revision {}, serial {}, new serial {}".format(
                against,
                rold.serial,
                rnew.serial,
            )
            if autoupdate_serial:
                # Bump relative to the old serial so the result is greater
                # than it even when the new serial was decreased.
                newserial = get_increased_serial(rold.serial)
                replace_serial(f, rnew.serial, newserial)
                errmsg += " Serial has been automatically increased."
                errmsg += " Check and recommit."
            raise HookException(
                errmsg,
                fname=f,
                stderr=diagmsg,
            )


def get_config(name, type_=None):
    """Return git config option *name*, or None if git config fails.

    :param type_: None (str), bool or int, selecting git's type conversion
    :raises ValueError: for any other *type_*
    """
    cmd = ["git", "config"]
    if type_ == bool:
        cmd.append("--bool")
    elif type_ == int:
        cmd.append("--int")
    elif type_:
        raise ValueError("Invalid type supplied")
    cmd.append(name)
    r = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
    )
    if r.returncode != 0:
        return None
    if type_ == bool:
        return r.stdout == b"true\n"
    elif type_ == int:
        return int(r.stdout)
    return r.stdout.decode("utf-8").rstrip("\n")


def replace_serial(path, oldserial, newserial):
    """Replace the SOA serial *oldserial* with *newserial* in file *path*.

    :raises HookException: if the serial cannot be located after "SOA"
    """
    contents = path.read_text()
    # The serial follows the SOA mname/rname and is preceded by whitespace
    # or directly by the opening parenthesis.
    updated, count = re.subn(
        r'(^.*\sSOA\s.+?[\s(]){}([^0-9])'.format(oldserial),
        r'\g<1>{}\g<2>'.format(newserial),
        contents,
        count=1,
        flags=re.DOTALL | re.IGNORECASE | re.MULTILINE,
    )
    if count != 1:
        raise HookException("Cannot update zone serial number", fname=path)
    path.write_text(updated)


def do_commit_checks(against, revision=None, autoupdate_serial=False):
    """Run all checks; print the error and exit with status 1 on failure."""
    try:
        if not get_config("dzonegit.ignorewhitespaceerrors", bool):
            check_whitespace_errors(against, revision=revision)
        check_updated_zones(
            against, revision=revision,
            autoupdate_serial=autoupdate_serial,
        )
    except HookException as e:
        print(e)
        raise SystemExit(1)


def pre_commit():
    """Entry point for the client-side ``pre-commit`` hook."""
    against = get_head()
    autoupdate_serial = not get_config("dzonegit.noserialupdate", bool)
    do_commit_checks(against, autoupdate_serial=autoupdate_serial)


def update(argv=None):
    """Entry point for the server-side ``update`` hook.

    :param argv: ``[prog, refname, oldrev, newrev]``; defaults to sys.argv
    """
    if argv is None:
        argv = sys.argv
    if "GIT_DIR" not in os.environ:
        raise SystemExit("Don't run this hook from the command line")
    if len(argv) < 4:
        raise SystemExit(
            "Usage: {} <refname> <oldrev> <newrev>".format(argv[0]),
        )
    refname, against, revision = argv[1:4]
    if against == NULL_REV:
        # Ref is being created: compare with HEAD (or the empty tree)
        against = get_head()
    if refname != "refs/heads/master":
        raise SystemExit("Nothing else than master branch is accepted here")
    do_commit_checks(against, revision)


def pre_receive(stdin=None):
    """Entry point for the server-side ``pre-receive`` hook.

    :param stdin: stream of "<oldrev> <newrev> <refname>" lines; defaults
        to sys.stdin
    """
    if stdin is None:
        stdin = sys.stdin
    if stdin.isatty():
        raise SystemExit("Don't run this hook from the command line")
    for line in stdin:
        against, revision, refname = line.rstrip().split(" ")
        if refname != "refs/heads/master":
            raise SystemExit(
                "Nothing else than master branch "
                "is accepted here",
            )
        if against == NULL_REV:
            # Ref is being created: compare with HEAD (or the empty tree)
            against = get_head()
        do_commit_checks(against, revision)


def main():
    """Dispatch to the hook implementation matching the program name."""
    name = Path(sys.argv[0]).name
    print(name)
    if name == "pre-commit":
        pre_commit()
    elif name == "update":
        update()
    elif name == "pre-receive":
        pre_receive()
    else:
        sys.exit("No valid command found")


if __name__ == "__main__":
    main()