// Copyright (C) 2022-2023 Luke Shumaker // // SPDX-License-Identifier: GPL-2.0-or-later package diskio import ( "context" "sync" "github.com/datawire/dlib/dlog" "git.lukeshu.com/btrfs-progs-ng/lib/containers" ) type bufferedBlock[A ~int64] struct { Mu sync.RWMutex Addr A Dirty bool Dat []byte Err error } type bufferedFile[A ~int64] struct { ctx context.Context //nolint:containedctx // don't have an option while keeping the io.ReaderAt/io.WriterAt API inner File[A] blockSize A blockCache containers.Cache[A, bufferedBlock[A]] } var _ File[assertAddr] = (*bufferedFile[assertAddr])(nil) func NewBufferedFile[A ~int64](ctx context.Context, file File[A], blockSize A, cacheSize int) *bufferedFile[A] { ret := &bufferedFile[A]{ ctx: ctx, inner: file, blockSize: blockSize, } ret.blockCache = containers.NewARCache[A, bufferedBlock[A]](cacheSize, bufferedBlockSource[A]{ret}) return ret } type bufferedBlockSource[A ~int64] struct { bf *bufferedFile[A] } func (src bufferedBlockSource[A]) Flush(ctx context.Context, block *bufferedBlock[A]) { if !block.Dirty { return } if _, err := src.bf.inner.WriteAt(block.Dat, block.Addr); err != nil { dlog.Errorf(ctx, "i/o error: write: %v", err) } block.Dirty = false } func (src bufferedBlockSource[A]) Load(ctx context.Context, blockAddr A, block *bufferedBlock[A]) { src.Flush(ctx, block) if block.Dat == nil { block.Dat = make([]byte, src.bf.blockSize) } n, err := src.bf.inner.ReadAt(block.Dat[:src.bf.blockSize], blockAddr) block.Addr = blockAddr block.Dat = block.Dat[:n] block.Err = err } func (bf *bufferedFile[A]) Name() string { return bf.inner.Name() } func (bf *bufferedFile[A]) Size() A { return bf.inner.Size() } func (bf *bufferedFile[A]) Close() error { return bf.inner.Close() } func (bf *bufferedFile[A]) Flush() { bf.blockCache.Flush(bf.ctx) } func (bf *bufferedFile[A]) ReadAt(dat []byte, off A) (n int, err error) { done := 0 for done < len(dat) { n, err := bf.maybeShortReadAt(dat[done:], off+A(done)) done += n if err != nil { return done, err } } return done, nil } func (bf *bufferedFile[A]) maybeShortReadAt(dat []byte, off A) (n int, err error) { offsetWithinBlock := off % bf.blockSize blockOffset := off - offsetWithinBlock cachedBlock := bf.blockCache.Acquire(bf.ctx, blockOffset) defer bf.blockCache.Release(blockOffset) cachedBlock.Mu.RLock() defer cachedBlock.Mu.RUnlock() n = copy(dat, cachedBlock.Dat[offsetWithinBlock:]) if n < len(dat) { return n, cachedBlock.Err } return n, nil } func (bf *bufferedFile[A]) WriteAt(dat []byte, off A) (n int, err error) { done := 0 for done < len(dat) { n, err := bf.maybeShortWriteAt(dat[done:], off+A(done)) done += n if err != nil { return done, err } } return done, nil } func (bf *bufferedFile[A]) maybeShortWriteAt(dat []byte, off A) (n int, err error) { offsetWithinBlock := off % bf.blockSize blockOffset := off - offsetWithinBlock cachedBlock := bf.blockCache.Acquire(bf.ctx, blockOffset) defer bf.blockCache.Release(blockOffset) cachedBlock.Mu.Lock() defer cachedBlock.Mu.Unlock() cachedBlock.Dirty = true n = copy(cachedBlock.Dat[offsetWithinBlock:], dat) if n < len(dat) { return n, cachedBlock.Err } return n, nil }