From d1e1ace3186d8908414f048989cdb5f988b64606 Mon Sep 17 00:00:00 2001 From: Mario Macias Date: Mon, 1 Nov 2021 00:42:07 +0100 Subject: [PATCH] Allow Flow Routines to be cancellable (#40) * Allow Flow Routines to be cancellable --- utils/netflow.go | 7 ++- utils/nflegacy.go | 7 ++- utils/sflow.go | 7 ++- utils/stopper.go | 28 ++++++++++++ utils/stopper_test.go | 51 +++++++++++++++++++++ utils/utils.go | 104 ++++++++++++++++++++++++++++-------------- utils/utils_test.go | 92 +++++++++++++++++++++++++++++++++++++ 7 files changed, 259 insertions(+), 37 deletions(-) create mode 100644 utils/stopper.go create mode 100644 utils/stopper_test.go create mode 100644 utils/utils_test.go diff --git a/utils/netflow.go b/utils/netflow.go index bab1021..1c92e5e 100644 --- a/utils/netflow.go +++ b/utils/netflow.go @@ -51,6 +51,8 @@ func (s *TemplateSystem) GetTemplate(version uint16, obsDomainId uint32, templat } type StateNetFlow struct { + stopper + Format format.FormatInterface Transport transport.TransportInterface Logger Logger @@ -373,7 +375,10 @@ func (s *StateNetFlow) initConfig() { } func (s *StateNetFlow) FlowRoutine(workers int, addr string, port int, reuseport bool) error { + if err := s.start(); err != nil { + return err + } s.InitTemplates() s.initConfig() - return UDPRoutine("NetFlow", s.DecodeFlow, workers, addr, port, reuseport, s.Logger) + return UDPStoppableRoutine(s.stopCh, "NetFlow", s.DecodeFlow, workers, addr, port, reuseport, s.Logger) } diff --git a/utils/nflegacy.go b/utils/nflegacy.go index f7971b4..1b3114b 100644 --- a/utils/nflegacy.go +++ b/utils/nflegacy.go @@ -13,6 +13,8 @@ import ( ) type StateNFLegacy struct { + stopper + Format format.FormatInterface Transport transport.TransportInterface Logger Logger @@ -95,5 +97,8 @@ func (s *StateNFLegacy) DecodeFlow(msg interface{}) error { } func (s *StateNFLegacy) FlowRoutine(workers int, addr string, port int, reuseport bool) error { - return UDPRoutine("NetFlowV5", s.DecodeFlow, workers, addr, port, 
reuseport, s.Logger) + if err := s.start(); err != nil { + return err + } + return UDPStoppableRoutine(s.stopCh, "NetFlowV5", s.DecodeFlow, workers, addr, port, reuseport, s.Logger) } diff --git a/utils/sflow.go b/utils/sflow.go index 62ac82d..57915ab 100644 --- a/utils/sflow.go +++ b/utils/sflow.go @@ -14,6 +14,8 @@ import ( ) type StateSFlow struct { + stopper + Format format.FormatInterface Transport transport.TransportInterface Logger Logger @@ -153,6 +155,9 @@ func (s *StateSFlow) initConfig() { } func (s *StateSFlow) FlowRoutine(workers int, addr string, port int, reuseport bool) error { + if err := s.start(); err != nil { + return err + } s.initConfig() - return UDPRoutine("sFlow", s.DecodeFlow, workers, addr, port, reuseport, s.Logger) + return UDPStoppableRoutine(s.stopCh, "sFlow", s.DecodeFlow, workers, addr, port, reuseport, s.Logger) } diff --git a/utils/stopper.go b/utils/stopper.go new file mode 100644 index 0000000..fa6f10d --- /dev/null +++ b/utils/stopper.go @@ -0,0 +1,28 @@ +package utils + +import ( + "errors" +) + +// ErrAlreadyStarted is returned when a flow routine is started twice +var ErrAlreadyStarted = errors.New("the routine is already started") + +// stopper mechanism, common for all the flow routines +type stopper struct { + stopCh chan struct{} +} + +func (s *stopper) start() error { + if s.stopCh != nil { + return ErrAlreadyStarted + } + s.stopCh = make(chan struct{}) + return nil +} + +func (s *stopper) Shutdown() { + if s.stopCh != nil { + close(s.stopCh) + s.stopCh = nil + } +} diff --git a/utils/stopper_test.go b/utils/stopper_test.go new file mode 100644 index 0000000..f76e7bf --- /dev/null +++ b/utils/stopper_test.go @@ -0,0 +1,51 @@ +package utils + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStopper(t *testing.T) { + r := routine{} + require.False(t, r.Running) + require.NoError(t, r.StartRoutine()) + assert.True(t, r.Running) + 
r.Shutdown() + assert.Eventually(t, func() bool { + return r.Running == false + }, time.Second, time.Millisecond) + + // after shutdown, we can start it again + require.NoError(t, r.StartRoutine()) + assert.True(t, r.Running) +} + +func TestStopper_CannotStartTwice(t *testing.T) { + r := routine{} + require.False(t, r.Running) + require.NoError(t, r.StartRoutine()) + assert.ErrorIs(t, r.StartRoutine(), ErrAlreadyStarted) +} + +type routine struct { + stopper + Running bool +} + +func (p *routine) StartRoutine() error { + if err := p.start(); err != nil { + return err + } + p.Running = true + waitForGoRoutine := make(chan struct{}) + go func() { + close(waitForGoRoutine) + <-p.stopCh + p.Running = false + }() + <-waitForGoRoutine + return nil +} diff --git a/utils/utils.go b/utils/utils.go index 130c81f..3ac4f1d 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -6,6 +6,7 @@ import ( "io" "net" "strconv" + "sync/atomic" "time" reuseport "github.com/libp2p/go-reuseport" @@ -99,6 +100,11 @@ func (cb *DefaultErrorCallback) Callback(name string, id int, start, end time.Ti } func UDPRoutine(name string, decodeFunc decoder.DecoderFunc, workers int, addr string, port int, sockReuse bool, logger Logger) error { + return UDPStoppableRoutine(make(chan struct{}), name, decodeFunc, workers, addr, port, sockReuse, logger) +} + +// UDPStoppableRoutine runs a UDPRoutine that can be stopped by closing the stopCh passed as argument +func UDPStoppableRoutine(stopCh <-chan struct{}, name string, decodeFunc decoder.DecoderFunc, workers int, addr string, port int, sockReuse bool, logger Logger) error { ecb := DefaultErrorCallback{ Logger: logger, } @@ -146,41 +152,71 @@ func UDPRoutine(name string, decodeFunc decoder.DecoderFunc, workers int, addr s localIP = "" } - for { - size, pktAddr, _ := udpconn.ReadFromUDP(payload) - payloadCut := make([]byte, size) - copy(payloadCut, payload[0:size]) + type udpData struct { + size int + pktAddr *net.UDPAddr + } - baseMessage := BaseMessage{ - 
Src: pktAddr.IP, - Port: pktAddr.Port, - Payload: payloadCut, + stopped := atomic.Value{} + stopped.Store(false) + udpDataCh := make(chan udpData) + go func() { + for { + u := udpData{} + u.size, u.pktAddr, _ = udpconn.ReadFromUDP(payload) + if stopped.Load() == false { + udpDataCh <- u + } else { + return + } + } + }() + for { + select { + case u := <-udpDataCh: + process(u.size, payload, u.pktAddr, processor, localIP, addrUDP, name) + case <-stopCh: + stopped.Store(true) + udpconn.Close() + close(udpDataCh) + return nil } - processor.ProcessMessage(baseMessage) - - MetricTrafficBytes.With( - prometheus.Labels{ - "remote_ip": pktAddr.IP.String(), - "local_ip": localIP, - "local_port": strconv.Itoa(addrUDP.Port), - "type": name, - }). - Add(float64(size)) - MetricTrafficPackets.With( - prometheus.Labels{ - "remote_ip": pktAddr.IP.String(), - "local_ip": localIP, - "local_port": strconv.Itoa(addrUDP.Port), - "type": name, - }). - Inc() - MetricPacketSizeSum.With( - prometheus.Labels{ - "remote_ip": pktAddr.IP.String(), - "local_ip": localIP, - "local_port": strconv.Itoa(addrUDP.Port), - "type": name, - }). - Observe(float64(size)) } } + +func process(size int, payload []byte, pktAddr *net.UDPAddr, processor decoder.Processor, localIP string, addrUDP net.UDPAddr, name string) { + payloadCut := make([]byte, size) + copy(payloadCut, payload[0:size]) + + baseMessage := BaseMessage{ + Src: pktAddr.IP, + Port: pktAddr.Port, + Payload: payloadCut, + } + processor.ProcessMessage(baseMessage) + + MetricTrafficBytes.With( + prometheus.Labels{ + "remote_ip": pktAddr.IP.String(), + "local_ip": localIP, + "local_port": strconv.Itoa(addrUDP.Port), + "type": name, + }). + Add(float64(size)) + MetricTrafficPackets.With( + prometheus.Labels{ + "remote_ip": pktAddr.IP.String(), + "local_ip": localIP, + "local_port": strconv.Itoa(addrUDP.Port), + "type": name, + }). 
+ Inc() + MetricPacketSizeSum.With( + prometheus.Labels{ + "remote_ip": pktAddr.IP.String(), + "local_ip": localIP, + "local_port": strconv.Itoa(addrUDP.Port), + "type": name, + }). + Observe(float64(size)) +} diff --git a/utils/utils_test.go b/utils/utils_test.go new file mode 100644 index 0000000..5ecbf3e --- /dev/null +++ b/utils/utils_test.go @@ -0,0 +1,92 @@ +package utils + +import ( + "fmt" + "net" + "testing" + "time" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCancelUDPRoutine(t *testing.T) { + testTimeout := time.After(10 * time.Second) + port, err := getFreeUDPPort() + require.NoError(t, err) + dp := dummyFlowProcessor{} + go func() { + require.NoError(t, dp.FlowRoutine("127.0.0.1", port)) + }() + + // wait slightly so we give time to the server to accept requests + time.Sleep(100 * time.Millisecond) + + sendMessage := func(msg string) error { + conn, err := net.Dial("udp", fmt.Sprintf("127.0.0.1:%d", port)) + if err != nil { + return err + } + defer conn.Close() + _, err = conn.Write([]byte(msg)) + return err + } + require.NoError(t, sendMessage("message 1")) + require.NoError(t, sendMessage("message 2")) + require.NoError(t, sendMessage("message 3")) + + readMessage := func() string { + select { + case msg := <-dp.receivedMessages: + return string(msg.(BaseMessage).Payload) + case <-testTimeout: + require.Fail(t, "test timed out while waiting for message") + return "" + } + } + + // in UDP, messages might arrive out of order or duplicated, so we just verify they arrive + // to avoid flaky tests + require.Contains(t, []string{"message 1", "message 2", "message 3"}, readMessage()) + require.Contains(t, []string{"message 1", "message 2", "message 3"}, readMessage()) + require.Contains(t, []string{"message 1", "message 2", "message 3"}, readMessage()) + + dp.Shutdown() + + _ = sendMessage("no more messages should be processed") + + select { + case msg := 
<-dp.receivedMessages: + assert.Fail(t, fmt.Sprint(msg)) + default: + // everything is correct + } +} + +type dummyFlowProcessor struct { + stopper + receivedMessages chan interface{} +} + +func (d *dummyFlowProcessor) FlowRoutine(host string, port int) error { + _ = d.start() + d.receivedMessages = make(chan interface{}) + return UDPStoppableRoutine(d.stopCh, "test_udp", func(msg interface{}) error { + d.receivedMessages <- msg + return nil + }, 3, host, port, false, logrus.StandardLogger()) +} + +func getFreeUDPPort() (int, error) { + a, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + if err != nil { + return 0, err + } + l, err := net.ListenUDP("udp", a) + if err != nil { + return 0, err + } + defer l.Close() + return l.LocalAddr().(*net.UDPAddr).Port, nil +}