You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
426 lines
12 KiB
426 lines
12 KiB
package inventory
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"time"

	"entgo.io/ent/dialect/sql"
	"github.com/cloudreve/Cloudreve/v4/ent"
	"github.com/cloudreve/Cloudreve/v4/ent/file"
	"github.com/cloudreve/Cloudreve/v4/ent/predicate"
	"github.com/cloudreve/Cloudreve/v4/ent/share"
	"github.com/cloudreve/Cloudreve/v4/ent/user"
	"github.com/cloudreve/Cloudreve/v4/inventory/types"
	"github.com/cloudreve/Cloudreve/v4/pkg/conf"
	"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
	"github.com/samber/lo"
)
|
|
|
|
// LoadShareFile is a context key. When a context maps it to the bool value
// true, queries passed through withShareEagerLoading also load the share's
// file edge.
type LoadShareFile struct{}

// LoadShareUser is a context key. When a context maps it to the bool value
// true, queries passed through withShareEagerLoading also load the share's
// user (owner) edge.
type LoadShareUser struct{}
|
|
|
|
// Sentinel errors returned by share validation. Their messages are static,
// so they are plain errors.New values; callers should match them with
// errors.Is.
var (
	// ErrShareLinkExpired is returned by IsShareExpired when the share's
	// expiry time has passed or its remaining-download limit is <= 0.
	ErrShareLinkExpired = errors.New("share link expired")

	// ErrOwnerInactive is returned by IsValidShare when the owner edge is not
	// present on the share or the owner's status is not active.
	ErrOwnerInactive = errors.New("owner is inactive")

	// ErrSourceFileInvalid is returned by IsValidShare when the file edge is
	// not present, its FileChildren is 0, or it is owned by someone other
	// than the share owner.
	ErrSourceFileInvalid = errors.New("source file is deleted")
)
|
|
|
|
type (
|
|
ShareClient interface {
|
|
TxOperator
|
|
// GetByIDs returns the shares with given ids.
|
|
GetByIDs(ctx context.Context, ids []int) ([]*ent.Share, error)
|
|
// GetByID returns the share with given id.
|
|
GetByID(ctx context.Context, id int) (*ent.Share, error)
|
|
// GetByIDUser returns the share with given id and user id.
|
|
GetByIDUser(ctx context.Context, id, uid int) (*ent.Share, error)
|
|
// GetByHashID returns the share with given hash id.
|
|
GetByHashID(ctx context.Context, idRaw string) (*ent.Share, error)
|
|
// Upsert creates or update a new share record.
|
|
Upsert(ctx context.Context, params *CreateShareParams) (*ent.Share, error)
|
|
// Viewed increase the view count of the share.
|
|
Viewed(ctx context.Context, share *ent.Share) error
|
|
// Downloaded increase the download count of the share.
|
|
Downloaded(ctx context.Context, share *ent.Share) error
|
|
// Delete deletes the share.
|
|
Delete(ctx context.Context, shareId int) error
|
|
// List returns a list of shares with the given args.
|
|
List(ctx context.Context, args *ListShareArgs) (*ListShareResult, error)
|
|
// CountByTimeRange counts the number of shares created in the given time range.
|
|
CountByTimeRange(ctx context.Context, start, end *time.Time) (int, error)
|
|
// DeleteBatch deletes the shares with the given ids.
|
|
DeleteBatch(ctx context.Context, shareIds []int) error
|
|
}
|
|
|
|
CreateShareParams struct {
|
|
Existed *ent.Share
|
|
Password string
|
|
RemainDownloads int
|
|
Expires *time.Time
|
|
OwnerID int
|
|
FileID int
|
|
Props *types.ShareProps
|
|
}
|
|
|
|
ListShareArgs struct {
|
|
*PaginationArgs
|
|
UserID int
|
|
FileID int
|
|
PublicOnly bool
|
|
ShareIDs []int
|
|
}
|
|
ListShareResult struct {
|
|
*PaginationResults
|
|
Shares []*ent.Share
|
|
}
|
|
)
|
|
|
|
func NewShareClient(client *ent.Client, dbType conf.DBType, hasher hashid.Encoder) ShareClient {
|
|
return &shareClient{
|
|
client: client,
|
|
hasher: hasher,
|
|
maxSQlParam: sqlParamLimit(dbType),
|
|
}
|
|
}
|
|
|
|
type shareClient struct {
|
|
maxSQlParam int
|
|
client *ent.Client
|
|
hasher hashid.Encoder
|
|
}
|
|
|
|
func (c *shareClient) SetClient(newClient *ent.Client) TxOperator {
|
|
return &shareClient{client: newClient, hasher: c.hasher, maxSQlParam: c.maxSQlParam}
|
|
}
|
|
|
|
func (c *shareClient) GetClient() *ent.Client {
|
|
return c.client
|
|
}
|
|
|
|
func (c *shareClient) CountByTimeRange(ctx context.Context, start, end *time.Time) (int, error) {
|
|
if start == nil || end == nil {
|
|
return c.client.Share.Query().Count(ctx)
|
|
}
|
|
|
|
return c.client.Share.Query().Where(share.CreatedAtGTE(*start), share.CreatedAtLT(*end)).Count(ctx)
|
|
}
|
|
|
|
func (c *shareClient) Upsert(ctx context.Context, params *CreateShareParams) (*ent.Share, error) {
|
|
if params.Existed != nil {
|
|
createQuery := c.client.Share.
|
|
UpdateOne(params.Existed)
|
|
if params.RemainDownloads > 0 {
|
|
createQuery.SetRemainDownloads(params.RemainDownloads)
|
|
} else {
|
|
createQuery.ClearRemainDownloads()
|
|
}
|
|
if params.Expires != nil {
|
|
createQuery.SetNillableExpires(params.Expires)
|
|
} else {
|
|
createQuery.ClearExpires()
|
|
}
|
|
|
|
if params.Props != nil {
|
|
createQuery.SetProps(params.Props)
|
|
}
|
|
|
|
return createQuery.Save(ctx)
|
|
}
|
|
|
|
query := c.client.Share.
|
|
Create().
|
|
SetUserID(params.OwnerID).
|
|
SetFileID(params.FileID)
|
|
if params.Password != "" {
|
|
query.SetPassword(params.Password)
|
|
}
|
|
if params.RemainDownloads > 0 {
|
|
query.SetRemainDownloads(params.RemainDownloads)
|
|
}
|
|
if params.Expires != nil {
|
|
query.SetNillableExpires(params.Expires)
|
|
}
|
|
if params.Props != nil {
|
|
query.SetProps(params.Props)
|
|
}
|
|
|
|
return query.Save(ctx)
|
|
}
|
|
|
|
func (c *shareClient) GetByHashID(ctx context.Context, idRaw string) (*ent.Share, error) {
|
|
id, err := c.hasher.Decode(idRaw, hashid.ShareID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decode hash id %q: %w", idRaw, err)
|
|
}
|
|
|
|
return c.GetByID(ctx, id)
|
|
}
|
|
|
|
func (c *shareClient) GetByID(ctx context.Context, id int) (*ent.Share, error) {
|
|
s, err := withShareEagerLoading(ctx, c.client.Share.Query().Where(share.ID(id))).First(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query share %d: %w", id, err)
|
|
}
|
|
|
|
return s, nil
|
|
}
|
|
|
|
func (c *shareClient) GetByIDUser(ctx context.Context, id, uid int) (*ent.Share, error) {
|
|
s, err := withShareEagerLoading(ctx, c.client.Share.Query().
|
|
Where(share.ID(id))).
|
|
Where(share.HasUserWith(user.ID(uid))).First(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query share %d: %w", id, err)
|
|
}
|
|
|
|
return s, nil
|
|
}
|
|
|
|
func (c *shareClient) GetByIDs(ctx context.Context, ids []int) ([]*ent.Share, error) {
|
|
s, err := withShareEagerLoading(ctx, c.client.Share.Query().Where(share.IDIn(ids...))).All(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query shares %v: %w", ids, err)
|
|
}
|
|
|
|
return s, nil
|
|
}
|
|
|
|
func (c *shareClient) DeleteBatch(ctx context.Context, shareIds []int) error {
|
|
_, err := c.client.Share.Delete().Where(share.IDIn(shareIds...)).Exec(ctx)
|
|
return err
|
|
}
|
|
|
|
func (c *shareClient) Delete(ctx context.Context, shareId int) error {
|
|
return c.client.Share.DeleteOneID(shareId).Exec(ctx)
|
|
}
|
|
|
|
// Viewed increments the view count of the share.
|
|
func (c *shareClient) Viewed(ctx context.Context, share *ent.Share) error {
|
|
_, err := c.client.Share.UpdateOneID(share.ID).AddViews(1).Save(ctx)
|
|
return err
|
|
}
|
|
|
|
// Downloaded increments the download count of the share.
|
|
func (c *shareClient) Downloaded(ctx context.Context, share *ent.Share) error {
|
|
stm := c.client.Share.
|
|
UpdateOneID(share.ID).
|
|
AddDownloads(1)
|
|
if share.RemainDownloads != nil && *share.RemainDownloads >= 0 {
|
|
stm.AddRemainDownloads(-1)
|
|
}
|
|
_, err := stm.Save(ctx)
|
|
return err
|
|
}
|
|
|
|
func IsValidShare(share *ent.Share) error {
|
|
// Check if share is expired
|
|
if err := IsShareExpired(share); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Check owner status
|
|
owner, err := share.Edges.UserOrErr()
|
|
if err != nil || owner.Status != user.StatusActive {
|
|
// Owner already deleted, or not active.
|
|
return ErrOwnerInactive
|
|
}
|
|
|
|
// Check source file status
|
|
file, err := share.Edges.FileOrErr()
|
|
if err != nil || file.FileChildren == 0 || file.OwnerID != owner.ID {
|
|
// Source file already deleted
|
|
return ErrSourceFileInvalid
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func IsShareExpired(share *ent.Share) error {
|
|
// Check if share is expired
|
|
if (share.Expires != nil && share.Expires.Before(time.Now())) ||
|
|
(share.RemainDownloads != nil && *share.RemainDownloads <= 0) {
|
|
return ErrShareLinkExpired
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *shareClient) List(ctx context.Context, args *ListShareArgs) (*ListShareResult, error) {
|
|
rawQuery := c.listQuery(args)
|
|
query := withShareEagerLoading(ctx, rawQuery)
|
|
|
|
var (
|
|
shares []*ent.Share
|
|
err error
|
|
paginationRes *PaginationResults
|
|
)
|
|
if args.UseCursorPagination {
|
|
shares, paginationRes, err = c.cursorPagination(ctx, query, args, 10)
|
|
} else {
|
|
shares, paginationRes, err = c.offsetPagination(ctx, query, args, 10)
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("query failed with paginiation: %w", err)
|
|
}
|
|
|
|
return &ListShareResult{
|
|
Shares: shares,
|
|
PaginationResults: paginationRes,
|
|
}, nil
|
|
}
|
|
|
|
func (c *shareClient) cursorPagination(ctx context.Context, query *ent.ShareQuery, args *ListShareArgs, paramMargin int) ([]*ent.Share, *PaginationResults, error) {
|
|
pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
|
|
query.Order(getShareOrderOption(args)...)
|
|
|
|
var (
|
|
pageToken *PageToken
|
|
err error
|
|
)
|
|
if args.PageToken != "" {
|
|
pageToken, err = pageTokenFromString(args.PageToken, c.hasher, hashid.ShareID)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("invalid page token %q: %w", args.PageToken, err)
|
|
}
|
|
}
|
|
queryPaged := getShareCursorQuery(args, pageToken, query)
|
|
|
|
// Use page size + 1 to determine if there are more items to come
|
|
queryPaged.Limit(pageSize + 1)
|
|
|
|
logs, err := queryPaged.
|
|
All(ctx)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// More items to come
|
|
nextTokenStr := ""
|
|
if len(logs) > pageSize {
|
|
lastItem := logs[len(logs)-2]
|
|
nextToken, err := getShareNextPageToken(c.hasher, lastItem, args)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to generate next page token: %w", err)
|
|
}
|
|
|
|
nextTokenStr = nextToken
|
|
}
|
|
|
|
return lo.Subset(logs, 0, uint(pageSize)), &PaginationResults{
|
|
PageSize: pageSize,
|
|
NextPageToken: nextTokenStr,
|
|
IsCursor: true,
|
|
}, nil
|
|
}
|
|
|
|
func (c *shareClient) offsetPagination(ctx context.Context, query *ent.ShareQuery, args *ListShareArgs, paramMargin int) ([]*ent.Share, *PaginationResults, error) {
|
|
pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
|
|
query.Order(getShareOrderOption(args)...)
|
|
|
|
total, err := query.Clone().Count(ctx)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
logs, err := query.Limit(pageSize).Offset(args.Page * args.PageSize).All(ctx)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
return logs, &PaginationResults{
|
|
PageSize: pageSize,
|
|
TotalItems: total,
|
|
Page: args.Page,
|
|
}, nil
|
|
}
|
|
|
|
func (c *shareClient) listQuery(args *ListShareArgs) *ent.ShareQuery {
|
|
query := c.client.Share.Query()
|
|
if args.UserID > 0 {
|
|
query.Where(share.HasUserWith(user.ID(args.UserID)))
|
|
}
|
|
|
|
if args.PublicOnly {
|
|
query.Where(share.PasswordIsNil())
|
|
}
|
|
|
|
if args.FileID > 0 {
|
|
query.Where(share.HasFileWith(file.ID(args.FileID)))
|
|
}
|
|
|
|
if len(args.ShareIDs) > 0 {
|
|
query.Where(share.IDIn(args.ShareIDs...))
|
|
}
|
|
|
|
return query
|
|
}
|
|
|
|
// getShareNextPageToken returns the next page token for the given last share.
|
|
func getShareNextPageToken(hasher hashid.Encoder, last *ent.Share, args *ListShareArgs) (string, error) {
|
|
token := &PageToken{
|
|
ID: last.ID,
|
|
}
|
|
|
|
return token.Encode(hasher, hashid.EncodeShareID)
|
|
}
|
|
|
|
func getShareCursorQuery(args *ListShareArgs, token *PageToken, query *ent.ShareQuery) *ent.ShareQuery {
|
|
o := &sql.OrderTermOptions{}
|
|
getOrderTerm(args.Order)(o)
|
|
|
|
predicates, ok := shareCursorQuery[args.OrderBy]
|
|
if !ok {
|
|
predicates = shareCursorQuery[share.FieldID]
|
|
}
|
|
|
|
if token != nil {
|
|
query.Where(predicates[o.Desc](token))
|
|
}
|
|
|
|
return query
|
|
}
|
|
|
|
var shareCursorQuery = map[string]map[bool]func(token *PageToken) predicate.Share{
|
|
share.FieldID: {
|
|
true: func(token *PageToken) predicate.Share {
|
|
return share.IDLT(token.ID)
|
|
},
|
|
false: func(token *PageToken) predicate.Share {
|
|
return share.IDGT(token.ID)
|
|
},
|
|
},
|
|
}
|
|
|
|
func getShareOrderOption(args *ListShareArgs) []share.OrderOption {
|
|
orderTerm := getOrderTerm(args.Order)
|
|
switch args.OrderBy {
|
|
case share.FieldViews:
|
|
return []share.OrderOption{share.ByViews(orderTerm), share.ByID(orderTerm)}
|
|
case share.FieldDownloads:
|
|
return []share.OrderOption{share.ByDownloads(orderTerm), share.ByID(orderTerm)}
|
|
case share.FieldRemainDownloads:
|
|
return []share.OrderOption{share.ByRemainDownloads(orderTerm), share.ByID(orderTerm)}
|
|
default:
|
|
return []share.OrderOption{share.ByID(orderTerm)}
|
|
}
|
|
}
|
|
|
|
func withShareEagerLoading(ctx context.Context, q *ent.ShareQuery) *ent.ShareQuery {
|
|
if v, ok := ctx.Value(LoadShareFile{}).(bool); ok && v {
|
|
q.WithFile(func(q *ent.FileQuery) {
|
|
withFileEagerLoading(ctx, q)
|
|
})
|
|
}
|
|
if v, ok := ctx.Value(LoadShareUser{}).(bool); ok && v {
|
|
q.WithUser(func(q *ent.UserQuery) {
|
|
withUserEagerLoading(ctx, q)
|
|
})
|
|
}
|
|
|
|
return q
|
|
}
|