mirror of
https://github.com/mxpv/podsync.git
synced 2024-05-11 05:55:04 +00:00
Refactor database storage
This commit is contained in:
184
pkg/storage/pg.go
Normal file
184
pkg/storage/pg.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/GoogleCloudPlatform/cloudsql-proxy/proxy/proxy"
|
||||
"github.com/go-pg/pg"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/mxpv/podsync/pkg/api"
|
||||
"github.com/mxpv/podsync/pkg/model"
|
||||
)
|
||||
|
||||
type Postgres struct {
|
||||
db *pg.DB
|
||||
}
|
||||
|
||||
func NewPG(connectionURL string, ping bool) (Postgres, error) {
|
||||
opts, err := pg.ParseURL(connectionURL)
|
||||
if err != nil {
|
||||
return Postgres{}, err
|
||||
}
|
||||
|
||||
// If host format is "projection:region:host", than use Google SQL Proxy
|
||||
// See https://github.com/go-pg/pg/issues/576
|
||||
if strings.Count(opts.Addr, ":") == 2 {
|
||||
log.Print("using GCP SQL proxy")
|
||||
opts.Dialer = func(network, addr string) (net.Conn, error) {
|
||||
return proxy.Dial(addr)
|
||||
}
|
||||
}
|
||||
|
||||
db := pg.Connect(opts)
|
||||
|
||||
// Check database connectivity
|
||||
if ping {
|
||||
if _, err := db.ExecOne("SELECT 1"); err != nil {
|
||||
_ = db.Close()
|
||||
return Postgres{}, errors.Wrap(err, "failed to check database connectivity")
|
||||
}
|
||||
}
|
||||
|
||||
return Postgres{db: db}, nil
|
||||
}
|
||||
|
||||
func (p Postgres) SaveFeed(feed *model.Feed) error {
|
||||
_, err := p.db.Model(feed).Insert()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to save feed to database")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (p Postgres) GetFeed(hashID string) (*model.Feed, error) {
|
||||
lastAccess := time.Now().UTC()
|
||||
|
||||
feed := &model.Feed{}
|
||||
res, err := p.db.Model(feed).
|
||||
Set("last_access = ?", lastAccess).
|
||||
Where("hash_id = ?", hashID).
|
||||
Returning("*").
|
||||
Update()
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to query feed: %s", hashID)
|
||||
}
|
||||
|
||||
if res.RowsAffected() != 1 {
|
||||
return nil, api.ErrNotFound
|
||||
}
|
||||
|
||||
return feed, nil
|
||||
}
|
||||
|
||||
func (p Postgres) GetMetadata(hashID string) (*model.Feed, error) {
|
||||
feed := &model.Feed{}
|
||||
err := p.db.
|
||||
Model(feed).
|
||||
Where("hash_id = ?", hashID).
|
||||
Column("provider", "format", "quality", "user_id").
|
||||
Select()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return feed, nil
|
||||
}
|
||||
|
||||
func (p Postgres) Downgrade(patronID string, featureLevel int) error {
|
||||
if featureLevel > api.ExtendedFeatures {
|
||||
return nil
|
||||
}
|
||||
|
||||
if featureLevel == api.ExtendedFeatures {
|
||||
const maxPages = 150
|
||||
_, err := p.db.
|
||||
Model(&model.Feed{}).
|
||||
Set("page_size = ?", maxPages).
|
||||
Where("user_id = ? AND page_size > ?", patronID, maxPages).
|
||||
Update()
|
||||
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to reduce page sizes for patron '%s'", patronID)
|
||||
}
|
||||
|
||||
_, err = p.db.
|
||||
Model(&model.Feed{}).
|
||||
Set("feature_level = ?", api.ExtendedFeatures).
|
||||
Where("user_id = ?", patronID, maxPages).
|
||||
Update()
|
||||
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to downgrade patron '%s' to feature level %d", patronID, featureLevel)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if featureLevel == api.DefaultFeatures {
|
||||
_, err := p.db.
|
||||
Model(&model.Feed{}).
|
||||
Set("page_size = ?", 50).
|
||||
Set("feature_level = ?", api.DefaultFeatures).
|
||||
Set("format = ?", api.FormatVideo).
|
||||
Set("quality = ?", api.QualityHigh).
|
||||
Where("user_id = ?", patronID).
|
||||
Update()
|
||||
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to downgrade patron '%s' to feature level %d", patronID, featureLevel)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.New("unsupported downgrade type")
|
||||
}
|
||||
|
||||
func (p Postgres) AddPledge(pledge *model.Pledge) error {
|
||||
return p.db.Insert(pledge)
|
||||
}
|
||||
|
||||
func (p Postgres) UpdatePledge(patronID string, pledge *model.Pledge) error {
|
||||
updateColumns := []string{
|
||||
"declined_since",
|
||||
"amount_cents",
|
||||
"total_historical_amount_cents",
|
||||
"outstanding_payment_amount_cents",
|
||||
"is_paused",
|
||||
}
|
||||
|
||||
res, err := p.db.Model(pledge).Column(updateColumns...).Where("patron_id = ?", patronID).Update()
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to update pledge %d for user %s: %v", pledge.PledgeID, patronID, err)
|
||||
}
|
||||
|
||||
if res.RowsAffected() != 1 {
|
||||
return errors.Wrapf(err, "unexpected number of updated rows: %d for user %s", res.RowsAffected(), patronID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p Postgres) DeletePledge(pledge *model.Pledge) error {
|
||||
err := p.db.Delete(pledge)
|
||||
if err == pg.ErrNoRows {
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (p Postgres) GetPledge(patronID string) (*model.Pledge, error) {
|
||||
pledge := &model.Pledge{}
|
||||
return pledge, p.db.Model(pledge).Where("patron_id = ?", patronID).Limit(1).Select()
|
||||
}
|
||||
|
||||
func (p Postgres) Close() error {
|
||||
return p.db.Close()
|
||||
}
|
64
pkg/storage/pg_sql.go
Normal file
64
pkg/storage/pg_sql.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package storage
|
||||
|
||||
//noinspection SpellCheckingInspection
|
||||
const pgsql = `
|
||||
BEGIN;
|
||||
|
||||
-- Pledges
|
||||
|
||||
CREATE TABLE IF NOT EXISTS pledges (
|
||||
pledge_id BIGSERIAL PRIMARY KEY,
|
||||
patron_id BIGINT NOT NULL UNIQUE,
|
||||
created_at TIMESTAMPTZ NOT NULL,
|
||||
declined_since TIMESTAMPTZ NULL,
|
||||
amount_cents INT NOT NULL,
|
||||
total_historical_amount_cents INT,
|
||||
outstanding_payment_amount_cents INT,
|
||||
is_paused BOOLEAN
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS patron_id_idx ON pledges(patron_id);
|
||||
|
||||
-- Feeds
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'provider') THEN
|
||||
CREATE TYPE provider AS ENUM ('youtube', 'vimeo');
|
||||
END IF;
|
||||
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'link_type') THEN
|
||||
CREATE TYPE link_type AS ENUM ('channel', 'playlist', 'user', 'group');
|
||||
END IF;
|
||||
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'quality') THEN
|
||||
CREATE TYPE quality AS ENUM ('low', 'high');
|
||||
END IF;
|
||||
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'format') THEN
|
||||
CREATE TYPE format AS ENUM ('video', 'audio');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS feeds (
|
||||
feed_id BIGSERIAL PRIMARY KEY,
|
||||
hash_id VARCHAR(12) NOT NULL UNIQUE,
|
||||
user_id VARCHAR(32) NULL,
|
||||
item_id VARCHAR(64) NOT NULL CHECK (item_id <> ''),
|
||||
provider provider NOT NULL,
|
||||
link_type link_type NOT NULL,
|
||||
page_size INT NOT NULL DEFAULT 50,
|
||||
format format NOT NULL DEFAULT 'video',
|
||||
quality quality NOT NULL DEFAULT 'high',
|
||||
feature_level INT NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
last_access TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS feeds_hash_id_idx ON feeds(hash_id);
|
||||
CREATE INDEX IF NOT EXISTS feeds_user_id_idx ON feeds(user_id);
|
||||
|
||||
COMMIT;
|
||||
END;
|
||||
`
|
226
pkg/storage/pg_test.go
Normal file
226
pkg/storage/pg_test.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-pg/pg"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/mxpv/podsync/pkg/api"
|
||||
"github.com/mxpv/podsync/pkg/model"
|
||||
)
|
||||
|
||||
var (
|
||||
testPledge = &model.Pledge{PledgeID: 12345, AmountCents: 400, PatronID: 1, CreatedAt: time.Now()}
|
||||
testFeed = &model.Feed{FeedID: 1, HashID: "3", UserID: "3", ItemID: "4", LinkType: api.LinkTypeChannel, Provider: api.ProviderVimeo, Format: api.FormatAudio ,Quality: api.QualityLow}
|
||||
)
|
||||
|
||||
func TestPostgres_SaveFeed(t *testing.T) {
|
||||
stor := createPG(t)
|
||||
defer func() { _ = stor.Close() }()
|
||||
|
||||
err := stor.SaveFeed(testFeed)
|
||||
require.NoError(t, err)
|
||||
|
||||
find := &model.Feed{FeedID: 1}
|
||||
err = stor.db.Model(find).Select()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, testFeed.FeedID, find.FeedID)
|
||||
require.Equal(t, testFeed.HashID, find.HashID)
|
||||
require.Equal(t, testFeed.UserID, find.UserID)
|
||||
require.Equal(t, testFeed.ItemID, find.ItemID)
|
||||
require.Equal(t, testFeed.LinkType, find.LinkType)
|
||||
require.Equal(t, testFeed.Provider, find.Provider)
|
||||
}
|
||||
|
||||
func TestPostgres_GetFeed(t *testing.T) {
|
||||
stor := createPG(t)
|
||||
defer func() { _ = stor.Close() }()
|
||||
|
||||
err := stor.SaveFeed(testFeed)
|
||||
require.NoError(t, err)
|
||||
|
||||
find, err := stor.GetFeed(testFeed.HashID)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, testFeed.FeedID, find.FeedID)
|
||||
require.Equal(t, testFeed.HashID, find.HashID)
|
||||
require.Equal(t, testFeed.UserID, find.UserID)
|
||||
require.Equal(t, testFeed.ItemID, find.ItemID)
|
||||
require.Equal(t, testFeed.LinkType, find.LinkType)
|
||||
require.Equal(t, testFeed.Provider, find.Provider)
|
||||
}
|
||||
|
||||
func TestService_UpdateLastAccess(t *testing.T) {
|
||||
stor := createPG(t)
|
||||
defer func() { _ = stor.Close() }()
|
||||
|
||||
err := stor.db.Insert(testFeed)
|
||||
require.NoError(t, err)
|
||||
|
||||
feed1, err := stor.GetFeed(testFeed.HashID)
|
||||
require.NoError(t, err)
|
||||
|
||||
feed2, err := stor.GetFeed(testFeed.HashID)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, feed2.LastAccess.After(feed1.LastAccess))
|
||||
}
|
||||
|
||||
func TestPostgres_GetMetadata(t *testing.T) {
|
||||
stor := createPG(t)
|
||||
defer func() { _ = stor.Close() }()
|
||||
|
||||
err := stor.SaveFeed(testFeed)
|
||||
require.NoError(t, err)
|
||||
|
||||
find, err := stor.GetMetadata(testFeed.HashID)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, testFeed.UserID, find.UserID)
|
||||
require.Equal(t, testFeed.Provider, find.Provider)
|
||||
require.Equal(t, testFeed.Quality, find.Quality)
|
||||
require.Equal(t, testFeed.Format, find.Format)
|
||||
}
|
||||
|
||||
func TestService_DowngradeToAnonymous(t *testing.T) {
|
||||
stor := createPG(t)
|
||||
defer func() { _ = stor.Close() }()
|
||||
|
||||
feed := &model.Feed{
|
||||
HashID: "123456",
|
||||
UserID: "123456",
|
||||
ItemID: "123456",
|
||||
Provider: api.ProviderVimeo,
|
||||
LinkType: api.LinkTypeGroup,
|
||||
PageSize: 150,
|
||||
Quality: api.QualityLow,
|
||||
Format: api.FormatAudio,
|
||||
FeatureLevel: api.ExtendedFeatures,
|
||||
}
|
||||
|
||||
err := stor.db.Insert(feed)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = stor.Downgrade(feed.UserID, api.DefaultFeatures)
|
||||
require.NoError(t, err)
|
||||
|
||||
downgraded := &model.Feed{FeedID: feed.FeedID}
|
||||
err = stor.db.Select(downgraded)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, 50, downgraded.PageSize)
|
||||
require.Equal(t, api.QualityHigh, downgraded.Quality)
|
||||
require.Equal(t, api.FormatVideo, downgraded.Format)
|
||||
require.Equal(t, api.DefaultFeatures, downgraded.FeatureLevel)
|
||||
}
|
||||
|
||||
func TestService_DowngradeToExtendedFeatures(t *testing.T) {
|
||||
stor := createPG(t)
|
||||
defer func() { _ = stor.Close() }()
|
||||
|
||||
feed := &model.Feed{
|
||||
HashID: "123456",
|
||||
UserID: "123456",
|
||||
ItemID: "123456",
|
||||
Provider: api.ProviderVimeo,
|
||||
LinkType: api.LinkTypeGroup,
|
||||
PageSize: 500,
|
||||
Quality: api.QualityLow,
|
||||
Format: api.FormatAudio,
|
||||
FeatureLevel: api.ExtendedFeatures,
|
||||
}
|
||||
|
||||
err := stor.db.Insert(feed)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = stor.Downgrade(feed.UserID, api.ExtendedFeatures)
|
||||
require.NoError(t, err)
|
||||
|
||||
downgraded := &model.Feed{FeedID: feed.FeedID}
|
||||
err = stor.db.Select(downgraded)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, 150, downgraded.PageSize)
|
||||
require.Equal(t, feed.Quality, downgraded.Quality)
|
||||
require.Equal(t, feed.Format, downgraded.Format)
|
||||
require.Equal(t, api.ExtendedFeatures, downgraded.FeatureLevel)
|
||||
}
|
||||
|
||||
func TestPostgres_AddPledge(t *testing.T) {
|
||||
stor := createPG(t)
|
||||
defer func() { _ = stor.Close() }()
|
||||
|
||||
err := stor.AddPledge(testPledge)
|
||||
require.NoError(t, err)
|
||||
|
||||
pledge := &model.Pledge{PledgeID: 12345}
|
||||
err = stor.db.Select(pledge)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int64(12345), pledge.PledgeID)
|
||||
require.Equal(t, 400, pledge.AmountCents)
|
||||
}
|
||||
|
||||
func TestPostgres_UpdatePledge(t *testing.T) {
|
||||
stor := createPG(t)
|
||||
defer func() { _ = stor.Close() }()
|
||||
|
||||
err := stor.AddPledge(testPledge)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = stor.UpdatePledge("1", &model.Pledge{AmountCents: 999})
|
||||
require.NoError(t, err)
|
||||
|
||||
pledge := &model.Pledge{PledgeID: 12345}
|
||||
err = stor.db.Select(pledge)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 999, pledge.AmountCents)
|
||||
}
|
||||
|
||||
func TestPostgres_DeletePledge(t *testing.T) {
|
||||
stor := createPG(t)
|
||||
defer func() { _ = stor.Close() }()
|
||||
|
||||
err := stor.AddPledge(testPledge)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = stor.DeletePledge(testPledge)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = stor.db.Select(&model.Pledge{PledgeID: 12345})
|
||||
require.Equal(t, pg.ErrNoRows, err)
|
||||
}
|
||||
|
||||
func TestPostgres_GetPledge(t *testing.T) {
|
||||
stor := createPG(t)
|
||||
defer func() { _ = stor.Close() }()
|
||||
|
||||
err := stor.AddPledge(testPledge)
|
||||
require.NoError(t, err)
|
||||
|
||||
pledge, err := stor.GetPledge("1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 400, pledge.AmountCents)
|
||||
require.Equal(t, int64(12345), pledge.PledgeID)
|
||||
}
|
||||
|
||||
// docker run -it --rm -p 5432:5432 -e POSTGRES_DB=podsync postgres
|
||||
func createPG(t *testing.T) Postgres {
|
||||
const localConnectionString = "postgres://postgres:@localhost/podsync?sslmode=disable"
|
||||
|
||||
postgres, err := NewPG(localConnectionString, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = postgres.db.Exec(pgsql)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, obj := range []interface{}{&model.Pledge{}, &model.Feed{}} {
|
||||
_, err = postgres.db.Model(obj).Where("1=1").Delete()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
return postgres
|
||||
}
|
Reference in New Issue
Block a user