支持多家云存储的云盘系统
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

101 lines
2.1 KiB

package inventory
import (
"context"
"fmt"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
)
// TxOperator is the constraint WithTx places on the client it wraps: the client
// must expose its current ent client and accept a replacement one.
type TxOperator interface {
// SetClient is called by WithTx with the transaction-bound ent client; WithTx
// type-asserts the returned operator back to the caller's concrete type.
// NOTE(review): presumably the result is an operator bound to newClient —
// whether the receiver itself is modified depends on the implementation.
SetClient(newClient *ent.Client) TxOperator
// GetClient returns the ent client WithTx uses to start a new transaction.
GetClient() *ent.Client
}
// Tx wraps an ent transaction together with the bookkeeping WithTx needs to
// let nested calls share a single underlying transaction.
type Tx struct {
	// tx is the underlying ent transaction. An inherited wrapper holds the
	// same *ent.Tx as the wrapper it was created from.
	tx *ent.Tx
	// parent is the wrapper this one was inherited from. WithTx leaves it nil
	// for a transaction it started itself.
	parent *Tx
	// inherited is true when WithTx reused a transaction found in the context
	// instead of starting one. Rollback and Commit do nothing for such wrappers.
	inherited bool
	// finished is set when Rollback or Commit runs on a non-inherited wrapper.
	// WithTx does not inherit from a finished transaction.
	finished bool
	// storageDiff collects the diffs passed to AppendStorageDiff. It is only
	// populated on the non-inherited wrapper at the top of the parent chain.
	storageDiff StorageDiff
}

// TxCtx is the context key for inherited transaction
type TxCtx struct{}
// AppendStorageDiff records diff on the transaction that owns t.
//
// Inherited wrappers do not keep a diff of their own: the parent chain is
// followed until a non-inherited wrapper is reached, and the diff is stored
// there. The first diff is stored as-is; later ones are merged into it.
func (t *Tx) AppendStorageDiff(diff StorageDiff) {
	owner := t
	for owner.inherited {
		owner = owner.parent
	}

	if owner.storageDiff != nil {
		owner.storageDiff.Merge(diff)
		return
	}
	owner.storageDiff = diff
}
// WithTx binds the inventory client c to a transaction.
//
// If ctx already carries an unfinished *Tx under the TxCtx key, that
// transaction is reused: the returned wrapper is marked inherited and ctx is
// returned unchanged. Otherwise a new transaction is started from
// c.GetClient(), and the returned context carries its wrapper so that nested
// WithTx calls can join it.
//
// It returns the operator produced by c.SetClient for the transactional client
// (asserted back to T), the transaction wrapper, the context to use for
// further calls, and an error if a new transaction could not be started. On
// error the original c and ctx are returned with a nil wrapper.
func WithTx[T TxOperator](ctx context.Context, c T) (T, *Tx, context.Context, error) {
	// Join the transaction already carried by ctx while it is still open.
	if outer, ok := ctx.Value(TxCtx{}).(*Tx); ok && !outer.finished {
		joined := &Tx{tx: outer.tx, parent: outer, inherited: true}
		return c.SetClient(joined.tx.Client()).(T), joined, ctx, nil
	}

	started, err := c.GetClient().Tx(ctx)
	if err != nil {
		return c, nil, ctx, fmt.Errorf("failed to create transaction: %w", err)
	}

	owned := &Tx{tx: started}
	// Expose the new transaction through the context for nested callers.
	txCtx := context.WithValue(ctx, TxCtx{}, owned)
	return c.SetClient(owned.tx.Client()).(T), owned, txCtx, nil
}
// Rollback rolls back the underlying transaction and marks tx as finished.
//
// For an inherited transaction it does nothing and returns nil, leaving the
// decision to whoever holds the non-inherited wrapper.
func Rollback(tx *Tx) error {
	if tx.inherited {
		return nil
	}

	tx.finished = true
	return tx.tx.Rollback()
}
// commit commits the underlying transaction unless tx is inherited.
//
// The boolean reports whether a commit was attempted: false for an inherited
// transaction, true otherwise, including when the commit itself returned an
// error. A non-inherited tx is marked finished before the commit is issued.
func commit(tx *Tx) (bool, error) {
	if tx.inherited {
		return false, nil
	}

	tx.finished = true
	commitErr := tx.tx.Commit()
	return true, commitErr
}
// Commit commits the transaction, returning any error from the commit.
// It does nothing and returns nil for an inherited transaction.
func Commit(tx *Tx) error {
	if _, err := commit(tx); err != nil {
		return err
	}
	return nil
}
// CommitWithStorageDiff commits tx and then hands the storage diff stored on
// tx to uc.ApplyStorageDiff.
//
// For an inherited transaction nothing is committed or applied and nil is
// returned. A commit error is returned as-is and the diff is not applied. A
// failure to apply the diff after a successful commit is only logged through
// l; the function still returns nil in that case.
func CommitWithStorageDiff(ctx context.Context, tx *Tx, l logging.Logger, uc UserClient) error {
	committed, err := commit(tx)
	switch {
	case err != nil:
		return err
	case !committed:
		// Inherited transaction: nothing was committed here, so nothing to apply.
		return nil
	}

	if applyErr := uc.ApplyStorageDiff(ctx, tx.storageDiff); applyErr != nil {
		l.Error("Failed to apply storage diff", "error", applyErr)
	}
	return nil
}