1
0
mirror of https://github.com/bgp/stayrtr.git synced 2024-05-06 15:54:54 +00:00
bgp-stayrtr/lib/server.go
Louis Poinsignon 8672363e36 Go fmt
2019-04-05 15:42:54 -07:00

951 lines
20 KiB
Go

package rtrlib
import (
"bytes"
"crypto/tls"
"fmt"
"golang.org/x/crypto/ssh"
"io"
"math/rand"
"net"
"sync"
"time"
)
// GenerateSessionId returns a pseudo-random 16-bit session identifier.
//
// The generator is seeded with the current UTC Unix time in seconds, so two
// calls made within the same second return the same value.
func GenerateSessionId() uint16 {
	seed := time.Now().UTC().Unix()
	return uint16(rand.New(rand.NewSource(seed)).Uint32())
}
// RTRServerEventHandler receives the connection lifecycle events of a Client
// and every PDU that the client decodes from its peer.
type RTRServerEventHandler interface {
// ClientConnected is called by Client.Start before its read loop begins.
ClientConnected(*Client)
// ClientDisconnected is called by Client.Disconnect.
ClientDisconnected(*Client)
// HandlePDU is called by Client.Start for each decoded PDU.
HandlePDU(*Client, PDU)
}
// RTREventHandler is the simplified handler a Client dispatches query PDUs to
// (see Client.passSimpleHandler).
type RTREventHandler interface {
// RequestCache is called when a *PDUResetQuery is received.
RequestCache(*Client)
// RequestNewVersion is called when a *PDUSerialQuery is received, with the
// session id and serial number carried by that PDU.
RequestNewVersion(*Client, uint16, uint32)
}
// ROAManager is the data source DefaultRTREventHandler queries to answer
// client requests. Server implements it.
type ROAManager interface {
// GetCurrentSerial returns the current serial for the given session id; the
// boolean is false when no data is available.
GetCurrentSerial(uint16) (uint32, bool)
// GetSessionId returns the session id to use for the given client.
GetSessionId(*Client) (uint16, error)
// GetCurrentROAs returns the full current ROA set; the boolean reports
// whether it exists.
GetCurrentROAs() ([]ROA, bool)
// GetROAsSerialDiff returns the ROA changes to send to a client that holds
// the given serial; the boolean is false when no diff is known for it.
GetROAsSerialDiff(uint32) ([]ROA, bool)
}
// DefaultRTREventHandler answers cache and serial requests from the data
// exposed by a ROAManager.
type DefaultRTREventHandler struct {
// roaManager supplies session id, serials and ROAs; set via SetROAManager.
roaManager ROAManager
// Log is optional; every use in this file is guarded by a nil check.
Log Logger
}
// SetROAManager installs the ROAManager that RequestCache and
// RequestNewVersion query.
func (e *DefaultRTREventHandler) SetROAManager(m ROAManager) {
	e.roaManager = m
}
// RequestCache answers a full cache request from client c.
//
// It replies with a "no data" error when the manager has no valid serial, with
// an internal error when the current ROA set does not exist, and otherwise
// sends the complete ROA set for the current serial.
func (e *DefaultRTREventHandler) RequestCache(c *Client) {
	debugf := func(format string, args ...interface{}) {
		if e.Log != nil {
			e.Log.Debugf(format, args...)
		}
	}

	debugf("%v > Request Cache", c)

	// The error returned by GetSessionId is discarded.
	sessionId, _ := e.roaManager.GetSessionId(c)
	serial, valid := e.roaManager.GetCurrentSerial(sessionId)
	if !valid {
		c.SendNoDataError()
		debugf("%v < No data", c)
		return
	}

	roas, exists := e.roaManager.GetCurrentROAs()
	if !exists {
		c.SendInternalError()
		debugf("%v < Internal error requesting cache (does not exists)", c)
		return
	}

	c.SendROAs(sessionId, serial, roas)
	debugf("%v < Sent ROAs (current serial %v)", c, serial)
}
// RequestNewVersion answers an incremental request from client c, which
// reports holding serialNumber for session sessionId.
//
// It replies with a "no data" error when the manager has no valid serial, with
// a cache reset when no diff is known for serialNumber, and otherwise sends
// the diff followed by the current serial.
func (e *DefaultRTREventHandler) RequestNewVersion(c *Client, sessionId uint16, serialNumber uint32) {
	debugf := func(format string, args ...interface{}) {
		if e.Log != nil {
			e.Log.Debugf(format, args...)
		}
	}

	debugf("%v > Request New Version", c)

	serial, valid := e.roaManager.GetCurrentSerial(sessionId)
	if !valid {
		c.SendNoDataError()
		debugf("%v < No data", c)
		return
	}

	roas, exists := e.roaManager.GetROAsSerialDiff(serialNumber)
	if !exists {
		c.SendCacheReset()
		debugf("%v < Sent cache reset", c)
		return
	}

	c.SendROAs(sessionId, serial, roas)
	debugf("%v < Sent ROAs (current serial %v)", c, serial)
}
// Server accepts RTR clients, tracks them, and stores the current ROA set
// together with a bounded history of diffs keyed by serial number.
type Server struct {
// baseVersion is the protocol version clients are downgraded to (or forced
// to, when enforceVersion is set) in HandlePDU.
baseVersion uint8
// clientlock guards clients and connected in ClientConnected,
// ClientDisconnected and GetClientList.
clientlock *sync.RWMutex
clients []*Client
// sessId is the session id returned by GetSessionId and sent in notifies.
sessId uint16
// connected counts connections; maxconn (when > 0) caps them in loopTCP.
connected int
maxconn int
// sshconfig is set by StartSSH and used by acceptClientSSH.
sshconfig *ssh.ServerConfig
// handler and simpleHandler are optional downstream handlers the server
// forwards events to; both are nil-checked before use.
handler RTRServerEventHandler
simpleHandler RTREventHandler
enforceVersion bool
// roalock guards the roa* fields below.
roalock *sync.RWMutex
// roaListDiff holds one diff per remembered serial; roaMapSerial maps a
// serial to its index in roaListDiff.
roaListDiff [][]ROA
roaMapSerial map[uint32]int
// roaListSerial lists the remembered serials, oldest first.
roaListSerial []uint32
roaCurrent []ROA
roaCurrentSerial uint32
// keepDiff bounds the serial history in addSerial; values <= 0 disable
// trimming.
keepDiff int
// pdu*Interval values are handed to each client through SetIntervals.
pduRefreshInterval uint32
pduRetryInterval uint32
pduExpireInterval uint32
// log is optional; every use is guarded by a nil check.
log Logger
}
// ServerConfiguration carries the options consumed by NewServer.
type ServerConfiguration struct {
// MaxConn caps simultaneous connections; the cap is only applied when > 0.
MaxConn int
ProtocolVersion uint8
EnforceVersion bool
// KeepDifference is copied into the server's diff-history bound.
KeepDifference int
// SessId is the session id; a negative value makes NewServer call
// GenerateSessionId, otherwise the value is converted to uint16.
SessId int
// RefreshInterval, RetryInterval and ExpireInterval default to 3600, 600
// and 7200 respectively when left at zero.
RefreshInterval uint32
RetryInterval uint32
ExpireInterval uint32
// Log is optional.
Log Logger
}
// NewServer builds a Server from configuration.
//
// handler and simpleHandler may be nil; when set, the server forwards client
// events and cache/serial requests to them. A negative configuration.SessId
// selects a generated session id. Zero-valued intervals fall back to 3600s
// (refresh), 600s (retry) and 7200s (expire).
func NewServer(configuration ServerConfiguration, handler RTRServerEventHandler, simpleHandler RTREventHandler) *Server {
	var sessid uint16
	if configuration.SessId < 0 {
		// The generated id must be stored: discarding it left the session id
		// at zero whenever a random id was requested.
		sessid = GenerateSessionId()
	} else {
		sessid = uint16(configuration.SessId)
	}

	refreshInterval := uint32(3600)
	if configuration.RefreshInterval != 0 {
		refreshInterval = configuration.RefreshInterval
	}
	retryInterval := uint32(600)
	if configuration.RetryInterval != 0 {
		retryInterval = configuration.RetryInterval
	}
	expireInterval := uint32(7200)
	if configuration.ExpireInterval != 0 {
		expireInterval = configuration.ExpireInterval
	}

	return &Server{
		roalock:       &sync.RWMutex{},
		roaListDiff:   make([][]ROA, 0),
		roaMapSerial:  make(map[uint32]int),
		roaListSerial: make([]uint32, 0),
		roaCurrent:    make([]ROA, 0),
		keepDiff:      configuration.KeepDifference,

		clientlock:     &sync.RWMutex{},
		clients:        make([]*Client, 0),
		sessId:         sessid,
		maxconn:        configuration.MaxConn,
		baseVersion:    configuration.ProtocolVersion,
		enforceVersion: configuration.EnforceVersion,
		handler:        handler,
		simpleHandler:  simpleHandler,

		pduRefreshInterval: refreshInterval,
		pduRetryInterval:   retryInterval,
		pduExpireInterval:  expireInterval,

		log: configuration.Log,
	}
}
// ComputeDiff compares newRoas against prevRoas using ROA.Equals.
//
// It returns three slices of copies:
//   - added: entries of newRoas absent from prevRoas, with Flags set to 1;
//   - removed: entries of prevRoas absent from newRoas, with Flags set to 0;
//   - unchanged: entries of prevRoas also present in newRoas, Flags untouched.
func ComputeDiff(newRoas []ROA, prevRoas []ROA) ([]ROA, []ROA, []ROA) {
	// contains reports whether any element of set equals target.
	contains := func(set []ROA, target ROA) bool {
		for _, candidate := range set {
			if target.Equals(candidate) {
				return true
			}
		}
		return false
	}

	added := make([]ROA, 0)
	removed := make([]ROA, 0)
	unchanged := make([]ROA, 0)

	for _, fresh := range newRoas {
		if contains(prevRoas, fresh) {
			continue
		}
		entry := fresh.Copy()
		entry.Flags = 1
		added = append(added, entry)
	}

	for _, old := range prevRoas {
		entry := old.Copy()
		if contains(newRoas, old) {
			unchanged = append(unchanged, entry)
			continue
		}
		entry.Flags = 0
		removed = append(removed, entry)
	}

	return added, removed, unchanged
}
// ApplyDiff merges diff into prevRoas and returns a new slice of copies.
//
// Entries of prevRoas that have no equal in diff are carried over first. Then
// each diff entry is considered in order: FLAG_ADDED entries are always
// appended; FLAG_REMOVED entries are appended when prevRoas has no equal
// entry, or when the first equal entry in prevRoas is itself FLAG_REMOVED.
// Entries with any other flag value are ignored.
func ApplyDiff(diff []ROA, prevRoas []ROA) []ROA {
	merged := make([]ROA, 0)

	for _, prev := range prevRoas {
		touched := false
		for _, change := range diff {
			if prev.Equals(change) {
				touched = true
				break
			}
		}
		if !touched {
			merged = append(merged, prev.Copy())
		}
	}

	for _, change := range diff {
		if change.Flags == FLAG_ADDED {
			merged = append(merged, change.Copy())
			continue
		}
		if change.Flags != FLAG_REMOVED {
			continue
		}

		// Keep the removal unless it cancels a non-removed entry of prevRoas.
		keep := true
		for _, prev := range prevRoas {
			if change.Equals(prev) {
				keep = prev.Flags == FLAG_REMOVED
				break
			}
		}
		if keep {
			merged = append(merged, change.Copy())
		}
	}

	return merged
}
// GetSessionId returns the server-wide session id; the client argument is not
// consulted and the error is always nil.
func (s *Server) GetSessionId(c *Client) (uint16, error) {
	return s.sessId, nil
}
// GetCurrentROAs returns the current ROA slice (not a copy), read under the
// ROA read lock. The boolean is always true.
func (s *Server) GetCurrentROAs() ([]ROA, bool) {
	s.roalock.RLock()
	defer s.roalock.RUnlock()
	return s.roaCurrent, true
}
// GetROAsSerialDiff is the locked wrapper around getROAsSerialDiff.
func (s *Server) GetROAsSerialDiff(serial uint32) ([]ROA, bool) {
	s.roalock.RLock()
	defer s.roalock.RUnlock()
	return s.getROAsSerialDiff(serial)
}
// getROAsSerialDiff returns the diff stored for serial. The caller must hold
// roalock.
//
// A client already at the current serial gets an empty diff and true. For an
// unknown serial it returns an empty slice and false.
func (s *Server) getROAsSerialDiff(serial uint32) ([]ROA, bool) {
	if serial == s.roaCurrentSerial {
		return []ROA{}, true
	}

	index, known := s.roaMapSerial[serial]
	if !known {
		return make([]ROA, 0), false
	}
	return s.roaListDiff[index], true
}
// GetCurrentSerial is the locked wrapper around getCurrentSerial. The session
// id argument is not consulted.
func (s *Server) GetCurrentSerial(sessId uint16) (uint32, bool) {
	s.roalock.RLock()
	defer s.roalock.RUnlock()
	return s.getCurrentSerial()
}
// getCurrentSerial returns the current serial and true once at least one
// serial has been recorded, and (0, false) before that. The caller must hold
// roalock.
func (s *Server) getCurrentSerial() (uint32, bool) {
	if len(s.roaListSerial) == 0 {
		return 0, false
	}
	return s.roaCurrentSerial, true
}
// GenerateSerial is the locked wrapper around generateSerial; it computes the
// next serial without recording it.
func (s *Server) GenerateSerial() uint32 {
	s.roalock.RLock()
	defer s.roalock.RUnlock()
	return s.generateSerial()
}
// generateSerial returns the serial following the most recently recorded one,
// or 1 when none has been recorded yet. The caller must hold roalock.
func (s *Server) generateSerial() uint32 {
	count := len(s.roaListSerial)
	if count == 0 {
		return uint32(1)
	}
	return s.roaListSerial[count-1] + 1
}
// AddROAs replaces the current ROA set with roas.
//
// It computes the difference against the current set under the read lock,
// then hands the additions followed by the removals to AddROAsDiff after the
// lock has been released.
func (s *Server) AddROAs(roas []ROA) {
	s.roalock.RLock()
	added, removed, unchanged := ComputeDiff(roas, s.roaCurrent)
	if s.log != nil {
		s.log.Debugf("Computed diff: added (%v), removed (%v), unchanged (%v)", added, removed, unchanged)
	}
	diff := append(added, removed...)
	s.roalock.RUnlock()

	s.AddROAsDiff(diff)
}
// addSerial appends serial to the serial history and returns the serials that
// were trimmed from its front. The caller must hold the roalock write lock.
//
// Trimming only happens when keepDiff is positive and the history already
// holds at least keepDiff entries; it is cut down to keepDiff entries before
// the new serial is appended.
func (s *Server) addSerial(serial uint32) []uint32 {
	dropped := make([]uint32, 0)
	if s.keepDiff > 0 && len(s.roaListSerial) >= s.keepDiff {
		cut := len(s.roaListSerial) - s.keepDiff
		dropped, s.roaListSerial = s.roaListSerial[:cut], s.roaListSerial[cut:]
	}
	s.roaListSerial = append(s.roaListSerial, serial)
	return dropped
}
// AddROAsDiff applies diff to the current ROA set and to every stored
// historical diff, then publishes the result under a newly generated serial.
func (s *Server) AddROAsDiff(diff []ROA) {
// Phase 1 (read lock): compute the updated history and ROA set.
s.roalock.RLock()
nextDiff := make([][]ROA, len(s.roaListDiff))
for i, prevRoas := range s.roaListDiff {
// Fold the new diff into each older diff so it still leads to the new state.
nextDiff[i] = ApplyDiff(diff, prevRoas)
}
newRoaCurrent := ApplyDiff(diff, s.roaCurrent)
curserial, valid := s.getCurrentSerial()
s.roalock.RUnlock()
// NOTE(review): the read lock is released before the write lock is taken, so
// another writer could run in between and the values computed above would be
// stale — confirm that callers serialize updates.
// Phase 2 (write lock): record the new serial and swap in the new state.
s.roalock.Lock()
newserial := s.generateSerial()
removed := s.addSerial(newserial)
if valid {
// A previous serial existed: diff is what takes a client from curserial to
// newserial, so it becomes the newest history entry.
nextDiff = append(nextDiff, diff)
if len(nextDiff) >= s.keepDiff && s.keepDiff > 0 {
// Drop as many of the oldest diffs as serials were trimmed.
nextDiff = nextDiff[len(removed):]
}
s.roaMapSerial[curserial] = len(nextDiff) - 1
if len(removed) > 0 {
// Shift the remaining indexes to account for the dropped front entries.
for k, v := range s.roaMapSerial {
if k != curserial {
s.roaMapSerial[k] = v - len(removed)
}
}
}
}
// Forget the serials that fell out of the history.
for _, removeSerial := range removed {
delete(s.roaMapSerial, removeSerial)
}
s.roaListDiff = nextDiff
s.roaCurrent = newRoaCurrent
s.roaCurrentSerial = newserial
s.roalock.Unlock()
}
// SetBaseVersion sets the protocol version the server compares client
// versions against in HandlePDU.
func (s *Server) SetBaseVersion(version uint8) {
	s.baseVersion = version
}
// SetVersionEnforced toggles strict version enforcement: when enabled,
// HandlePDU rejects clients whose version differs from the base version.
func (s *Server) SetVersionEnforced(adapt bool) {
	s.enforceVersion = adapt
}
// SetMaxConnections updates the connection cap. When more clients are
// connected than the new cap allows, the excess is disconnected starting from
// the front of the client list.
func (s *Server) SetMaxConnections(maxconn int) {
	if s.connected > maxconn {
		todisconnect := s.connected - maxconn
		snapshot := s.GetClientList()
		if s.log != nil {
			s.log.Debugf("Too many clients connected, disconnecting first %v", todisconnect)
		}
		// The snapshot may be shorter than the counter suggests, so bound by both.
		for i := 0; i < todisconnect && i < len(snapshot); i++ {
			snapshot[i].Disconnect()
		}
	}
	s.maxconn = maxconn
}
// GetMaxConnections returns the configured connection cap.
func (s *Server) GetMaxConnections() int {
	return s.maxconn
}
// SetSessionId overrides the session id returned by GetSessionId and sent in
// notifications.
func (s *Server) SetSessionId(sessId uint16) {
	s.sessId = sessId
}
// ClientConnected registers c in the client list, bumps the connection
// counter, and then forwards the event to the optional downstream handler
// outside the lock.
func (s *Server) ClientConnected(c *Client) {
	s.clientlock.Lock()
	s.connected++
	s.clients = append(s.clients, c)
	s.clientlock.Unlock()

	if s.handler == nil {
		return
	}
	s.handler.ClientConnected(c)
}
// ClientDisconnected removes every occurrence of c from the client list,
// decrements the connection counter, and then forwards the event to the
// optional downstream handler outside the lock.
func (s *Server) ClientDisconnected(c *Client) {
	s.clientlock.Lock()
	remaining := make([]*Client, 0, len(s.clients))
	for _, other := range s.clients {
		if other == c {
			continue
		}
		remaining = append(remaining, other)
	}
	s.clients = remaining
	s.connected--
	s.clientlock.Unlock()

	if s.handler == nil {
		return
	}
	s.handler.ClientDisconnected(c)
}
// HandlePDU applies the server's version policy to a PDU received from c and
// then forwards it to the optional downstream handler.
//
// With version enforcement enabled, a client whose version differs from the
// base version receives a wrong-version error and is disconnected; its PDU is
// not forwarded. Otherwise a client announcing a newer version is downgraded
// to the base version.
func (s *Server) HandlePDU(c *Client, pdu PDU) {
	if s.enforceVersion && c.GetVersion() != s.baseVersion {
		// Enforce a single version
		if s.log != nil {
			s.log.Debugf("Client %v uses version %v and server is using %v", c.String(), c.GetVersion(), s.baseVersion)
		}
		c.SendWrongVersionError()
		c.Disconnect()
		// Stop here: the client is gone, so its PDU must not reach the handler.
		return
	}
	if c.GetVersion() > s.baseVersion {
		// Downgrade
		c.SetVersion(s.baseVersion)
	}
	if s.handler != nil {
		s.handler.HandlePDU(c, pdu)
	}
}
// RequestCache forwards a cache request to the simple handler, if one is set.
func (s *Server) RequestCache(c *Client) {
	if s.simpleHandler == nil {
		return
	}
	s.simpleHandler.RequestCache(c)
}
// RequestNewVersion forwards a serial query to the simple handler, if one is
// set.
func (s *Server) RequestNewVersion(c *Client, sessionId uint16, serial uint32) {
	if s.simpleHandler == nil {
		return
	}
	s.simpleHandler.RequestNewVersion(c, sessionId, serial)
}
// Start listens on the TCP address bind and serves plain-TCP clients from the
// calling goroutine. It returns an error only when listening fails.
func (s *Server) Start(bind string) error {
	listener, err := net.Listen("tcp", bind)
	if err != nil {
		return err
	}

	s.loopTCP(listener, s.acceptClientTCP)
	return nil
}
// acceptClientTCP wraps an accepted connection in a Client whose two handlers
// are the server itself, configures it, and runs it in its own goroutine. It
// always returns nil.
func (s *Server) acceptClientTCP(tcpconn net.Conn) error {
	c := ClientFromConn(tcpconn, s, s)
	c.log = s.log
	if s.enforceVersion {
		// Pin the version up front so a mismatching client is rejected.
		c.SetVersion(s.baseVersion)
	}
	c.SetIntervals(s.pduRefreshInterval, s.pduRetryInterval, s.pduExpireInterval)

	go c.Start()
	return nil
}
// acceptClientSSH performs the SSH handshake on tcpconn and, in a new
// goroutine, serves "session" channels that request the rpki-rtr subsystem.
// It returns the handshake error, or nil once the goroutine has been started.
func (s *Server) acceptClientSSH(tcpconn net.Conn) error {
_, chans, reqs, err := ssh.NewServerConn(tcpconn, s.sshconfig)
if err != nil {
return err
}
go func() {
// NOTE(review): connected is modified here without clientlock, and
// client.Start below reaches Server.ClientConnected, which increments it
// again — this looks like double counting; confirm.
s.connected++
cont := true
for cont {
select {
case req := <-reqs:
// Global requests are refused; a nil request means the channel was closed.
if req != nil && req.WantReply {
req.Reply(false, nil)
} else if req == nil {
cont = false
// This break only leaves the select; the loop ends because cont is false.
break
}
case newChannel := <-chans:
if newChannel != nil && newChannel.ChannelType() != "session" {
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
continue
} else if newChannel == nil {
cont = false
break
}
channel, requests, err := newChannel.Accept()
if err != nil {
if s.log != nil {
s.log.Errorf("Could not accept channel: %v", err)
}
cont = false
break
}
for req := range requests {
// The payload is a 4-byte length (8) followed by the ASCII bytes "rpki-rtr".
if req != nil && req.Type == "subsystem" && bytes.Equal(req.Payload, []byte{0, 0, 0, 8, 114, 112, 107, 105, 45, 114, 116, 114}) {
err := req.Reply(true, nil)
if err != nil {
if s.log != nil {
s.log.Errorf("Could not accept channel: %v", err)
}
cont = false
break
}
client := ClientFromConnSSH(tcpconn, channel, s, s)
client.log = s.log
if s.enforceVersion {
client.SetVersion(s.baseVersion)
}
client.SetIntervals(s.pduRefreshInterval, s.pduRetryInterval, s.pduExpireInterval)
// Start is called synchronously: this goroutine blocks until the client's
// read loop returns.
client.Start()
} else {
// Any other channel request ends the session.
cont = false
break
}
}
}
}
s.connected--
tcpconn.Close()
}()
return nil
}
// ClientCallback is invoked by loopTCP for each accepted connection; a non-nil
// error is logged.
type ClientCallback func(net.Conn) error
// loopTCP accepts connections from tcplist forever and hands each one to
// clientCallback.
//
// Connections beyond maxconn (when maxconn > 0) are closed immediately. A
// failed Accept is logged and retried after a delay that doubles from 5ms up
// to 1s, instead of dereferencing the nil connection returned with the error.
func (s *Server) loopTCP(tcplist net.Listener, clientCallback ClientCallback) {
	const (
		minBackoff = 5 * time.Millisecond
		maxBackoff = time.Second
	)
	var backoff time.Duration

	for {
		tcpconn, err := tcplist.Accept()
		if err != nil {
			// On error the returned connection is nil and must not be touched.
			if backoff == 0 {
				backoff = minBackoff
			} else {
				backoff *= 2
			}
			if backoff > maxBackoff {
				backoff = maxBackoff
			}
			if s.log != nil {
				s.log.Errorf("Could not accept connection: %v (retrying in %v)", err, backoff)
			}
			time.Sleep(backoff)
			continue
		}
		backoff = 0

		if s.maxconn > 0 && s.connected >= s.maxconn {
			if s.log != nil {
				s.log.Warnf("Could not accept connection from %v (not enough slots avaible: %v)", tcpconn.RemoteAddr(), s.maxconn)
			}
			tcpconn.Close()
			continue
		}

		if s.log != nil {
			s.log.Infof("Accepted connection from %v (%v/%v)", tcpconn.RemoteAddr(), s.connected+1, s.maxconn)
		}
		if clientCallback != nil {
			err := clientCallback(tcpconn)
			if err != nil && s.log != nil {
				s.log.Errorf("Error with client %v: %v", tcpconn.RemoteAddr(), err)
			}
		}
	}
}
// StartSSH listens on the TCP address bind and serves RTR-over-SSH clients
// from the calling goroutine, using config for the SSH handshake. It returns
// an error only when listening fails.
func (s *Server) StartSSH(bind string, config *ssh.ServerConfig) error {
	listener, err := net.Listen("tcp", bind)
	if err != nil {
		return err
	}

	s.sshconfig = config
	s.loopTCP(listener, s.acceptClientSSH)
	return nil
}
// StartTLS listens on bind with a TLS listener built from config and serves
// clients from the calling goroutine. It returns an error only when listening
// fails.
func (s *Server) StartTLS(bind string, config *tls.Config) error {
	listener, err := tls.Listen("tcp", bind, config)
	if err != nil {
		return err
	}

	s.loopTCP(listener, s.acceptClientTCP)
	return nil
}
// GetClientList returns a snapshot copy of the connected clients, taken under
// the client read lock.
func (s *Server) GetClientList() []*Client {
	s.clientlock.RLock()
	defer s.clientlock.RUnlock()

	snapshot := make([]*Client, len(s.clients))
	copy(snapshot, s.clients)
	return snapshot
}
// NotifyClientsLatest notifies every client of the current serial. The
// validity flag of the serial is ignored, so serial 0 is announced when no
// data has been loaded yet.
func (s *Server) NotifyClientsLatest() {
	latest, _ := s.GetCurrentSerial(s.sessId)
	s.NotifyClients(latest)
}
// NotifyClients sends a serial notify carrying serialNumber and the server's
// session id to every client in a snapshot of the client list.
func (s *Server) NotifyClients(serialNumber uint32) {
	for _, client := range s.GetClientList() {
		client.Notify(s.sessId, serialNumber)
	}
}
// SendPDU queues pdu for transmission on every connected client.
//
// It iterates over a snapshot from GetClientList, as NotifyClients does,
// rather than reading the client slice directly while ClientConnected or
// ClientDisconnected may be replacing it under clientlock.
func (s *Server) SendPDU(pdu PDU) {
	for _, client := range s.GetClientList() {
		client.SendPDU(pdu)
	}
}
// ClientFromConn builds a Client that reads from and writes to tcpconn
// directly. The transmit queue holds up to 256 PDUs; the quit channel is
// unbuffered.
func ClientFromConn(tcpconn net.Conn, handler RTRServerEventHandler, simpleHandler RTREventHandler) *Client {
	client := &Client{
		tcpconn:       tcpconn,
		handler:       handler,
		simpleHandler: simpleHandler,
		transmits:     make(chan PDU, 256),
		quit:          make(chan bool),
	}
	client.rd = tcpconn
	client.wr = tcpconn
	return client
}
// ClientFromConnSSH builds a Client like ClientFromConn, but routes reads and
// writes through the SSH channel. tcpconn is kept for addresses and closing.
func ClientFromConnSSH(tcpconn net.Conn, channel ssh.Channel, handler RTRServerEventHandler, simpleHandler RTREventHandler) *Client {
	c := ClientFromConn(tcpconn, handler, simpleHandler)
	c.rd, c.wr = channel, channel
	return c
}
// Client is the server-side state of one connected RTR peer.
type Client struct {
// connected is set by Start and cleared by Disconnect; both loops poll it.
connected bool
// version is the negotiated protocol version; versionset records whether
// SetVersion has been called.
version uint8
versionset bool
// tcpconn is the underlying connection (addresses, Close); rd and wr are
// the streams actually used for PDUs (the SSH channel for SSH clients).
tcpconn net.Conn
rd io.Reader
wr io.Writer
// handler and simpleHandler receive events and queries; both are optional.
handler RTRServerEventHandler
simpleHandler RTREventHandler
// curserial is the serial from the last *PDUSerialQuery received.
curserial uint32
// transmits queues outgoing PDUs for sendLoop; quit wakes sendLoop on
// disconnect.
transmits chan PDU
quit chan bool
// NOTE(review): enforceVersion is read in Start but never assigned in the
// visible source — confirm where it is set.
enforceVersion bool
disableVersionCheck bool
// The three intervals are copied into each End Of Data PDU by SendROAs.
refreshInterval uint32
retryInterval uint32
expireInterval uint32
// log is optional; every use is guarded by a nil check.
log Logger
}
// String describes the client by remote address, protocol version and the
// last serial it reported.
func (c *Client) String() string {
	remote := c.tcpconn.RemoteAddr()
	return fmt.Sprintf("%v (v%v) / Serial: %v", remote, c.version, c.curserial)
}
// GetRemoteAddress returns the remote address of the underlying connection.
func (c *Client) GetRemoteAddress() net.Addr {
	return c.tcpconn.RemoteAddr()
}
// GetLocalAddress returns the local address of the underlying connection.
func (c *Client) GetLocalAddress() net.Addr {
	return c.tcpconn.LocalAddr()
}
// GetVersion returns the protocol version currently recorded for the client.
func (c *Client) GetVersion() uint8 {
	return c.version
}
// SetIntervals stores the refresh, retry and expire intervals that SendROAs
// places in the End Of Data PDU.
func (c *Client) SetIntervals(refreshInterval uint32, retryInterval uint32, expireInterval uint32) {
	c.refreshInterval, c.retryInterval, c.expireInterval = refreshInterval, retryInterval, expireInterval
}
// SetVersion records newversion as the client's protocol version and marks
// the version as fixed, so later PDUs must match it (see checkVersion).
func (c *Client) SetVersion(newversion uint8) {
	c.version = newversion
	c.versionset = true
}
// SetDisableVersionCheck controls whether Start skips checkVersion for
// incoming PDUs.
func (c *Client) SetDisableVersionCheck(disableCheck bool) {
	c.disableVersionCheck = disableCheck
}
// checkVersion validates the version carried by an incoming PDU.
//
// The version is accepted when it is protocol version 0 or 1 and either no
// version has been fixed yet or it matches the fixed one; it is then recorded
// through SetVersion. Otherwise the client is sent a wrong-version error and
// disconnected.
func (c *Client) checkVersion(newversion uint8) {
	matches := !c.versionset || newversion == c.version
	supported := newversion == PROTOCOL_VERSION_1 || newversion == PROTOCOL_VERSION_0
	if matches && supported {
		c.SetVersion(newversion)
		return
	}

	if c.log != nil {
		c.log.Debugf("%v: has bad version (received: v%v, current: v%v) error", c.String(), newversion, c.version)
	}
	c.SendWrongVersionError()
	c.Disconnect()
}
// passSimpleHandler translates query PDUs into calls on the simple handler:
// a serial query becomes RequestNewVersion and a reset query becomes
// RequestCache. Other PDU types are ignored, as is everything when no simple
// handler is set.
func (c *Client) passSimpleHandler(pdu PDU) {
	if c.simpleHandler == nil {
		return
	}

	switch query := pdu.(type) {
	case *PDUResetQuery:
		c.simpleHandler.RequestCache(c)
	case *PDUSerialQuery:
		c.simpleHandler.RequestNewVersion(c, query.SessionId, query.SerialNumber)
	default:
		// not a proper client packet
	}
}
// sendLoop writes queued PDUs to the client for as long as it is marked
// connected. Write errors are not inspected.
func (c *Client) sendLoop() {
	for c.connected {
		select {
		case <-c.quit:
			// Woken by Disconnect; the loop condition decides whether to stop.
		case outgoing := <-c.transmits:
			c.wr.Write(outgoing.Bytes())
		}
	}
}
// Start marks the client connected, announces it to the handler, launches the
// send loop and then reads PDUs until the client is disconnected.
//
// Each read is decoded as one PDU. A read error, an empty read or a decode
// failure disconnects the client. A PDU that fails the version checks also
// disconnects it and is not passed to the handlers.
func (c *Client) Start() {
	c.connected = true
	if c.handler != nil {
		c.handler.ClientConnected(c)
	}

	go c.sendLoop()

	buf := make([]byte, 8000)
	for c.connected {
		length, err := c.rd.Read(buf)
		if err != nil || length == 0 {
			if c.log != nil {
				c.log.Debugf("Error %v", err)
			}
			c.Disconnect()
			return
		}

		dec, err := DecodeBytes(buf[0:length])
		if err != nil || dec == nil {
			if c.log != nil {
				c.log.Errorf("Error %v", err)
			}
			c.Disconnect()
			return
		}

		if !c.disableVersionCheck {
			c.checkVersion(dec.GetVersion())
			if !c.connected {
				// checkVersion rejected the PDU and disconnected the client.
				return
			}
		}
		if c.log != nil {
			c.log.Debugf("%v: Received %v", c.String(), dec)
		}

		if c.enforceVersion && !IsCorrectPDUVersion(dec, c.version) {
			if c.log != nil {
				c.log.Debugf("Bad version error")
			}
			c.SendWrongVersionError()
			c.Disconnect()
			// Do not hand a rejected PDU to the handlers.
			return
		}

		// Remember the serial the client claims to hold (used by String).
		switch pduconv := dec.(type) {
		case *PDUSerialQuery:
			c.curserial = pduconv.SerialNumber
		}

		if c.handler != nil {
			c.handler.HandlePDU(c, dec)
		}
		c.passSimpleHandler(dec)
	}
}
// Notify queues a Serial Notify PDU carrying sessionId and serialNumber.
func (c *Client) Notify(sessionId uint16, serialNumber uint32) {
	c.SendPDU(&PDUSerialNotify{
		SessionId:    sessionId,
		SerialNumber: serialNumber,
	})
}
// ROA is one route origin authorization entry: a prefix, the maximum prefix
// length allowed for it, the origin AS number, and a flags byte.
type ROA struct {
	Prefix net.IPNet
	MaxLen uint8
	ASN    uint32
	Flags  uint8
}

// String renders the ROA for logs.
func (r ROA) String() string {
	prefix := r.Prefix.String()
	return fmt.Sprintf("ROA %v -> /%v, AS%v, Flags: %v", prefix, r.MaxLen, r.ASN, r.Flags)
}

// Equals reports whether r1 and r2 have the same max length, ASN, prefix IP
// bytes and prefix mask bytes. Flags are not compared.
func (r1 ROA) Equals(r2 ROA) bool {
	if r1.MaxLen != r2.MaxLen || r1.ASN != r2.ASN {
		return false
	}
	if !bytes.Equal(r1.Prefix.IP, r2.Prefix.IP) {
		return false
	}
	return bytes.Equal(r1.Prefix.Mask, r2.Prefix.Mask)
}

// Copy returns a ROA whose prefix IP and mask are freshly allocated, so the
// copy shares no backing memory with r1.
func (r1 ROA) Copy() ROA {
	ip := make([]byte, len(r1.Prefix.IP))
	copy(ip, r1.Prefix.IP)
	mask := make([]byte, len(r1.Prefix.Mask))
	copy(mask, r1.Prefix.Mask)

	dup := ROA{
		MaxLen: r1.MaxLen,
		ASN:    r1.ASN,
		Flags:  r1.Flags,
	}
	dup.Prefix = net.IPNet{IP: ip, Mask: mask}
	return dup
}
// SendROAs queues a complete response: a Cache Response PDU, one prefix PDU
// per ROA in order, and an End Of Data PDU carrying serialNumber and the
// client's refresh, retry and expire intervals.
func (c *Client) SendROAs(sessionId uint16, serialNumber uint32, roas []ROA) {
	c.SendPDU(&PDUCacheResponse{
		SessionId: sessionId,
	})

	for i := range roas {
		c.SendROA(roas[i])
	}

	c.SendPDU(&PDUEndOfData{
		SessionId:       sessionId,
		SerialNumber:    serialNumber,
		RefreshInterval: c.refreshInterval,
		RetryInterval:   c.retryInterval,
		ExpireInterval:  c.expireInterval,
	})
}
// SendCacheReset queues a Cache Reset PDU.
func (c *Client) SendCacheReset() {
	c.SendPDU(&PDUCacheReset{})
}
// SendInternalError queues an Error Report PDU with the internal-error code.
func (c *Client) SendInternalError() {
	c.SendPDU(&PDUErrorReport{
		ErrorCode: PDU_ERROR_INTERNALERR,
		ErrorMsg:  "Unknown internal error",
	})
}
// SendNoDataError queues an Error Report PDU with the no-data code.
func (c *Client) SendNoDataError() {
	c.SendPDU(&PDUErrorReport{
		ErrorCode: PDU_ERROR_NODATA,
		ErrorMsg:  "No data available",
	})
}
// SendWrongVersionError queues an Error Report PDU with the
// bad-protocol-version code.
func (c *Client) SendWrongVersionError() {
	c.SendPDU(&PDUErrorReport{
		ErrorCode: PDU_ERROR_BADPROTOVERSION,
		ErrorMsg:  "Bad protocol version",
	})
}
// SendROA queues the prefix PDU matching the ROA's address family: an IPv4
// prefix PDU when the IP converts with To4, an IPv6 prefix PDU when it only
// converts with To16. A ROA whose IP converts with neither is not sent.
func (c *Client) SendROA(roa ROA) {
	ip := roa.Prefix.IP
	switch {
	case ip.To4() != nil:
		c.SendPDU(&PDUIPv4Prefix{
			Flags:  roa.Flags,
			MaxLen: roa.MaxLen,
			ASN:    roa.ASN,
			Prefix: roa.Prefix,
		})
	case ip.To16() != nil:
		c.SendPDU(&PDUIPv6Prefix{
			Flags:  roa.Flags,
			MaxLen: roa.MaxLen,
			ASN:    roa.ASN,
			Prefix: roa.Prefix,
		})
	}
}
// SendRawPDU places pdu on the transmit queue as is, without setting its
// version. The send blocks while the queue is full.
func (c *Client) SendRawPDU(pdu PDU) {
	// The PDU is written to the connection by sendLoop, not here.
	c.transmits <- pdu
}
// SendPDU stamps pdu with the client's protocol version and queues it for
// transmission.
func (c *Client) SendPDU(pdu PDU) {
	version := c.version
	pdu.SetVersion(version)
	c.SendRawPDU(pdu)
}
// Disconnect marks the client as disconnected, notifies the handler, nudges
// the send loop and closes the underlying connection.
func (c *Client) Disconnect() {
	c.connected = false

	if logger := c.log; logger != nil {
		logger.Infof("Disconnecting client %v", c.String())
	}
	if handler := c.handler; handler != nil {
		handler.ClientDisconnected(c)
	}

	// Non-blocking signal: skipped when sendLoop is not currently waiting on quit.
	select {
	case c.quit <- true:
	default:
	}

	c.tcpconn.Close()
}