// Mirror of https://github.com/bgp/stayrtr.git (synced 2024-05-06 15:54:54 +00:00).
// 515 lines, 12 KiB, Go.

package main
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/ecdsa"
|
|
"crypto/sha256"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/json"
|
|
"encoding/pem"
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
rtr "github.com/cloudflare/gortr/lib"
|
|
"github.com/cloudflare/gortr/prefixfile"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
|
log "github.com/sirupsen/logrus"
|
|
"golang.org/x/crypto/ssh"
|
|
"io"
|
|
"io/ioutil"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"runtime"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// Application-wide constants.
const (
	// AppVersion is the string printed when the -version flag is given.
	AppVersion = "GoRTR 0.10.0"

	// ENV_SSH_PASSWORD names the environment variable consulted for the SSH
	// password when the -ssh.auth.password flag is left blank.
	ENV_SSH_PASSWORD = "RTR_SSH_PASSWORD"

	// SSH authentication method identifiers. iota counts every entry of this
	// const block, so METHOD_NONE is 2 (not 0), METHOD_PASSWORD is 3 and
	// METHOD_KEY is 4. Within this file they are only compared to each other.
	METHOD_NONE = iota
	METHOD_PASSWORD
	METHOD_KEY
)
|
|
|
|
var (
|
|
MetricsAddr = flag.String("metrics.addr", ":8080", "Metrics address")
|
|
MetricsPath = flag.String("metrics.path", "/metrics", "Metrics path")
|
|
RTRVersion = flag.Int("protocol", 1, "RTR protocol version")
|
|
|
|
Bind = flag.String("bind", ":8282", "Bind address")
|
|
|
|
BindTLS = flag.String("tls.bind", "", "Bind address for TLS")
|
|
TLSCert = flag.String("tls.cert", "", "Certificate path")
|
|
TLSKey = flag.String("tls.key", "", "Private key path")
|
|
|
|
BindSSH = flag.String("ssh.bind", "", "Bind address for SSH")
|
|
SSHKey = flag.String("ssh.key", "private.pem", "SSH host key")
|
|
SSHAuth = flag.String("ssh.method", "none", "Select SSH method (none or password)")
|
|
SSHAuthUser = flag.String("ssh.auth.user", "rpki", "SSH user")
|
|
SSHAuthPassword = flag.String("ssh.auth.password", "", "SSH password (if blank, will use envvar GORTR_SSH_PASSWORD)")
|
|
|
|
TimeCheck = flag.Bool("checktime", true, "Check if file is still valid")
|
|
Verify = flag.Bool("verify", true, "Check signature using provided public key")
|
|
PublicKey = flag.String("verify.key", "cf.pub", "Public key path (PEM file)")
|
|
|
|
CacheBin = flag.String("cache", "https://rpki.cloudflare.com/rpki.json", "URL of the cached JSON data")
|
|
UserAgent = flag.String("useragent", "Cloudflare-GoRTR (+https://github.com/cloudflare/gortr)", "User-Agent header")
|
|
RefreshInterval = flag.Int("refresh", 600, "Refresh interval in seconds")
|
|
MaxConn = flag.Int("maxconn", 0, "Max simultaneous connections (0 to disable limit)")
|
|
SendNotifs = flag.Bool("notifications", true, "Send notifications to clients")
|
|
|
|
LogLevel = flag.String("loglevel", "info", "Log level")
|
|
Version = flag.Bool("version", false, "Print version")
|
|
|
|
NumberOfROAs = prometheus.NewGaugeVec(
|
|
prometheus.GaugeOpts{
|
|
Name: "rpki_roas",
|
|
Help: "Number of ROAS.",
|
|
},
|
|
[]string{"ip_version", "filtered", "path"},
|
|
)
|
|
LastRefresh = prometheus.NewGaugeVec(
|
|
prometheus.GaugeOpts{
|
|
Name: "rpki_refresh",
|
|
Help: "Last refresh.",
|
|
},
|
|
[]string{"path"},
|
|
)
|
|
ClientsMetric = prometheus.NewGaugeVec(
|
|
prometheus.GaugeOpts{
|
|
Name: "rtr_clients",
|
|
Help: "Number of clients connected.",
|
|
},
|
|
[]string{"bind"},
|
|
)
|
|
PDUsRecv = prometheus.NewCounterVec(
|
|
prometheus.CounterOpts{
|
|
Name: "rtr_pdus",
|
|
Help: "PDU received.",
|
|
},
|
|
[]string{"type"},
|
|
)
|
|
|
|
protoverToLib = map[int]uint8{
|
|
0: rtr.PROTOCOL_VERSION_0,
|
|
1: rtr.PROTOCOL_VERSION_1,
|
|
}
|
|
authToId = map[string]int{
|
|
"none": METHOD_NONE,
|
|
"password": METHOD_PASSWORD,
|
|
//"key": METHOD_KEY,
|
|
}
|
|
)
|
|
|
|
func initMetrics() {
|
|
prometheus.MustRegister(NumberOfROAs)
|
|
prometheus.MustRegister(LastRefresh)
|
|
prometheus.MustRegister(ClientsMetric)
|
|
prometheus.MustRegister(PDUsRecv)
|
|
}
|
|
|
|
func metricHTTP() {
|
|
http.Handle(*MetricsPath, promhttp.Handler())
|
|
log.Fatal(http.ListenAndServe(*MetricsAddr, nil))
|
|
}
|
|
|
|
// fetchFile returns the raw contents of file.
//
// When file starts with "http://" or "https://" it is fetched with an HTTP
// GET that carries ua as the User-Agent header; any other value is opened as
// a local path. The underlying body or file is always closed before
// returning. A non-2xx HTTP status is reported as an error instead of
// handing an error page to the caller as data.
func fetchFile(file string, ua string) ([]byte, error) {
	var src io.ReadCloser

	if strings.HasPrefix(file, "http://") || strings.HasPrefix(file, "https://") {
		req, err := http.NewRequest("GET", file, nil)
		if err != nil {
			// req is nil here, so it must not be touched before this check.
			return nil, err
		}
		req.Header.Set("User-Agent", ua)
		req.Header.Set("Accept", "text/json")

		client := &http.Client{}
		resp, err := client.Do(req)
		if err != nil {
			return nil, err
		}
		if resp.StatusCode < 200 || resp.StatusCode > 299 {
			resp.Body.Close()
			return nil, fmt.Errorf("fetching %v: unexpected HTTP status %v", file, resp.Status)
		}
		src = resp.Body
	} else {
		f, err := os.Open(file)
		if err != nil {
			return nil, err
		}
		src = f
	}
	defer src.Close()

	data, err := ioutil.ReadAll(src)
	if err != nil {
		return nil, err
	}
	return data, nil
}
|
|
|
|
// checkFile returns the SHA-256 digest of data as a 32-byte slice.
// The returned error is always nil.
func checkFile(data []byte) ([]byte, error) {
	hasher := sha256.New()
	hasher.Write(data)
	return hasher.Sum(nil), nil
}
|
|
|
|
func decodeJSON(data []byte) (*prefixfile.ROAList, error) {
|
|
buf := bytes.NewBuffer(data)
|
|
dec := json.NewDecoder(buf)
|
|
|
|
var roalistjson prefixfile.ROAList
|
|
err := dec.Decode(&roalistjson)
|
|
return &roalistjson, err
|
|
}
|
|
|
|
func processData(roalistjson *prefixfile.ROAList) ([]rtr.ROA, int, int, int) {
|
|
filterDuplicates := make(map[string]bool)
|
|
|
|
roalist := make([]rtr.ROA, 0)
|
|
|
|
var count int
|
|
var countv4 int
|
|
var countv6 int
|
|
for _, v := range roalistjson.Data {
|
|
_, prefix, _ := net.ParseCIDR(v.Prefix)
|
|
asnStr := v.ASN[2:len(v.ASN)]
|
|
asnInt, _ := strconv.ParseUint(asnStr, 10, 32)
|
|
asn := uint32(asnInt)
|
|
|
|
count++
|
|
if prefix.IP.To4() != nil {
|
|
countv4++
|
|
} else if prefix.IP.To16() != nil {
|
|
countv6++
|
|
}
|
|
|
|
key := fmt.Sprintf("%v,%v,%v", prefix, asn, v.Length)
|
|
_, exists := filterDuplicates[key]
|
|
if !exists {
|
|
filterDuplicates[key] = true
|
|
} else {
|
|
continue
|
|
}
|
|
|
|
roa := rtr.ROA{
|
|
Prefix: *prefix,
|
|
ASN: asn,
|
|
MaxLen: v.Length,
|
|
}
|
|
roalist = append(roalist, roa)
|
|
}
|
|
return roalist, count, countv4, countv6
|
|
}
|
|
|
|
// IdenticalFile is the error value returned when a freshly fetched file has
// the same hash as the previously processed one.
type IdenticalFile struct {
	// File is the path or URL that was fetched.
	File string
}

// Error implements the error interface.
func (e IdenticalFile) Error() string {
	return "File " + e.File + " is identical to the previous version"
}
|
|
|
|
func (s *state) updateFile(file string) error {
|
|
log.Debugf("Refreshing cache from %v", file)
|
|
data, err := fetchFile(file, s.userAgent)
|
|
if err != nil {
|
|
log.Error(err)
|
|
return err
|
|
}
|
|
hsum, _ := checkFile(data)
|
|
if s.lasthash != nil {
|
|
cres := bytes.Compare(s.lasthash, hsum)
|
|
if cres == 0 {
|
|
return IdenticalFile{File: file}
|
|
}
|
|
}
|
|
|
|
s.lastts = time.Now().UTC()
|
|
s.lastdata = data
|
|
|
|
roalistjson, err := decodeJSON(s.lastdata)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if s.checktime {
|
|
validtime := time.Unix(int64(roalistjson.Metadata.Valid), 0).UTC()
|
|
if time.Now().UTC().After(validtime) {
|
|
return errors.New(fmt.Sprintf("File is expired: %v", validtime))
|
|
}
|
|
}
|
|
|
|
if s.verify {
|
|
log.Debugf("Verifying signature in %v", file)
|
|
if roalistjson.Metadata.SignatureDate == "" || roalistjson.Metadata.Signature == "" {
|
|
return errors.New("No signatures in file")
|
|
}
|
|
|
|
validdata, validdatatime, err := roalistjson.CheckFile(s.pubkey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !(validdata && (validdatatime || !s.checktime)) {
|
|
return errors.New("Invalid signatures")
|
|
}
|
|
log.Debugf("Signature verified")
|
|
}
|
|
|
|
roas, count, countv4, countv6 := processData(roalistjson)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
log.Infof("New update (%v uniques, %v total prefixes). %v bytes. Updating sha256 hash %x -> %x",
|
|
len(roas), count, len(s.lastconverted), s.lasthash, hsum)
|
|
s.lasthash = hsum
|
|
|
|
s.server.AddROAs(roas)
|
|
|
|
sessid, _ := s.server.GetSessionId(nil)
|
|
serial, _ := s.server.GetCurrentSerial(sessid)
|
|
log.Infof("Updated added, new serial %v", serial)
|
|
if s.sendNotifs {
|
|
log.Debugf("Sending notifications to clients")
|
|
s.server.NotifyClientsLatest()
|
|
}
|
|
|
|
if s.metricsEvent != nil {
|
|
var countv4_dup int
|
|
var countv6_dup int
|
|
for _, roa := range roas {
|
|
if roa.Prefix.IP.To4() != nil {
|
|
countv4_dup++
|
|
} else if roa.Prefix.IP.To16() != nil {
|
|
countv6_dup++
|
|
}
|
|
}
|
|
s.metricsEvent.UpdateMetrics(countv4, countv6, countv4_dup, countv6_dup, s.lastts, file)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *state) routineUpdate(file string, interval int) {
|
|
log.Debugf("Starting refresh routine (file: %v, interval: %vs)", file, interval)
|
|
for {
|
|
select {
|
|
case <-time.After(time.Duration(interval) * time.Second):
|
|
err := s.updateFile(file)
|
|
if err != nil {
|
|
switch err.(type) {
|
|
case IdenticalFile:
|
|
log.Info(err)
|
|
default:
|
|
log.Errorf("Error updating: %v", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
type state struct {
|
|
lastdata []byte
|
|
lastconverted []byte
|
|
lasthash []byte
|
|
lastts time.Time
|
|
sendNotifs bool
|
|
userAgent string
|
|
|
|
server *rtr.Server
|
|
|
|
metricsEvent *metricsEvent
|
|
|
|
pubkey *ecdsa.PublicKey
|
|
verify bool
|
|
checktime bool
|
|
}
|
|
|
|
// metricsEvent is a stateless receiver whose methods update the
// package-level Prometheus collectors.
type metricsEvent struct{}
|
|
|
|
func (m *metricsEvent) ClientConnected(c *rtr.Client) {
|
|
ClientsMetric.WithLabelValues(c.GetLocalAddress().String()).Inc()
|
|
}
|
|
|
|
func (m *metricsEvent) ClientDisconnected(c *rtr.Client) {
|
|
ClientsMetric.WithLabelValues(c.GetLocalAddress().String()).Dec()
|
|
}
|
|
|
|
func (m *metricsEvent) HandlePDU(c *rtr.Client, pdu rtr.PDU) {
|
|
PDUsRecv.WithLabelValues(
|
|
strings.ToLower(
|
|
strings.Replace(
|
|
rtr.TypeToString(
|
|
pdu.GetType()),
|
|
" ",
|
|
"_", -1))).Inc()
|
|
}
|
|
|
|
func (m *metricsEvent) UpdateMetrics(numIPv4 int, numIPv6 int, numIPv4filtered int, numIPv6filtered int, refreshed time.Time, file string) {
|
|
NumberOfROAs.WithLabelValues("ipv4", "filtered", file).Set(float64(numIPv4filtered))
|
|
NumberOfROAs.WithLabelValues("ipv4", "unfiltered", file).Set(float64(numIPv4))
|
|
NumberOfROAs.WithLabelValues("ipv6", "filtered", file).Set(float64(numIPv6filtered))
|
|
NumberOfROAs.WithLabelValues("ipv6", "unfiltered", file).Set(float64(numIPv6))
|
|
LastRefresh.WithLabelValues(file).Set(float64(refreshed.UnixNano() / 1e9))
|
|
}
|
|
|
|
// ReadPublicKey parses key as a PKIX-encoded ECDSA public key.
//
// When isPem is true, key is first PEM-decoded and the DER bytes of the first
// PEM block are parsed; otherwise key is treated as raw DER. An error is
// returned when no PEM block is found, when the DER cannot be parsed, or when
// the key is not an ECDSA key.
func ReadPublicKey(key []byte, isPem bool) (*ecdsa.PublicKey, error) {
	if isPem {
		block, _ := pem.Decode(key)
		if block == nil {
			// pem.Decode reports "no PEM data" with a nil block.
			return nil, errors.New("No PEM data found in public key")
		}
		key = block.Bytes
	}

	k, err := x509.ParsePKIXPublicKey(key)
	if err != nil {
		return nil, err
	}
	kconv, ok := k.(*ecdsa.PublicKey)
	if !ok {
		return nil, errors.New("Not ECDSA public key")
	}
	return kconv, nil
}
|
|
|
|
func main() {
|
|
runtime.GOMAXPROCS(runtime.NumCPU())
|
|
|
|
flag.Parse()
|
|
if *Version {
|
|
fmt.Println(AppVersion)
|
|
os.Exit(0)
|
|
}
|
|
|
|
lvl, _ := log.ParseLevel(*LogLevel)
|
|
log.SetLevel(lvl)
|
|
|
|
deh := &rtr.DefaultRTREventHandler{
|
|
Log: log.StandardLogger(),
|
|
}
|
|
|
|
sc := rtr.ServerConfiguration{
|
|
ProtocolVersion: protoverToLib[*RTRVersion],
|
|
KeepDifference: 3,
|
|
Log: log.StandardLogger(),
|
|
}
|
|
|
|
var me *metricsEvent
|
|
if *MetricsAddr != "" {
|
|
initMetrics()
|
|
go metricHTTP()
|
|
me = &metricsEvent{}
|
|
}
|
|
|
|
server := rtr.NewServer(sc, me, deh)
|
|
deh.SetROAManager(server)
|
|
|
|
var pubkey *ecdsa.PublicKey
|
|
if *Verify {
|
|
pubkeyBytes, err := ioutil.ReadFile(*PublicKey)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
pubkey, err = ReadPublicKey(pubkeyBytes, true)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
|
|
s := state{
|
|
server: server,
|
|
metricsEvent: me,
|
|
sendNotifs: *SendNotifs,
|
|
pubkey: pubkey,
|
|
verify: *Verify,
|
|
checktime: *TimeCheck,
|
|
userAgent: *UserAgent,
|
|
}
|
|
|
|
if *Bind == "" && *BindTLS == "" && *BindSSH == "" {
|
|
log.Fatalf("Specify at least a bind address")
|
|
}
|
|
|
|
if *Bind != "" {
|
|
go func() {
|
|
err := server.Start(*Bind)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}()
|
|
}
|
|
if *BindTLS != "" {
|
|
cert, err := tls.LoadX509KeyPair(*TLSCert, *TLSKey)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
tlsConfig := tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
}
|
|
go func() {
|
|
err := server.StartTLS(*BindTLS, &tlsConfig)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}()
|
|
}
|
|
if *BindSSH != "" {
|
|
sshkey, err := ioutil.ReadFile(*SSHKey)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
private, err := ssh.ParsePrivateKey(sshkey)
|
|
if err != nil {
|
|
log.Fatal("Failed to parse private key: ", err)
|
|
}
|
|
|
|
sshConfig := ssh.ServerConfig{}
|
|
|
|
if authType, ok := authToId[*SSHAuth]; ok {
|
|
if authType == METHOD_PASSWORD {
|
|
password := *SSHAuthPassword
|
|
if password == "" {
|
|
password = os.Getenv(ENV_SSH_PASSWORD)
|
|
}
|
|
sshConfig.PasswordCallback = func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
|
log.Infof("Connected: %v/%v", conn.User(), conn.RemoteAddr())
|
|
if conn.User() != *SSHAuthUser || !bytes.Equal(password, []byte(*SSHAuthPassword)) {
|
|
log.Warnf("Wrong user or password for %v/%v. Disconnecting.", conn.User(), conn.RemoteAddr())
|
|
return nil, errors.New("Wrong user or password")
|
|
}
|
|
|
|
return &ssh.Permissions{
|
|
CriticalOptions: make(map[string]string),
|
|
Extensions: make(map[string]string),
|
|
}, nil
|
|
}
|
|
} else if authType == METHOD_NONE {
|
|
sshConfig.NoClientAuth = true
|
|
}
|
|
} else {
|
|
log.Fatalf("Auth type %v unknown", *SSHAuth)
|
|
}
|
|
|
|
sshConfig.AddHostKey(private)
|
|
go func() {
|
|
err := server.StartSSH(*BindSSH, &sshConfig)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}()
|
|
}
|
|
|
|
err := s.updateFile(*CacheBin)
|
|
if err != nil {
|
|
switch err.(type) {
|
|
case IdenticalFile:
|
|
log.Info(err)
|
|
default:
|
|
log.Errorf("Error updating: %v", err)
|
|
}
|
|
}
|
|
s.routineUpdate(*CacheBin, *RefreshInterval)
|
|
|
|
}
|