diff --git a/web/pkg/database/models.go b/web/pkg/database/models.go new file mode 100644 index 0000000..f0892fb --- /dev/null +++ b/web/pkg/database/models.go @@ -0,0 +1,43 @@ +package database + +type Quality string +type Format string + +const ( + HighQuality = Quality("high") + LowQuality = Quality("low") + AudioFormat = Format("audio") + VideoFormat = Format("video") +) + +type Feed struct { + Id int64 + HashId string + UserId string + URL string + PageSize int + Quality Quality + Format Format +} + +// Query helpers + +type WhereFunc func() (string, interface{}) + +func WithId(id int) WhereFunc { + return func() (string, interface{}) { + return "id", id + } +} + +func WithHashId(hashId string) WhereFunc { + return func() (string, interface{}) { + return "hash_id", hashId + } +} + +func WithUserId(userId string) WhereFunc { + return func() (string, interface{}) { + return "user_id", userId + } +} diff --git a/web/pkg/database/pg.go b/web/pkg/database/pg.go new file mode 100644 index 0000000..7e90e71 --- /dev/null +++ b/web/pkg/database/pg.go @@ -0,0 +1,77 @@ +package database + +import ( + "github.com/GoogleCloudPlatform/cloudsql-proxy/proxy/proxy" + "github.com/go-pg/pg" + "github.com/go-pg/pg/orm" + "github.com/pkg/errors" + "log" + "net" + "strings" +) + +type PgConfig struct { + ConnectionUrl string `yaml:"connectionUrl"` +} + +type PgStorage struct { + db *pg.DB +} + +func (p *PgStorage) CreateFeed(feed *Feed) error { + _, err := p.db.Model(feed).Insert() + if err != nil { + return errors.Wrap(err, "failed to create feed") + } + + return nil +} + +func (p *PgStorage) GetFeed(q ...WhereFunc) (out []Feed, err error) { + out = []Feed{} + err = p.db.Model(&out).Apply(whereFunc(q...)).Select() + return +} + +func whereFunc(where ...WhereFunc) func(*orm.Query) (*orm.Query, error) { + return func(q *orm.Query) (*orm.Query, error) { + for _, fn := range where { + field, value := fn() + q = q.Where(field+" = ?", value) + } + + return q, nil + } +} + +func NewPgStorage(config *PgConfig) (*PgStorage, error) { + opts, err := pg.ParseURL(config.ConnectionUrl) + if err != nil { + return nil, err + } + + // If host format is "projection:region:host", than use Google SQL Proxy + // See https://github.com/go-pg/pg/issues/576 + if strings.Count(opts.Addr, ":") == 2 { + log.Print("using GCP SQL proxy") + opts.Dialer = func(network, addr string) (net.Conn, error) { + return proxy.Dial(addr) + } + } + + db := pg.Connect(opts) + + // Check database connectivity + if _, err := db.ExecOne("SELECT 1"); err != nil { + db.Close() + return nil, errors.Wrap(err, "failed to check database connectivity") + } + + log.Print("running update script") + if _, err := db.Exec(installScript); err != nil { + return nil, errors.Wrap(err, "failed to upgrade database structure") + } + + storage := &PgStorage{db: db} + return storage, nil +} diff --git a/web/pkg/database/pg_sql.go b/web/pkg/database/pg_sql.go new file mode 100644 index 0000000..543781e --- /dev/null +++ b/web/pkg/database/pg_sql.go @@ -0,0 +1,28 @@ +package database + +const installScript = ` +BEGIN; + +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'quality') THEN + CREATE TYPE quality AS ENUM ('high', 'low'); + END IF; + IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'format') THEN + CREATE TYPE format AS ENUM ('audio', 'video'); + END IF; +END +$$; + +CREATE TABLE IF NOT EXISTS feeds ( + id BIGSERIAL PRIMARY KEY, + hash_id VARCHAR(12) NOT NULL CHECK (hash_id <> ''), + user_id VARCHAR(32) NULL, + url VARCHAR(64) NOT NULL CHECK (url <> ''), + page_size INT NOT NULL DEFAULT 50, + quality quality NOT NULL DEFAULT 'high', + format format NOT NULL DEFAULT 'video' +); + +COMMIT; +` diff --git a/web/pkg/database/pg_test.go b/web/pkg/database/pg_test.go new file mode 100644 index 0000000..7c6ef23 --- /dev/null +++ b/web/pkg/database/pg_test.go @@ -0,0 +1,51 @@ +package database + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func TestCreate(t *testing.T) { + feed := &Feed{ + HashId: "xyz", + URL: "http://youtube.com", + } + + client := createClient(t) + err := client.CreateFeed(feed) + require.NoError(t, err) + require.True(t, feed.Id > 0) +} + +func TestGetFeed(t *testing.T) { + feed := &Feed{ + HashId: "xyz", + UserId: "123", + URL: "http://youtube.com", + } + + client := createClient(t) + client.CreateFeed(feed) + + out, err := client.GetFeed(WithUserId("123")) + require.NoError(t, err) + require.Equal(t, 1, len(out)) + require.Equal(t, feed.Id, out[0].Id) + + out, err = client.GetFeed(WithHashId("xyz")) + require.NoError(t, err) + require.Equal(t, 1, len(out)) + require.Equal(t, feed.Id, out[0].Id) +} + +const TestDatabaseConnectionUrl = "postgres://postgres:@localhost/podsync?sslmode=disable" + +func createClient(t *testing.T) *PgStorage { + pg, err := NewPgStorage(&PgConfig{ConnectionUrl: TestDatabaseConnectionUrl}) + require.NoError(t, err) + + _, err = pg.db.Model(&Feed{}).Where("1=1").Delete() + require.NoError(t, err) + + return pg +}