package writegroup
import (
	"context"
	"errors"
	"sync"
	"time"
)
// WriteGroup coalesces concurrent writes that share a key into batches.
//
// Do submits input under key and waits for the batch containing it to be
// processed by fn, returning the WriteResult that fn produced for that input,
// or ctx.Err() if ctx is done first. fn receives a batch of inputs and is
// expected to return one WriteResult per input, in the same order.
type WriteGroup[T, R any] interface {
	Do(
		ctx context.Context,
		key string,
		input T,
		fn func(context.Context, []T) []WriteResult[R],
	) (R, error)
}
// writeGroup is the batching ("group write") implementation of WriteGroup:
// inputs are queued on a per-key channel and written in batches.
type writeGroup[T, R any] struct {
	// batchTimeout bounds how long a worker waits for a batch to fill
	// before flushing the tasks it has collected.
	batchTimeout time.Duration
	// bufferSize is the capacity of each per-key task channel; batchSize
	// is the number of collected tasks that triggers an immediate flush.
	bufferSize, batchSize int

	// workerMu hands out one mutex per key; the key's worker goroutine
	// holds it while running, so at most one worker exists per key.
	workerMu keyedMutex
	// chMu guards workerCh, which maps each key to its task channel.
	chMu     sync.Mutex
	workerCh map[string]chan writeTask[T, R]
}
// writeTask is one queued Do call: the caller's input together with the
// channel on which the worker delivers that input's result.
type writeTask[T, R any] struct {
	input  T                     // value handed to the batch function
	result chan<- WriteResult[R] // receives exactly this task's result
}
// WriteResult is the outcome of writing a single input. A batch function
// returns one WriteResult per input it receives, in input order.
type WriteResult[R any] struct {
	Result R     // value produced for the input
	Err    error // non-nil when the input could not be written
}
// NewWriteGroup returns a WriteGroup whose per-key task queues have capacity
// bufferSize. Queued inputs are handed to the batch function in groups that
// are flushed as soon as batchSize inputs are collected, or when batchTimeout
// elapses with a partial group pending.
func NewWriteGroup[T, R any](bufferSize, batchSize int, batchTimeout time.Duration) WriteGroup[T, R] {
	g := new(writeGroup[T, R])
	g.bufferSize = bufferSize
	g.batchSize = batchSize
	g.batchTimeout = batchTimeout
	g.workerCh = map[string]chan writeTask[T, R]{}
	return g
}
// ErrMissingResult is delivered to a Do caller when the batch function returns
// fewer results than the number of inputs it was given, so no result exists
// for that caller's input.
var ErrMissingResult = errors.New("writegroup: batch function returned fewer results than inputs")

// Do enqueues input under key and waits until the batch containing it has been
// processed, returning the result fn produced for that input.
//
// Inputs sharing a key are queued on a per-key channel and drained by at most
// one worker goroutine per key. The worker groups up to batchSize inputs
// (waiting at most batchTimeout for a batch to fill) and calls fn once per
// batch; fn must return one WriteResult per input, in input order. A missing
// result is reported as ErrMissingResult.
//
// The worker uses the fn of the call that started it, and that call's ctx with
// cancellation stripped (values are kept), for every batch it runs: its
// batches mix inputs from many callers, so one caller's cancellation must not
// abort writes for the others. fn arguments of calls that join an
// already-running worker are ignored.
//
// If ctx is done before the input is enqueued, Do returns ctx.Err() and the
// input is never written. If ctx is done afterwards, Do returns ctx.Err()
// immediately, but the input still takes part in a batch and its result is
// discarded.
func (w *writeGroup[T, R]) Do(ctx context.Context, key string, input T, fn func(context.Context, []T) []WriteResult[R]) (R, error) {
	ch := w.channelFor(key)

	// Capacity 1 so the worker never blocks delivering a result to a caller
	// that has already given up.
	resultCh := make(chan WriteResult[R], 1)
	task := writeTask[T, R]{input: input, result: resultCh}

	var zero R
	select {
	case ch <- task:
	case <-ctx.Done():
		return zero, ctx.Err()
	}

	// Start a worker unless one is already running for this key. TryLock is
	// done under chMu so it is ordered against the worker's exit check in
	// runWorker: either the exiting worker sees the task we just queued, or we
	// see the released mutex and start a fresh worker. Without this, a task
	// queued between the worker's emptiness check and its unlock is stranded.
	lock := w.workerMu.GetMutex(key)
	w.chMu.Lock()
	start := lock.TryLock()
	w.chMu.Unlock()
	if start {
		go w.runWorker(detachedContext{parent: ctx}, ch, lock, fn)
	}

	select {
	case <-ctx.Done():
		return zero, ctx.Err()
	case ret := <-resultCh:
		return ret.Result, ret.Err
	}
}

// channelFor returns the task queue for key, creating it on first use. The
// capacity is clamped to at least 1: with an unbuffered queue the first send
// in Do would block forever, because no worker exists yet to receive it.
func (w *writeGroup[T, R]) channelFor(key string) chan writeTask[T, R] {
	w.chMu.Lock()
	defer w.chMu.Unlock()
	if w.workerCh == nil {
		w.workerCh = make(map[string]chan writeTask[T, R])
	}
	ch, ok := w.workerCh[key]
	if !ok {
		size := w.bufferSize
		if size < 1 {
			size = 1
		}
		ch = make(chan writeTask[T, R], size)
		w.workerCh[key] = ch
	}
	return ch
}

// runWorker drains ch in batches, calling fn once per batch, until the queue
// is empty. It owns lock on entry and releases it on exit. The release happens
// under chMu together with the emptiness check, so a concurrent Do cannot
// queue a task that no worker will ever receive.
func (w *writeGroup[T, R]) runWorker(ctx context.Context, ch <-chan writeTask[T, R], lock *sync.Mutex, fn func(context.Context, []T) []WriteResult[R]) {
	limit := w.batchSize
	if limit < 1 {
		limit = 1 // a non-positive batchSize flushes every task on its own
	}
	capHint := limit
	if c := cap(ch); c < capHint {
		capHint = c
	}

	timer := time.NewTimer(w.batchTimeout)
	defer timer.Stop()
	for {
		resetTimer(timer, w.batchTimeout)
		if jobs := collectBatch(ch, timer.C, limit, capHint); len(jobs) > 0 {
			deliverBatch(ctx, jobs, fn)
		}

		w.chMu.Lock()
		if len(ch) == 0 {
			lock.Unlock()
			w.chMu.Unlock()
			return
		}
		w.chMu.Unlock()
	}
}

// collectBatch receives tasks from ch until limit tasks are gathered, timeout
// fires, or ch is closed, and returns what it gathered (possibly nothing).
func collectBatch[T, R any](ch <-chan writeTask[T, R], timeout <-chan time.Time, limit, capHint int) []writeTask[T, R] {
	jobs := make([]writeTask[T, R], 0, capHint)
	for len(jobs) < limit {
		select {
		case <-timeout:
			return jobs
		case task, ok := <-ch:
			if !ok {
				return jobs
			}
			jobs = append(jobs, task)
		}
	}
	return jobs
}

// deliverBatch calls fn with the inputs of jobs and sends each job its result.
// Jobs for which fn returned no result receive ErrMissingResult, instead of
// the worker crashing the process with an index-out-of-range panic.
func deliverBatch[T, R any](ctx context.Context, jobs []writeTask[T, R], fn func(context.Context, []T) []WriteResult[R]) {
	inputs := make([]T, len(jobs))
	for i := range jobs {
		inputs[i] = jobs[i].input
	}
	results := fn(ctx, inputs)
	for i := range jobs {
		res := WriteResult[R]{Err: ErrMissingResult}
		if i < len(results) {
			res = results[i]
		}
		// Never blocks: Do gives every result channel capacity 1, and each
		// task receives exactly one value.
		jobs[i].result <- res
	}
}

// resetTimer re-arms t to fire after d, first discarding any expiry that was
// never received. Without the drain, on Go versions before 1.23 a batch that
// ended by reaching its size limit would leave a stale tick in t.C and cut the
// next batch's wait short.
func resetTimer(t *time.Timer, d time.Duration) {
	if !t.Stop() {
		select {
		case <-t.C:
		default:
		}
	}
	t.Reset(d)
}

// detachedContext exposes the values of parent but is never done and has no
// deadline, so a worker serving many callers is not cancelled by whichever
// caller happened to start it.
type detachedContext struct {
	parent context.Context
}

func (detachedContext) Deadline() (time.Time, bool) { return time.Time{}, false }

func (detachedContext) Done() <-chan struct{} { return nil }

func (detachedContext) Err() error { return nil }

func (c detachedContext) Value(key any) any { return c.parent.Value(key) }
// keyedMutex lazily hands out one *sync.Mutex per key; the same key always
// yields the same mutex. Entries are never removed, so storage grows with the
// number of distinct keys ever requested.
type keyedMutex struct {
	m sync.Map // key (string) -> *sync.Mutex
}

// GetMutex returns the mutex for key, creating and storing it on first use.
// Concurrent first calls for the same key agree on one mutex, because
// LoadOrStore returns whichever value was stored first.
func (km *keyedMutex) GetMutex(key string) *sync.Mutex {
	// Fast path: skip allocating a candidate mutex when the key already exists.
	if existing, found := km.m.Load(key); found {
		return existing.(*sync.Mutex)
	}
	actual, _ := km.m.LoadOrStore(key, new(sync.Mutex))
	return actual.(*sync.Mutex)
}