diff options
Diffstat (limited to 'src/decompress_lzma.cc')
| -rw-r--r-- | src/decompress_lzma.cc | 110 |
1 files changed, 110 insertions, 0 deletions
diff --git a/src/decompress_lzma.cc b/src/decompress_lzma.cc new file mode 100644 index 0000000..6baea18 --- /dev/null +++ b/src/decompress_lzma.cc @@ -0,0 +1,110 @@ +#include "decompress.hh" + +#include "buffer.hh" + +#include <lzma.h> + +#include <algorithm> +#include <cstddef> +#include <cstdint> +#include <cstring> +#include <expected> +#include <memory> +#include <optional> +#include <utility> + +namespace decompress { + +namespace { + +const size_t kBufferSizeXz = static_cast<size_t>(1024) * 1024; + +class XzReader : public io::Reader { + public: + explicit XzReader(std::unique_ptr<io::Reader> reader) + : reader_(std::move(reader)) {} + + ~XzReader() override { + if (initialized_) + lzma_end(&stream_); + } + + std::expected<size_t, io::ReadError> read(void* dst, size_t max) override { + auto err = fill(); + if (err.has_value()) + return std::unexpected(err.value()); + + stream_.next_out = reinterpret_cast<unsigned char*>(dst); + stream_.avail_out = max; + + if (!initialized_) { + if (in_eof_ && buffer_->empty()) + return 0; + + lzma_mt options; + memset(&options, 0, sizeof(options)); + options.threads = std::max(static_cast<uint32_t>(1), lzma_cputhreads()); + options.memlimit_threading = lzma_physmem() / 4; + options.memlimit_stop = lzma_physmem() / 4; + auto ret = lzma_stream_decoder_mt(&stream_, &options); + if (ret != LZMA_OK) + return std::unexpected(io::ReadError::Error); + initialized_ = true; + } + + auto* const rptr = stream_.next_in; + auto ret = lzma_code(&stream_, in_eof_ ? LZMA_FINISH : LZMA_RUN); + auto got = max - stream_.avail_out; + if (ret == LZMA_STREAM_END) { + lzma_end(&stream_); + initialized_ = false; + buffer_->consume(stream_.next_in - rptr); + } else if (ret == LZMA_OK) { + if (!in_eof_) + buffer_->consume(stream_.next_in - rptr); + } else { + return std::unexpected( + ret == LZMA_DATA_ERROR + ? io::ReadError::InvalidData : io::ReadError::Error); + } + return got; + } + + std::expected<size_t, io::ReadError> skip(size_t max) override { + auto tmp = std::make_unique_for_overwrite<char[]>(max); + return read(tmp.get(), max); + } + + private: + std::optional<io::ReadError> fill() { + auto* rptr = buffer_->rptr(stream_.avail_in); + if (!in_eof_ && stream_.avail_in < kBufferSizeXz / 2) { + auto* wptr = buffer_->wptr(stream_.avail_in); + auto got = reader_->read(wptr, stream_.avail_in); + if (got.has_value()) { + buffer_->commit(got.value()); + if (got.value() == 0) + in_eof_ = true; + } else { + return got.error(); + } + rptr = buffer_->rptr(stream_.avail_in); + } + stream_.next_in = reinterpret_cast<const unsigned char*>(rptr); + return std::nullopt; + } + + std::unique_ptr<io::Reader> reader_; + bool in_eof_{false}; + std::unique_ptr<Buffer> buffer_{Buffer::fixed(kBufferSizeXz)}; + bool initialized_{false}; + lzma_stream stream_ = LZMA_STREAM_INIT; +}; + +} // namespace + +std::unique_ptr<io::Reader> xz(std::unique_ptr<io::Reader> reader) { + return std::make_unique<XzReader>(std::move(reader)); +} + +} // namespace decompress |
