From 9d5fbf2f79603d37315adb6a75e4acfaa8d6e0b7 Mon Sep 17 00:00:00 2001 From: Maksym Pavlenko Date: Sun, 2 Dec 2018 13:27:31 -0800 Subject: [PATCH] Implement DynamoDB storage, refactor unit tests --- go.mod | 2 +- go.sum | 6 +- pkg/api/api.go | 13 ++ pkg/model/model.go | 31 +-- pkg/storage/dynamo.go | 404 ++++++++++++++++++++++++++++++++++++ pkg/storage/dynamo_test.go | 117 +++++++++++ pkg/storage/pg.go | 7 +- pkg/storage/pg_test.go | 187 +---------------- pkg/storage/storage_test.go | 272 ++++++++++++++++++++++++ 9 files changed, 839 insertions(+), 200 deletions(-) create mode 100644 pkg/storage/dynamo.go create mode 100644 pkg/storage/dynamo_test.go create mode 100644 pkg/storage/storage_test.go diff --git a/go.mod b/go.mod index 154da83..fa05d56 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ require ( github.com/BrianHicks/finch v0.0.0-20140409222414-419bd73c29ec github.com/BurntSushi/toml v0.3.1 // indirect github.com/GoogleCloudPlatform/cloudsql-proxy v0.0.0-20170929212804-61590edac4c7 + github.com/aws/aws-sdk-go v1.15.81 github.com/boj/redistore v0.0.0-20160128113310-fc113767cd6b // indirect github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737 // indirect github.com/bradleypeabody/gorilla-sessions-memcache v0.0.0-20180621172731-4e5d6d543851 // indirect @@ -39,7 +40,6 @@ require ( github.com/spf13/jwalterweatherman v0.0.0-20180109140146-7c0cea34c8ec // indirect github.com/spf13/pflag v1.0.1 // indirect github.com/spf13/viper v1.0.2 - github.com/stretchr/objx v0.1.1 // indirect github.com/stretchr/testify v1.2.2 github.com/teris-io/shortid v0.0.0-20171029131806-771a37caa5cf // indirect github.com/ugorji/go v1.1.1 // indirect diff --git a/go.sum b/go.sum index 12cb380..64fe483 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/GoogleCloudPlatform/cloudsql-proxy v0.0.0-20170929212804-61590edac4c7 h1:Clo7QBZv+fHzjCgVp4ELlbIsY5rScCmj+4VCfoMfqtQ= github.com/GoogleCloudPlatform/cloudsql-proxy v0.0.0-20170929212804-61590edac4c7/go.mod h1:aJ4qN3TfrelA6NZ6AXsXRfmEVaYin3EDbSPJrKS8OXo= +github.com/aws/aws-sdk-go v1.15.81 h1:va7uoFaV9uKAtZ6BTmp1u7paoMsizYRRLvRuoC07nQ8= +github.com/aws/aws-sdk-go v1.15.81/go.mod h1:E3/ieXAlvM0XWO57iftYVDLLvQ824smPP3ATZkfNZeM= github.com/boj/redistore v0.0.0-20160128113310-fc113767cd6b h1:PfxLkkgJYE095CKZji++BNwZjxWfoAF21WFPzkzOZEs= github.com/boj/redistore v0.0.0-20160128113310-fc113767cd6b/go.mod h1:5r9chGCb4uUhBCGMDDCYfyHU/awSRoBeG53Zaj1crhU= github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737 h1:rRISKWyXfVxvoa702s91Zl5oREZTrR3yv+tXrrX7G/g= @@ -48,6 +50,8 @@ github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a h1:eeaG9XMUvRBYXJi4pg1ZKM7nxc5AfXfojeLLW7O5J3k= github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8 h1:12VvqtR6Aowv3l/EQUlocDHW2Cp4G9WJVH7uyH8QFJE= +github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/kidstuff/mongostore v0.0.0-20180412085134-db2a8b4fac1f h1:84d0qxD9AiuBNpeK5TkYwTKKNezsYxIVn8nWh0pq51E= github.com/kidstuff/mongostore v0.0.0-20180412085134-db2a8b4fac1f/go.mod h1:g2nVr8KZVXJSS97Jo8pJ0jgq29P6H7dG0oplUA86MQw= github.com/magiconair/properties v1.8.0 h1:LLgXmsheXeRoUOBOjtwPQCWIYqM/LU1ayDtDePerRcY= @@ -85,8 +89,6 @@ github.com/spf13/pflag v1.0.1 h1:aCvUg6QPl3ibpQUxyLkrEkCHtPqYJL4x9AuhqVqFis4= github.com/spf13/pflag v1.0.1/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/spf13/viper v1.0.2 h1:Ncr3ZIuJn322w2k1qmzXDnkLAdQMlJqBa9kfAH+irso= github.com/spf13/viper v1.0.2/go.mod h1:A8kyI5cUJhb8N+3pkfONlcEcZbueH6nhAm0Fq7SrnBM= -github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A= -github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/teris-io/shortid v0.0.0-20171029131806-771a37caa5cf h1:Z2X3Os7oRzpdJ75iPqWZc0HeJWFYNCvKsfpQwFpRNTA= diff --git a/pkg/api/api.go b/pkg/api/api.go index ff20c63..1584a4f 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -54,9 +54,22 @@ type Metadata struct { } const ( + // Page size: 50 + // Format: video + // Quality: high DefaultFeatures = iota + + // Max page size: 150 + // Format: any + // Quality: any ExtendedFeatures + + // Max page size: 600 + // Format: any + // Quality: any ExtendedPagination + + // Unlimited PodcasterFeature ) diff --git a/pkg/model/model.go b/pkg/model/model.go index 3252f32..9e5be12 100644 --- a/pkg/model/model.go +++ b/pkg/model/model.go @@ -6,28 +6,31 @@ import ( "github.com/mxpv/podsync/pkg/api" ) +//noinspection SpellCheckingInspection type Pledge struct { PledgeID int64 `sql:",pk"` PatronID int64 - CreatedAt time.Time - DeclinedSince time.Time + CreatedAt time.Time `dynamodbav:",unixtime"` + DeclinedSince time.Time `dynamodbav:",unixtime"` AmountCents int TotalHistoricalAmountCents int OutstandingPaymentAmountCents int IsPaused bool } +//noinspection SpellCheckingInspection type Feed struct { - FeedID int64 `sql:",pk"` - HashID string // Short human readable feed id for users - UserID string // Patreon user id - ItemID string - LinkType api.LinkType // Either group, channel or user - Provider api.Provider // Youtube or Vimeo - PageSize int // The number of episodes to return - Format api.Format - Quality api.Quality - FeatureLevel int - CreatedAt time.Time - LastAccess time.Time // Available features + FeedID int64 `sql:",pk" dynamodbav:"-"` + HashID string // Short human readable feed id for users + UserID string // Patreon user id + ItemID string + LinkType api.LinkType // Either group, channel or user + Provider api.Provider // Youtube or Vimeo + PageSize int // The number of episodes to return + Format api.Format + Quality api.Quality + FeatureLevel int + CreatedAt time.Time `dynamodbav:",unixtime"` + LastAccess time.Time `dynamodbav:",unixtime"` + ExpirationTime time.Time `sql:"-" dynamodbav:",unixtime"` } diff --git a/pkg/storage/dynamo.go b/pkg/storage/dynamo.go new file mode 100644 index 0000000..06f8325 --- /dev/null +++ b/pkg/storage/dynamo.go @@ -0,0 +1,404 @@ +package storage + +import ( + "context" + "strconv" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/dynamodb" + attr "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" + expr "github.com/aws/aws-sdk-go/service/dynamodb/expression" + "github.com/pkg/errors" + + "github.com/mxpv/podsync/pkg/api" + "github.com/mxpv/podsync/pkg/model" +) + +const ( + defaultRegion = "us-east-1" + + pingTimeout = 5 * time.Second + pledgesPrimaryKey = "PatronID" + feedsPrimaryKey = "HashID" + + // Update LastAccess field every hour + feedLastAccessUpdatePeriod = time.Hour + feedTimeToLive = time.Hour * 24 * 90 +) + +var ( + pledgesTableName = aws.String("Pledges") + feedsTableName = aws.String("Feeds") + feedTimeToLiveField = aws.String("ExpirationTime") + feedDowngradeIndexName = aws.String("UserID-HashID-Index") +) + +/* +Pledges: + Table name: Pledges + Primary key: PatronID (Number) + RCU: 1 (used while creating a new feed) + WCU: 1 (used when pledge changes) + No secondary indexed needed +Feeds: + Table name: Feeds + Primary key: HashID (String) + Secondary index: + Primary key: UserID (String) + Sort key: HashID (String) + RCU: 10 + WCU: 5 + Index name: UserID-HashID-Index + Projected attr: Keys only + RCU/WCU: 1/1 + TTL attr: ExpirationTime +*/ +type Dynamo struct { + dynamo *dynamodb.DynamoDB +} + +func NewDynamo(region, endpoint string) (Dynamo, error) { + if region == "" { + region = defaultRegion + } + + sess, err := session.NewSession(&aws.Config{ + Region: aws.String(region), + Endpoint: aws.String(endpoint), + }) + + if err != nil { + return Dynamo{}, err + } + + db := dynamodb.New(sess) + + // Verify connectivity + ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) + defer cancel() + + _, err = db.ListTablesWithContext(ctx, &dynamodb.ListTablesInput{}) + if err != nil { + return Dynamo{}, err + } + + return Dynamo{dynamo: db}, nil +} + +func (d Dynamo) SaveFeed(feed *model.Feed) error { + now := time.Now().UTC() + + feed.LastAccess = now + feed.ExpirationTime = now.Add(feedTimeToLive) + + item, err := attr.MarshalMap(feed) + if err != nil { + return err + } + + input := &dynamodb.PutItemInput{ + TableName: feedsTableName, + Item: item, + ConditionExpression: aws.String("attribute_not_exists(HashID)"), + } + + _, err = d.dynamo.PutItem(input) + return err +} + +func (d Dynamo) GetFeed(hashID string) (*model.Feed, error) { + getInput := &dynamodb.GetItemInput{ + TableName: feedsTableName, + Key: map[string]*dynamodb.AttributeValue{ + "HashID": {S: aws.String(hashID)}, + }, + } + + getOutput, err := d.dynamo.GetItem(getInput) + if err != nil { + return nil, err + } + + if getOutput.Item == nil { + return nil, errors.New("not found") + } + + var feed model.Feed + if err := attr.UnmarshalMap(getOutput.Item, &feed); err != nil { + return nil, err + } + + // Check if we need to update LastAccess field (no more than once per hour) + now := time.Now().UTC() + if feed.LastAccess.Add(feedLastAccessUpdatePeriod).Before(now) { + // Set LastAccess field to now + // Set ExpirationTime field to now + feedTimeToLive + builder := expr. + Set(expr.Name("LastAccess"), expr.Value(now)). + Set(expr.Name("ExpirationTime"), expr.Value(now.Add(feedTimeToLive))) + + updateExpression, err := expr.NewBuilder().WithUpdate(builder).Build() + if err != nil { + return nil, err + } + + updateInput := &dynamodb.UpdateItemInput{ + TableName: feedsTableName, + Key: getInput.Key, + UpdateExpression: updateExpression.Update(), + } + + _, err = d.dynamo.UpdateItem(updateInput) + if err != nil { + return nil, err + } + + feed.LastAccess = now + } + + return &feed, nil +} + +func (d Dynamo) GetMetadata(hashID string) (*model.Feed, error) { + projectionExpression, err := expr. + NewBuilder(). + WithProjection( + expr.NamesList( + expr.Name("FeedID"), + expr.Name("HashID"), + expr.Name("UserID"), + expr.Name("Provider"), + expr.Name("Format"), + expr.Name("Quality"))). + Build() + + + input := &dynamodb.GetItemInput{ + TableName: feedsTableName, + Key: map[string]*dynamodb.AttributeValue{ + "HashID": {S: aws.String(hashID)}, + }, + ProjectionExpression: projectionExpression.Projection(), + ExpressionAttributeNames: projectionExpression.Names(), + } + + output, err := d.dynamo.GetItem(input) + if err != nil { + return nil, err + } + + if output.Item == nil { + return nil, errors.New("not found") + } + + var feed model.Feed + if err := attr.UnmarshalMap(output.Item, &feed); err != nil { + return nil, err + } + + return &feed, nil +} + +func (d Dynamo) Downgrade(userID string, featureLevel int) error { + if featureLevel > api.ExtendedFeatures { + // Max page size: 600 + // Format: any + // Quality: any + return nil + } + + keyConditionExpression, err := expr. + NewBuilder(). + WithKeyCondition(expr.KeyEqual(expr.Key("UserID"), expr.Value(userID))). + Build() + + if err != nil { + return err + } + + // Query all feed's hash ids for specified + + queryInput := &dynamodb.QueryInput{ + TableName: feedsTableName, + IndexName: feedDowngradeIndexName, + KeyConditionExpression: keyConditionExpression.KeyCondition(), + ExpressionAttributeNames: keyConditionExpression.Names(), + ExpressionAttributeValues: keyConditionExpression.Values(), + Select: aws.String(dynamodb.SelectAllProjectedAttributes), + } + + var keys []map[string]*dynamodb.AttributeValue + err = d.dynamo.QueryPages(queryInput, func(output *dynamodb.QueryOutput, lastPage bool) bool { + for _, item := range output.Items { + keys = append(keys, map[string]*dynamodb.AttributeValue{ + feedsPrimaryKey: item[feedsPrimaryKey], + }) + } + + return true + }) + + if err != nil { + return err + } + + if featureLevel == api.ExtendedFeatures { + // Max page size: 150 + // Format: any + // Quality: any + updateExpression, err := expr. + NewBuilder(). + WithUpdate(expr. + Set(expr.Name("PageSize"), expr.Value(150)). + Set(expr.Name("FeatureLevel"), expr.Value(api.ExtendedFeatures))). + WithCondition(expr. + Name("PageSize").GreaterThan(expr.Value(150))). + Build() + + if err != nil { + return err + } + + for _, key := range keys { + input := &dynamodb.UpdateItemInput{ + TableName: feedsTableName, + Key: key, + ConditionExpression: updateExpression.Condition(), + UpdateExpression: updateExpression.Update(), + ExpressionAttributeNames: updateExpression.Names(), + ExpressionAttributeValues: updateExpression.Values(), + } + + _, err := d.dynamo.UpdateItem(input) + if err != nil { + return err + } + } + + } else if featureLevel == api.DefaultFeatures { + // Page size: 50 + // Format: video + // Quality: high + updateExpression, err := expr. + NewBuilder(). + WithUpdate(expr. + Set(expr.Name("PageSize"), expr.Value(50)). + Set(expr.Name("FeatureLevel"), expr.Value(api.DefaultFeatures)). + Set(expr.Name("Format"), expr.Value(api.FormatVideo)). + Set(expr.Name("Quality"), expr.Value(api.QualityHigh))). + Build() + + if err != nil { + return err + } + + for _, key := range keys { + input := &dynamodb.UpdateItemInput{ + TableName: feedsTableName, + Key: key, + UpdateExpression: updateExpression.Update(), + ExpressionAttributeNames: updateExpression.Names(), + ExpressionAttributeValues: updateExpression.Values(), + } + + _, err := d.dynamo.UpdateItem(input) + if err != nil { + return err + } + } + } + + return nil +} + +func (d Dynamo) AddPledge(pledge *model.Pledge) error { + item, err := attr.MarshalMap(pledge) + if err != nil { + return err + } + + input := &dynamodb.PutItemInput{ + TableName: pledgesTableName, + Item: item, + ConditionExpression: aws.String("attribute_not_exists(PatronID)"), + } + + _, err = d.dynamo.PutItem(input) + return err +} + +func (d Dynamo) UpdatePledge(patronID string, pledge *model.Pledge) error { + builder := expr. + Set(expr.Name("DeclinedSince"), expr.Value(pledge.DeclinedSince)). + Set(expr.Name("AmountCents"), expr.Value(pledge.AmountCents)). + Set(expr.Name("TotalHistoricalAmountCents"), expr.Value(pledge.TotalHistoricalAmountCents)). + Set(expr.Name("OutstandingPaymentAmountCents"), expr.Value(pledge.OutstandingPaymentAmountCents)). + Set(expr.Name("IsPaused"), expr.Value(pledge.IsPaused)) + + updateExpression, err := expr.NewBuilder().WithUpdate(builder).Build() + if err != nil { + return err + } + + input := &dynamodb.UpdateItemInput{ + TableName: pledgesTableName, + Key: map[string]*dynamodb.AttributeValue{ + pledgesPrimaryKey: {N: aws.String(patronID)}, + }, + UpdateExpression: updateExpression.Update(), + ExpressionAttributeNames: updateExpression.Names(), + ExpressionAttributeValues: updateExpression.Values(), + } + + _, err = d.dynamo.UpdateItem(input) + if err != nil { + return err + } + + return nil +} + +func (d Dynamo) DeletePledge(pledge *model.Pledge) error { + pk := strconv.FormatInt(pledge.PatronID, 10) + + input := &dynamodb.DeleteItemInput{ + TableName: pledgesTableName, + Key: map[string]*dynamodb.AttributeValue{ + pledgesPrimaryKey: {N: aws.String(pk)}, + }, + } + + _, err := d.dynamo.DeleteItem(input) + return err +} + +func (d Dynamo) GetPledge(patronID string) (*model.Pledge, error) { + input := &dynamodb.GetItemInput{ + TableName: pledgesTableName, + Key: map[string]*dynamodb.AttributeValue{ + pledgesPrimaryKey: {N: aws.String(patronID)}, + }, + } + + output, err := d.dynamo.GetItem(input) + if err != nil { + return nil, err + } + + if output.Item == nil { + return nil, errors.New("not found") + } + + var pledge model.Pledge + if err := attr.UnmarshalMap(output.Item, &pledge); err != nil { + return nil, err + } + + return &pledge, nil +} + +func (d Dynamo) Close() error { + return nil +} diff --git a/pkg/storage/dynamo_test.go b/pkg/storage/dynamo_test.go new file mode 100644 index 0000000..91d6a1c --- /dev/null +++ b/pkg/storage/dynamo_test.go @@ -0,0 +1,117 @@ +package storage + +import ( + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/stretchr/testify/require" +) + +func TestDynamo(t *testing.T) { + runStorageTests(t, createDynamo) +} + +// docker run -it --rm -p 8000:8000 amazon/dynamodb-local +// noinspection ALL +func createDynamo(t *testing.T) storage { + d, err := NewDynamo("", "http://localhost:8000/") + require.NoError(t, err) + + d.dynamo.DeleteTable(&dynamodb.DeleteTableInput{TableName: pledgesTableName}) + d.dynamo.DeleteTable(&dynamodb.DeleteTableInput{TableName: feedsTableName}) + + // Create Pledges table + _, err = d.dynamo.CreateTable(&dynamodb.CreateTableInput{ + TableName: pledgesTableName, + AttributeDefinitions: []*dynamodb.AttributeDefinition{ + { + AttributeName: aws.String(pledgesPrimaryKey), + AttributeType: aws.String("N"), + }, + }, + KeySchema: []*dynamodb.KeySchemaElement{ + { + AttributeName: aws.String(pledgesPrimaryKey), + KeyType: aws.String("HASH"), + }, + }, + ProvisionedThroughput: &dynamodb.ProvisionedThroughput{ + ReadCapacityUnits: aws.Int64(1), + WriteCapacityUnits: aws.Int64(1), + }, + }) + + require.NoError(t, err) + + // Create Feeds table + _, err = d.dynamo.CreateTable(&dynamodb.CreateTableInput{ + TableName: feedsTableName, + AttributeDefinitions: []*dynamodb.AttributeDefinition{ + { + AttributeName: aws.String(feedsPrimaryKey), + AttributeType: aws.String("S"), + }, + { + AttributeName: aws.String("UserID"), + AttributeType: aws.String("S"), + }, + { + AttributeName: aws.String("CreatedAt"), + AttributeType: aws.String("N"), + }, + }, + KeySchema: []*dynamodb.KeySchemaElement{ + { + AttributeName: aws.String(feedsPrimaryKey), + KeyType: aws.String("HASH"), + }, + }, + GlobalSecondaryIndexes: []*dynamodb.GlobalSecondaryIndex{ + { + IndexName: feedDowngradeIndexName, + KeySchema: []*dynamodb.KeySchemaElement{ + { + AttributeName: aws.String("UserID"), + KeyType: aws.String("HASH"), + }, + { + AttributeName: aws.String("CreatedAt"), + KeyType: aws.String("RANGE"), + }, + }, + Projection: &dynamodb.Projection{ + ProjectionType: aws.String("KEYS_ONLY"), + }, + ProvisionedThroughput: &dynamodb.ProvisionedThroughput{ + ReadCapacityUnits: aws.Int64(1), + WriteCapacityUnits: aws.Int64(1), + }, + }, + }, + ProvisionedThroughput: &dynamodb.ProvisionedThroughput{ + ReadCapacityUnits: aws.Int64(1), + WriteCapacityUnits: aws.Int64(1), + }, + }) + + require.NoError(t, err) + + err = d.dynamo.WaitUntilTableExists(&dynamodb.DescribeTableInput{TableName: pledgesTableName}) + require.NoError(t, err) + + err = d.dynamo.WaitUntilTableExists(&dynamodb.DescribeTableInput{TableName: feedsTableName}) + require.NoError(t, err) + + _, err = d.dynamo.UpdateTimeToLive(&dynamodb.UpdateTimeToLiveInput{ + TableName: feedsTableName, + TimeToLiveSpecification: &dynamodb.TimeToLiveSpecification{ + AttributeName: feedTimeToLiveField, + Enabled: aws.Bool(true), + }, + }) + + require.NoError(t, err) + + return d +} diff --git a/pkg/storage/pg.go b/pkg/storage/pg.go index 306886c..8abc121 100644 --- a/pkg/storage/pg.go +++ b/pkg/storage/pg.go @@ -176,7 +176,12 @@ func (p Postgres) DeletePledge(pledge *model.Pledge) error { func (p Postgres) GetPledge(patronID string) (*model.Pledge, error) { pledge := &model.Pledge{} - return pledge, p.db.Model(pledge).Where("patron_id = ?", patronID).Limit(1).Select() + err := p.db.Model(pledge).Where("patron_id = ?", patronID).Limit(1).Select() + if err != nil { + return nil, err + } + + return pledge, nil } func (p Postgres) Close() error { diff --git a/pkg/storage/pg_test.go b/pkg/storage/pg_test.go index 62e2798..342c65e 100644 --- a/pkg/storage/pg_test.go +++ b/pkg/storage/pg_test.go @@ -2,58 +2,13 @@ package storage import ( "testing" - "time" - "github.com/go-pg/pg" "github.com/stretchr/testify/require" - "github.com/mxpv/podsync/pkg/api" "github.com/mxpv/podsync/pkg/model" ) -var ( - testPledge = &model.Pledge{PledgeID: 12345, AmountCents: 400, PatronID: 1, CreatedAt: time.Now()} - testFeed = &model.Feed{FeedID: 1, HashID: "3", UserID: "3", ItemID: "4", LinkType: api.LinkTypeChannel, Provider: api.ProviderVimeo, Format: api.FormatAudio ,Quality: api.QualityLow} -) - -func TestPostgres_SaveFeed(t *testing.T) { - stor := createPG(t) - defer func() { _ = stor.Close() }() - - err := stor.SaveFeed(testFeed) - require.NoError(t, err) - - find := &model.Feed{FeedID: 1} - err = stor.db.Model(find).Select() - require.NoError(t, err) - - require.Equal(t, testFeed.FeedID, find.FeedID) - require.Equal(t, testFeed.HashID, find.HashID) - require.Equal(t, testFeed.UserID, find.UserID) - require.Equal(t, testFeed.ItemID, find.ItemID) - require.Equal(t, testFeed.LinkType, find.LinkType) - require.Equal(t, testFeed.Provider, find.Provider) -} - -func TestPostgres_GetFeed(t *testing.T) { - stor := createPG(t) - defer func() { _ = stor.Close() }() - - err := stor.SaveFeed(testFeed) - require.NoError(t, err) - - find, err := stor.GetFeed(testFeed.HashID) - require.NoError(t, err) - - require.Equal(t, testFeed.FeedID, find.FeedID) - require.Equal(t, testFeed.HashID, find.HashID) - require.Equal(t, testFeed.UserID, find.UserID) - require.Equal(t, testFeed.ItemID, find.ItemID) - require.Equal(t, testFeed.LinkType, find.LinkType) - require.Equal(t, testFeed.Provider, find.Provider) -} - -func TestService_UpdateLastAccess(t *testing.T) { +func TestPostgres_UpdateLastAccess(t *testing.T) { stor := createPG(t) defer func() { _ = stor.Close() }() @@ -69,142 +24,10 @@ func TestService_UpdateLastAccess(t *testing.T) { require.True(t, feed2.LastAccess.After(feed1.LastAccess)) } -func TestPostgres_GetMetadata(t *testing.T) { - stor := createPG(t) - defer func() { _ = stor.Close() }() - - err := stor.SaveFeed(testFeed) - require.NoError(t, err) - - find, err := stor.GetMetadata(testFeed.HashID) - require.NoError(t, err) - - require.Equal(t, testFeed.UserID, find.UserID) - require.Equal(t, testFeed.Provider, find.Provider) - require.Equal(t, testFeed.Quality, find.Quality) - require.Equal(t, testFeed.Format, find.Format) -} - -func TestService_DowngradeToAnonymous(t *testing.T) { - stor := createPG(t) - defer func() { _ = stor.Close() }() - - feed := &model.Feed{ - HashID: "123456", - UserID: "123456", - ItemID: "123456", - Provider: api.ProviderVimeo, - LinkType: api.LinkTypeGroup, - PageSize: 150, - Quality: api.QualityLow, - Format: api.FormatAudio, - FeatureLevel: api.ExtendedFeatures, - } - - err := stor.db.Insert(feed) - require.NoError(t, err) - - err = stor.Downgrade(feed.UserID, api.DefaultFeatures) - require.NoError(t, err) - - downgraded := &model.Feed{FeedID: feed.FeedID} - err = stor.db.Select(downgraded) - require.NoError(t, err) - - require.Equal(t, 50, downgraded.PageSize) - require.Equal(t, api.QualityHigh, downgraded.Quality) - require.Equal(t, api.FormatVideo, downgraded.Format) - require.Equal(t, api.DefaultFeatures, downgraded.FeatureLevel) -} - -func TestService_DowngradeToExtendedFeatures(t *testing.T) { - stor := createPG(t) - defer func() { _ = stor.Close() }() - - feed := &model.Feed{ - HashID: "123456", - UserID: "123456", - ItemID: "123456", - Provider: api.ProviderVimeo, - LinkType: api.LinkTypeGroup, - PageSize: 500, - Quality: api.QualityLow, - Format: api.FormatAudio, - FeatureLevel: api.ExtendedFeatures, - } - - err := stor.db.Insert(feed) - require.NoError(t, err) - - err = stor.Downgrade(feed.UserID, api.ExtendedFeatures) - require.NoError(t, err) - - downgraded := &model.Feed{FeedID: feed.FeedID} - err = stor.db.Select(downgraded) - require.NoError(t, err) - - require.Equal(t, 150, downgraded.PageSize) - require.Equal(t, feed.Quality, downgraded.Quality) - require.Equal(t, feed.Format, downgraded.Format) - require.Equal(t, api.ExtendedFeatures, downgraded.FeatureLevel) -} - -func TestPostgres_AddPledge(t *testing.T) { - stor := createPG(t) - defer func() { _ = stor.Close() }() - - err := stor.AddPledge(testPledge) - require.NoError(t, err) - - pledge := &model.Pledge{PledgeID: 12345} - err = stor.db.Select(pledge) - require.NoError(t, err) - - require.Equal(t, int64(12345), pledge.PledgeID) - require.Equal(t, 400, pledge.AmountCents) -} - -func TestPostgres_UpdatePledge(t *testing.T) { - stor := createPG(t) - defer func() { _ = stor.Close() }() - - err := stor.AddPledge(testPledge) - require.NoError(t, err) - - err = stor.UpdatePledge("1", &model.Pledge{AmountCents: 999}) - require.NoError(t, err) - - pledge := &model.Pledge{PledgeID: 12345} - err = stor.db.Select(pledge) - require.NoError(t, err) - require.Equal(t, 999, pledge.AmountCents) -} - -func TestPostgres_DeletePledge(t *testing.T) { - stor := createPG(t) - defer func() { _ = stor.Close() }() - - err := stor.AddPledge(testPledge) - require.NoError(t, err) - - err = stor.DeletePledge(testPledge) - require.NoError(t, err) - - err = stor.db.Select(&model.Pledge{PledgeID: 12345}) - require.Equal(t, pg.ErrNoRows, err) -} - -func TestPostgres_GetPledge(t *testing.T) { - stor := createPG(t) - defer func() { _ = stor.Close() }() - - err := stor.AddPledge(testPledge) - require.NoError(t, err) - - pledge, err := stor.GetPledge("1") - require.NoError(t, err) - require.Equal(t, 400, pledge.AmountCents) - require.Equal(t, int64(12345), pledge.PledgeID) +func TestPostgres(t *testing.T) { + runStorageTests(t, func(t *testing.T) storage { + return createPG(t) + }) } // docker run -it --rm -p 5432:5432 -e POSTGRES_DB=podsync postgres diff --git a/pkg/storage/storage_test.go b/pkg/storage/storage_test.go new file mode 100644 index 0000000..bf2e524 --- /dev/null +++ b/pkg/storage/storage_test.go @@ -0,0 +1,272 @@ +package storage + +import ( + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/mxpv/podsync/pkg/api" + "github.com/mxpv/podsync/pkg/model" +) + +type storage interface { + SaveFeed(feed *model.Feed) error + GetFeed(hashID string) (*model.Feed, error) + GetMetadata(hashID string) (*model.Feed, error) + Downgrade(userID string, featureLevel int) error + + // Patreon pledges + AddPledge(pledge *model.Pledge) error + UpdatePledge(patronID string, pledge *model.Pledge) error + DeletePledge(pledge *model.Pledge) error + GetPledge(patronID string) (*model.Pledge, error) + + Close() error +} + +var ( + testPledge = &model.Pledge{ + PledgeID: 12345, + AmountCents: 400, + PatronID: 1, + CreatedAt: time.Now().UTC(), + TotalHistoricalAmountCents: 100, + OutstandingPaymentAmountCents: 100, + IsPaused: true, + } + + testFeed = &model.Feed{ + FeedID: 1, + HashID: "3", + UserID: "4", + ItemID: "5", + LinkType: api.LinkTypeChannel, + Provider: api.ProviderVimeo, + Format: api.FormatAudio, + Quality: api.QualityLow, + PageSize: 150, + FeatureLevel: api.ExtendedFeatures, + CreatedAt: time.Now().UTC(), + LastAccess: time.Now().UTC(), + } +) + +func runStorageTests(t *testing.T, createFn func(t *testing.T) storage) { + // Feeds + t.Run("SaveFeed", makeTest(createFn, testSaveFeed)) + t.Run("LastAccess", makeTest(createFn, testLastAccess)) + t.Run("GetMetadata", makeTest(createFn, testGetMetadata)) + t.Run("Downgrade", func(t *testing.T) { + t.Run("DefaultFeatures", makeTest(createFn, testDowngradeToDefaultFeatures)) + t.Run("ExtendedFeatures", makeTest(createFn, testDowngradeToExtendedFeatures)) + }) + + // Pledge tests + t.Run("AddPledge", makeTest(createFn, testAddPledge)) + t.Run("GetPledge", makeTest(createFn, testGetPledge)) + t.Run("DeletePledge", makeTest(createFn, testDeletePledge)) + t.Run("UpdatePledge", makeTest(createFn, testUpdatePledge)) +} + +func makeTest(createFn func(t *testing.T) storage, testFn func(t *testing.T, storage storage)) func(t *testing.T) { + return func(t *testing.T) { + storage := createFn(t) + + testFn(t, storage) + + err := storage.Close() + require.Nil(t, err) + } +} + +func testSaveFeed(t *testing.T, storage storage) { + err := storage.SaveFeed(testFeed) + require.NoError(t, err) + + find, err := storage.GetFeed(testFeed.HashID) + require.NoError(t, err) + + require.Equal(t, testFeed.HashID, find.HashID) + require.Equal(t, testFeed.UserID, find.UserID) + require.Equal(t, testFeed.ItemID, find.ItemID) + require.Equal(t, testFeed.LinkType, find.LinkType) + require.Equal(t, testFeed.Provider, find.Provider) +} + +func testGetMetadata(t *testing.T, storage storage) { + err := storage.SaveFeed(testFeed) + require.NoError(t, err) + + find, err := storage.GetMetadata(testFeed.HashID) + require.NoError(t, err) + + require.Equal(t, testFeed.UserID, find.UserID) + require.Equal(t, testFeed.Provider, find.Provider) + require.Equal(t, testFeed.Quality, find.Quality) + require.Equal(t, testFeed.Format, find.Format) + + require.Equal(t, 0, find.PageSize) + require.Equal(t, time.Time{}.Unix(), find.CreatedAt.Unix()) + require.Equal(t, time.Time{}.Unix(), find.LastAccess.Unix()) + require.Equal(t, 0, find.FeatureLevel) +} + +func testDowngradeToDefaultFeatures(t *testing.T, storage storage) { + feed := &model.Feed{ + HashID: "123456", + UserID: "123456", + ItemID: "123456", + Provider: api.ProviderVimeo, + LinkType: api.LinkTypeGroup, + PageSize: 200, + Quality: api.QualityLow, + Format: api.FormatAudio, + FeatureLevel: api.ExtendedFeatures, + } + + err := storage.SaveFeed(feed) + require.NoError(t, err) + + err = storage.Downgrade(feed.UserID, api.DefaultFeatures) + require.NoError(t, err) + + downgraded, err := storage.GetFeed(feed.HashID) + require.NoError(t, err) + + require.Equal(t, 50, downgraded.PageSize) + require.Equal(t, api.QualityHigh, downgraded.Quality) + require.Equal(t, api.FormatVideo, downgraded.Format) + require.Equal(t, api.DefaultFeatures, downgraded.FeatureLevel) +} + +func testDowngradeToExtendedFeatures(t *testing.T, storage storage) { + feed := &model.Feed{ + HashID: "123456", + UserID: "123456", + ItemID: "123456", + Provider: api.ProviderVimeo, + LinkType: api.LinkTypeGroup, + PageSize: 500, + Quality: api.QualityLow, + Format: api.FormatAudio, + FeatureLevel: api.ExtendedFeatures, + } + + err := storage.SaveFeed(feed) + require.NoError(t, err) + + err = storage.Downgrade(feed.UserID, api.ExtendedFeatures) + require.NoError(t, err) + + downgraded, err := storage.GetFeed(feed.HashID) + require.NoError(t, err) + + require.Equal(t, 150, downgraded.PageSize) + require.Equal(t, feed.Quality, downgraded.Quality) + require.Equal(t, feed.Format, downgraded.Format) + require.Equal(t, api.ExtendedFeatures, downgraded.FeatureLevel) +} + +func testLastAccess(t *testing.T, storage storage) { + date := time.Now().AddDate(-1, 0, 0).UTC() + + feed := &model.Feed{ + FeedID: 1, + HashID: "3", + UserID: "4", + ItemID: "5", + LinkType: api.LinkTypeChannel, + Provider: api.ProviderVimeo, + Format: api.FormatAudio, + Quality: api.QualityLow, + PageSize: 150, + FeatureLevel: api.ExtendedFeatures, + CreatedAt: date, + LastAccess: date, + } + + err := storage.SaveFeed(feed) + require.NoError(t, err) + + result, err := storage.GetFeed(feed.HashID) + require.NoError(t, err) + + require.True(t, result.LastAccess.Sub(time.Now().UTC()) < 2*time.Second) +} + +func testAddPledge(t *testing.T, storage storage) { + err := storage.AddPledge(testPledge) + require.NoError(t, err) + + pledge, err := storage.GetPledge(strconv.FormatInt(testPledge.PatronID, 10)) + require.NoError(t, err) + + require.Equal(t, testPledge.PledgeID, pledge.PledgeID) + require.Equal(t, testPledge.PatronID, pledge.PatronID) + require.Equal(t, testPledge.CreatedAt.Unix(), pledge.CreatedAt.Unix()) + require.Equal(t, testPledge.DeclinedSince.Unix(), pledge.DeclinedSince.Unix()) + require.Equal(t, testPledge.AmountCents, pledge.AmountCents) + require.Equal(t, testPledge.TotalHistoricalAmountCents, pledge.TotalHistoricalAmountCents) + require.Equal(t, testPledge.OutstandingPaymentAmountCents, pledge.OutstandingPaymentAmountCents) + require.Equal(t, testPledge.IsPaused, pledge.IsPaused) +} + +func testGetPledge(t *testing.T, storage storage) { + err := storage.AddPledge(testPledge) + require.NoError(t, err) + + pledge, err := storage.GetPledge(strconv.FormatInt(testPledge.PatronID, 10)) + require.NoError(t, err) + + require.Equal(t, testPledge.PledgeID, pledge.PledgeID) + require.Equal(t, testPledge.PatronID, pledge.PatronID) + require.Equal(t, testPledge.CreatedAt.Unix(), pledge.CreatedAt.Unix()) + require.Equal(t, testPledge.DeclinedSince.Unix(), pledge.DeclinedSince.Unix()) + require.Equal(t, testPledge.AmountCents, pledge.AmountCents) + require.Equal(t, testPledge.TotalHistoricalAmountCents, pledge.TotalHistoricalAmountCents) + require.Equal(t, testPledge.OutstandingPaymentAmountCents, pledge.OutstandingPaymentAmountCents) + require.Equal(t, testPledge.IsPaused, pledge.IsPaused) +} + +func testDeletePledge(t *testing.T, storage storage) { + err := storage.AddPledge(testPledge) + require.NoError(t, err) + + err = storage.DeletePledge(testPledge) + require.NoError(t, err) + + pledge, err := storage.GetPledge(strconv.FormatInt(testPledge.PatronID, 10)) + require.Error(t, err) + require.Nil(t, pledge) +} + +func testUpdatePledge(t *testing.T, storage storage) { + err := storage.AddPledge(testPledge) + require.NoError(t, err) + + now := time.Now().UTC() + + err = storage.UpdatePledge(strconv.FormatInt(testPledge.PatronID, 10), &model.Pledge{ + DeclinedSince: now, + AmountCents: 400, + TotalHistoricalAmountCents: 800, + OutstandingPaymentAmountCents: 900, + IsPaused: true, + }) + + require.NoError(t, err) + + pledge, err := storage.GetPledge("1") + require.NoError(t, err) + + require.Equal(t, testPledge.PledgeID, pledge.PledgeID) + require.Equal(t, testPledge.PatronID, pledge.PatronID) + require.Equal(t, testPledge.CreatedAt.Unix(), pledge.CreatedAt.Unix()) + require.Equal(t, now.Unix(), pledge.DeclinedSince.Unix()) + require.Equal(t, 400, pledge.AmountCents) + require.Equal(t, 800, pledge.TotalHistoricalAmountCents) + require.Equal(t, 900, pledge.OutstandingPaymentAmountCents) + require.Equal(t, true, pledge.IsPaused) +}