You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
101 lines
2.1 KiB
101 lines
2.1 KiB
package inventory
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"github.com/cloudreve/Cloudreve/v4/ent"
|
|
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
|
)
|
|
|
|
type TxOperator interface {
|
|
SetClient(newClient *ent.Client) TxOperator
|
|
GetClient() *ent.Client
|
|
}
|
|
|
|
type (
|
|
Tx struct {
|
|
tx *ent.Tx
|
|
parent *Tx
|
|
inherited bool
|
|
finished bool
|
|
storageDiff StorageDiff
|
|
}
|
|
|
|
// TxCtx is the context key for inherited transaction
|
|
TxCtx struct{}
|
|
)
|
|
|
|
// AppendStorageDiff appends the given storage diff to the transaction.
|
|
func (t *Tx) AppendStorageDiff(diff StorageDiff) {
|
|
root := t
|
|
for root.inherited {
|
|
root = root.parent
|
|
}
|
|
|
|
if root.storageDiff == nil {
|
|
root.storageDiff = diff
|
|
} else {
|
|
root.storageDiff.Merge(diff)
|
|
}
|
|
}
|
|
|
|
// WithTx wraps the given inventory client with a transaction.
|
|
func WithTx[T TxOperator](ctx context.Context, c T) (T, *Tx, context.Context, error) {
|
|
var txClient *ent.Client
|
|
var txWrapper *Tx
|
|
|
|
if txInherited, ok := ctx.Value(TxCtx{}).(*Tx); ok && !txInherited.finished {
|
|
txWrapper = &Tx{inherited: true, tx: txInherited.tx, parent: txInherited}
|
|
} else {
|
|
tx, err := c.GetClient().Tx(ctx)
|
|
if err != nil {
|
|
return c, nil, ctx, fmt.Errorf("failed to create transaction: %w", err)
|
|
}
|
|
|
|
txWrapper = &Tx{inherited: false, tx: tx}
|
|
ctx = context.WithValue(ctx, TxCtx{}, txWrapper)
|
|
}
|
|
|
|
txClient = txWrapper.tx.Client()
|
|
return c.SetClient(txClient).(T), txWrapper, ctx, nil
|
|
}
|
|
|
|
func Rollback(tx *Tx) error {
|
|
if !tx.inherited {
|
|
tx.finished = true
|
|
return tx.tx.Rollback()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func commit(tx *Tx) (bool, error) {
|
|
if !tx.inherited {
|
|
tx.finished = true
|
|
return true, tx.tx.Commit()
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func Commit(tx *Tx) error {
|
|
_, err := commit(tx)
|
|
return err
|
|
}
|
|
|
|
// CommitWithStorageDiff commits the transaction and applies the storage diff, only if the transaction is not inherited.
|
|
func CommitWithStorageDiff(ctx context.Context, tx *Tx, l logging.Logger, uc UserClient) error {
|
|
commited, err := commit(tx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !commited {
|
|
return nil
|
|
}
|
|
|
|
if err := uc.ApplyStorageDiff(ctx, tx.storageDiff); err != nil {
|
|
l.Error("Failed to apply storage diff", "error", err)
|
|
}
|
|
|
|
return nil
|
|
}
|