mirror of
https://github.com/mxpv/podsync.git
synced 2024-05-11 05:55:04 +00:00
Implement database layer
This commit is contained in:
43
web/pkg/database/models.go
Normal file
43
web/pkg/database/models.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package database
|
||||
|
||||
type Quality string
|
||||
type Format string
|
||||
|
||||
const (
|
||||
HighQuality = Quality("high")
|
||||
LowQuality = Quality("low")
|
||||
AudioFormat = Format("audio")
|
||||
VideoFormat = Format("video")
|
||||
)
|
||||
|
||||
type Feed struct {
|
||||
Id int64
|
||||
HashId string
|
||||
UserId string
|
||||
URL string
|
||||
PageSize int
|
||||
Quality Quality
|
||||
Format Format
|
||||
}
|
||||
|
||||
// Query helpers
|
||||
|
||||
type WhereFunc func() (string, interface{})
|
||||
|
||||
func WithId(id int) WhereFunc {
|
||||
return func() (string, interface{}) {
|
||||
return "id", id
|
||||
}
|
||||
}
|
||||
|
||||
func WithHashId(hashId string) WhereFunc {
|
||||
return func() (string, interface{}) {
|
||||
return "hash_id", hashId
|
||||
}
|
||||
}
|
||||
|
||||
func WithUserId(userId string) WhereFunc {
|
||||
return func() (string, interface{}) {
|
||||
return "user_id", userId
|
||||
}
|
||||
}
|
77
web/pkg/database/pg.go
Normal file
77
web/pkg/database/pg.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"github.com/GoogleCloudPlatform/cloudsql-proxy/proxy/proxy"
|
||||
"github.com/go-pg/pg"
|
||||
"github.com/go-pg/pg/orm"
|
||||
"github.com/pkg/errors"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type PgConfig struct {
|
||||
ConnectionUrl string `yaml:"connectionUrl"`
|
||||
}
|
||||
|
||||
type PgStorage struct {
|
||||
db *pg.DB
|
||||
}
|
||||
|
||||
func (p *PgStorage) CreateFeed(feed *Feed) error {
|
||||
_, err := p.db.Model(feed).Insert()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create feed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PgStorage) GetFeed(q ...WhereFunc) (out []Feed, err error) {
|
||||
out = []Feed{}
|
||||
err = p.db.Model(&out).Apply(whereFunc(q...)).Select()
|
||||
return
|
||||
}
|
||||
|
||||
func whereFunc(where ...WhereFunc) func(*orm.Query) (*orm.Query, error) {
|
||||
return func(q *orm.Query) (*orm.Query, error) {
|
||||
for _, fn := range where {
|
||||
field, value := fn()
|
||||
q = q.Where(field+" = ?", value)
|
||||
}
|
||||
|
||||
return q, nil
|
||||
}
|
||||
}
|
||||
|
||||
func NewPgStorage(config *PgConfig) (*PgStorage, error) {
|
||||
opts, err := pg.ParseURL(config.ConnectionUrl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If host format is "projection:region:host", than use Google SQL Proxy
|
||||
// See https://github.com/go-pg/pg/issues/576
|
||||
if strings.Count(opts.Addr, ":") == 2 {
|
||||
log.Print("using GCP SQL proxy")
|
||||
opts.Dialer = func(network, addr string) (net.Conn, error) {
|
||||
return proxy.Dial(addr)
|
||||
}
|
||||
}
|
||||
|
||||
db := pg.Connect(opts)
|
||||
|
||||
// Check database connectivity
|
||||
if _, err := db.ExecOne("SELECT 1"); err != nil {
|
||||
db.Close()
|
||||
return nil, errors.Wrap(err, "failed to check database connectivity")
|
||||
}
|
||||
|
||||
log.Print("running update script")
|
||||
if _, err := db.Exec(installScript); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to upgrade database structure")
|
||||
}
|
||||
|
||||
storage := &PgStorage{db: db}
|
||||
return storage, nil
|
||||
}
|
28
web/pkg/database/pg_sql.go
Normal file
28
web/pkg/database/pg_sql.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package database
|
||||
|
||||
const installScript = `
|
||||
BEGIN;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'quality') THEN
|
||||
CREATE TYPE quality AS ENUM ('high', 'low');
|
||||
END IF;
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'format') THEN
|
||||
CREATE TYPE format AS ENUM ('audio', 'video');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS feeds (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
hash_id VARCHAR(12) NOT NULL CHECK (hash_id <> ''),
|
||||
user_id VARCHAR(32) NULL,
|
||||
url VARCHAR(64) NOT NULL CHECK (url <> ''),
|
||||
page_size INT NOT NULL DEFAULT 50,
|
||||
quality quality NOT NULL DEFAULT 'high',
|
||||
format format NOT NULL DEFAULT 'video'
|
||||
);
|
||||
|
||||
COMMIT;
|
||||
`
|
51
web/pkg/database/pg_test.go
Normal file
51
web/pkg/database/pg_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCreate(t *testing.T) {
|
||||
feed := &Feed{
|
||||
HashId: "xyz",
|
||||
URL: "http://youtube.com",
|
||||
}
|
||||
|
||||
client := createClient(t)
|
||||
err := client.CreateFeed(feed)
|
||||
require.NoError(t, err)
|
||||
require.True(t, feed.Id > 0)
|
||||
}
|
||||
|
||||
func TestGetFeed(t *testing.T) {
|
||||
feed := &Feed{
|
||||
HashId: "xyz",
|
||||
UserId: "123",
|
||||
URL: "http://youtube.com",
|
||||
}
|
||||
|
||||
client := createClient(t)
|
||||
client.CreateFeed(feed)
|
||||
|
||||
out, err := client.GetFeed(WithUserId("123"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(out))
|
||||
require.Equal(t, feed.Id, out[0].Id)
|
||||
|
||||
out, err = client.GetFeed(WithHashId("xyz"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(out))
|
||||
require.Equal(t, feed.Id, out[0].Id)
|
||||
}
|
||||
|
||||
const TestDatabaseConnectionUrl = "postgres://postgres:@localhost/podsync?sslmode=disable"
|
||||
|
||||
func createClient(t *testing.T) *PgStorage {
|
||||
pg, err := NewPgStorage(&PgConfig{ConnectionUrl: TestDatabaseConnectionUrl})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = pg.db.Model(&Feed{}).Where("1=1").Delete()
|
||||
require.NoError(t, err)
|
||||
|
||||
return pg
|
||||
}
|
Reference in New Issue
Block a user