1
0
mirror of https://github.com/oskar456/dzonegit.git synced 2024-05-11 05:55:41 +00:00
Files
oskar456-dzonegit/dzonegit.py
Ondřej Caletka d26d5a8ebb Refactor commit checks add pre-receive hook
The pre-receive hook has exactly the same function as the update hook;
deploy one or the other, preferably not both.
2018-07-16 12:27:41 +02:00

334 lines
10 KiB
Python

#!/usr/bin/env python3
import os
import sys
import subprocess
import re
import time
import datetime
from collections import namedtuple
from hashlib import sha256
from pathlib import Path
class HookException(ValueError):
    """Exception raised when there is an error in input data.

    Attributes:
        message -- the cause of the problem
        fname -- affected file (optional)
        stderr -- output of the specific checker (optional)
    """

    def __init__(self, message, fname=None, stderr=None):
        # Hand the message to ValueError as well, so that ``args`` and
        # ``repr()`` carry it instead of being empty.
        super().__init__(message)
        self.message = message
        self.fname = fname
        self.stderr = stderr

    def __str__(self):
        """Format as "<fname>: <message>", then the checker output if any."""
        parts = []
        if self.fname:
            parts.append("{fname}: ".format(fname=self.fname))
        parts.append(self.message)
        parts.append("\n")
        if self.stderr:
            # Checker output goes into its own paragraph below the message.
            parts.append("\n")
            parts.append(self.stderr)
            parts.append("\n\n")
        return "".join(parts)
def get_head():
    """Return the id that ``git rev-parse --verify HEAD`` prints.

    When that command exits with a non-zero status, the id of an empty
    tree object is returned instead, so callers always get something
    to diff against.
    """
    empty_tree = "4b825dc642cb6eb9a060e54bf8d69288fbee4904"
    proc = subprocess.run(
        ["git", "rev-parse", "--verify", "HEAD"],
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
    )
    if proc.returncode != 0:
        # Initial commit: diff against an empty tree object
        return empty_tree
    return proc.stdout.decode("utf-8").strip()
def check_whitespace_errors(against, revision=None):
    """Raise HookException when git's ``--check`` run exits non-zero.

    With a *revision*, ``git diff-tree --check <against> <revision>`` is
    run; without one, ``git diff-index --check --cached <against>``.
    The combined stdout/stderr of git is attached to the exception.
    """
    if not revision:
        command = ["git", "diff-index", "--check", "--cached", against]
    else:
        command = ["git", "diff-tree", "--check", against, revision]
    proc = subprocess.run(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    if proc.returncode == 0:
        return
    report = proc.stdout.decode("utf-8")
    raise HookException("Whitespace errors", stderr=report)
def get_file_contents(path, revision=None):
    """Return the bytes ``git show <revision>:<path>`` writes to stdout.

    A *revision* of None becomes an empty revision part (``:<path>``),
    which the original docstring describes as the staged version.
    Raises subprocess.CalledProcessError when git exits non-zero.
    """
    rev = revision if revision is not None else ""
    object_spec = "{r}:{p}".format(r=rev, p=path)
    proc = subprocess.run(
        ["git", "show", object_spec],
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
        check=True,
    )
    return proc.stdout
def compile_zone(zonename, zonedata):
    """Compile a zone with named-compilezone and summarise the outcome.

    *zonedata* (bytes) is fed to the compiler on standard input. The
    result is a ``CompileResults(success, serial, zonehash, stderr)``
    named tuple. Compilation counts as successful only when the exit
    status is zero and stderr contains a "loaded serial N" line; then
    ``serial`` is that number as a string and ``zonehash`` is the
    SHA-256 hex digest of the compiler's stdout. Otherwise both are
    None. ``stderr`` always holds the decoded compiler messages.
    """
    CompileResults = namedtuple(
        "CompileResults", "success, serial, zonehash, stderr",
    )
    proc = subprocess.run(
        ["/usr/sbin/named-compilezone", "-o", "-", zonename, "/dev/stdin"],
        input=zonedata,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    messages = proc.stderr.decode("utf-8")
    loaded = re.search(
        r"^zone.*loaded serial ([0-9]*)$", messages, re.MULTILINE,
    )
    if proc.returncode != 0 or not loaded:
        return CompileResults(False, None, None, messages)
    serial = loaded.group(1)
    digest = sha256(proc.stdout).hexdigest()
    return CompileResults(True, serial, digest, messages)
def is_serial_increased(old, new):
    """Return True if *new* is greater than *old* using RFC 1982 logic.

    Both arguments may be ints or decimal strings. Serials are compared
    modulo 2**32, so a wrap-around from a high value to a low one still
    counts as an increase.
    """
    old, new = int(old), int(new)
    diff = (new - old) % 2**32
    # RFC 1982 allows increments of 1 .. 2**31 - 1 inclusive. A distance
    # of exactly 2**31 is undefined there and is treated as "not increased".
    return 0 < diff < 2**31
def get_increased_serial(old):
    """Return an increased serial number as a string.

    The scheme is guessed from the value of *old* (int or decimal string):

    * between 1e9 and the current unix time: treated as a unix timestamp,
      the current timestamp is returned;
    * between 2e9 and today's YYYYMMDD00: treated as a date-based serial
      from an earlier day, today's YYYYMMDD00 is returned;
    * anything else: ``old + 1``, wrapped modulo 2**32 so the result
      stays within the 32-bit range of a DNS serial.
    """
    old = int(old)
    now = int(time.time())
    todayserial = int(datetime.date.today().strftime("%Y%m%d00"))
    # Note from the original author: this is expected to break on
    # 2034-06-16, as the unix timestamp will then be in the same range
    # as a YYYYMMDDnn serial.
    if 1e9 < old < now:
        # Serial is unix timestamp
        return str(now)
    elif 2e9 < old < todayserial:
        # Serial is YYYYMMDDnn, updated before today
        return str(todayserial)
    else:
        # No pattern recognized, just increase the number. Serials are
        # 32-bit, so 4294967295 wraps to 0 rather than overflowing.
        return str((old + 1) % 2**32)
def get_altered_files(against, diff_filter=None, revision=None):
    """Return a list of changed files as ``Path`` objects.

    If *revision* is None, the staging area is compared with *against*
    (``git diff --cached``). Otherwise the differences between
    *against* and *revision* are listed. *diff_filter*, when given, is
    passed to git as ``--diff-filter``.

    Raises subprocess.CalledProcessError when git exits non-zero.
    """
    cmd = ["git", "diff", "--name-only", "-z"]
    if diff_filter:
        cmd.append("--diff-filter={}".format(diff_filter))
    if revision:
        cmd.extend([against, revision])
    else:
        cmd.extend(["--cached", against])
    r = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
        check=True,
    )
    # With -z the names are NUL-terminated; drop the trailing terminator
    # so that splitting does not produce an empty last entry.
    names = r.stdout.decode("utf-8").rstrip("\0")
    if not names:
        return []
    # Always a real list (never a one-shot generator), as documented.
    return [Path(p) for p in names.split("\0")]
def get_zone_origin(zonedata):
    """Return the zone name from a $ORIGIN directive preceding the SOA.

    *zonedata* is the zone file as bytes. Lines are scanned in order;
    the first $ORIGIN line with a name ending in a dot wins, and its
    name is returned lower-cased without that trailing dot. Scanning
    stops at the SOA record. None is returned when no such directive
    was seen before it.
    """
    soa_record = br"^[^\s;]+\s+([0-9]+\s+)?(IN\s+)?SOA\s+"
    origin_directive = br"^\$ORIGIN\s+([^ ]+)\.\s*(;.*)?$"
    for row in zonedata.splitlines():
        if re.match(soa_record, row, re.I):
            # Directives after the SOA record are not considered.
            return None
        found = re.match(origin_directive, row, re.I)
        if found is not None:
            return found.group(1).decode("utf-8").lower()
    return None
def get_zone_name(path, zonedata):
    """Determine the zone name from the first $ORIGIN or the file name.

    When the zone data contains a usable $ORIGIN, that name is returned,
    provided it agrees with the lower-cased file stem once the
    characters ``/_,:-+*%^&#$`` are removed from both. If they disagree
    a HookException is raised. Without a $ORIGIN the lower-cased file
    stem is returned.
    """
    stemname = Path(path).stem.lower()
    originname = get_zone_origin(zonedata)
    if not originname:
        return stemname
    # Punctuation is ignored so that e.g. a slash in the origin may be
    # spelled with a different separator in the file name.
    ignored = str.maketrans("", "", "/_,:-+*%^&#$")
    if stemname.translate(ignored) != originname.translate(ignored):
        raise HookException(
            "Zone origin {o} differs from zone file.".format(o=originname),
            fname=path,
        )
    return originname
def check_updated_zones(against, revision=None, autoupdate_serial=False):
    """Check that every added or modified ``.zone`` file is acceptable.

    Each such file must compile. If the file also exists in *against*
    and that old version compiles, then a change in the compiled zone
    contents must come with an increased serial number.

    Raises HookException on the first file that fails. When
    *autoupdate_serial* is true and the serial was not increased, the
    serial in the file on disk is bumped before the exception is raised,
    so the user only has to review and recommit.
    """
    for zonefile in get_altered_files(against, "AM", revision):
        if zonefile.suffix != ".zone":
            continue
        print("Checking file {f}".format(f=zonefile))

        newdata = get_file_contents(zonefile, revision)
        rnew = compile_zone(get_zone_name(zonefile, newdata), newdata)
        if not rnew.success:
            raise HookException(
                "New zone version does not compile",
                zonefile, rnew.stderr,
            )

        try:
            olddata = get_file_contents(zonefile, against)
        except subprocess.CalledProcessError:
            continue  # Old version of zone did not exist

        rold = compile_zone(get_zone_name(zonefile, olddata), olddata)
        # Nothing to enforce when the old version was broken, when the
        # compiled contents are unchanged, or when the serial went up.
        if not rold.success:
            continue
        if rold.zonehash == rnew.zonehash:
            continue
        if is_serial_increased(rold.serial, rnew.serial):
            continue

        errmsg = "Zone contents changed without increasing serial."
        diagmsg = "Old revision {}, serial {}, new serial {}".format(
            against, rold.serial, rnew.serial,
        )
        if autoupdate_serial:
            bumped = get_increased_serial(rnew.serial)
            replace_serial(zonefile, rnew.serial, bumped)
            errmsg += " Serial has been automatically increased."
            errmsg += " Check and recommit."
        raise HookException(
            errmsg,
            fname=zonefile,
            stderr=diagmsg,
        )
def get_config(name, type_=None):
    """Read the git configuration value *name* via ``git config``.

    *type_* may be ``bool`` (adds ``--bool``), ``int`` (adds ``--int``)
    or None; any other truthy value raises ValueError. Returns None
    when git exits non-zero. Otherwise the value is returned as a bool
    (True only for the output ``true``), an int, or a string without
    its trailing newline.
    """
    wants_bool = type_ == bool
    wants_int = type_ == int
    if type_ and not (wants_bool or wants_int):
        raise ValueError("Invalid type supplied")

    command = ["git", "config"]
    if wants_bool:
        command += ["--bool"]
    elif wants_int:
        command += ["--int"]
    command += [name]

    proc = subprocess.run(
        command,
        stdout=subprocess.PIPE,
    )
    if proc.returncode != 0:
        return None
    raw = proc.stdout
    if wants_bool:
        return raw == b"true\n"
    if wants_int:
        return int(raw)
    return raw.decode("utf-8").rstrip("\n")
def replace_serial(path, oldserial, newserial):
    """Rewrite the SOA serial in the zone file at *path* in place.

    *path* is a ``pathlib.Path``. A single occurrence of *oldserial*
    that follows an SOA record and is followed by a non-digit is
    replaced with *newserial*. Raises HookException, leaving the file
    untouched, when no such occurrence is found.
    """
    serial_after_soa = re.compile(
        r'(^.*\sSOA\s.+?\s){}([^0-9])'.format(oldserial),
        re.DOTALL | re.IGNORECASE | re.MULTILINE,
    )
    replacement = r'\g<1>{}\g<2>'.format(newserial)
    original = path.read_text()
    updated, hits = serial_after_soa.subn(replacement, original, count=1)
    if hits == 1:
        path.write_text(updated)
        return
    raise HookException("Cannot update zone serial number")
def do_commit_checks(against, revision=None, autoupdate_serial=False):
    """Run all checks and exit with status 1 if any of them fails.

    The whitespace check is skipped when the git option
    ``dzonegit.ignorewhitespaceerrors`` is true. The zone checks always
    run. A HookException from either is printed and turned into
    ``SystemExit(1)``.
    """
    try:
        skip_whitespace = get_config("dzonegit.ignorewhitespaceerrors", bool)
        if not skip_whitespace:
            check_whitespace_errors(against, revision=revision)
        check_updated_zones(
            against,
            revision=revision,
            autoupdate_serial=autoupdate_serial,
        )
    except HookException as problem:
        print(problem)
        sys.exit(1)
def pre_commit():
    """Entry point of the pre-commit hook: check the staged changes.

    The staged changes are checked against the result of ``get_head()``.
    Serial numbers are updated automatically unless the git option
    ``dzonegit.noserialupdate`` is true.
    """
    head = get_head()
    serial_update_disabled = get_config("dzonegit.noserialupdate", bool)
    do_commit_checks(head, autoupdate_serial=not serial_update_disabled)
def update(argv=sys.argv):
    """Entry point of the update hook.

    *argv* is expected to be ``[prog, refname, oldrev, newrev]``. Exits
    via SystemExit when GIT_DIR is not set in the environment, when too
    few arguments are given, or when the ref is not
    ``refs/heads/master``. An all-zero old revision is replaced by the
    result of ``get_head()`` before the checks are run.
    """
    if os.environ.get("GIT_DIR") is None:
        raise SystemExit("Don't run this hook from the command line")
    if len(argv) < 4:
        usage = "Usage: {} <ref> <oldrev> <newrev>".format(argv[0])
        raise SystemExit(usage)

    refname = argv[1]
    against = argv[2]
    revision = argv[3]
    if against == "0000000000000000000000000000000000000000":
        against = get_head()  # Empty commit
    if refname != "refs/heads/master":
        raise SystemExit("Nothing else than master branch is accepted here")
    do_commit_checks(against, revision)
def pre_receive(stdin=sys.stdin):
    """Entry point of the pre-receive hook.

    Every line read from *stdin* must have the form
    ``<oldrev> <newrev> <refname>``, separated by single spaces. Exits
    via SystemExit when *stdin* is a terminal or when a ref other than
    ``refs/heads/master`` is pushed. An all-zero old revision is
    replaced by the result of ``get_head()`` before the checks are run.
    """
    if stdin.isatty():
        raise SystemExit("Don't run this hook from the command line")
    for record in stdin:
        fields = record.rstrip().split(" ")
        against, revision, refname = fields
        if refname != "refs/heads/master":
            raise SystemExit(
                "Nothing else than master branch is accepted here",
            )
        if against == "0000000000000000000000000000000000000000":
            against = get_head()  # Empty commit
        do_commit_checks(against, revision)
def main():
    """Run the hook selected by the name this script was invoked under.

    The file name part of ``sys.argv[0]`` is printed and then used to
    choose between the ``pre-commit``, ``update`` and ``pre-receive``
    hooks. Any other name terminates the program with an error message.
    """
    hook = Path(sys.argv[0]).name
    print(hook)
    if hook not in ("pre-commit", "update", "pre-receive"):
        sys.exit("No valid command found")
    if hook == "pre-commit":
        pre_commit()
    elif hook == "update":
        update()
    else:
        pre_receive()


# Only dispatch when executed as a script, not when imported.
if __name__ == "__main__":
    main()