1
0
mirror of https://github.com/bgp/stayrtr.git synced 2024-05-06 15:54:54 +00:00
bgp-stayrtr/cmd/gortr/gortr.go
Louis Poinsignon 79b7ec03ea Bump version
2019-04-05 15:40:39 -07:00

515 lines
12 KiB
Go

package main
import (
"bytes"
"crypto/ecdsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"flag"
"fmt"
rtr "github.com/cloudflare/gortr/lib"
"github.com/cloudflare/gortr/prefixfile"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"io"
"io/ioutil"
"net"
"net/http"
"os"
"runtime"
"strconv"
"strings"
"time"
)
const (
	// AppVersion is printed by the -version flag.
	AppVersion = "GoRTR 0.10.0"
	// ENV_SSH_PASSWORD names the environment variable consulted for the SSH
	// password when -ssh.auth.password is left empty.
	ENV_SSH_PASSWORD = "RTR_SSH_PASSWORD"
)

// SSH authentication methods selectable with -ssh.method (mapped by authToId).
// They live in their own const block so that iota starts at 0 instead of
// depending on how many unrelated constants happen to precede them.
const (
	METHOD_NONE = iota
	METHOD_PASSWORD
	METHOD_KEY
)
var (
	// Command-line flags.
	MetricsAddr = flag.String("metrics.addr", ":8080", "Metrics address")
	MetricsPath = flag.String("metrics.path", "/metrics", "Metrics path")
	RTRVersion  = flag.Int("protocol", 1, "RTR protocol version")
	Bind        = flag.String("bind", ":8282", "Bind address")
	BindTLS     = flag.String("tls.bind", "", "Bind address for TLS")
	TLSCert     = flag.String("tls.cert", "", "Certificate path")
	TLSKey      = flag.String("tls.key", "", "Private key path")
	BindSSH     = flag.String("ssh.bind", "", "Bind address for SSH")
	SSHKey      = flag.String("ssh.key", "private.pem", "SSH host key")
	SSHAuth     = flag.String("ssh.method", "none", "Select SSH method (none or password)")
	SSHAuthUser = flag.String("ssh.auth.user", "rpki", "SSH user")
	// The help text is built from ENV_SSH_PASSWORD so it always names the
	// environment variable that is actually consulted.
	SSHAuthPassword = flag.String("ssh.auth.password", "", "SSH password (if blank, will use envvar "+ENV_SSH_PASSWORD+")")
	TimeCheck       = flag.Bool("checktime", true, "Check if file is still valid")
	Verify          = flag.Bool("verify", true, "Check signature using provided public key")
	PublicKey       = flag.String("verify.key", "cf.pub", "Public key path (PEM file)")
	CacheBin        = flag.String("cache", "https://rpki.cloudflare.com/rpki.json", "URL of the cached JSON data")
	UserAgent       = flag.String("useragent", "Cloudflare-GoRTR (+https://github.com/cloudflare/gortr)", "User-Agent header")
	RefreshInterval = flag.Int("refresh", 600, "Refresh interval in seconds")
	MaxConn         = flag.Int("maxconn", 0, "Max simultaneous connections (0 to disable limit)")
	SendNotifs      = flag.Bool("notifications", true, "Send notifications to clients")
	LogLevel        = flag.String("loglevel", "info", "Log level")
	Version         = flag.Bool("version", false, "Print version")

	// Prometheus metrics, registered by initMetrics.
	NumberOfROAs = prometheus.NewGaugeVec(
		prometheus.GaugeOpts{
			Name: "rpki_roas",
			Help: "Number of ROAS.",
		},
		[]string{"ip_version", "filtered", "path"},
	)
	LastRefresh = prometheus.NewGaugeVec(
		prometheus.GaugeOpts{
			Name: "rpki_refresh",
			Help: "Last refresh.",
		},
		[]string{"path"},
	)
	ClientsMetric = prometheus.NewGaugeVec(
		prometheus.GaugeOpts{
			Name: "rtr_clients",
			Help: "Number of clients connected.",
		},
		[]string{"bind"},
	)
	PDUsRecv = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "rtr_pdus",
			Help: "PDU received.",
		},
		[]string{"type"},
	)

	// protoverToLib maps the -protocol flag value to the library's protocol constant.
	protoverToLib = map[int]uint8{
		0: rtr.PROTOCOL_VERSION_0,
		1: rtr.PROTOCOL_VERSION_1,
	}
	// authToId maps the -ssh.method flag value to a METHOD_* constant.
	authToId = map[string]int{
		"none":     METHOD_NONE,
		"password": METHOD_PASSWORD,
		//"key": METHOD_KEY,
	}
)
// initMetrics registers all of this binary's Prometheus collectors with the
// default registry. As a Must* helper, MustRegister panics if registration fails.
func initMetrics() {
	prometheus.MustRegister(
		NumberOfROAs,
		LastRefresh,
		ClientsMetric,
		PDUsRecv,
	)
}
// metricHTTP exposes the Prometheus handler at MetricsPath on the default
// ServeMux and serves it on MetricsAddr. It blocks; when ListenAndServe
// returns, its error is reported through log.Fatal.
func metricHTTP() {
	handler := promhttp.Handler()
	http.Handle(*MetricsPath, handler)
	err := http.ListenAndServe(*MetricsAddr, nil)
	log.Fatal(err)
}
// fetchFile returns the contents of file, which is either an http:// or
// https:// URL or a path on the local filesystem.
//
// For URLs a GET request is sent with ua as the User-Agent header and
// "text/json" as the Accept header. A response whose status is not 2xx is
// returned as an error instead of being treated as data. The response body
// or opened file is always closed before returning.
func fetchFile(file string, ua string) ([]byte, error) {
	var f io.ReadCloser
	if strings.HasPrefix(file, "http://") || strings.HasPrefix(file, "https://") {
		req, err := http.NewRequest("GET", file, nil)
		if err != nil {
			// Must be checked before touching req: it is nil when err is set.
			return nil, err
		}
		req.Header.Set("User-Agent", ua)
		req.Header.Set("Accept", "text/json")
		// NOTE(review): this client has no Timeout, so a stalled server blocks
		// the caller indefinitely; consider setting http.Client.Timeout.
		client := &http.Client{}
		resp, err := client.Do(req)
		if err != nil {
			return nil, err
		}
		if resp.StatusCode < 200 || resp.StatusCode > 299 {
			resp.Body.Close()
			return nil, fmt.Errorf("Fetching %v: unexpected HTTP status %v", file, resp.Status)
		}
		f = resp.Body
	} else {
		fd, err := os.Open(file)
		if err != nil {
			return nil, err
		}
		f = fd
	}
	defer f.Close()
	data, err := ioutil.ReadAll(f)
	if err != nil {
		return nil, err
	}
	return data, nil
}
// checkFile computes the SHA-256 digest of data and returns it as a slice.
// The error result is currently always nil.
func checkFile(data []byte) ([]byte, error) {
	digest := sha256.Sum256(data)
	sum := digest[:]
	return sum, nil
}
func decodeJSON(data []byte) (*prefixfile.ROAList, error) {
buf := bytes.NewBuffer(data)
dec := json.NewDecoder(buf)
var roalistjson prefixfile.ROAList
err := dec.Decode(&roalistjson)
return &roalistjson, err
}
// processData converts the decoded JSON ROA list into rtr.ROA values.
//
// Entries are de-duplicated on (prefix, ASN, max length). An entry is skipped
// rather than crashing the process if:
//   - its prefix is not valid CIDR, which previously meant dereferencing a
//     nil *net.IPNet, or
//   - its ASN is not a 32-bit decimal number.
//
// The ASN may carry an optional case-insensitive "AS" prefix, e.g. "AS13335"
// or "13335".
//
// It returns four values:
//   - the unique ROAs,
//   - the number of valid entries,
//   - how many of those entries are IPv4,
//   - how many of those entries are IPv6.
//
// The three counts include duplicates.
func processData(roalistjson *prefixfile.ROAList) ([]rtr.ROA, int, int, int) {
	filterDuplicates := make(map[string]struct{}, len(roalistjson.Data))
	roalist := make([]rtr.ROA, 0, len(roalistjson.Data))
	var count, countv4, countv6 int
	for _, v := range roalistjson.Data {
		_, prefix, err := net.ParseCIDR(v.Prefix)
		if err != nil {
			continue
		}
		asnStr := v.ASN
		if len(asnStr) >= 2 && strings.EqualFold(asnStr[:2], "AS") {
			asnStr = asnStr[2:]
		}
		asnInt, err := strconv.ParseUint(asnStr, 10, 32)
		if err != nil {
			// Without this check a bad ASN silently became 0 (or the max uint32).
			continue
		}
		asn := uint32(asnInt)

		count++
		if prefix.IP.To4() != nil {
			countv4++
		} else if prefix.IP.To16() != nil {
			countv6++
		}

		key := fmt.Sprintf("%v,%v,%v", prefix, asn, v.Length)
		if _, exists := filterDuplicates[key]; exists {
			continue
		}
		filterDuplicates[key] = struct{}{}
		roalist = append(roalist, rtr.ROA{
			Prefix: *prefix,
			ASN:    asn,
			MaxLen: v.Length,
		})
	}
	return roalist, count, countv4, countv6
}
// IdenticalFile is the error returned by state.updateFile when the fetched
// content has the same SHA-256 hash as the last file that was applied, so
// there is nothing new to load. Callers log it at info level rather than as
// a failure.
type IdenticalFile struct {
	File string // path or URL of the unchanged file
}

// Error implements the error interface.
func (e IdenticalFile) Error() string {
	return fmt.Sprint("File ", e.File, " is identical to the previous version")
}
// updateFile fetches file (URL or local path). If the content differs from
// the last file that was applied, it loads the ROAs the file contains into
// the RTR server.
//
// The file is rejected when:
//   - it cannot be decoded,
//   - s.checktime is set and the file's validity time has passed, or
//   - s.verify is set and the signature does not check against s.pubkey.
//
// It returns IdenticalFile when the content hash matches the previously
// applied file. lastdata, lastts and lasthash are only updated once the file
// has been accepted, so a rejected file is not recorded as current and is
// retried on the next refresh.
func (s *state) updateFile(file string) error {
	log.Debugf("Refreshing cache from %v", file)
	data, err := fetchFile(file, s.userAgent)
	if err != nil {
		// Not logged here: callers log the returned error.
		return err
	}
	hsum, err := checkFile(data)
	if err != nil {
		return err
	}
	if s.lasthash != nil && bytes.Equal(s.lasthash, hsum) {
		return IdenticalFile{File: file}
	}

	roalistjson, err := decodeJSON(data)
	if err != nil {
		return err
	}
	if s.checktime {
		validtime := time.Unix(int64(roalistjson.Metadata.Valid), 0).UTC()
		if time.Now().UTC().After(validtime) {
			return fmt.Errorf("File is expired: %v", validtime)
		}
	}
	if s.verify {
		log.Debugf("Verifying signature in %v", file)
		if roalistjson.Metadata.SignatureDate == "" || roalistjson.Metadata.Signature == "" {
			return errors.New("No signatures in file")
		}
		validdata, validdatatime, err := roalistjson.CheckFile(s.pubkey)
		if err != nil {
			return err
		}
		// The signature-date check only counts when time checking is enabled.
		if !(validdata && (validdatatime || !s.checktime)) {
			return errors.New("Invalid signatures")
		}
		log.Debugf("Signature verified")
	}

	roas, count, countv4, countv6 := processData(roalistjson)
	log.Infof("New update (%v uniques, %v total prefixes). %v bytes. Updating sha256 hash %x -> %x",
		len(roas), count, len(data), s.lasthash, hsum)
	s.lastts = time.Now().UTC()
	s.lastdata = data
	s.lasthash = hsum

	s.server.AddROAs(roas)
	sessid, _ := s.server.GetSessionId(nil)
	serial, _ := s.server.GetCurrentSerial(sessid)
	log.Infof("Updated added, new serial %v", serial)
	if s.sendNotifs {
		log.Debugf("Sending notifications to clients")
		s.server.NotifyClientsLatest()
	}

	if s.metricsEvent != nil {
		// countv4/countv6 include duplicates; recount the de-duplicated list.
		var uniqueV4, uniqueV6 int
		for _, roa := range roas {
			if roa.Prefix.IP.To4() != nil {
				uniqueV4++
			} else if roa.Prefix.IP.To16() != nil {
				uniqueV6++
			}
		}
		s.metricsEvent.UpdateMetrics(countv4, countv6, uniqueV4, uniqueV6, s.lastts, file)
	}
	return nil
}
func (s *state) routineUpdate(file string, interval int) {
log.Debugf("Starting refresh routine (file: %v, interval: %vs)", file, interval)
for {
select {
case <-time.After(time.Duration(interval) * time.Second):
err := s.updateFile(file)
if err != nil {
switch err.(type) {
case IdenticalFile:
log.Info(err)
default:
log.Errorf("Error updating: %v", err)
}
}
}
}
}
// state carries the refresh loop's configuration (filled in once by main)
// together with the bookkeeping that updateFile maintains about fetched files.
type state struct {
	// Configuration.
	server       *rtr.Server
	metricsEvent *metricsEvent // nil when the metrics endpoint is disabled
	pubkey       *ecdsa.PublicKey
	verify       bool
	checktime    bool
	sendNotifs   bool
	userAgent    string

	// Bookkeeping maintained by updateFile.
	lastdata      []byte
	lastconverted []byte // NOTE(review): never assigned in this file
	lasthash      []byte
	lastts        time.Time
}
// metricsEvent is a stateless event sink that feeds RTR server events into
// the package-level Prometheus collectors.
type metricsEvent struct{}
// ClientConnected increments the rtr_clients gauge, labelled with the local
// address the client connected to.
func (m *metricsEvent) ClientConnected(c *rtr.Client) {
	bind := c.GetLocalAddress().String()
	ClientsMetric.WithLabelValues(bind).Inc()
}

// ClientDisconnected decrements the rtr_clients gauge, labelled with the
// local address the client was connected to.
func (m *metricsEvent) ClientDisconnected(c *rtr.Client) {
	bind := c.GetLocalAddress().String()
	ClientsMetric.WithLabelValues(bind).Dec()
}
// HandlePDU counts a received PDU in rtr_pdus. The label is the PDU type
// name with spaces replaced by underscores, then lower-cased.
func (m *metricsEvent) HandlePDU(c *rtr.Client, pdu rtr.PDU) {
	typeName := rtr.TypeToString(pdu.GetType())
	underscored := strings.Replace(typeName, " ", "_", -1)
	PDUsRecv.WithLabelValues(strings.ToLower(underscored)).Inc()
}
// UpdateMetrics sets the rpki_roas gauges for file from the caller-supplied
// per-family counts: num* go to the "unfiltered" label and num*filtered to
// the "filtered" label. It also sets rpki_refresh to refreshed, truncated to
// whole Unix seconds.
func (m *metricsEvent) UpdateMetrics(numIPv4 int, numIPv6 int, numIPv4filtered int, numIPv6filtered int, refreshed time.Time, file string) {
	setROAs := func(family, filtering string, n int) {
		NumberOfROAs.WithLabelValues(family, filtering, file).Set(float64(n))
	}
	setROAs("ipv4", "filtered", numIPv4filtered)
	setROAs("ipv4", "unfiltered", numIPv4)
	setROAs("ipv6", "filtered", numIPv6filtered)
	setROAs("ipv6", "unfiltered", numIPv6)

	seconds := refreshed.UnixNano() / 1e9
	LastRefresh.WithLabelValues(file).Set(float64(seconds))
}
// ReadPublicKey parses an ECDSA public key encoded as PKIX
// (SubjectPublicKeyInfo) DER.
//
// When isPem is true, key must contain a PEM block whose payload is that DER;
// otherwise key is the DER itself. It returns an error if no PEM block is
// found, the DER cannot be parsed, or the key is not an ECDSA key.
func ReadPublicKey(key []byte, isPem bool) (*ecdsa.PublicKey, error) {
	if isPem {
		block, _ := pem.Decode(key)
		if block == nil {
			// pem.Decode returns a nil block when no PEM data is found.
			return nil, errors.New("No PEM data found in public key")
		}
		key = block.Bytes
	}
	k, err := x509.ParsePKIXPublicKey(key)
	if err != nil {
		return nil, err
	}
	kconv, ok := k.(*ecdsa.PublicKey)
	if !ok {
		return nil, errors.New("Not ECDSA public key")
	}
	return kconv, nil
}
// main parses the command-line flags, starts the metrics endpoint (if
// enabled) and the requested RTR listeners (TCP, TLS and/or SSH). It then
// loads the ROA cache once and refreshes it forever via routineUpdate.
func main() {
	runtime.GOMAXPROCS(runtime.NumCPU())
	flag.Parse()
	if *Version {
		fmt.Println(AppVersion)
		os.Exit(0)
	}

	lvl, err := log.ParseLevel(*LogLevel)
	if err != nil {
		// Fail fast instead of running with whatever level ParseLevel
		// returns alongside an error.
		log.Fatal(err)
	}
	log.SetLevel(lvl)

	protoVersion, ok := protoverToLib[*RTRVersion]
	if !ok {
		// A missing map entry would otherwise silently yield the zero value.
		log.Fatalf("Unsupported RTR protocol version %v", *RTRVersion)
	}

	deh := &rtr.DefaultRTREventHandler{
		Log: log.StandardLogger(),
	}
	sc := rtr.ServerConfiguration{
		ProtocolVersion: protoVersion,
		KeepDifference:  3,
		Log:             log.StandardLogger(),
	}

	var me *metricsEvent
	if *MetricsAddr != "" {
		initMetrics()
		go metricHTTP()
		me = &metricsEvent{}
	}
	server := rtr.NewServer(sc, me, deh)
	deh.SetROAManager(server)

	var pubkey *ecdsa.PublicKey
	if *Verify {
		pubkeyBytes, err := ioutil.ReadFile(*PublicKey)
		if err != nil {
			log.Fatal(err)
		}
		pubkey, err = ReadPublicKey(pubkeyBytes, true)
		if err != nil {
			log.Fatal(err)
		}
	}

	s := state{
		server:       server,
		metricsEvent: me,
		sendNotifs:   *SendNotifs,
		pubkey:       pubkey,
		verify:       *Verify,
		checktime:    *TimeCheck,
		userAgent:    *UserAgent,
	}

	if *Bind == "" && *BindTLS == "" && *BindSSH == "" {
		log.Fatalf("Specify at least a bind address")
	}
	if *Bind != "" {
		go func() {
			if err := server.Start(*Bind); err != nil {
				log.Fatal(err)
			}
		}()
	}
	if *BindTLS != "" {
		cert, err := tls.LoadX509KeyPair(*TLSCert, *TLSKey)
		if err != nil {
			log.Fatal(err)
		}
		tlsConfig := &tls.Config{
			Certificates: []tls.Certificate{cert},
		}
		go func() {
			if err := server.StartTLS(*BindTLS, tlsConfig); err != nil {
				log.Fatal(err)
			}
		}()
	}
	if *BindSSH != "" {
		sshConfig, err := newSSHServerConfig(*SSHKey, *SSHAuth, *SSHAuthUser, *SSHAuthPassword)
		if err != nil {
			log.Fatal(err)
		}
		go func() {
			if err := server.StartSSH(*BindSSH, sshConfig); err != nil {
				log.Fatal(err)
			}
		}()
	}

	if err := s.updateFile(*CacheBin); err != nil {
		if _, unchanged := err.(IdenticalFile); unchanged {
			log.Info(err)
		} else {
			log.Errorf("Error updating: %v", err)
		}
	}
	s.routineUpdate(*CacheBin, *RefreshInterval)
}

// newSSHServerConfig builds the configuration for the RTR-over-SSH listener.
//
// Parameters:
//   - keyPath is the file holding the SSH host private key.
//   - method is a key of authToId.
//   - user and password are the credentials accepted by the "password"
//     method.
//
// For the "password" method, an empty password falls back to the
// ENV_SSH_PASSWORD environment variable. If both are empty an error is
// returned rather than accepting an empty password.
func newSSHServerConfig(keyPath, method, user, password string) (*ssh.ServerConfig, error) {
	sshkey, err := ioutil.ReadFile(keyPath)
	if err != nil {
		return nil, err
	}
	private, err := ssh.ParsePrivateKey(sshkey)
	if err != nil {
		return nil, fmt.Errorf("Failed to parse private key: %v", err)
	}
	authType, ok := authToId[method]
	if !ok {
		return nil, fmt.Errorf("Auth type %v unknown", method)
	}

	config := &ssh.ServerConfig{}
	switch authType {
	case METHOD_PASSWORD:
		if password == "" {
			password = os.Getenv(ENV_SSH_PASSWORD)
		}
		if password == "" {
			return nil, fmt.Errorf("SSH password authentication selected but no password given (set -ssh.auth.password or %v)", ENV_SSH_PASSWORD)
		}
		// Compare against the resolved password (flag or environment); the
		// callback's own argument is the password the client supplied.
		expected := []byte(password)
		config.PasswordCallback = func(conn ssh.ConnMetadata, supplied []byte) (*ssh.Permissions, error) {
			log.Infof("Connected: %v/%v", conn.User(), conn.RemoteAddr())
			if conn.User() != user || !bytes.Equal(supplied, expected) {
				log.Warnf("Wrong user or password for %v/%v. Disconnecting.", conn.User(), conn.RemoteAddr())
				return nil, errors.New("Wrong user or password")
			}
			return &ssh.Permissions{
				CriticalOptions: make(map[string]string),
				Extensions:      make(map[string]string),
			}, nil
		}
	case METHOD_NONE:
		config.NoClientAuth = true
	}
	config.AddHostKey(private)
	return config, nil
}