支持多家云存储的云盘系统
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

340 lines
8.4 KiB

package inventory
import (
"context"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/ent/task"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
"github.com/gofrs/uuid"
"github.com/samber/lo"
)
type (
// Ctx keys for eager loading options.
// LoadTaskUser is a context key: when the context holds `true` under
// LoadTaskUser{}, queries built through withTaskEagerLoading also load the
// task's owning user.
LoadTaskUser struct{}
// TaskArgs carries the field values used when creating (New) or updating
// (Update) a task.
TaskArgs struct {
// Status of the task; New leaves it unset when empty.
Status task.Status
// Type identifies the kind of task; always written by New.
Type string
// PublicState is always written by both New and Update.
PublicState *types.TaskPublicState
// PrivateState; New leaves it unset when empty.
PrivateState string
// OwnerID is the owning user's ID; New sets no owner when it is 0.
OwnerID int
// CorrelationID; New leaves it unset when it equals uuid.Nil.
CorrelationID uuid.UUID
}
)
// TaskClient provides persistence operations for ent Task entities. It embeds
// TxOperator, which the ent-backed implementation presumably satisfies through
// its SetClient/GetClient methods.
type TaskClient interface {
TxOperator
// New creates a new task with the given args.
New(ctx context.Context, task *TaskArgs) (*ent.Task, error)
// Update updates the task with the given args. PublicState is always written.
Update(ctx context.Context, task *ent.Task, args *TaskArgs) (*ent.Task, error)
// GetPendingTasks returns all pending tasks of given type.
// "Pending" means processing, queued or suspending. Tasks with no owning user
// have the anonymous user attached.
GetPendingTasks(ctx context.Context, taskType ...string) ([]*ent.Task, error)
// GetTaskByID returns the task with the given ID.
GetTaskByID(ctx context.Context, taskID int) (*ent.Task, error)
// SetCompleteByID sets the task with the given ID to complete.
SetCompleteByID(ctx context.Context, taskID int) error
// List returns a list of tasks with the given args, using cursor or offset
// pagination depending on args.UseCursorPagination.
List(ctx context.Context, args *ListTaskArgs) (*ListTaskResult, error)
// DeleteByIDs deletes the tasks with the given IDs.
DeleteByIDs(ctx context.Context, ids ...int) error
// DeleteBy deletes the tasks with the given args: tasks created at or before
// args.NotAfter, optionally restricted by status and type.
DeleteBy(ctx context.Context, args *DeleteTaskArgs) error
}
type (
// ListTaskArgs selects and paginates tasks for List. Unset filters are not
// applied: UserID 0, nil Types, nil Status and nil CorrelationID match any task.
ListTaskArgs struct {
// Pagination options; the pagination helpers read PageSize, Page, PageToken,
// Order, OrderBy and UseCursorPagination from here.
*PaginationArgs
Types []string
Status []task.Status
UserID int
CorrelationID *uuid.UUID
}
// ListTaskResult is one page of tasks plus the pagination metadata for it.
ListTaskResult struct {
*PaginationResults
Tasks []*ent.Task
}
// DeleteTaskArgs selects tasks for DeleteBy. NotAfter is always applied as an
// inclusive upper bound on creation time; empty Types/Status apply no filter.
DeleteTaskArgs struct {
NotAfter time.Time
Types []string
Status []task.Status
}
)
// NewTaskClient builds the ent-backed TaskClient. The database type decides
// the SQL parameter limit used to cap page sizes, and hasher is used to encode
// and decode cursor page tokens.
func NewTaskClient(client *ent.Client, dbType conf.DBType, hasher hashid.Encoder) TaskClient {
	limit := sqlParamLimit(dbType)
	return &taskClient{
		hasher:      hasher,
		client:      client,
		maxSQlParam: limit,
	}
}
// taskClient is the ent-backed implementation of TaskClient.
type taskClient struct {
// maxSQlParam is the SQL parameter limit for the configured database type,
// passed to capPageSize to bound page sizes.
maxSQlParam int
// hasher encodes and decodes task IDs inside cursor page tokens.
hasher hashid.Encoder
// client is the ent client all queries are issued through.
client *ent.Client
}
// SetClient returns a copy of this task client that issues its queries through
// newClient, keeping the same SQL parameter limit and hasher.
func (c *taskClient) SetClient(newClient *ent.Client) TxOperator {
	clone := *c
	clone.client = newClient
	return &clone
}
// GetClient exposes the ent client this task client currently queries through.
func (c *taskClient) GetClient() *ent.Client {
	current := c.client
	return current
}
// New inserts a task built from args and returns the created entity.
//
// Type and PublicState are always written. PrivateState, OwnerID, Status and
// CorrelationID are only set when they differ from their zero value
// ("", 0, "", uuid.Nil respectively).
func (c *taskClient) New(ctx context.Context, args *TaskArgs) (*ent.Task, error) {
	stm := c.client.Task.
		Create().
		SetType(args.Type).
		SetPublicState(args.PublicState)

	if args.PrivateState != "" {
		stm.SetPrivateState(args.PrivateState)
	}
	if args.OwnerID != 0 {
		stm.SetUserID(args.OwnerID)
	}
	if args.Status != "" {
		stm.SetStatus(args.Status)
	}
	// uuid.UUID is a fixed-size byte array and therefore comparable; there is
	// no need to format both sides as strings to detect the nil UUID.
	if args.CorrelationID != uuid.Nil {
		stm.SetCorrelationID(args.CorrelationID)
	}

	newTask, err := stm.Save(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to create task: %w", err)
	}
	return newTask, nil
}
// DeleteByIDs removes every task whose ID appears in ids and reports any error
// returned by the database.
func (c *taskClient) DeleteByIDs(ctx context.Context, ids ...int) error {
	byID := task.IDIn(ids...)
	if _, err := c.client.Task.Delete().Where(byID).Exec(ctx); err != nil {
		return err
	}
	return nil
}
// DeleteBy removes tasks created at or before args.NotAfter. Non-empty
// args.Status and args.Types further restrict the deletion to matching
// statuses and types; an empty slice adds no restriction.
func (c *taskClient) DeleteBy(ctx context.Context, args *DeleteTaskArgs) error {
	del := c.client.Task.Delete().Where(task.CreatedAtLTE(args.NotAfter))
	if len(args.Status) != 0 {
		del = del.Where(task.StatusIn(args.Status...))
	}
	if len(args.Types) != 0 {
		del = del.Where(task.TypeIn(args.Types...))
	}
	_, err := del.Exec(ctx)
	return err
}
// Update persists args onto the task t and returns t with its in-memory
// fields synchronised to what was written.
//
// PublicState is always written. PrivateState and Status are only written when
// args supplies a non-empty value; an empty value leaves the stored column
// untouched. t is only modified after the database update succeeds, so a
// failed update does not leave the caller's entity out of sync with storage.
func (c *taskClient) Update(ctx context.Context, t *ent.Task, args *TaskArgs) (*ent.Task, error) {
	stm := c.client.Task.UpdateOne(t).
		SetPublicState(args.PublicState)

	// Decide what to write from the requested change (args), not from the
	// entity's current values: checking t would re-write the old private state
	// and could attempt to store an empty status.
	if args.PrivateState != "" {
		stm.SetPrivateState(args.PrivateState)
	}
	if args.Status != "" {
		stm.SetStatus(args.Status)
	}

	if err := stm.Exec(ctx); err != nil {
		return nil, fmt.Errorf("failed to update task: %w", err)
	}

	t.PublicState = args.PublicState
	if args.PrivateState != "" {
		t.PrivateState = args.PrivateState
	}
	if args.Status != "" {
		t.Status = args.Status
	}
	return t, nil
}
// GetPendingTasks returns every task of the given types whose status is
// processing, queued or suspending, honouring the eager-loading options in ctx.
//
// Tasks without an owning user (UserTasks == 0) get the anonymous user
// attached. The anonymous user is only looked up when at least one such task
// exists, and its lookup error is returned only in that case.
func (c *taskClient) GetPendingTasks(ctx context.Context, taskType ...string) ([]*ent.Task, error) {
	tasks, err := withTaskEagerLoading(ctx, c.client.Task.Query()).
		Where(task.StatusIn(task.StatusProcessing, task.StatusQueued, task.StatusSuspending)).
		Where(task.TypeIn(taskType...)).
		All(ctx)
	if err != nil {
		return nil, err
	}

	// Anonymous user is not loaded by default, so it must be attached manually
	// to tasks that have no owner.
	var ownerless []*ent.Task
	for _, t := range tasks {
		if t.UserTasks == 0 {
			ownerless = append(ownerless, t)
		}
	}
	if len(ownerless) == 0 {
		// Nothing needs the anonymous user; skip the extra query.
		return tasks, nil
	}

	anonymous, err := NewUserClient(c.client).AnonymousUser(ctx)
	if err != nil {
		return nil, err
	}
	for _, t := range ownerless {
		t.SetUser(anonymous)
	}
	return tasks, nil
}
// GetTaskByID fetches the single task with the given primary key, applying the
// eager-loading options carried by ctx. Whatever error ent's First reports
// (including the no-match case) is returned unchanged.
func (c *taskClient) GetTaskByID(ctx context.Context, taskID int) (*ent.Task, error) {
	q := c.client.Task.Query().Where(task.ID(taskID))
	return withTaskEagerLoading(ctx, q).First(ctx)
}
// SetCompleteByID marks the task identified by taskID as completed, returning
// any error from the update.
func (c *taskClient) SetCompleteByID(ctx context.Context, taskID int) error {
	return c.client.Task.
		UpdateOneID(taskID).
		SetStatus(task.StatusCompleted).
		Exec(ctx)
}
// List returns one page of tasks matching args.
//
// Filters are only applied when set: UserID != 0, non-nil Types, non-nil
// Status and non-nil CorrelationID. Eager loading follows the options carried
// by ctx. Cursor pagination is used when args.UseCursorPagination is true,
// offset pagination otherwise.
func (c *taskClient) List(ctx context.Context, args *ListTaskArgs) (*ListTaskResult, error) {
	q := c.client.Task.Query()
	if args.UserID != 0 {
		q.Where(task.UserTasks(args.UserID))
	}
	if args.Types != nil {
		q.Where(task.TypeIn(args.Types...))
	}
	if args.Status != nil {
		q.Where(task.StatusIn(args.Status...))
	}
	if args.CorrelationID != nil {
		q.Where(task.CorrelationID(*args.CorrelationID))
	}

	q = withTaskEagerLoading(ctx, q)

	var (
		tasks         []*ent.Task
		err           error
		paginationRes *PaginationResults
	)
	if args.UseCursorPagination {
		tasks, paginationRes, err = c.cursorPagination(ctx, q, args, 1)
	} else {
		tasks, paginationRes, err = c.offsetPagination(ctx, q, args, 1)
	}
	if err != nil {
		return nil, fmt.Errorf("query failed with pagination: %w", err)
	}

	return &ListTaskResult{
		Tasks:             tasks,
		PaginationResults: paginationRes,
	}, nil
}
// cursorPagination returns up to one page of tasks ordered by descending ID.
// When args.PageToken is set, only tasks with an ID below the one encoded in
// the token are considered. One row beyond the page size is requested to learn
// whether another page follows; if so, NextPageToken encodes the ID of the
// last task on this page.
func (c *taskClient) cursorPagination(ctx context.Context, query *ent.TaskQuery, args *ListTaskArgs, paramMargin int) ([]*ent.Task, *PaginationResults, error) {
	pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
	query.Order(task.ByID(sql.OrderDesc()))

	if args.PageToken != "" {
		token, err := pageTokenFromString(args.PageToken, c.hasher, hashid.TaskID)
		if err != nil {
			return nil, nil, fmt.Errorf("invalid page token %q: %w", args.PageToken, err)
		}
		query = query.Where(task.IDLT(token.ID))
	}

	// The extra row tells us whether more items come after this page.
	tasks, err := query.Limit(pageSize + 1).All(ctx)
	if err != nil {
		return nil, nil, err
	}

	res := &PaginationResults{
		PageSize: pageSize,
		IsCursor: true,
	}
	if len(tasks) > pageSize {
		next, err := getTaskNextPageToken(c.hasher, tasks[len(tasks)-2])
		if err != nil {
			return nil, nil, fmt.Errorf("failed to generate next page token: %w", err)
		}
		res.NextPageToken = next
	}
	return lo.Subset(tasks, 0, uint(pageSize)), res, nil
}
// offsetPagination returns page args.Page (zero-based) of the query results,
// ordered according to args, together with the total number of matches.
//
// The page size is capped by capPageSize. The offset is computed from the
// capped size so that the rows skipped agree with the PageSize reported back
// to the caller. Using the raw args.PageSize here would skip or repeat rows
// whenever the cap applies.
func (c *taskClient) offsetPagination(ctx context.Context, query *ent.TaskQuery, args *ListTaskArgs, paramMargin int) ([]*ent.Task, *PaginationResults, error) {
	pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
	query.Order(getTaskOrderOption(args)...)

	// Count total items on a clone so the paged query below is unaffected.
	total, err := query.Clone().Count(ctx)
	if err != nil {
		return nil, nil, err
	}

	tasks, err := query.
		Limit(pageSize).
		Offset(args.Page * pageSize).
		All(ctx)
	if err != nil {
		return nil, nil, err
	}

	return tasks, &PaginationResults{
		PageSize:   pageSize,
		TotalItems: total,
		Page:       args.Page,
	}, nil
}
// getTaskOrderOption builds the ordering used by offset pagination. Tasks can
// only be ordered by ID, so args.OrderBy is not consulted; args.Order is
// converted into the order term through getOrderTerm.
func getTaskOrderOption(args *ListTaskArgs) []task.OrderOption {
	return []task.OrderOption{
		task.ByID(getOrderTerm(args.Order)),
	}
}
// getTaskNextPageToken encodes the ID of last, the final task on the current
// page, into a cursor page token using the task-ID hash encoding.
func getTaskNextPageToken(hasher hashid.Encoder, last *ent.Task) (string, error) {
	return (&PageToken{ID: last.ID}).Encode(hasher, hashid.EncodeTaskID)
}
// withTaskEagerLoading applies the eager-loading options carried by ctx to q
// and returns q. When ctx holds true under LoadTaskUser{}, the owning user is
// loaded as well, with the user query configured by withUserEagerLoading.
func withTaskEagerLoading(ctx context.Context, q *ent.TaskQuery) *ent.TaskQuery {
	// A missing or non-bool value yields false, i.e. no user loading.
	loadUser, _ := ctx.Value(LoadTaskUser{}).(bool)
	if !loadUser {
		return q
	}
	q.WithUser(func(uq *ent.UserQuery) {
		withUserEagerLoading(ctx, uq)
	})
	return q
}