1
0
mirror of https://github.com/bgp/stayrtr.git synced 2024-05-06 15:54:54 +00:00
Files
Ties de Kock b93d7a8477 Merge pull request #117 from cjeker/avoid_empty_deltas
Avoid adding empty deltas to the cache
2024-03-01 09:36:03 +01:00

1166 lines
26 KiB
Go

package rtrlib
import (
"bytes"
"crypto/tls"
"flag"
"fmt"
"io"
"math/rand"
"net"
"net/netip"
"sync"
"time"
"golang.org/x/crypto/ssh"
)
func GenerateSessionId() uint16 {
var sessid uint16
r := rand.New(rand.NewSource(time.Now().UTC().Unix()))
sessid = uint16(r.Uint32())
return sessid
}
type RTRServerEventHandler interface {
ClientConnected(*Client)
ClientDisconnected(*Client)
HandlePDU(*Client, PDU)
}
type RTREventHandler interface {
RequestCache(*Client)
RequestNewVersion(*Client, uint16, uint32)
}
// This is a general interface for things like a VRP, BGPsec Router key or ASPA object
// Be sure to have all of these as pointers, or SetFlag() cannot work!
type SendableData interface {
Copy() SendableData
Equals(SendableData) bool
HashKey() string
String() string
Type() string
SetFlag(uint8)
GetFlag() uint8
}
// This handles things like ROAs, BGPsec Router keys, ASPA info etc
type SendableDataManager interface {
GetCurrentSerial(uint16) (uint32, bool)
GetSessionId() uint16
GetCurrentSDs() ([]SendableData, bool)
GetSDsSerialDiff(uint32) ([]SendableData, bool)
}
type DefaultRTREventHandler struct {
sdManager SendableDataManager
Log Logger
}
func (e *DefaultRTREventHandler) SetSDManager(m SendableDataManager) {
e.sdManager = m
}
func (e *DefaultRTREventHandler) RequestCache(c *Client) {
if e.Log != nil {
e.Log.Debugf("%v > Request Cache", c)
}
sessionId := e.sdManager.GetSessionId()
serial, valid := e.sdManager.GetCurrentSerial(sessionId)
if !valid {
c.SendNoDataError()
if e.Log != nil {
e.Log.Debugf("%v < No data", c)
}
} else {
data, exists := e.sdManager.GetCurrentSDs()
if !exists {
c.SendInternalError()
if e.Log != nil {
e.Log.Debugf("%v < Internal error requesting cache (does not exists)", c)
}
} else {
c.SendSDs(sessionId, serial, data)
if e.Log != nil {
e.Log.Debugf("%v < Sent cache (current serial %d, session: %d)", c, serial, sessionId)
}
}
}
}
func (e *DefaultRTREventHandler) RequestNewVersion(c *Client, sessionId uint16, serialNumber uint32) {
if e.Log != nil {
e.Log.Debugf("%v > Request New Version", c)
}
serverSessionId := e.sdManager.GetSessionId()
if sessionId != serverSessionId {
c.SendCorruptData()
if e.Log != nil {
e.Log.Debugf("%v < Invalid request (client asked for session %d but server is at %d)", c, sessionId, serverSessionId)
}
c.Disconnect()
return
}
serial, valid := e.sdManager.GetCurrentSerial(sessionId)
if !valid {
c.SendNoDataError()
if e.Log != nil {
e.Log.Debugf("%v < No data", c)
}
} else {
data, exists := e.sdManager.GetSDsSerialDiff(serialNumber)
if !exists {
c.SendCacheReset()
if e.Log != nil {
e.Log.Debugf("%v < Sent cache reset", c)
}
} else {
c.SendSDs(sessionId, serial, data)
if e.Log != nil {
e.Log.Debugf("%v < Sent cache (current serial %d, session from client: %d)", c, serial, sessionId)
}
}
}
}
type Server struct {
baseVersion uint8
clientlock *sync.RWMutex
clients []*Client
sessId uint16
connected int
maxconn int
sshconfig *ssh.ServerConfig
handler RTRServerEventHandler
simpleHandler RTREventHandler
enforceVersion bool
sdlock *sync.RWMutex
sdListDiff [][]SendableData
sdMapSerial map[uint32]int
sdListSerial []uint32
sdCurrent []SendableData
sdCurrentSerial uint32
keepDiff int
pduRefreshInterval uint32
pduRetryInterval uint32
pduExpireInterval uint32
log Logger
logverbose bool
}
type ServerConfiguration struct {
MaxConn int
ProtocolVersion uint8
EnforceVersion bool
KeepDifference int
SessId int
RefreshInterval uint32
RetryInterval uint32
ExpireInterval uint32
Log Logger
LogVerbose bool
}
func NewServer(configuration ServerConfiguration, handler RTRServerEventHandler, simpleHandler RTREventHandler) *Server {
sessid := GenerateSessionId()
refreshInterval := uint32(3600)
if configuration.RefreshInterval != 0 {
refreshInterval = configuration.RefreshInterval
}
retryInterval := uint32(600)
if configuration.RetryInterval != 0 {
retryInterval = configuration.RetryInterval
}
expireInterval := uint32(7200)
if configuration.ExpireInterval != 0 {
expireInterval = configuration.ExpireInterval
}
return &Server{
sdlock: &sync.RWMutex{},
sdListDiff: make([][]SendableData, 0),
sdMapSerial: make(map[uint32]int),
sdListSerial: make([]uint32, 0),
sdCurrent: make([]SendableData, 0),
keepDiff: configuration.KeepDifference,
clientlock: &sync.RWMutex{},
clients: make([]*Client, 0),
sessId: sessid,
maxconn: configuration.MaxConn,
baseVersion: configuration.ProtocolVersion,
enforceVersion: configuration.EnforceVersion,
handler: handler,
simpleHandler: simpleHandler,
pduRefreshInterval: refreshInterval,
pduRetryInterval: retryInterval,
pduExpireInterval: expireInterval,
log: configuration.Log,
logverbose: configuration.LogVerbose,
}
}
func ConvertSDListToMap(SDs []SendableData) map[string]uint8 {
sdMap := make(map[string]uint8, len(SDs))
for _, v := range SDs {
sdMap[v.HashKey()] = v.GetFlag()
}
return sdMap
}
func ComputeDiff(newSDs, prevSDs []SendableData, populateUnchanged bool) (added, removed, unchanged []SendableData) {
added = make([]SendableData, 0)
removed = make([]SendableData, 0)
unchanged = make([]SendableData, 0)
newSDsMap := ConvertSDListToMap(newSDs)
prevSDsMap := ConvertSDListToMap(prevSDs)
for _, item := range newSDs {
_, exists := prevSDsMap[item.HashKey()]
if !exists {
rcopy := item.Copy()
rcopy.SetFlag(FLAG_ADDED)
added = append(added, rcopy)
}
}
for _, item := range prevSDs {
_, exists := newSDsMap[item.HashKey()]
if !exists {
rcopy := item.Copy()
rcopy.SetFlag(FLAG_REMOVED)
removed = append(removed, rcopy)
} else if populateUnchanged {
rcopy := item.Copy()
unchanged = append(unchanged, rcopy)
}
}
return added, removed, unchanged
}
func ApplyDiff(diff, prevSDs []SendableData) []SendableData {
newSDs := make([]SendableData, 0)
diffMap := ConvertSDListToMap(diff)
prevSDsMap := ConvertSDListToMap(prevSDs)
for _, item := range prevSDs {
_, exists := diffMap[item.HashKey()]
if !exists {
rcopy := item.Copy()
newSDs = append(newSDs, rcopy)
}
}
for _, item := range diff {
if item.GetFlag() == FLAG_ADDED {
rcopy := item.Copy()
newSDs = append(newSDs, rcopy)
} else if item.GetFlag() == FLAG_REMOVED {
citem, exists := prevSDsMap[item.HashKey()]
if !exists {
rcopy := item.Copy()
newSDs = append(newSDs, rcopy)
} else {
if citem == FLAG_REMOVED {
rcopy := item.Copy()
newSDs = append(newSDs, rcopy)
}
}
}
}
return newSDs
}
func (s *Server) GetSessionId() uint16 {
return s.sessId
}
func (s *Server) GetCurrentSDs() ([]SendableData, bool) {
s.sdlock.RLock()
sd := s.sdCurrent
s.sdlock.RUnlock()
return sd, true
}
func (s *Server) GetSDsSerialDiff(serial uint32) ([]SendableData, bool) {
s.sdlock.RLock()
sd, ok := s.getSDsSerialDiff(serial)
s.sdlock.RUnlock()
return sd, ok
}
func (s *Server) getSDsSerialDiff(serial uint32) ([]SendableData, bool) {
if serial == s.sdCurrentSerial {
return []SendableData{}, true
}
sd := make([]SendableData, 0)
index, ok := s.sdMapSerial[serial]
if ok {
sd = s.sdListDiff[index]
}
return sd, ok
}
func (s *Server) GetCurrentSerial(sessId uint16) (uint32, bool) {
s.sdlock.RLock()
serial, valid := s.getCurrentSerial()
s.sdlock.RUnlock()
return serial, valid
}
func (s *Server) getCurrentSerial() (uint32, bool) {
return s.sdCurrentSerial, len(s.sdListSerial) > 0
}
func (s *Server) GenerateSerial() uint32 {
s.sdlock.RLock()
newserial := s.generateSerial()
s.sdlock.RUnlock()
return newserial
}
func (s *Server) generateSerial() uint32 {
newserial := s.sdCurrentSerial
if len(s.sdListSerial) > 0 {
newserial = s.sdListSerial[len(s.sdListSerial)-1] + 1
}
return newserial
}
func (s *Server) setSerial(serial uint32) {
s.sdCurrentSerial = serial
}
// This function sets the serial. Function must
// be called before the cache data is added.
func (s *Server) SetSerial(serial uint32) {
s.sdlock.RLock()
defer s.sdlock.RUnlock()
//s.sdListSerial = make([]uint32, 0)
s.setSerial(serial)
}
func (s *Server) CountSDs() int {
s.sdlock.RLock()
defer s.sdlock.RUnlock()
return len(s.sdCurrent)
}
func (s *Server) AddData(new []SendableData) bool {
s.sdlock.RLock()
added, removed, _ := ComputeDiff(new, s.sdCurrent, false)
if s.log != nil && s.logverbose {
s.log.Debugf("Computed diff: added (%v), removed (%v)", added, removed)
} else if s.log != nil {
s.log.Debugf("Computed diff: added (%d), removed (%d)", len(added), len(removed))
}
curDiff := append(added, removed...)
s.sdlock.RUnlock()
if len(curDiff) == 0 {
return false
} else {
s.AddSDsDiff(curDiff)
return true
}
}
func (s *Server) addSerial(serial uint32) []uint32 {
removed := make([]uint32, 0)
if len(s.sdListSerial) >= s.keepDiff && s.keepDiff > 0 {
removeDiff := len(s.sdListSerial) - s.keepDiff
removed = s.sdListSerial[0:removeDiff]
s.sdListSerial = s.sdListSerial[removeDiff:]
}
s.sdListSerial = append(s.sdListSerial, serial)
return removed
}
func (s *Server) AddSDsDiff(diff []SendableData) {
s.sdlock.RLock()
nextDiff := make([][]SendableData, len(s.sdListDiff))
for i, prevSDs := range s.sdListDiff {
nextDiff[i] = ApplyDiff(diff, prevSDs)
}
newSDCurrent := ApplyDiff(diff, s.sdCurrent)
curserial, _ := s.getCurrentSerial()
s.sdlock.RUnlock()
s.sdlock.Lock()
defer s.sdlock.Unlock()
newserial := s.generateSerial()
removed := s.addSerial(newserial)
nextDiff = append(nextDiff, diff)
if len(nextDiff) >= s.keepDiff && s.keepDiff > 0 {
nextDiff = nextDiff[len(removed):]
}
s.sdMapSerial[curserial] = len(nextDiff) - 1
if len(removed) > 0 {
for k, v := range s.sdMapSerial {
if k != curserial {
s.sdMapSerial[k] = v - len(removed)
}
}
}
for _, removeSerial := range removed {
delete(s.sdMapSerial, removeSerial)
}
s.sdListDiff = nextDiff
s.sdCurrent = newSDCurrent
s.setSerial(newserial)
}
func (s *Server) SetBaseVersion(version uint8) {
s.baseVersion = version
}
func (s *Server) SetVersionEnforced(adapt bool) {
s.enforceVersion = adapt
}
func (s *Server) SetMaxConnections(maxconn int) {
if s.connected > maxconn {
todisconnect := s.connected - maxconn
clients := s.GetClientList()
if s.log != nil {
s.log.Debugf("Too many clients connected, disconnecting first %v", todisconnect)
}
for i := 0; i < todisconnect; i++ {
if len(clients) > i {
clients[i].Disconnect()
}
}
}
s.maxconn = maxconn
}
func (s *Server) GetMaxConnections() int {
return s.maxconn
}
func (s *Server) SetSessionId(sessId uint16) {
s.sessId = sessId
}
func (s *Server) ClientConnected(c *Client) {
s.clientlock.Lock()
s.clients = append(s.clients, c)
s.connected++
s.clientlock.Unlock()
if s.handler != nil {
s.handler.ClientConnected(c)
}
}
func (s *Server) ClientDisconnected(c *Client) {
s.clientlock.Lock()
tmpclients := make([]*Client, 0)
for _, cc := range s.clients {
if cc != c {
tmpclients = append(tmpclients, cc)
}
}
s.clients = tmpclients
s.connected--
s.clientlock.Unlock()
if s.handler != nil {
s.handler.ClientDisconnected(c)
}
}
func (s *Server) HandlePDU(c *Client, pdu PDU) {
if s.enforceVersion && c.GetVersion() != s.baseVersion {
// Enforce a single version
if s.log != nil {
s.log.Debugf("Client %v uses version %v and server is using %v", c.String(), c.GetVersion(), s.baseVersion)
}
c.SendWrongVersionError()
c.Disconnect()
}
if c.GetVersion() > s.baseVersion {
// Downgrade
c.SetVersion(s.baseVersion)
}
if s.handler != nil {
s.handler.HandlePDU(c, pdu)
}
}
func (s *Server) RequestCache(c *Client) {
if s.simpleHandler != nil {
s.simpleHandler.RequestCache(c)
}
}
func (s *Server) RequestNewVersion(c *Client, sessionId uint16, serial uint32) {
if s.simpleHandler != nil {
s.simpleHandler.RequestNewVersion(c, sessionId, serial)
}
}
func (s *Server) Start(bind string) error {
tcplist, err := net.Listen("tcp", bind)
if err != nil {
return err
}
return s.loopTCP(tcplist, "tcp", s.acceptClientTCP)
}
var DisableBGPSec = flag.Bool("disable.bgpsec", false, "Disable sending out BGPSEC Router Keys")
var DisableASPA = flag.Bool("disable.aspa", false, "Disable sending out ASPA objects")
var EnableNODELAY = flag.Bool("enable.nodelay", false, "Force enable TCP NODELAY (Likely increases CPU)")
func (s *Server) acceptClientTCP(tcpconn net.Conn) error {
if !*EnableNODELAY {
tc, ok := tcpconn.(*net.TCPConn)
if ok {
tc.SetNoDelay(false)
}
}
client := ClientFromConn(tcpconn, s, s)
client.log = s.log
if s.enforceVersion {
client.SetVersion(s.baseVersion)
}
client.SetIntervals(s.pduRefreshInterval, s.pduRetryInterval, s.pduExpireInterval)
if *DisableBGPSec {
client.DisableBGPsec()
}
if *DisableASPA {
client.DisableASPA()
}
go client.Start()
return nil
}
func (s *Server) acceptClientSSH(tcpconn net.Conn) error {
_, chans, reqs, err := ssh.NewServerConn(tcpconn, s.sshconfig)
if err != nil {
return err
}
go func() {
s.connected++
cont := true
for cont {
select {
case req := <-reqs:
if req != nil && req.WantReply {
req.Reply(false, nil)
} else if req == nil {
cont = false
break
}
case newChannel := <-chans:
if newChannel != nil && newChannel.ChannelType() != "session" {
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
continue
} else if newChannel == nil {
cont = false
break
}
channel, requests, err := newChannel.Accept()
if err != nil {
if s.log != nil {
s.log.Errorf("Could not accept channel: %v", err)
}
cont = false
break
}
for req := range requests {
if req != nil && req.Type == "subsystem" && bytes.Equal(req.Payload, []byte{0, 0, 0, 8, 114, 112, 107, 105, 45, 114, 116, 114}) {
err := req.Reply(true, nil)
if err != nil {
if s.log != nil {
s.log.Errorf("Could not accept channel: %v", err)
}
cont = false
break
}
client := ClientFromConnSSH(tcpconn, channel, s, s)
client.log = s.log
if s.enforceVersion {
client.SetVersion(s.baseVersion)
}
client.SetIntervals(s.pduRefreshInterval, s.pduRetryInterval, s.pduExpireInterval)
client.Start()
} else {
cont = false
break
}
}
}
}
s.connected--
tcpconn.Close()
}()
return nil
}
type ClientCallback func(net.Conn) error
func (s *Server) loopTCP(tcplist net.Listener, logEnv string, clientCallback ClientCallback) error {
for {
tcpconn, err := tcplist.Accept()
if err != nil {
if s.log != nil {
s.log.Errorf("Failed to accept %s connection: %s", logEnv, err)
}
continue
}
if s.maxconn > 0 && s.connected >= s.maxconn {
if s.log != nil {
s.log.Warnf("Could not accept %s connection from %v (not enough slots available: %d)", logEnv, tcpconn.RemoteAddr(), s.maxconn)
}
tcpconn.Close()
} else {
if s.log != nil {
s.log.Infof("Accepted %s connection from %v (%d/%d)", logEnv, tcpconn.RemoteAddr(), s.connected+1, s.maxconn)
}
if clientCallback != nil {
err := clientCallback(tcpconn)
if err != nil && s.log != nil {
s.log.Errorf("Error with %s client %v: %v", logEnv, tcpconn.RemoteAddr(), err)
}
}
}
}
}
func (s *Server) StartSSH(bind string, config *ssh.ServerConfig) error {
tcplist, err := net.Listen("tcp", bind)
if err != nil {
return err
}
s.sshconfig = config
return s.loopTCP(tcplist, "ssh", s.acceptClientSSH)
}
func (s *Server) StartTLS(bind string, config *tls.Config) error {
tcplist, err := tls.Listen("tcp", bind, config)
if err != nil {
return err
}
return s.loopTCP(tcplist, "tls", s.acceptClientTCP)
}
func (s *Server) GetClientList() []*Client {
s.clientlock.RLock()
list := make([]*Client, len(s.clients))
copy(list, s.clients)
s.clientlock.RUnlock()
return list
}
func (s *Server) NotifyClientsLatest() {
serial, _ := s.GetCurrentSerial(s.sessId)
s.NotifyClients(serial)
}
func (s *Server) NotifyClients(serialNumber uint32) {
clients := s.GetClientList()
for _, c := range clients {
c.Notify(s.sessId, serialNumber)
}
}
func (s *Server) SendPDU(pdu PDU) {
for _, client := range s.clients {
client.SendPDU(pdu)
}
}
func ClientFromConn(tcpconn net.Conn, handler RTRServerEventHandler, simpleHandler RTREventHandler) *Client {
return &Client{
tcpconn: tcpconn,
rd: tcpconn,
wr: tcpconn,
handler: handler,
simpleHandler: simpleHandler,
transmits: make(chan PDU, 256),
quit: make(chan bool),
}
}
func ClientFromConnSSH(tcpconn net.Conn, channel ssh.Channel, handler RTRServerEventHandler, simpleHandler RTREventHandler) *Client {
client := ClientFromConn(tcpconn, handler, simpleHandler)
client.rd = channel
client.wr = channel
return client
}
type Client struct {
connected bool
version uint8
versionset bool
tcpconn net.Conn
rd io.Reader
wr io.Writer
handler RTRServerEventHandler
simpleHandler RTREventHandler
curserial uint32
transmits chan PDU
quit chan bool
enforceVersion bool
disableVersionCheck bool
refreshInterval uint32
retryInterval uint32
expireInterval uint32
dontSendBGPsecKeys bool
dontSendASPA bool
log Logger
}
func (c *Client) String() string {
return fmt.Sprintf("%v (v%v) / Serial: %v", c.tcpconn.RemoteAddr(), c.version, c.curserial)
}
func (c *Client) GetRemoteAddress() net.Addr {
return c.tcpconn.RemoteAddr()
}
func (c *Client) GetLocalAddress() net.Addr {
return c.tcpconn.LocalAddr()
}
func (c *Client) GetVersion() uint8 {
return c.version
}
func (c *Client) DisableBGPsec() {
c.dontSendBGPsecKeys = true
}
func (c *Client) DisableASPA() {
c.dontSendASPA = true
}
func (c *Client) SetIntervals(refreshInterval uint32, retryInterval uint32, expireInterval uint32) {
c.refreshInterval = refreshInterval
c.retryInterval = retryInterval
c.expireInterval = expireInterval
}
func (c *Client) SetVersion(newversion uint8) {
c.versionset = true
c.version = newversion
}
func (c *Client) SetDisableVersionCheck(disableCheck bool) {
c.disableVersionCheck = disableCheck
}
func (c *Client) checkVersion(newversion uint8) {
if (!c.versionset || newversion == c.version) && (newversion == PROTOCOL_VERSION_2 || newversion == PROTOCOL_VERSION_1 || newversion == PROTOCOL_VERSION_0) {
c.SetVersion(newversion)
} else {
if c.log != nil {
c.log.Debugf("%v: has bad version (received: v%v, current: v%v) error", c.String(), newversion, c.version)
}
c.SendWrongVersionError()
c.Disconnect()
}
}
func (c *Client) passSimpleHandler(pdu PDU) {
if c.simpleHandler != nil {
switch pduConv := pdu.(type) {
case *PDUSerialQuery:
c.simpleHandler.RequestNewVersion(c, pduConv.SessionId, pduConv.SerialNumber)
case *PDUResetQuery:
c.simpleHandler.RequestCache(c)
default:
// not a proper client packet
}
}
}
func (c *Client) sendLoop() {
defer c.tcpconn.Close()
for c.connected {
select {
case pdu := <-c.transmits:
c.wr.Write(pdu.Bytes())
case <-c.quit:
return
}
}
}
func (c *Client) Start() {
c.connected = true
if c.handler != nil {
c.handler.ClientConnected(c)
}
go c.sendLoop()
buf := make([]byte, 8000)
for c.connected {
// Remove this?
length, err := c.rd.Read(buf)
if err != nil || length == 0 {
if c.log != nil {
c.log.Debugf("Error %v", err)
}
c.Disconnect()
return
}
pkt := buf[0:length]
dec, err := DecodeBytes(pkt)
if err != nil || dec == nil {
if c.log != nil {
c.log.Errorf("Error %v", err)
}
c.Disconnect()
continue
}
if !c.disableVersionCheck {
c.checkVersion(dec.GetVersion())
}
if c.log != nil {
c.log.Debugf("%v: Received %v", c.String(), dec)
}
if c.enforceVersion {
if !IsCorrectPDUVersion(dec, c.version) {
if c.log != nil {
c.log.Debugf("Bad version error")
}
c.SendWrongVersionError()
c.Disconnect()
}
}
switch pduconv := dec.(type) {
case *PDUSerialQuery:
c.curserial = pduconv.SerialNumber
}
if c.handler != nil {
c.handler.HandlePDU(c, dec)
}
c.passSimpleHandler(dec)
}
}
func (c *Client) Notify(sessionId uint16, serialNumber uint32) {
pdu := &PDUSerialNotify{
SessionId: sessionId,
SerialNumber: serialNumber,
}
c.SendPDU(pdu)
}
type VRP struct {
Prefix netip.Prefix
ASN uint32
MaxLen uint8
Flags uint8
}
func (r *VRP) Type() string {
return "VRP"
}
func (r *VRP) String() string {
return fmt.Sprintf("VRP %v -> /%v, AS%v, Flags: %v", r.Prefix.String(), r.MaxLen, r.ASN, r.Flags)
}
func (vrp *VRP) HashKey() string {
return fmt.Sprintf("%v-%v-%v", vrp.Prefix.String(), vrp.MaxLen, vrp.ASN)
}
func (r1 *VRP) Equals(r2 SendableData) bool {
if r1.Type() != r2.Type() {
return false
}
r2True := r2.(*VRP)
return r1.MaxLen == r2True.MaxLen && r1.ASN == r2True.ASN && r1.Prefix == r2True.Prefix
}
func (r1 *VRP) Copy() SendableData {
return &VRP{
Prefix: r1.Prefix,
ASN: r1.ASN,
MaxLen: r1.MaxLen,
Flags: r1.Flags}
}
func (r1 *VRP) SetFlag(f uint8) {
r1.Flags = f
}
func (r1 *VRP) GetFlag() uint8 {
return r1.Flags
}
type BgpsecKey struct {
ASN uint32
Pubkey []byte
Ski []byte
Flags uint8
}
func (brk *BgpsecKey) Type() string {
return "BGPsecKey"
}
func (brk *BgpsecKey) String() string {
return fmt.Sprintf("BGPsec AS%v -> %x, Flags: %v", brk.ASN, brk.Ski, brk.Flags)
}
func (brk *BgpsecKey) HashKey() string {
return fmt.Sprintf("%v-%x-%x", brk.ASN, brk.Ski, brk.Pubkey)
}
func (r1 *BgpsecKey) Equals(r2 SendableData) bool {
if r1.Type() != r2.Type() {
return false
}
r2True := r2.(*BgpsecKey)
return r1.ASN == r2True.ASN && bytes.Equal(r1.Pubkey, r2True.Pubkey) && bytes.Equal(r1.Ski, r2True.Ski)
}
func (brk *BgpsecKey) Copy() SendableData {
cop := BgpsecKey{
ASN: brk.ASN,
Pubkey: make([]byte, len(brk.Pubkey)),
Ski: make([]byte, len(brk.Ski)),
Flags: brk.Flags,
}
copy(cop.Pubkey, brk.Pubkey)
copy(cop.Ski, brk.Ski)
return &cop
}
func (brk *BgpsecKey) SetFlag(f uint8) {
brk.Flags = f
}
func (brk *BgpsecKey) GetFlag() uint8 {
return brk.Flags
}
type VAP struct {
Flags uint8
CustomerASN uint32
Providers []uint32
}
func (vap *VAP) Type() string {
return "ASPA"
}
func (vap *VAP) String() string {
return fmt.Sprintf("ASPA AS%v -> Providers: %v", vap.CustomerASN, vap.Providers)
}
func (vap *VAP) HashKey() string {
return fmt.Sprintf("%v-%v", vap.CustomerASN, vap.Providers)
}
func (r1 *VAP) Equals(r2 SendableData) bool {
if r1.Type() != r2.Type() {
return false
}
r2True := r2.(*VAP)
return r1.CustomerASN == r2True.CustomerASN && fmt.Sprint(r1.Providers) == fmt.Sprint(r2True.Providers) /*This could be made faster*/
}
func (vap *VAP) Copy() SendableData {
cop := VAP{
CustomerASN: vap.CustomerASN,
Flags: vap.Flags,
Providers: make([]uint32, 0),
}
cop.Providers = append(cop.Providers, vap.Providers...)
return &cop
}
func (vap *VAP) SetFlag(f uint8) {
vap.Flags = f
}
func (vap *VAP) GetFlag() uint8 {
return vap.Flags
}
func (c *Client) SendSDs(sessionId uint16, serialNumber uint32, data []SendableData) {
pduBegin := &PDUCacheResponse{
SessionId: sessionId,
}
c.SendPDU(pduBegin)
for _, item := range data {
c.SendData(item.Copy())
}
pduEnd := &PDUEndOfData{
SessionId: sessionId,
SerialNumber: serialNumber,
RefreshInterval: c.refreshInterval,
RetryInterval: c.retryInterval,
ExpireInterval: c.expireInterval,
}
c.SendPDU(pduEnd)
}
func (c *Client) SendCacheReset() {
pdu := &PDUCacheReset{}
c.SendPDU(pdu)
}
func (c *Client) SendInternalError() {
pdu := &PDUErrorReport{
ErrorCode: PDU_ERROR_INTERNALERR,
ErrorMsg: "Unknown internal error",
}
c.SendPDU(pdu)
}
func (c *Client) SendNoDataError() {
pdu := &PDUErrorReport{
ErrorCode: PDU_ERROR_NODATA,
ErrorMsg: "No data available",
}
c.SendPDU(pdu)
}
func (c *Client) SendCorruptData() {
pdu := &PDUErrorReport{
ErrorCode: PDU_ERROR_CORRUPTDATA,
ErrorMsg: "Session ID mismatch: client is desynchronized",
}
c.SendPDU(pdu)
}
func (c *Client) SendWrongVersionError() {
pdu := &PDUErrorReport{
ErrorCode: PDU_ERROR_BADPROTOVERSION,
ErrorMsg: "Bad protocol version",
}
c.SendPDU(pdu)
}
// Converts a SendableData to a PDU and sends it to the client
func (c *Client) SendData(sd SendableData) {
switch t := sd.(type) {
case *VRP:
if t.Prefix.Addr().Is6() {
pdu := &PDUIPv6Prefix{
Flags: t.Flags,
MaxLen: t.MaxLen,
ASN: t.ASN,
Prefix: t.Prefix,
}
c.SendPDU(pdu)
} else if t.Prefix.Addr().Is4() {
pdu := &PDUIPv4Prefix{
Flags: t.Flags,
MaxLen: t.MaxLen,
ASN: t.ASN,
Prefix: t.Prefix,
}
c.SendPDU(pdu)
}
case *BgpsecKey:
if c.version == 0 || c.dontSendBGPsecKeys {
return
}
pdu := &PDURouterKey{
Version: c.version, // The RouterKey PDU is unchanged from rfc8210 to draft-ietf-sidrops-8210bis-10
Flags: t.Flags,
SubjectKeyIdentifier: t.Ski,
ASN: t.ASN,
SubjectPublicKeyInfo: t.Pubkey,
}
c.SendPDU(pdu)
case *VAP:
if c.version < 2 || c.dontSendASPA {
return
}
pdu4 := &PDUASPA{
Version: c.version,
Flags: t.Flags,
AFIFlags: AFI_IPv4,
ProviderASCount: uint16(len(t.Providers)),
CustomerASNumber: t.CustomerASN,
ProviderASNumbers: t.Providers,
}
pdu6 := &PDUASPA{
Version: c.version,
Flags: t.Flags,
AFIFlags: AFI_IPv6,
ProviderASCount: uint16(len(t.Providers)),
CustomerASNumber: t.CustomerASN,
ProviderASNumbers: t.Providers,
}
c.SendPDU(pdu4)
c.SendPDU(pdu6)
}
}
func (c *Client) SendRawPDU(pdu PDU) {
c.transmits <- pdu
}
func (c *Client) SendPDU(pdu PDU) {
pdu.SetVersion(c.version)
c.SendRawPDU(pdu)
}
func (c *Client) Disconnect() {
c.connected = false
if c.log != nil {
c.log.Infof("Disconnecting client %v", c.String())
}
if c.handler != nil {
c.handler.ClientDisconnected(c)
}
select {
case c.quit <- true:
default:
}
}