#!/usr/bin/env python3
"""
BGP RAW Update Generator

Christian Giese, March 2022

Copyright (C) 2020-2022, RtBrick, Inc.
SPDX-License-Identifier: BSD-3-Clause
"""
import argparse
import ipaddress
import json
import logging
import socket
import struct
import sys

# Load scapy and its BGP contrib module; the script cannot work without them.
try:
    from scapy.all import *
    # Raise the scapy runtime logger threshold while the BGP contrib module
    # is imported and restore it afterwards (presumably to hide warnings
    # emitted during that import).
    log_runtime.setLevel(logging.ERROR)
    from scapy.contrib.bgp import *
    log_runtime.setLevel(logging.INFO)
except Exception as exc:
    # Catch Exception rather than everything, so Ctrl-C (KeyboardInterrupt)
    # and SystemExit are not swallowed, and report why the import failed.
    print("Failed to load scapy!")
    print("Reason: %s" % exc)
    sys.exit(1)

# ==============================================================
# DEFINITIONS
# ==============================================================

# Program description handed to argparse.ArgumentParser in main().
DESCRIPTION = """
The BGP RAW update generator is a simple 
tool to generate BGP RAW update streams 
for use with the BNG Blaster. 
"""

# Maps the --log-level command line choices to logging module levels.
LOG_LEVELS = {
    'warning': logging.WARNING,
    'info': logging.INFO,
    'debug': logging.DEBUG
}

# Inclusive range accepted by label_type() for MPLS label arguments
# (1048575 == 2**20 - 1).
MPLS_LABEL_MIN = 1
MPLS_LABEL_MAX = 1048575

# Byte budgets that main() subtracts from BGP_MAXIMUM_MESSAGE_SIZE before
# filling an update message with prefixes.
# Always subtracted, together with 4 bytes per configured ASN.
# NOTE(review): presumably header + ORIGIN + empty AS_PATH overhead - confirm.
BGP_UPDATE_MIN_LEN = 34
# Subtracted when a LOCAL_PREF attribute is added.
BGP_LOCAL_PREF_LEN = 7
# Subtracted for labelled IPv4 prefixes.
BGP_MP_REACH_IPV4_FIXED_HDR_LEN = 14
# Subtracted for all IPv6 prefixes.
BGP_MP_REACH_IPV6_FIXED_HDR_LEN = 26

# ==============================================================
# SCAPY EXTENSIONS
# ==============================================================

class BGPFieldLabelledIPv4(Field):
    """Labelled IPv4 Field (CIDR).

    Human representation is the string ``x.x.x.x/y/zzzz`` (address, mask
    length in bits, label). The internal representation is the tuple
    ``(label, mask, ip)`` with ``label`` and ``mask`` as integers and ``ip``
    as the address string. Only the build direction (value to bytes) is
    implemented by this class.
    """

    def mask2iplen(self, mask):
        """Get the IP field mask length (in bytes)."""
        return (mask + 7) // 8

    def h2i(self, pkt, h):
        """Human (x.x.x.x/y/zzzz) to internal representation.

        Raises ValueError if ``h`` does not consist of exactly three
        "/"-separated parts or if mask/label are not integers.
        """
        # Plain str.split is sufficient for a fixed separator and does not
        # depend on the ``re`` module being in scope.
        ip, mask, label = h.split("/")
        return int(label), int(mask), ip

    def i2h(self, pkt, i):
        """Internal to human (x.x.x.x/y/zzzz) representation."""
        label, mask, ip = i
        return "%s/%s/%s" % (ip, mask, label)

    def i2repr(self, pkt, i):
        """Display representation; identical to the human representation."""
        return self.i2h(pkt, i)

    def i2len(self, pkt, i):
        """Encoded size in bytes: address bytes + 1 length byte + 3 label bytes."""
        _label, mask, _ip = i
        return self.mask2iplen(mask) + 1 + 3

    def i2m(self, pkt, i):
        """Internal to machine representation.

        Layout: one length byte (mask bits + 24 label bits), three label
        bytes, then only as many address bytes as the mask requires.
        """
        label, mask, ip = i
        # The length byte counts bits: the 24 bits of label plus the prefix bits.
        bit_length = mask + 24
        # The label is shifted into the upper 20 bits of a 24-bit value with the
        # lowest bit set (presumably the bottom-of-stack flag); the leading
        # byte of the 32-bit packing is dropped to get 3 bytes.
        label_bytes = struct.pack(">I", (label << 4) | 1)[1:]
        packed_ip = socket.inet_aton(ip)
        return struct.pack(">B", bit_length) + label_bytes + packed_ip[:self.mask2iplen(mask)]

    def addfield(self, pkt, s, val):
        """Append the machine representation of ``val`` to the bytes ``s``."""
        return s + self.i2m(pkt, val)


class BGPNLRI_LabelledIPv4(Packet):
    """Packet handling labelled IPv4 NLRI fields.

    Consists of a single ``prefix`` field of type BGPFieldLabelledIPv4. The
    value is given as ``x.x.x.x/y/zzzz`` (address/mask length/label) and
    defaults to ``0.0.0.0/0/0``.
    """
    name = "Labelled IPv4 NLRI"
    fields_desc = [BGPFieldLabelledIPv4("prefix", "0.0.0.0/0/0")]


class BGPFieldLabelledIPv6(Field):
    """Labelled IPv6 Field (CIDR).

    Human representation is the string ``::/y/zzzz`` (address, mask length
    in bits, label). The internal representation is the tuple
    ``(label, mask, ip)`` with ``label`` and ``mask`` as integers and ``ip``
    as the address string. Only the build direction (value to bytes) is
    implemented by this class.
    """

    def mask2iplen(self, mask):
        """Get the IP field mask length (in bytes)."""
        return (mask + 7) // 8

    def h2i(self, pkt, h):
        """Human (::/y/zzzz) to internal representation.

        Raises ValueError if ``h`` does not consist of exactly three
        "/"-separated parts or if mask/label are not integers.
        """
        # Plain str.split is sufficient for a fixed separator and does not
        # depend on the ``re`` module being in scope.
        ip, mask, label = h.split("/")
        return int(label), int(mask), ip

    def i2h(self, pkt, i):
        """Internal to human (::/y/zzzz) representation."""
        label, mask, ip = i
        return "%s/%s/%s" % (ip, mask, label)

    def i2repr(self, pkt, i):
        """Display representation; identical to the human representation."""
        return self.i2h(pkt, i)

    def i2len(self, pkt, i):
        """Encoded size in bytes: address bytes + 1 length byte + 3 label bytes."""
        _label, mask, _ip = i
        return self.mask2iplen(mask) + 1 + 3

    def i2m(self, pkt, i):
        """Internal to machine representation.

        Layout: one length byte (mask bits + 24 label bits), three label
        bytes, then only as many address bytes as the mask requires.
        """
        label, mask, ip = i
        # The length byte counts bits: the 24 bits of label plus the prefix bits.
        bit_length = mask + 24
        # The label is shifted into the upper 20 bits of a 24-bit value with the
        # lowest bit set (presumably the bottom-of-stack flag); the leading
        # byte of the 32-bit packing is dropped to get 3 bytes.
        label_bytes = struct.pack(">I", (label << 4) | 1)[1:]
        packed_ip = pton_ntop.inet_pton(socket.AF_INET6, ip)
        return struct.pack(">B", bit_length) + label_bytes + packed_ip[:self.mask2iplen(mask)]

    def addfield(self, pkt, s, val):
        """Append the machine representation of ``val`` to the bytes ``s``."""
        return s + self.i2m(pkt, val)


class BGPNLRI_LabelledIPv6(Packet):
    """Packet handling labelled IPv6 NLRI fields.

    Consists of a single ``prefix`` field of type BGPFieldLabelledIPv6. The
    value is given as ``::/y/zzzz`` (address/mask length/label) and defaults
    to ``::/0/0``.
    """
    name = "Labelled IPv6 NLRI"
    fields_desc = [BGPFieldLabelledIPv6("prefix", "::/0/0")]


# ==============================================================
# FUNCTIONS
# ==============================================================

def init_logging(log_level: str) -> logging.Logger:
    """Init logging.

    Configures the root logger with a stdout handler using the format
    ``[<timestamp>][<level>] <message>``.

    Args:
        log_level: One of the keys of LOG_LEVELS ('warning', 'info', 'debug').

    Returns:
        The configured root logger.

    Raises:
        KeyError: If ``log_level`` is not a key of LOG_LEVELS.
    """
    level = LOG_LEVELS[log_level]

    formatter = logging.Formatter(
        fmt='[%(asctime)s][%(levelname)-7s] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )
    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(level)
    handler.setFormatter(formatter)

    log = logging.getLogger()
    log.setLevel(level)
    log.addHandler(handler)
    return log


def label_type(label: str) -> int:
    """Argument parser type for MPLS labels.

    Args:
        label: The command line value (anything accepted by ``int()``).

    Returns:
        The label as an integer.

    Raises:
        ValueError: If ``label`` cannot be converted to an integer.
        argparse.ArgumentTypeError: If the label is outside
            MPLS_LABEL_MIN..MPLS_LABEL_MAX (inclusive).
    """
    value = int(label)
    if not MPLS_LABEL_MIN <= value <= MPLS_LABEL_MAX:
        raise argparse.ArgumentTypeError("MPLS label out of range %s - %s" % (MPLS_LABEL_MIN, MPLS_LABEL_MAX))
    return value


# ==============================================================
# MAIN
# ==============================================================

def main():
    """Command line entry point of the BGP RAW update generator.

    Steps, in order:
      1. Parse the command line and set up logging.
      2. Generate ``--prefix-num`` consecutive prefixes starting at
         ``--prefix-base`` and distribute them round-robin over
         ``--next-hop-num`` next-hops; labels cycle through ``--label-num``
         values starting at ``--label-base``.
      3. Optionally write one traffic stream definition per prefix to the
         ``--streams`` JSON file.
      4. Pack the prefixes into BGP update (or withdraw) messages, write the
         raw bytes to ``--file`` and optionally to a PCAP file.

    Exits with status 1 if prefix and next-hop address families differ.
    """
    # Only main() needs a queue type, so the import is kept local.
    from collections import deque

    # parse arguments
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument('-a', '--asn', type=int, default=[], action='append', help='autonomous system number')
    parser.add_argument('-n', '--next-hop-base', metavar='ADDRESS', type=ipaddress.ip_address, required=True, help='next-hop base address (IPv4 or IPv6)')
    parser.add_argument('-N', '--next-hop-num', metavar='N', type=int, default=1, help='next-hop count')
    parser.add_argument('-p', '--prefix-base', metavar='PREFIX', type=ipaddress.ip_network, required=True, help='prefix base network (IPv4 or IPv6)')
    parser.add_argument('-P', '--prefix-num', metavar='N', type=int, default=1, help='prefix count')
    parser.add_argument('-m', '--label-base', metavar='LABEL', type=label_type, help='label base')
    parser.add_argument('-M', '--label-num', metavar='N', type=int, default=1, help='label count')
    parser.add_argument('-l', '--local-pref', type=int, help='local preference')
    parser.add_argument('-f', '--file', type=str, default="out.bgp", help='output file')
    parser.add_argument('-w', '--withdraw', action="store_true", help="withdraw prefixes")
    parser.add_argument('-s', '--streams', type=str, help="generate BNG Blaster traffic stream file")
    parser.add_argument('--stream-tx-label', metavar='LABEL', type=label_type, help="stream TX outer label")
    parser.add_argument('--stream-tx-inner-label', metavar='LABEL', type=label_type, help="stream TX inner label")
    parser.add_argument('--stream-rx-label', metavar='LABEL', type=label_type, help="stream RX label")
    parser.add_argument('--stream-rx-label-num', metavar='N', type=int, default=1, help="stream RX label count")
    parser.add_argument('--stream-threads', metavar='N', type=int, default=1, help="stream TX threads")
    parser.add_argument('--stream-pps', metavar='N', type=float, default=1.0, help="stream packets per seconds")
    parser.add_argument('--stream-interface', metavar='IFACE', type=str, help="stream interface")
    parser.add_argument('--stream-group-id', metavar='N', type=int, help="stream group identifier")
    parser.add_argument('--stream-direction', default="downstream", choices=['upstream', 'downstream', 'both'], help="stream direction")
    parser.add_argument('--stream-append', action="store_true", help="append to stream file if exist")
    parser.add_argument('--end-of-rib', action="store_true", help="add end-of-rib message")
    parser.add_argument('--append', action="store_true", help="append to file if exist")
    parser.add_argument('--pcap', metavar='FILE', type=str, help="write BGP updates to PCAP file")
    parser.add_argument('--log-level', type=str, default='info', choices=LOG_LEVELS.keys(), help='logging Level')
    args = parser.parse_args()

    # init logging
    log = init_logging(args.log_level)

    # Prefixes are labelled as soon as a label base was given.
    if args.label_base:
        log.info("init %s labelled IPv%s prefixes" % (args.prefix_num, args.prefix_base.version))
        labelled = True
    else:
        log.info("init %s IPv%s prefixes" % (args.prefix_num, args.prefix_base.version))
        labelled = False

    # An IPv4 next-hop for IPv6 prefixes is rewritten as ::FFFF:<ipv4>.
    if args.prefix_base.version == 6 and args.next_hop_base.version == 4:
        log.warning("next-hop converted to IPv6 compatible IPv4 address ::FFFF:%s" % args.next_hop_base)
        args.next_hop_base = ipaddress.ip_address("::FFFF:%s" % args.next_hop_base)

    if args.prefix_base.version != args.next_hop_base.version:
        log.error("prefix and next-hop address family must be equal")
        sys.exit(1)

    ip_version = args.prefix_base.version

    streams = []
    stream_thread = 1
    stream_label_index = 0
    stream_label = args.stream_rx_label
    if args.streams and args.stream_append:
        # Best effort: continue with an empty list if the existing file cannot
        # be used, but tell the user instead of failing silently.
        try:
            with open(args.streams) as json_file:
                data = json.load(json_file)
                streams = data.get("streams", [])
        except FileNotFoundError:
            # Nothing to append to yet; the file is created further below.
            log.debug("stream file %s not found, creating a new one", args.streams)
        except (OSError, ValueError, AttributeError) as e:
            # ValueError covers invalid JSON, AttributeError a JSON document
            # that is not an object.
            log.warning("failed to load streams from file %s (%s), existing content will be replaced", args.streams, e)

    # Here we will store packets for optional PCAP output
    pcap_packets = []
    def pcap(message):
        # Wrap the BGP message into a TCP segment towards port 179; every
        # message gets its own source port (10000 + running index).
        if args.pcap:
            pcap_packets.append(Ether()/IP()/TCP(sport=len(pcap_packets)+10000, dport=179, seq=1, flags='PA')/message)

    # The prefixes are ordered by nexthop index
    #
    # prefixes = {
    #    0: ["<prefix1>", "<prefix2>", "..."],
    #    1: ["<prefix1>", "<prefix2>", "..."]
    # }
    #
    # Each entry is a (prefix, label) tuple. Deques are used because the
    # packing loop below consumes the entries from the front.
    prefixes = {i: deque() for i in range(args.next_hop_num)}
    next_hops = [str(args.next_hop_base+nh_index) for nh_index in range(args.next_hop_num)]

    nh_index = 0
    label_index = 0
    prefix = args.prefix_base
    label = args.label_base
    for prefix_index in range(args.prefix_num):
        log.debug("add prefix %s via %s label %s" % (prefix, next_hops[nh_index], label))
        prefixes[nh_index].append((prefix, label))

        if args.streams:
            # One traffic stream per prefix, addressed to the first address
            # after the network address.
            stream = {
                "name": "%s" % prefix,
                "direction": args.stream_direction,
                "pps": args.stream_pps,
                "threaded": True,
                "thread-group": stream_thread
            }
            if ip_version == 4:
                stream["type"] = "ipv4"
                stream["destination-ipv4-address"] = str(prefix.network_address+1)
            else:
                stream["type"] = "ipv6"
                stream["destination-ipv6-address"] = str(prefix.network_address+1)

            if args.stream_interface:
                stream["network-interface"] = args.stream_interface

            if args.stream_group_id:
                stream["stream-group-id"] = args.stream_group_id

            if stream_label:
                # Explicit RX label becomes the outer label, the prefix label
                # (if any) the inner one.
                stream["rx-label1"] = stream_label
                if labelled:
                    stream["rx-label2"] = label
                # Cycle through stream_rx_label_num labels, wrapping to the
                # base label at the end or above MPLS_LABEL_MAX.
                stream_label_index += 1
                if stream_label_index < args.stream_rx_label_num:
                    stream_label = args.stream_rx_label + stream_label_index
                    if stream_label > MPLS_LABEL_MAX:
                        stream_label_index = 0
                        stream_label = args.stream_rx_label
                else:
                    stream_label_index = 0
                    stream_label = args.stream_rx_label
            else:
                if labelled:
                    stream["rx-label1"] = label

            if args.stream_tx_label:
                stream["tx-label1"] = args.stream_tx_label
                if args.stream_tx_inner_label:
                    stream["tx-label2"] = args.stream_tx_inner_label

            streams.append(stream)
            # Thread groups are assigned round-robin from 1..stream_threads.
            stream_thread += 1
            if stream_thread > args.stream_threads:
                stream_thread = 1

        # next...
        nh_index += 1
        if nh_index >= args.next_hop_num:
            nh_index = 0

        if labelled:
            # Cycle through label_num labels, wrapping to the base label at
            # the end or above MPLS_LABEL_MAX.
            label_index += 1
            if label_index < args.label_num:
                label = args.label_base + label_index
                if label > MPLS_LABEL_MAX:
                    label_index = 0
                    label = args.label_base
            else:
                label_index = 0
                label = args.label_base

        # Advance to the adjacent network only if another prefix follows.
        # Computing it after the last prefix is useless and fails when that
        # prefix ends at the top of the address space.
        if prefix_index + 1 < args.prefix_num:
            prefix = ipaddress.ip_network("%s/%s" % (prefix.broadcast_address+1, prefix.prefixlen))

    if args.streams:
        log.info("write %s streams to file %s", len(streams), args.streams)
        with open(args.streams, "w") as f:
            json.dump({ "streams": streams}, f, indent=4)

    # All prefixes share the base prefix length, so every NLRI entry has the
    # same encoded size.
    prefix_bytes = (args.prefix_base.prefixlen + 7) // 8
    if labelled:
        prefix_attr_len = prefix_bytes + 4 # N prefix bytes + 1 byte prefix len + 3 byte label
    else:
        prefix_attr_len = prefix_bytes + 1 # N prefix bytes + 1 byte prefix len

    if args.append:
        log.info("open file %s (append)" % args.file)
        file_flags = "ab"
    else:
        log.info("open file %s (replace)" % args.file)
        file_flags = "wb"

    with open(args.file, file_flags) as f:
        # Path attributes shared by all update messages.
        origin_attr = BGPPathAttr(type_flags=64, type_code="ORIGIN", attribute=BGPPAOrigin())
        as_path_attr = BGPPathAttr(type_flags=64, type_code="AS_PATH", attribute=BGPPAAS4BytesPath(segments = [BGPPAAS4BytesPath.ASPathSegment(segment_type="AS_SEQUENCE", segment_value=args.asn)]))

        if args.local_pref is not None:
            local_pref_attr = BGPPathAttr(type_flags=64, type_code="LOCAL_PREF", attribute=BGPPALocalPref(local_pref=args.local_pref))

        # Round-robin over the next-hops, emitting one update message per
        # next-hop and round until every prefix queue is drained.
        while len(prefixes):
            for nh_index in list(prefixes.keys()):
                prefix_list = prefixes[nh_index]
                prefix_count = 0

                path_attr = [origin_attr, as_path_attr]

                nlri = []

                # Byte budget left for prefixes in this message.
                remaining = BGP_MAXIMUM_MESSAGE_SIZE - (BGP_UPDATE_MIN_LEN + (len(args.asn) * 4))

                if args.local_pref is not None:
                    remaining -= BGP_LOCAL_PREF_LEN
                    path_attr.append(local_pref_attr)

                if ip_version == 4:
                    if labelled:
                        remaining -= BGP_MP_REACH_IPV4_FIXED_HDR_LEN
                    else:
                        remaining -= 5 # BGP IPv4 next-hop path attribute
                        next_hop_attr = BGPPANextHop(next_hop=next_hops[nh_index])
                        path_attr.append(BGPPathAttr(type_flags=64, type_code="NEXT_HOP", attribute=next_hop_attr))
                else:
                    remaining -= BGP_MP_REACH_IPV6_FIXED_HDR_LEN

                # Withdraw messages carry none of the attributes built above.
                if args.withdraw:
                    path_attr = []

                while len(prefix_list) > 0:
                    if remaining < prefix_attr_len:
                        break
                    remaining -= prefix_attr_len

                    # get next prefix and label
                    prefix, label = prefix_list.popleft()
                    prefix_count += 1
                    if labelled:
                        labelled_prefix = "%s/%s" % (prefix, label)
                        if prefix.version == 4:
                            nlri.append(BGPNLRI_LabelledIPv4(prefix=labelled_prefix))
                        else:
                            nlri.append(BGPNLRI_LabelledIPv6(prefix=labelled_prefix))
                        # For SAFI labelled-unicast only one prefix can be
                        # withdrawn per update message.
                        if args.withdraw:
                            break
                    else:
                        if prefix.version == 4:
                            nlri.append(BGPNLRI_IPv4(prefix=str(prefix)))
                        else:
                            nlri.append(BGPNLRI_IPv6(prefix=str(prefix)))

                if len(prefix_list) == 0:
                    del prefixes[nh_index]

                if prefix_count == 0:
                    # skip empty updates
                    continue

                # Everything except plain IPv4 unicast is carried in an
                # MP_REACH_NLRI / MP_UNREACH_NLRI attribute.
                if labelled or ip_version == 6:
                    if ip_version == 4:
                        # labelled IPv4 unicast
                        if args.withdraw:
                            mp_reach_attr = BGPPAMPUnreachNLRI(afi=1, safi=4, afi_safi_specific=nlri)
                        else:
                            mp_reach_attr = BGPPAMPReachNLRI(afi=1, safi=4, nh_v4_addr=next_hops[nh_index], nh_addr_len=4, nlri=nlri)
                    elif labelled and ip_version == 6:
                        # labelled IPv6 unicast
                        if args.withdraw:
                            mp_reach_attr = BGPPAMPUnreachNLRI(afi=2, safi=4, afi_safi_specific=nlri)
                        else:
                            mp_reach_attr = BGPPAMPReachNLRI(afi=2, safi=4, nh_v6_addr=next_hops[nh_index], nh_addr_len=16, nlri=nlri)
                    else:
                        # IPv6 unicast
                        if args.withdraw:
                            mp_reach_attr = BGPPAMPUnreachNLRI(afi=2, safi=1, afi_safi_specific=BGPPAMPUnreachNLRI_IPv6(withdrawn_routes=nlri))
                        else:
                            mp_reach_attr = BGPPAMPReachNLRI(afi=2, safi=1, nh_v6_addr=next_hops[nh_index], nh_addr_len=16, nlri=nlri)

                    if args.withdraw:
                        path_attr.append(BGPPathAttr(type_flags=144, type_code="MP_UNREACH_NLRI", attribute=mp_reach_attr))
                    else:
                        path_attr.append(BGPPathAttr(type_flags=144, type_code="MP_REACH_NLRI", attribute=mp_reach_attr))
                    # The prefixes now live inside the attribute, so the
                    # message level NLRI / withdrawn routes stay empty.
                    nlri = []

                # build update message
                if args.withdraw:
                    message = BGPHeader(type="UPDATE")/BGPUpdate(path_attr=path_attr, withdrawn_routes=nlri)
                else:
                    message = BGPHeader(type="UPDATE")/BGPUpdate(path_attr=path_attr, nlri=nlri)
                message_bin = bytearray(message.build())
                log.debug("add update with %s prefixes and length of %s bytes" % (prefix_count, len(message_bin)))
                if len(message_bin) > BGP_MAXIMUM_MESSAGE_SIZE:
                    # not expected ...
                    log.error("invalid BGP update message with length of %s bytes generated, please open a ticket", len(message_bin))
                pcap(message)
                f.write(message_bin)

        # add end-of-rib update message
        if args.end_of_rib:
            message = BGPHeader(type="UPDATE")/BGPUpdate()
            log.debug("add end-of-rib")
            pcap(message)
            f.write(bytearray(message.build()))

    if args.pcap:
        log.info("create PCAP file %s" % args.pcap)
        try:
            wrpcap(args.pcap, pcap_packets)
        except Exception as e:
            log.error("failed to create PCAP file")
            log.debug(e)

    log.info("finished")


if __name__ == "__main__":
    main()