支持多家云存储的云盘系统
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

426 lines
12 KiB

package inventory
import (
"context"
"fmt"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
"time"
"entgo.io/ent/dialect/sql"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/ent/file"
"github.com/cloudreve/Cloudreve/v4/ent/predicate"
"github.com/cloudreve/Cloudreve/v4/ent/share"
"github.com/cloudreve/Cloudreve/v4/ent/user"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
"github.com/samber/lo"
)
// LoadShareFile is a context key. When the context carries the bool value
// true under this key, withShareEagerLoading also loads the share's file edge.
type LoadShareFile struct{}

// LoadShareUser is a context key. When the context carries the bool value
// true under this key, withShareEagerLoading also loads the share's user edge.
type LoadShareUser struct{}
// ErrShareLinkExpired is returned by IsShareExpired when the share's
// expiration time has passed or its remaining downloads are used up.
var ErrShareLinkExpired = fmt.Errorf("share link expired")

// ErrOwnerInactive is returned by IsValidShare when the share owner cannot be
// loaded or is not in the active status.
var ErrOwnerInactive = fmt.Errorf("owner is inactive")

// ErrSourceFileInvalid is returned by IsValidShare when the shared file cannot
// be loaded or fails the validity checks.
var ErrSourceFileInvalid = fmt.Errorf("source file is deleted")
// ShareClient is the data-access interface for share records.
type ShareClient interface {
	TxOperator
	// GetByIDs returns the shares matching the given IDs.
	GetByIDs(ctx context.Context, ids []int) ([]*ent.Share, error)
	// GetByID returns the share with the given ID.
	GetByID(ctx context.Context, id int) (*ent.Share, error)
	// GetByIDUser returns the share with the given ID that belongs to user uid.
	GetByIDUser(ctx context.Context, id, uid int) (*ent.Share, error)
	// GetByHashID decodes the hashed ID and returns the matching share.
	GetByHashID(ctx context.Context, idRaw string) (*ent.Share, error)
	// Upsert updates params.Existed when it is set; otherwise it creates a
	// new share record.
	Upsert(ctx context.Context, params *CreateShareParams) (*ent.Share, error)
	// Viewed increases the view count of the share by one.
	Viewed(ctx context.Context, share *ent.Share) error
	// Downloaded increases the download count of the share by one.
	Downloaded(ctx context.Context, share *ent.Share) error
	// Delete removes the share with the given ID.
	Delete(ctx context.Context, shareId int) error
	// List returns one page of shares selected by args.
	List(ctx context.Context, args *ListShareArgs) (*ListShareResult, error)
	// CountByTimeRange counts the shares created within [start, end). If
	// either bound is nil, all shares are counted.
	CountByTimeRange(ctx context.Context, start, end *time.Time) (int, error)
	// DeleteBatch removes all shares with the given IDs.
	DeleteBatch(ctx context.Context, shareIds []int) error
}

// CreateShareParams is the input of ShareClient.Upsert.
type CreateShareParams struct {
	// Existed, when non-nil, is the share to update instead of creating one.
	Existed *ent.Share
	// Password is stored on creation when non-empty; the update path does not
	// touch it.
	Password string
	// RemainDownloads is stored when positive. On update, a non-positive
	// value clears the stored limit.
	RemainDownloads int
	// Expires is stored when non-nil. On update, nil clears the stored
	// expiration.
	Expires *time.Time
	// OwnerID is the user ID assigned to a newly created share.
	OwnerID int
	// FileID is the file ID assigned to a newly created share.
	FileID int
	// Props is stored when non-nil.
	Props *types.ShareProps
}

// ListShareArgs selects and paginates shares for ShareClient.List.
type ListShareArgs struct {
	*PaginationArgs
	// UserID, when positive, restricts results to shares of that user.
	UserID int
	// FileID, when positive, restricts results to shares of that file.
	FileID int
	// PublicOnly restricts results to shares whose password is NULL.
	PublicOnly bool
	// ShareIDs, when non-empty, restricts results to these share IDs.
	ShareIDs []int
}

// ListShareResult is one page of shares together with pagination metadata.
type ListShareResult struct {
	*PaginationResults
	Shares []*ent.Share
}
// NewShareClient builds a ShareClient on top of the given ent client. hasher
// is used to decode share hash IDs and page tokens; dbType is passed to
// sqlParamLimit to obtain the SQL parameter limit used when capping page sizes.
func NewShareClient(client *ent.Client, dbType conf.DBType, hasher hashid.Encoder) ShareClient {
	sc := new(shareClient)
	sc.client = client
	sc.hasher = hasher
	sc.maxSQlParam = sqlParamLimit(dbType)
	return sc
}
// shareClient is the ent-backed implementation of ShareClient.
type shareClient struct {
	// client is the ent client (or transaction-bound client) used for queries.
	client *ent.Client
	// hasher decodes share hash IDs and page tokens.
	hasher hashid.Encoder
	// maxSQlParam is the value sqlParamLimit returned for the database type;
	// it is passed to capPageSize when paginating.
	maxSQlParam int
}
// SetClient returns a new shareClient that uses newClient for queries while
// keeping the receiver's hasher and SQL parameter limit. The receiver itself
// is not modified.
func (c *shareClient) SetClient(newClient *ent.Client) TxOperator {
	next := &shareClient{}
	next.client = newClient
	next.hasher = c.hasher
	next.maxSQlParam = c.maxSQlParam
	return next
}
// GetClient returns the ent client this share client currently operates on.
func (c *shareClient) GetClient() *ent.Client {
	current := c.client
	return current
}
// CountByTimeRange counts shares whose creation time is at or after *start and
// strictly before *end. The time filter is applied only when both bounds are
// non-nil; otherwise every share is counted.
func (c *shareClient) CountByTimeRange(ctx context.Context, start, end *time.Time) (int, error) {
	counter := c.client.Share.Query()
	if start != nil && end != nil {
		counter = counter.Where(share.CreatedAtGTE(*start), share.CreatedAtLT(*end))
	}
	return counter.Count(ctx)
}
// Upsert creates a new share, or updates params.Existed when it is non-nil.
//
// Create path: the owner and file IDs are always set; password, remaining
// downloads, expiration and props are set only when provided (non-empty,
// positive, non-nil and non-nil respectively).
//
// Update path: remaining downloads and expiration are overwritten when
// provided and cleared otherwise; props are overwritten only when non-nil.
// The password, owner and file are left untouched.
func (c *shareClient) Upsert(ctx context.Context, params *CreateShareParams) (*ent.Share, error) {
	if params.Existed == nil {
		creator := c.client.Share.Create().
			SetUserID(params.OwnerID).
			SetFileID(params.FileID)
		if params.Password != "" {
			creator.SetPassword(params.Password)
		}
		if params.RemainDownloads > 0 {
			creator.SetRemainDownloads(params.RemainDownloads)
		}
		if params.Expires != nil {
			creator.SetNillableExpires(params.Expires)
		}
		if params.Props != nil {
			creator.SetProps(params.Props)
		}
		return creator.Save(ctx)
	}

	updater := c.client.Share.UpdateOne(params.Existed)

	// A non-positive download limit means "no limit": drop the stored value.
	switch {
	case params.RemainDownloads > 0:
		updater.SetRemainDownloads(params.RemainDownloads)
	default:
		updater.ClearRemainDownloads()
	}

	// A nil expiration means "never expires": drop the stored value.
	switch {
	case params.Expires != nil:
		updater.SetNillableExpires(params.Expires)
	default:
		updater.ClearExpires()
	}

	if params.Props != nil {
		updater.SetProps(params.Props)
	}
	return updater.Save(ctx)
}
// GetByHashID decodes idRaw as a share hash ID and delegates to GetByID.
// A decoding failure is returned wrapped, with the raw ID in the message.
func (c *shareClient) GetByHashID(ctx context.Context, idRaw string) (*ent.Share, error) {
	shareID, decodeErr := c.hasher.Decode(idRaw, hashid.ShareID)
	if decodeErr == nil {
		return c.GetByID(ctx, shareID)
	}
	return nil, fmt.Errorf("failed to decode hash id %q: %w", idRaw, decodeErr)
}
// GetByID returns the first share whose ID equals id. Edges are eager-loaded
// according to the options carried by ctx. Query errors are returned wrapped.
func (c *shareClient) GetByID(ctx context.Context, id int) (*ent.Share, error) {
	lookup := c.client.Share.Query().Where(share.ID(id))
	found, err := withShareEagerLoading(ctx, lookup).First(ctx)
	if err == nil {
		return found, nil
	}
	return nil, fmt.Errorf("failed to query share %d: %w", id, err)
}
// GetByIDUser returns the first share whose ID equals id and whose user edge
// points to the user with ID uid. Edges are eager-loaded according to the
// options carried by ctx. Query errors are returned wrapped.
func (c *shareClient) GetByIDUser(ctx context.Context, id, uid int) (*ent.Share, error) {
	lookup := c.client.Share.Query().Where(share.ID(id))
	lookup = withShareEagerLoading(ctx, lookup).Where(share.HasUserWith(user.ID(uid)))
	found, err := lookup.First(ctx)
	if err == nil {
		return found, nil
	}
	return nil, fmt.Errorf("failed to query share %d: %w", id, err)
}
// GetByIDs returns every share whose ID is contained in ids. Edges are
// eager-loaded according to the options carried by ctx. Query errors are
// returned wrapped.
func (c *shareClient) GetByIDs(ctx context.Context, ids []int) ([]*ent.Share, error) {
	lookup := c.client.Share.Query().Where(share.IDIn(ids...))
	found, err := withShareEagerLoading(ctx, lookup).All(ctx)
	if err == nil {
		return found, nil
	}
	return nil, fmt.Errorf("failed to query shares %v: %w", ids, err)
}
// DeleteBatch deletes every share whose ID is contained in shareIds. The
// number of affected rows is discarded; only the error is reported.
func (c *shareClient) DeleteBatch(ctx context.Context, shareIds []int) error {
	deletion := c.client.Share.Delete().Where(share.IDIn(shareIds...))
	if _, err := deletion.Exec(ctx); err != nil {
		return err
	}
	return nil
}
// Delete deletes the single share identified by shareId.
func (c *shareClient) Delete(ctx context.Context, shareId int) error {
	deletion := c.client.Share.DeleteOneID(shareId)
	return deletion.Exec(ctx)
}
// Viewed adds one to the stored view counter of the given share. The updated
// entity is discarded; only the error is reported.
func (c *shareClient) Viewed(ctx context.Context, share *ent.Share) error {
	bump := c.client.Share.UpdateOneID(share.ID).AddViews(1)
	if _, err := bump.Save(ctx); err != nil {
		return err
	}
	return nil
}
// Downloaded adds one to the stored download counter of the given share. When
// the passed-in share has a remaining-download limit that is non-negative,
// that limit is decremented by one in the same update.
func (c *shareClient) Downloaded(ctx context.Context, share *ent.Share) error {
	update := c.client.Share.UpdateOneID(share.ID).AddDownloads(1)

	// The decision is based on the in-memory share value, not a fresh read.
	if remaining := share.RemainDownloads; remaining != nil && *remaining >= 0 {
		update.AddRemainDownloads(-1)
	}

	if _, err := update.Save(ctx); err != nil {
		return err
	}
	return nil
}
// IsValidShare reports whether the share can still be served. It expects the
// share's user and file edges to be loaded and checks, in order:
//
//  1. expiry via IsShareExpired (returns its error);
//  2. the owner edge is loaded and the owner status is active
//     (otherwise ErrOwnerInactive);
//  3. the file edge is loaded, its FileChildren is non-zero and it is owned
//     by the share owner (otherwise ErrSourceFileInvalid).
//
// It returns nil when every check passes.
func IsValidShare(share *ent.Share) error {
	if expiredErr := IsShareExpired(share); expiredErr != nil {
		return expiredErr
	}

	// A missing owner edge means the owner was deleted or never loaded.
	owner, ownerErr := share.Edges.UserOrErr()
	if ownerErr != nil {
		return ErrOwnerInactive
	}
	if owner.Status != user.StatusActive {
		return ErrOwnerInactive
	}

	// A missing file edge is treated as "source file already deleted".
	source, sourceErr := share.Edges.FileOrErr()
	if sourceErr != nil {
		return ErrSourceFileInvalid
	}
	if source.FileChildren == 0 || source.OwnerID != owner.ID {
		return ErrSourceFileInvalid
	}

	return nil
}
// IsShareExpired returns ErrShareLinkExpired when the share has an expiration
// time earlier than the current time, or a remaining-download limit that is
// zero or negative. Shares with neither field set never expire here.
func IsShareExpired(share *ent.Share) error {
	if deadline := share.Expires; deadline != nil && deadline.Before(time.Now()) {
		return ErrShareLinkExpired
	}
	if left := share.RemainDownloads; left != nil && *left <= 0 {
		return ErrShareLinkExpired
	}
	return nil
}
// List returns one page of shares matching args. The filter query is built by
// listQuery, edges are eager-loaded according to ctx, and the page is fetched
// with cursor pagination when args.UseCursorPagination is set, otherwise with
// offset pagination. Pagination errors are returned wrapped.
func (c *shareClient) List(ctx context.Context, args *ListShareArgs) (*ListShareResult, error) {
	query := withShareEagerLoading(ctx, c.listQuery(args))

	// Both strategies share a signature, so pick one and call it once.
	paginate := c.offsetPagination
	if args.UseCursorPagination {
		paginate = c.cursorPagination
	}

	shares, paginationRes, err := paginate(ctx, query, args, 10)
	if err != nil {
		return nil, fmt.Errorf("query failed with paginiation: %w", err)
	}

	result := &ListShareResult{PaginationResults: paginationRes}
	result.Shares = shares
	return result, nil
}
// cursorPagination fetches one page of shares using a cursor (page token).
//
// The page size is args.PageSize passed through capPageSize with paramMargin.
// When args.PageToken is non-empty it is decoded and turned into an extra
// predicate by getShareCursorQuery. One row more than the page size is
// requested; if that extra row comes back, a next-page token is generated
// from the last row that belongs to the current page. At most pageSize rows
// are returned.
func (c *shareClient) cursorPagination(ctx context.Context, query *ent.ShareQuery, args *ListShareArgs, paramMargin int) ([]*ent.Share, *PaginationResults, error) {
	pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
	query.Order(getShareOrderOption(args)...)

	var cursor *PageToken
	if args.PageToken != "" {
		decoded, err := pageTokenFromString(args.PageToken, c.hasher, hashid.ShareID)
		if err != nil {
			return nil, nil, fmt.Errorf("invalid page token %q: %w", args.PageToken, err)
		}
		cursor = decoded
	}

	// Ask for pageSize + 1 rows: the extra row only signals that more exist.
	rows, err := getShareCursorQuery(args, cursor, query).
		Limit(pageSize + 1).
		All(ctx)
	if err != nil {
		return nil, nil, err
	}

	meta := &PaginationResults{
		PageSize: pageSize,
		IsCursor: true,
	}
	if len(rows) > pageSize {
		// rows holds pageSize + 1 items, so the second to last one is the
		// final item of the current page.
		token, err := getShareNextPageToken(c.hasher, rows[len(rows)-2], args)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to generate next page token: %w", err)
		}
		meta.NextPageToken = token
	}

	return lo.Subset(rows, 0, uint(pageSize)), meta, nil
}
// offsetPagination fetches one page of shares using LIMIT/OFFSET.
//
// The effective page size is args.PageSize passed through capPageSize with
// paramMargin. The total number of matching rows is counted on a clone of the
// query so the count is not affected by the limit and offset applied later.
//
// The offset is computed from the effective (capped) page size, the same
// value used for the limit and reported in PaginationResults.PageSize.
// Using the raw args.PageSize here would skip or repeat rows whenever
// capPageSize returns a value different from the requested one.
//
// It returns the page of shares, the pagination metadata (page size, total
// item count and the requested page index), or the first query error.
func (c *shareClient) offsetPagination(ctx context.Context, query *ent.ShareQuery, args *ListShareArgs, paramMargin int) ([]*ent.Share, *PaginationResults, error) {
	pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
	query.Order(getShareOrderOption(args)...)

	// Count total items before limit/offset are applied to the query.
	total, err := query.Clone().Count(ctx)
	if err != nil {
		return nil, nil, err
	}

	shares, err := query.
		Limit(pageSize).
		Offset(args.Page * pageSize).
		All(ctx)
	if err != nil {
		return nil, nil, err
	}

	return shares, &PaginationResults{
		PageSize:   pageSize,
		TotalItems: total,
		Page:       args.Page,
	}, nil
}
// listQuery builds the share query for the filters in args. Each filter is
// applied only when set: positive UserID, PublicOnly (password is NULL),
// positive FileID, and a non-empty ShareIDs list.
func (c *shareClient) listQuery(args *ListShareArgs) *ent.ShareQuery {
	var filters []predicate.Share

	if args.UserID > 0 {
		filters = append(filters, share.HasUserWith(user.ID(args.UserID)))
	}
	if args.PublicOnly {
		filters = append(filters, share.PasswordIsNil())
	}
	if args.FileID > 0 {
		filters = append(filters, share.HasFileWith(file.ID(args.FileID)))
	}
	if len(args.ShareIDs) > 0 {
		filters = append(filters, share.IDIn(args.ShareIDs...))
	}

	return c.client.Share.Query().Where(filters...)
}
// getShareNextPageToken encodes a page token that carries the ID of last, the
// final share of the current page. args is currently not consulted.
func getShareNextPageToken(hasher hashid.Encoder, last *ent.Share, args *ListShareArgs) (string, error) {
	return (&PageToken{ID: last.ID}).Encode(hasher, hashid.EncodeShareID)
}
// getShareCursorQuery adds the cursor predicate for token to query and
// returns it. The predicate builder is looked up in shareCursorQuery by
// args.OrderBy (falling back to the ID entry for unknown fields) and by the
// sort direction derived from args.Order. With a nil token the query is
// returned without an extra predicate.
func getShareCursorQuery(args *ListShareArgs, token *PageToken, query *ent.ShareQuery) *ent.ShareQuery {
	// Apply the order term to a scratch options value to learn the direction.
	var opts sql.OrderTermOptions
	getOrderTerm(args.Order)(&opts)

	byDirection, known := shareCursorQuery[args.OrderBy]
	if !known {
		byDirection = shareCursorQuery[share.FieldID]
	}

	if token == nil {
		return query
	}
	return query.Where(byDirection[opts.Desc](token))
}
// shareCursorQuery maps an order-by field and a sort direction (true means
// descending) to a function that builds the cursor predicate for a page
// token. Only ordering by ID is registered.
var shareCursorQuery = map[string]map[bool]func(token *PageToken) predicate.Share{
	share.FieldID: {
		// Ascending: continue with IDs greater than the cursor.
		false: func(cursor *PageToken) predicate.Share { return share.IDGT(cursor.ID) },
		// Descending: continue with IDs smaller than the cursor.
		true: func(cursor *PageToken) predicate.Share { return share.IDLT(cursor.ID) },
	},
}
// getShareOrderOption returns the ordering for a share listing. Ordering by
// views, downloads or remaining downloads uses that field first and the ID as
// a secondary key; any other args.OrderBy orders by ID only. The direction
// comes from args.Order and applies to every key.
func getShareOrderOption(args *ListShareArgs) []share.OrderOption {
	direction := getOrderTerm(args.Order)

	var primary share.OrderOption
	switch args.OrderBy {
	case share.FieldViews:
		primary = share.ByViews(direction)
	case share.FieldDownloads:
		primary = share.ByDownloads(direction)
	case share.FieldRemainDownloads:
		primary = share.ByRemainDownloads(direction)
	}

	if primary == nil {
		return []share.OrderOption{share.ByID(direction)}
	}
	return []share.OrderOption{primary, share.ByID(direction)}
}
// withShareEagerLoading configures edge loading on q based on ctx and returns
// q. The file edge is loaded when ctx holds true under LoadShareFile{}, and
// the user edge when ctx holds true under LoadShareUser{}. The nested file and
// user queries are passed through their own eager-loading helpers.
func withShareEagerLoading(ctx context.Context, q *ent.ShareQuery) *ent.ShareQuery {
	// A missing or non-bool value yields false from the type assertion.
	if wantFile, _ := ctx.Value(LoadShareFile{}).(bool); wantFile {
		q.WithFile(func(fq *ent.FileQuery) {
			withFileEagerLoading(ctx, fq)
		})
	}

	if wantUser, _ := ctx.Value(LoadShareUser{}).(bool); wantUser {
		q.WithUser(func(uq *ent.UserQuery) {
			withUserEagerLoading(ctx, uq)
		})
	}

	return q
}