支持多家云存储的云盘系统
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

340 lines
8.4 KiB

  1. package inventory
  2. import (
  3. "context"
  4. "fmt"
  5. "time"
  6. "entgo.io/ent/dialect/sql"
  7. "github.com/cloudreve/Cloudreve/v4/ent"
  8. "github.com/cloudreve/Cloudreve/v4/ent/task"
  9. "github.com/cloudreve/Cloudreve/v4/inventory/types"
  10. "github.com/cloudreve/Cloudreve/v4/pkg/conf"
  11. "github.com/cloudreve/Cloudreve/v4/pkg/hashid"
  12. "github.com/gofrs/uuid"
  13. "github.com/samber/lo"
  14. )
  15. type (
  16. // Ctx keys for eager loading options.
  17. LoadTaskUser struct{}
  18. TaskArgs struct {
  19. Status task.Status
  20. Type string
  21. PublicState *types.TaskPublicState
  22. PrivateState string
  23. OwnerID int
  24. CorrelationID uuid.UUID
  25. }
  26. )
  27. type TaskClient interface {
  28. TxOperator
  29. // New creates a new task with the given args.
  30. New(ctx context.Context, task *TaskArgs) (*ent.Task, error)
  31. // Update updates the task with the given args.
  32. Update(ctx context.Context, task *ent.Task, args *TaskArgs) (*ent.Task, error)
  33. // GetPendingTasks returns all pending tasks of given type.
  34. GetPendingTasks(ctx context.Context, taskType ...string) ([]*ent.Task, error)
  35. // GetTaskByID returns the task with the given ID.
  36. GetTaskByID(ctx context.Context, taskID int) (*ent.Task, error)
  37. // SetCompleteByID sets the task with the given ID to complete.
  38. SetCompleteByID(ctx context.Context, taskID int) error
  39. // List returns a list of tasks with the given args.
  40. List(ctx context.Context, args *ListTaskArgs) (*ListTaskResult, error)
  41. // DeleteByIDs deletes the tasks with the given IDs.
  42. DeleteByIDs(ctx context.Context, ids ...int) error
  43. // DeleteBy deletes the tasks with the given args.
  44. DeleteBy(ctx context.Context, args *DeleteTaskArgs) error
  45. }
  46. type (
  47. ListTaskArgs struct {
  48. *PaginationArgs
  49. Types []string
  50. Status []task.Status
  51. UserID int
  52. CorrelationID *uuid.UUID
  53. }
  54. ListTaskResult struct {
  55. *PaginationResults
  56. Tasks []*ent.Task
  57. }
  58. DeleteTaskArgs struct {
  59. NotAfter time.Time
  60. Types []string
  61. Status []task.Status
  62. }
  63. )
  64. func NewTaskClient(client *ent.Client, dbType conf.DBType, hasher hashid.Encoder) TaskClient {
  65. return &taskClient{client: client, maxSQlParam: sqlParamLimit(dbType), hasher: hasher}
  66. }
  67. type taskClient struct {
  68. maxSQlParam int
  69. hasher hashid.Encoder
  70. client *ent.Client
  71. }
  72. func (c *taskClient) SetClient(newClient *ent.Client) TxOperator {
  73. return &taskClient{client: newClient, maxSQlParam: c.maxSQlParam, hasher: c.hasher}
  74. }
  75. func (c *taskClient) GetClient() *ent.Client {
  76. return c.client
  77. }
  78. func (c *taskClient) New(ctx context.Context, task *TaskArgs) (*ent.Task, error) {
  79. stm := c.client.Task.
  80. Create().
  81. SetType(task.Type).
  82. SetPublicState(task.PublicState)
  83. if task.PrivateState != "" {
  84. stm.SetPrivateState(task.PrivateState)
  85. }
  86. if task.OwnerID != 0 {
  87. stm.SetUserID(task.OwnerID)
  88. }
  89. if task.Status != "" {
  90. stm.SetStatus(task.Status)
  91. }
  92. if task.CorrelationID.String() != uuid.Nil.String() {
  93. stm.SetCorrelationID(task.CorrelationID)
  94. }
  95. newTask, err := stm.Save(ctx)
  96. if err != nil {
  97. return nil, fmt.Errorf("failed to create task: %w", err)
  98. }
  99. return newTask, nil
  100. }
  101. func (c *taskClient) DeleteByIDs(ctx context.Context, ids ...int) error {
  102. _, err := c.client.Task.Delete().Where(task.IDIn(ids...)).Exec(ctx)
  103. return err
  104. }
  105. func (c *taskClient) DeleteBy(ctx context.Context, args *DeleteTaskArgs) error {
  106. query := c.client.Task.
  107. Delete().
  108. Where(task.CreatedAtLTE(args.NotAfter))
  109. if len(args.Status) > 0 {
  110. query.Where(task.StatusIn(args.Status...))
  111. }
  112. if len(args.Types) > 0 {
  113. query.Where(task.TypeIn(args.Types...))
  114. }
  115. _, err := query.Exec(ctx)
  116. return err
  117. }
  118. func (c *taskClient) Update(ctx context.Context, task *ent.Task, args *TaskArgs) (*ent.Task, error) {
  119. stm := c.client.Task.UpdateOne(task).
  120. SetPublicState(args.PublicState)
  121. task.PublicState = args.PublicState
  122. if task.PrivateState != "" {
  123. stm.SetPrivateState(task.PrivateState)
  124. task.PrivateState = args.PrivateState
  125. }
  126. if task.Status != "" {
  127. stm.SetStatus(args.Status)
  128. task.Status = args.Status
  129. }
  130. if err := stm.Exec(ctx); err != nil {
  131. return nil, fmt.Errorf("failed to create task: %w", err)
  132. }
  133. return task, nil
  134. }
  135. func (c *taskClient) GetPendingTasks(ctx context.Context, taskType ...string) ([]*ent.Task, error) {
  136. tasks, err := withTaskEagerLoading(ctx, c.client.Task.Query()).
  137. Where(task.StatusIn(task.StatusProcessing, task.StatusQueued, task.StatusSuspending)).
  138. Where(task.TypeIn(taskType...)).
  139. All(ctx)
  140. if err != nil {
  141. return nil, err
  142. }
  143. // Anonymous user is not loaded by default, so we need to load it manually.
  144. userClient := NewUserClient(c.client)
  145. anonymous, err := userClient.AnonymousUser(ctx)
  146. for _, t := range tasks {
  147. if t.UserTasks == 0 {
  148. if err != nil {
  149. return nil, err
  150. }
  151. t.SetUser(anonymous)
  152. }
  153. }
  154. return tasks, nil
  155. }
  156. func (c *taskClient) GetTaskByID(ctx context.Context, taskID int) (*ent.Task, error) {
  157. return withTaskEagerLoading(ctx, c.client.Task.Query()).
  158. Where(task.ID(taskID)).
  159. First(ctx)
  160. }
  161. func (c *taskClient) SetCompleteByID(ctx context.Context, taskID int) error {
  162. _, err := c.client.Task.UpdateOneID(taskID).
  163. SetStatus(task.StatusCompleted).
  164. Save(ctx)
  165. return err
  166. }
  167. func (c *taskClient) List(ctx context.Context, args *ListTaskArgs) (*ListTaskResult, error) {
  168. q := c.client.Task.Query()
  169. if args.UserID != 0 {
  170. q.Where(task.UserTasks(args.UserID))
  171. }
  172. if args.Types != nil {
  173. q.Where(task.TypeIn(args.Types...))
  174. }
  175. if args.Status != nil {
  176. q.Where(task.StatusIn(args.Status...))
  177. }
  178. if args.CorrelationID != nil {
  179. q.Where(task.CorrelationID(*args.CorrelationID))
  180. }
  181. q = withTaskEagerLoading(ctx, q)
  182. var (
  183. tasks []*ent.Task
  184. err error
  185. paginationRes *PaginationResults
  186. )
  187. if args.UseCursorPagination {
  188. tasks, paginationRes, err = c.cursorPagination(ctx, q, args, 1)
  189. } else {
  190. tasks, paginationRes, err = c.offsetPagination(ctx, q, args, 1)
  191. }
  192. if err != nil {
  193. return nil, fmt.Errorf("query failed with paginiation: %w", err)
  194. }
  195. return &ListTaskResult{
  196. Tasks: tasks,
  197. PaginationResults: paginationRes,
  198. }, nil
  199. }
  200. func (c *taskClient) cursorPagination(ctx context.Context, query *ent.TaskQuery, args *ListTaskArgs, paramMargin int) ([]*ent.Task, *PaginationResults, error) {
  201. pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
  202. query.Order(task.ByID(sql.OrderDesc()))
  203. var (
  204. pageToken *PageToken
  205. err error
  206. queryPaged = query
  207. )
  208. if args.PageToken != "" {
  209. pageToken, err = pageTokenFromString(args.PageToken, c.hasher, hashid.TaskID)
  210. if err != nil {
  211. return nil, nil, fmt.Errorf("invalid page token %q: %w", args.PageToken, err)
  212. }
  213. queryPaged = query.Where(task.IDLT(pageToken.ID))
  214. }
  215. // Use page size + 1 to determine if there are more items to come
  216. queryPaged.Limit(pageSize + 1)
  217. tasks, err := queryPaged.
  218. All(ctx)
  219. if err != nil {
  220. return nil, nil, err
  221. }
  222. // More items to come
  223. nextTokenStr := ""
  224. if len(tasks) > pageSize {
  225. lastItem := tasks[len(tasks)-2]
  226. nextToken, err := getTaskNextPageToken(c.hasher, lastItem)
  227. if err != nil {
  228. return nil, nil, fmt.Errorf("failed to generate next page token: %w", err)
  229. }
  230. nextTokenStr = nextToken
  231. }
  232. return lo.Subset(tasks, 0, uint(pageSize)), &PaginationResults{
  233. PageSize: pageSize,
  234. NextPageToken: nextTokenStr,
  235. IsCursor: true,
  236. }, nil
  237. }
  238. func (c *taskClient) offsetPagination(ctx context.Context, query *ent.TaskQuery, args *ListTaskArgs, paramMargin int) ([]*ent.Task, *PaginationResults, error) {
  239. pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
  240. query.Order(getTaskOrderOption(args)...)
  241. // Count total items
  242. total, err := query.Clone().Count(ctx)
  243. if err != nil {
  244. return nil, nil, err
  245. }
  246. logs, err := query.
  247. Limit(pageSize).
  248. Offset(args.Page * args.PageSize).
  249. All(ctx)
  250. if err != nil {
  251. return nil, nil, err
  252. }
  253. return logs, &PaginationResults{
  254. PageSize: pageSize,
  255. TotalItems: total,
  256. Page: args.Page,
  257. }, nil
  258. }
  259. func getTaskOrderOption(args *ListTaskArgs) []task.OrderOption {
  260. orderTerm := getOrderTerm(args.Order)
  261. switch args.OrderBy {
  262. default:
  263. return []task.OrderOption{task.ByID(orderTerm)}
  264. }
  265. }
  266. // getTaskNextPageToken returns the next page token for the given last task.
  267. func getTaskNextPageToken(hasher hashid.Encoder, last *ent.Task) (string, error) {
  268. token := &PageToken{
  269. ID: last.ID,
  270. }
  271. return token.Encode(hasher, hashid.EncodeTaskID)
  272. }
  273. func withTaskEagerLoading(ctx context.Context, q *ent.TaskQuery) *ent.TaskQuery {
  274. if v, ok := ctx.Value(LoadTaskUser{}).(bool); ok && v {
  275. q.WithUser(func(q *ent.UserQuery) {
  276. withUserEagerLoading(ctx, q)
  277. })
  278. }
  279. return q
  280. }