diff --git a/cmd/alice-lg/main.go b/cmd/alice-lg/main.go index 41554bc..bb9a512 100644 --- a/cmd/alice-lg/main.go +++ b/cmd/alice-lg/main.go @@ -109,7 +109,11 @@ func main() { go m.Start(ctx) neighborsBackend = postgres.NewNeighborsBackend(pool) - routesBackend = postgres.NewRoutesBackend(pool) + routesBackend = postgres.NewRoutesBackend( + pool, cfg.Sources) + if err := routesBackend.(*postgres.RoutesBackend).Init(ctx); err != nil { + log.Println("error while initializing routes backend:", err) + } } neighborsStore := store.NewNeighborsStore(cfg, neighborsBackend) diff --git a/pkg/store/backends/postgres/routes_backend.go b/pkg/store/backends/postgres/routes_backend.go index 71f84d3..2229854 100644 --- a/pkg/store/backends/postgres/routes_backend.go +++ b/pkg/store/backends/postgres/routes_backend.go @@ -3,27 +3,59 @@ package postgres import ( "context" "fmt" + "regexp" "strings" "time" "github.com/alice-lg/alice-lg/pkg/api" + "github.com/alice-lg/alice-lg/pkg/config" + "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/pgxpool" ) +var ( + // ReMatchNonChar will match on all characters not + // within a to Z or 0 to 9 + ReMatchNonChar = regexp.MustCompile(`[^a-zA-Z0-9]`) +) + // RoutesBackend implements a postgres store for routes. type RoutesBackend struct { - pool *pgxpool.Pool + pool *pgxpool.Pool + sources []*config.SourceConfig } // NewRoutesBackend creates a new instance with a postgres // connection pool. -func NewRoutesBackend(pool *pgxpool.Pool) *RoutesBackend { +func NewRoutesBackend( + pool *pgxpool.Pool, + sources []*config.SourceConfig, +) *RoutesBackend { return &RoutesBackend{ - pool: pool, + pool: pool, + sources: sources, } } +// Init will initialize all the route tables +func (b *RoutesBackend) Init(ctx context.Context) error { + tx, err := b.pool.BeginTx(ctx, pgx.TxOptions{ + IsoLevel: pgx.ReadCommitted, + }) + if err != nil { + return err + } + defer tx.Rollback(ctx) + + for _, src := range b.sources { + if err := b.initTable(ctx, tx, src.ID); err != nil { + return err + } + } + return tx.Commit(ctx) +} + // SetRoutes implements the RoutesStoreBackend interface // function for setting all routes of a source identified // by ID. @@ -42,9 +74,11 @@ func (b *RoutesBackend) SetRoutes( } defer tx.Rollback(ctx) - if err := b.clear(ctx, tx, sourceID); err != nil { + // Create table from template + if err := b.initTable(ctx, tx, sourceID); err != nil { return err } + // persist all routes for _, r := range routes { if err := b.persist(ctx, tx, sourceID, r, now); err != nil { @@ -58,6 +92,29 @@ func (b *RoutesBackend) SetRoutes( return nil } +// Private routesTable returns the name of the routes table +// for a sourceID +func (b *RoutesBackend) routesTable(sourceID string) string { + sourceID = ReMatchNonChar.ReplaceAllString(sourceID, "_") + return "routes_" + sourceID +} + +// Private initTable recreates the routes table +// for a single sourceID +func (b *RoutesBackend) initTable( + ctx context.Context, + tx pgx.Tx, + sourceID string, +) error { + tbl := b.routesTable(sourceID) + qry := ` + DROP TABLE IF EXISTS ` + tbl + `; + CREATE TABLE ` + tbl + ` ( LIKE routes INCLUDING ALL ) + ` + _, err := tx.Exec(ctx, qry) + return err +} + // Private persist route in database func (b *RoutesBackend) persist( ctx context.Context, @@ -66,8 +123,9 @@ func (b *RoutesBackend) persist( route *api.LookupRoute, now time.Time, ) error { + tbl := b.routesTable(sourceID) qry := ` - INSERT INTO routes ( + INSERT INTO ` + tbl + ` ( id, rs_id, neighbor_id, @@ -91,17 +149,19 @@ func (b *RoutesBackend) persist( } // Private clear removes all routes. +/* func (b *RoutesBackend) clear( ctx context.Context, tx pgx.Tx, sourceID string, ) error { qry := ` - DELETE FROM routes WHERE rs_id = $1 + DELETE FROM routes WHERE rs_id = $1 ` _, err := tx.Exec(ctx, qry, sourceID) return err } +*/ // Private queryCountByState will query routes and filter // by state @@ -111,10 +171,11 @@ func (b *RoutesBackend) queryCountByState( sourceID string, state string, ) pgx.Row { - qry := `SELECT COUNT(1) FROM routes - WHERE rs_id = $1 AND route -> 'state' = $2` + tbl := b.routesTable(sourceID) + qry := `SELECT COUNT(1) FROM ` + tbl + ` + WHERE route -> 'state' = $1` - return tx.QueryRow(ctx, qry, sourceID, "\""+state+"\"") + return tx.QueryRow(ctx, qry, "\""+state+"\"") } // CountRoutesAt returns the number of filtered and imported @@ -171,9 +232,17 @@ func (b *RoutesBackend) FindByNeighbors( vars[i] = fmt.Sprintf("$%d", i+1) } listQry := strings.Join(vars, ",") - qry := ` - SELECT route FROM routes - WHERE neighbor_id IN (` + listQry + `)` + + qrys := []string{} + for _, src := range b.sources { + tbl := b.routesTable(src.ID) + qry := ` + SELECT route FROM ` + tbl + ` + WHERE neighbor_id IN (` + listQry + `)` + qrys = append(qrys, qry) + } + + qry := strings.Join(qrys, " UNION ") rows, err := tx.Query(ctx, qry, vals...) if err != nil { @@ -196,10 +265,16 @@ func (b *RoutesBackend) FindByPrefix( } defer tx.Rollback(ctx) // We are searching route.Network - qry := ` - SELECT route FROM routes - WHERE network ILIKE $1 - ` + qrys := []string{} + for _, src := range b.sources { + tbl := b.routesTable(src.ID) + qry := ` + SELECT route FROM ` + tbl + ` + WHERE network ILIKE $1 + ` + qrys = append(qrys, qry) + } + qry := strings.Join(qrys, " UNION ") rows, err := tx.Query(ctx, qry, prefix+"%") if err != nil { return nil, err diff --git a/pkg/store/backends/postgres/routes_backend_test.go b/pkg/store/backends/postgres/routes_backend_test.go index fd8a0c5..04f9b84 100644 --- a/pkg/store/backends/postgres/routes_backend_test.go +++ b/pkg/store/backends/postgres/routes_backend_test.go @@ -6,8 +6,17 @@ import ( "time" "github.com/alice-lg/alice-lg/pkg/api" + "github.com/alice-lg/alice-lg/pkg/config" ) +func TestRoutesTable(t *testing.T) { + b := &RoutesBackend{} + tbl := b.routesTable("rs0-example!/;") + if tbl != "routes_rs0_example___" { + t.Error("unexpected table:", tbl) + } +} + func TestCountRoutesAt(t *testing.T) { ctx := context.Background() now := time.Now().UTC() @@ -29,6 +38,7 @@ func TestCountRoutesAt(t *testing.T) { Network: "1.2.3.0/24", }, } + b.initTable(ctx, tx, "rs1") b.persist(ctx, tx, "rs1", r, now) r.Route.ID = "r4242" @@ -63,7 +73,13 @@ func TestFindByNeighbors(t *testing.T) { t.Fatal(err) } defer tx.Rollback(ctx) - b := &RoutesBackend{pool: pool} + b := &RoutesBackend{ + pool: pool, + sources: []*config.SourceConfig{ + {ID: "rs1"}, + {ID: "rs2"}, + }, + } r := &api.LookupRoute{ State: "filtered", Neighbor: &api.Neighbor{ @@ -74,6 +90,8 @@ func TestFindByNeighbors(t *testing.T) { Network: "1.2.3.0/24", }, } + b.initTable(ctx, tx, "rs1") + b.initTable(ctx, tx, "rs2") b.persist(ctx, tx, "rs1", r, now) r.Route.ID = "r4242" @@ -85,7 +103,7 @@ func TestFindByNeighbors(t *testing.T) { r.Route.ID = "r4244" r.Neighbor.ID = "n25" - b.persist(ctx, tx, "rs1", r, now) + b.persist(ctx, tx, "rs2", r, now) if err := tx.Commit(ctx); err != nil { t.Fatal(err) @@ -113,7 +131,13 @@ func TestFindByPrefix(t *testing.T) { t.Fatal(err) } defer tx.Rollback(ctx) - b := &RoutesBackend{pool: pool} + b := &RoutesBackend{ + pool: pool, + sources: []*config.SourceConfig{ + {ID: "rs1"}, + {ID: "rs2"}, + }, + } r := &api.LookupRoute{ State: "filtered", Neighbor: &api.Neighbor{ @@ -124,6 +148,9 @@ func TestFindByPrefix(t *testing.T) { Network: "1.2.3.0/24", }, } + + b.initTable(ctx, tx, "rs1") + b.initTable(ctx, tx, "rs2") b.persist(ctx, tx, "rs1", r, now) r.Route.ID = "r4242" @@ -133,7 +160,7 @@ func TestFindByPrefix(t *testing.T) { r.Route.ID = "r4243" r.Route.Network = "1.2.5.0/24" r.Neighbor.ID = "n24" - b.persist(ctx, tx, "rs1", r, now) + b.persist(ctx, tx, "rs2", r, now) r.Route.ID = "r4244" r.Route.Network = "5.5.5.0/24"