diff --git a/pkg/server/server.go b/pkg/server/server.go index a962a1e2..b20bab5e 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -2627,8 +2627,10 @@ func (s *BgpServer) getAdjRib(addr string, family bgp.RouteFamily, in bool, enab options := &table.PolicyOptions{ Validate: s.roaTable.Validate, } - if s.policy.ApplyPolicy(peer.TableID(), table.POLICY_DIRECTION_IMPORT, path, options) == nil { + if p := s.policy.ApplyPolicy(peer.TableID(), table.POLICY_DIRECTION_IMPORT, path, options); p == nil { filtered[path.GetNlri().String()] = path + } else { + adjRib.Update([]*table.Path{p}) } } } diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index b400e0a8..bb6c5dbe 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -238,9 +238,11 @@ func waitEstablished(s *BgpServer, ch chan struct{}) { func TestListPathEnableFiltered(test *testing.T) { assert := assert.New(test) - s := NewBgpServer() - go s.Serve() - err := s.StartBgp(context.Background(), &api.StartBgpRequest{ + + // Create servers and add peers + server1 := NewBgpServer() + go server1.Serve() + err := server1.StartBgp(context.Background(), &api.StartBgpRequest{ Global: &api.Global{ Asn: 1, RouterId: "1.1.1.1", @@ -248,7 +250,19 @@ func TestListPathEnableFiltered(test *testing.T) { }, }) assert.Nil(err) - defer s.StopBgp(context.Background(), &api.StopBgpRequest{}) + defer server1.StopBgp(context.Background(), &api.StopBgpRequest{}) + + server2 := NewBgpServer() + go server2.Serve() + err = server2.StartBgp(context.Background(), &api.StartBgpRequest{ + Global: &api.Global{ + Asn: 2, + RouterId: "2.2.2.2", + ListenPort: -1, + }, + }) + assert.Nil(err) + defer server2.StopBgp(context.Background(), &api.StopBgpRequest{}) peer1 := &api.Peer{ Conf: &api.PeerConf{ @@ -259,9 +273,32 @@ func TestListPathEnableFiltered(test *testing.T) { PassiveMode: true, }, } - err = s.AddPeer(context.Background(), &api.AddPeerRequest{Peer: peer1}) + err = server1.AddPeer(context.Background(), &api.AddPeerRequest{Peer: peer1}) assert.Nil(err) + peer2 := &api.Peer{ + Conf: &api.PeerConf{ + NeighborAddress: "127.0.0.1", + PeerAsn: 1, + }, + Transport: &api.Transport{ + RemotePort: 10179, + }, + Timers: &api.Timers{ + Config: &api.TimersConfig{ + ConnectRetry: 1, + IdleHoldTimeAfterReset: 1, + }, + }, + } + ch := make(chan struct{}) + go waitEstablished(server1, ch) + + err = server2.AddPeer(context.Background(), &api.AddPeerRequest{Peer: peer2}) + assert.Nil(err) + <-ch + + // Add IMPORT policy at server1 for rejecting 10.1.0.0/24 d1 := &api.DefinedSet{ DefinedType: api.DefinedType_PREFIX, Name: "d1", @@ -284,15 +321,15 @@ func TestListPathEnableFiltered(test *testing.T) { RouteAction: api.RouteAction_REJECT, }, } - err = s.AddDefinedSet(context.Background(), &api.AddDefinedSetRequest{DefinedSet: d1}) + err = server1.AddDefinedSet(context.Background(), &api.AddDefinedSetRequest{DefinedSet: d1}) assert.Nil(err) p1 := &api.Policy{ Name: "p1", Statements: []*api.Statement{s1}, } - err = s.AddPolicy(context.Background(), &api.AddPolicyRequest{Policy: p1}) + err = server1.AddPolicy(context.Background(), &api.AddPolicyRequest{Policy: p1}) assert.Nil(err) - err = s.AddPolicyAssignment(context.Background(), &api.AddPolicyAssignmentRequest{ + err = server1.AddPolicyAssignment(context.Background(), &api.AddPolicyAssignmentRequest{ Assignment: &api.PolicyAssignment{ Name: table.GLOBAL_RIB_NAME, Direction: api.PolicyDirection_IMPORT, @@ -302,18 +339,75 @@ func TestListPathEnableFiltered(test *testing.T) { }) assert.Nil(err) - t := NewBgpServer() - go t.Serve() - err = t.StartBgp(context.Background(), &api.StartBgpRequest{ - Global: &api.Global{ - Asn: 2, - RouterId: "2.2.2.2", - ListenPort: -1, - }, + // Add EXPORT policy at server2 for accepting all routes and adding communities. + commSet, _ := table.NewCommunitySet(oc.CommunitySet{ + CommunitySetName: "comset1", + CommunityList: []string{"100:100"}, }) - assert.Nil(err) - defer t.StopBgp(context.Background(), &api.StopBgpRequest{}) + server2.policy.AddDefinedSet(commSet, false) + statement := oc.Statement{ + Name: "stmt1", + Actions: oc.Actions{ + BgpActions: oc.BgpActions{ + SetCommunity: oc.SetCommunity{ + SetCommunityMethod: oc.SetCommunityMethod{ + CommunitiesList: []string{"100:100"}, + }, + Options: string(oc.BGP_SET_COMMUNITY_OPTION_TYPE_ADD), + }, + }, + RouteDisposition: oc.ROUTE_DISPOSITION_ACCEPT_ROUTE, + }, + } + policy := oc.PolicyDefinition{ + Name: "policy1", + Statements: []oc.Statement{statement}, + } + p, err := table.NewPolicy(policy) + if err != nil { + test.Fatalf("cannot create new policy: %v", err) + } + server2.policy.AddPolicy(p, false) + policies := []*oc.PolicyDefinition{ + { + Name: "policy1", + }, + } + server2.policy.AddPolicyAssignment(table.GLOBAL_RIB_NAME, table.POLICY_DIRECTION_EXPORT, policies, table.ROUTE_TYPE_REJECT) + + // Add IMPORT policy at server1 for accepting all routes and replacing communities. + statement = oc.Statement{ + Name: "stmt1", + Actions: oc.Actions{ + BgpActions: oc.BgpActions{ + SetCommunity: oc.SetCommunity{ + SetCommunityMethod: oc.SetCommunityMethod{ + CommunitiesList: []string{"200:200"}, + }, + Options: string(oc.BGP_SET_COMMUNITY_OPTION_TYPE_REPLACE), + }, + }, + RouteDisposition: oc.ROUTE_DISPOSITION_ACCEPT_ROUTE, + }, + } + policy = oc.PolicyDefinition{ + Name: "policy1", + Statements: []oc.Statement{statement}, + } + p, err = table.NewPolicy(policy) + if err != nil { + test.Fatalf("cannot create new policy: %v", err) + } + server1.policy.AddPolicy(p, false) + policies = []*oc.PolicyDefinition{ + { + Name: "policy1", + }, + } + server1.policy.AddPolicyAssignment(table.GLOBAL_RIB_NAME, table.POLICY_DIRECTION_IMPORT, policies, table.ROUTE_TYPE_REJECT) + + // Add paths family := &api.Family{ Afi: api.Family_AFI_IP, Safi: api.Family_SAFI_UNICAST, @@ -332,7 +426,7 @@ func TestListPathEnableFiltered(test *testing.T) { }) attrs := []*apb.Any{a1, a2} - t.AddPath(context.Background(), &api.AddPathRequest{ + server2.AddPath(context.Background(), &api.AddPathRequest{ TableType: api.TableType_GLOBAL, Path: &api.Path{ Family: family, @@ -345,7 +439,7 @@ func TestListPathEnableFiltered(test *testing.T) { Prefix: "10.2.0.0", PrefixLen: 24, }) - t.AddPath(context.Background(), &api.AddPathRequest{ + server2.AddPath(context.Background(), &api.AddPathRequest{ TableType: api.TableType_GLOBAL, Path: &api.Path{ Family: family, @@ -354,51 +448,169 @@ func TestListPathEnableFiltered(test *testing.T) { }, }) - peer2 := &api.Peer{ - Conf: &api.PeerConf{ - NeighborAddress: "127.0.0.1", - PeerAsn: 1, - }, - Transport: &api.Transport{ - RemotePort: 10179, - }, - Timers: &api.Timers{ - Config: &api.TimersConfig{ - ConnectRetry: 1, - IdleHoldTimeAfterReset: 1, - }, - }, - } - ch := make(chan struct{}) - go waitEstablished(s, ch) - - err = t.AddPeer(context.Background(), &api.AddPeerRequest{Peer: peer2}) - assert.Nil(err) - <-ch + var wantEmptyCommunities []uint32 + wantCommunitiesAfterExportPolicies := []uint32{100<<16 | 100} + wantCommunitiesAfterImportPolicies := []uint32{200<<16 | 200} + // Check ADJ_OUT routes before applying export policies. for { count := 0 - s.ListPath(context.Background(), &api.ListPathRequest{TableType: api.TableType_ADJ_IN, Family: family, Name: "127.0.0.1"}, func(d *api.Destination) { + server2.ListPath(context.Background(), &api.ListPathRequest{ + TableType: api.TableType_ADJ_OUT, + Family: family, Name: "127.0.0.1", + // TODO(wenovus): This is confusing and we may want to change this. + EnableFiltered: true, + }, func(d *api.Destination) { count++ + for _, path := range d.Paths { + var comms []uint32 + for _, attr := range path.GetPattrs() { + m, err := attr.UnmarshalNew() + if err != nil { + test.Fatalf("Unable to unmarshal a GoBGP path attribute: %v", err) + continue + } + switch m := m.(type) { + case *api.CommunitiesAttribute: + comms = m.GetCommunities() + } + } + if diff := cmp.Diff(wantEmptyCommunities, comms); diff != "" { + test.Errorf("AdjRibOutPre communities for %v (-want, +got):\n%s", d.GetPrefix(), diff) + } else { + test.Logf("Got expected communities for %v: %v", d.GetPrefix(), comms) + } + } }) if count == 2 { break } } + // Check ADJ_OUT routes after applying export policies. + for { + count := 0 + server2.ListPath(context.Background(), &api.ListPathRequest{ + TableType: api.TableType_ADJ_OUT, + Family: family, Name: "127.0.0.1", + // TODO(wenovus): This is confusing and we may want to change this. + EnableFiltered: false, + }, func(d *api.Destination) { + count++ + for _, path := range d.Paths { + if path.Filtered { + continue + } + var comms []uint32 + for _, attr := range path.GetPattrs() { + m, err := attr.UnmarshalNew() + if err != nil { + test.Fatalf("Unable to unmarshal a GoBGP path attribute: %v", err) + continue + } + switch m := m.(type) { + case *api.CommunitiesAttribute: + comms = m.GetCommunities() + } + } + if diff := cmp.Diff(wantCommunitiesAfterExportPolicies, comms); diff != "" { + test.Errorf("AdjRibOutPost communities for %v (-want, +got):\n%s", d.GetPrefix(), diff) + } else { + test.Logf("Got expected communities for %v: %v", d.GetPrefix(), comms) + } + } + }) + if count == 2 { + break + } + } + // Check ADJ_IN routes before applying import policies. + for { + count := 0 + server1.ListPath(context.Background(), &api.ListPathRequest{ + TableType: api.TableType_ADJ_IN, + Family: family, + Name: "127.0.0.1", + EnableFiltered: false, + }, func(d *api.Destination) { + count++ + for _, path := range d.Paths { + var comms []uint32 + for _, attr := range path.GetPattrs() { + m, err := attr.UnmarshalNew() + if err != nil { + test.Fatalf("Unable to unmarshal a GoBGP path attribute: %v", err) + continue + } + switch m := m.(type) { + case *api.CommunitiesAttribute: + comms = m.GetCommunities() + } + } + if diff := cmp.Diff(wantCommunitiesAfterExportPolicies, comms); diff != "" { + test.Errorf("AdjRibInPre communities for %v (-want, +got):\n%s", d.GetPrefix(), diff) + } else { + test.Logf("Got expected communities for %v: %v", d.GetPrefix(), comms) + } + } + }) + if count == 2 { + break + } + } + // Check ADJ_IN routes after applying import policies. + for { + count := 0 + server1.ListPath(context.Background(), &api.ListPathRequest{ + TableType: api.TableType_ADJ_IN, + Family: family, + Name: "127.0.0.1", + EnableFiltered: true, + }, func(d *api.Destination) { + count++ + for _, path := range d.Paths { + if path.Filtered { + continue + } + var comms []uint32 + for _, attr := range path.GetPattrs() { + m, err := attr.UnmarshalNew() + if err != nil { + test.Fatalf("Unable to unmarshal a GoBGP path attribute: %v", err) + continue + } + switch m := m.(type) { + case *api.CommunitiesAttribute: + comms = m.GetCommunities() + } + } + if diff := cmp.Diff(wantCommunitiesAfterImportPolicies, comms); diff != "" { + test.Errorf("AdjRibInPost communities for %v (-want, +got):\n%s", d.GetPrefix(), diff) + } else { + test.Logf("Got expected communities for %v: %v", d.GetPrefix(), comms) + } + } + }) + if count == 2 { + break + } + } + + // Check that 10.1.0.0/24 is filtered at the import side. count := 0 - s.ListPath(context.Background(), &api.ListPathRequest{TableType: api.TableType_GLOBAL, Family: family}, func(d *api.Destination) { + server1.ListPath(context.Background(), &api.ListPathRequest{TableType: api.TableType_GLOBAL, Family: family}, func(d *api.Destination) { count++ }) assert.Equal(1, count) filtered := 0 - s.ListPath(context.Background(), &api.ListPathRequest{TableType: api.TableType_ADJ_IN, Family: family, Name: "127.0.0.1", EnableFiltered: true}, func(d *api.Destination) { + server1.ListPath(context.Background(), &api.ListPathRequest{TableType: api.TableType_ADJ_IN, Family: family, Name: "127.0.0.1", EnableFiltered: true}, func(d *api.Destination) { if d.Paths[0].Filtered { filtered++ } }) assert.Equal(1, filtered) + // Validate filtering at the export side. d2 := &api.DefinedSet{ DefinedType: api.DefinedType_PREFIX, Name: "d2", @@ -421,15 +633,15 @@ func TestListPathEnableFiltered(test *testing.T) { RouteAction: api.RouteAction_REJECT, }, } - err = s.AddDefinedSet(context.Background(), &api.AddDefinedSetRequest{DefinedSet: d2}) + err = server1.AddDefinedSet(context.Background(), &api.AddDefinedSetRequest{DefinedSet: d2}) assert.Nil(err) p2 := &api.Policy{ Name: "p2", Statements: []*api.Statement{s2}, } - err = s.AddPolicy(context.Background(), &api.AddPolicyRequest{Policy: p2}) + err = server1.AddPolicy(context.Background(), &api.AddPolicyRequest{Policy: p2}) assert.Nil(err) - err = s.AddPolicyAssignment(context.Background(), &api.AddPolicyAssignmentRequest{ + err = server1.AddPolicyAssignment(context.Background(), &api.AddPolicyAssignmentRequest{ Assignment: &api.PolicyAssignment{ Name: table.GLOBAL_RIB_NAME, Direction: api.PolicyDirection_EXPORT, @@ -443,7 +655,7 @@ func TestListPathEnableFiltered(test *testing.T) { Prefix: "10.3.0.0", PrefixLen: 24, }) - s.AddPath(context.Background(), &api.AddPathRequest{ + server1.AddPath(context.Background(), &api.AddPathRequest{ TableType: api.TableType_GLOBAL, Path: &api.Path{ Family: family, @@ -456,7 +668,7 @@ func TestListPathEnableFiltered(test *testing.T) { Prefix: "10.4.0.0", PrefixLen: 24, }) - s.AddPath(context.Background(), &api.AddPathRequest{ + server1.AddPath(context.Background(), &api.AddPathRequest{ TableType: api.TableType_GLOBAL, Path: &api.Path{ Family: family, @@ -466,14 +678,14 @@ func TestListPathEnableFiltered(test *testing.T) { }) count = 0 - s.ListPath(context.Background(), &api.ListPathRequest{TableType: api.TableType_GLOBAL, Family: family}, func(d *api.Destination) { + server1.ListPath(context.Background(), &api.ListPathRequest{TableType: api.TableType_GLOBAL, Family: family}, func(d *api.Destination) { count++ }) assert.Equal(3, count) count = 0 filtered = 0 - s.ListPath(context.Background(), &api.ListPathRequest{TableType: api.TableType_ADJ_OUT, Family: family, Name: "127.0.0.1", EnableFiltered: true}, func(d *api.Destination) { + server1.ListPath(context.Background(), &api.ListPathRequest{TableType: api.TableType_ADJ_OUT, Family: family, Name: "127.0.0.1", EnableFiltered: true}, func(d *api.Destination) { count++ if d.Paths[0].Filtered { filtered++