mirror of
https://github.com/netsampler/goflow2.git
synced 2024-05-06 15:54:52 +00:00
93 lines
2.3 KiB
Go
93 lines
2.3 KiB
Go
package utils
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/sirupsen/logrus"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestCancelUDPRoutine(t *testing.T) {
|
|
testTimeout := time.After(10 * time.Second)
|
|
port, err := getFreeUDPPort()
|
|
require.NoError(t, err)
|
|
dp := dummyFlowProcessor{}
|
|
go func() {
|
|
require.NoError(t, dp.FlowRoutine("127.0.0.1", port))
|
|
}()
|
|
|
|
// wait slightly so we give time to the server to accept requests
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
sendMessage := func(msg string) error {
|
|
conn, err := net.Dial("udp", fmt.Sprintf("127.0.0.1:%d", port))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer conn.Close()
|
|
_, err = conn.Write([]byte(msg))
|
|
return err
|
|
}
|
|
require.NoError(t, sendMessage("message 1"))
|
|
require.NoError(t, sendMessage("message 2"))
|
|
require.NoError(t, sendMessage("message 3"))
|
|
|
|
readMessage := func() string {
|
|
select {
|
|
case msg := <-dp.receivedMessages:
|
|
return string(msg.(BaseMessage).Payload)
|
|
case <-testTimeout:
|
|
require.Fail(t, "test timed out while waiting for message")
|
|
return ""
|
|
}
|
|
}
|
|
|
|
// in UDP, messages might arrive out of order or duplicate, so whe just verify they arrive
|
|
// to avoid flaky tests
|
|
require.Contains(t, []string{"message 1", "message 2", "message 3"}, readMessage())
|
|
require.Contains(t, []string{"message 1", "message 2", "message 3"}, readMessage())
|
|
require.Contains(t, []string{"message 1", "message 2", "message 3"}, readMessage())
|
|
|
|
dp.Shutdown()
|
|
|
|
_ = sendMessage("no more messages should be processed")
|
|
|
|
select {
|
|
case msg := <-dp.receivedMessages:
|
|
assert.Fail(t, fmt.Sprint(msg))
|
|
default:
|
|
// everything is correct
|
|
}
|
|
}
|
|
|
|
type dummyFlowProcessor struct {
|
|
stopper
|
|
receivedMessages chan interface{}
|
|
}
|
|
|
|
func (d *dummyFlowProcessor) FlowRoutine(host string, port int) error {
|
|
_ = d.start()
|
|
d.receivedMessages = make(chan interface{})
|
|
return UDPStoppableRoutine(d.stopCh, "test_udp", func(msg interface{}) error {
|
|
d.receivedMessages <- msg
|
|
return nil
|
|
}, 3, host, port, false, logrus.StandardLogger())
|
|
}
|
|
|
|
// getFreeUDPPort binds a UDP socket to 127.0.0.1 on port 0 so the OS picks an
// unused port, reads that port back and closes the socket before returning.
// The port is not held afterwards, so it could in principle be taken by
// someone else before the caller binds it.
func getFreeUDPPort() (int, error) {
	addr, resolveErr := net.ResolveUDPAddr("udp", "127.0.0.1:0")
	if resolveErr != nil {
		return 0, resolveErr
	}

	conn, listenErr := net.ListenUDP("udp", addr)
	if listenErr != nil {
		return 0, listenErr
	}
	defer conn.Close()

	assigned := conn.LocalAddr().(*net.UDPAddr)
	port := assigned.Port
	return port, nil
}
|