summaryrefslogtreecommitdiff
path: root/src/packages.cc
blob: 0309df116871d9457debe2581636796ad2fad5f6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
// -*- mode: c++; c-basic-offset: 2; -*-

#include "common.hh"

#include <istream>
#include <memory>
#include <ostream>
#include <string.h>

#include "data.hh"
#include "packages.hh"

namespace {

class PackagesWriterImpl : public PackagesWriter {
public:
  PackagesWriterImpl(size_t count, std::ostream* out)
    : count_(count), out_(out) {
    uint8_t header[8];
    memcpy(header, "TPP", 3);
    header[3] = 0x1;  // Version
    write_u32(header + 4, count_);  // Count
    out->write(reinterpret_cast<char*>(header), 8);

    if (count == 0) {
      write_u64(header, 0);  // EOF
      out_->write(reinterpret_cast<char*>(header), 8);
    }
  }

  ~PackagesWriterImpl() {
    assert(count_ == 0);
  }

  void write(Package const& package, std::string const& data) override {
    if (count_ == 0) {
      assert(false);
      return;
    }
    uint8_t buf[8192];
    std::unique_ptr<uint8_t[]> backup;
    uint8_t* ptr = buf;
    size_t need = write_package(package, buf, sizeof(buf));
    if (need > sizeof(buf)) {
      backup.reset(new uint8_t[need]);
      ptr = backup.get();
      write_package(package, ptr, need);
    }
    uint8_t size[8];
    write_u64(size, need + data.size());
    out_->write(reinterpret_cast<char*>(size), 8);
    out_->write(reinterpret_cast<char*>(ptr), need);
    backup.reset();
    out_->write(data.data(), data.size());

    if (--count_ == 0) {
      write_u64(size, 0);  // EOF
      out_->write(reinterpret_cast<char*>(size), 8);
    }
  }

private:
  size_t count_;
  std::ostream* const out_;
};

}  // namespace

// static
// Factory for the stream writer. The returned object writes the stream
// header to `out` during construction and then expects exactly `count`
// write() calls. The caller takes ownership of the returned pointer.
// `out` is stored and used by later write() calls, so it must outlive
// the writer.
PackagesWriter* PackagesWriter::create(size_t count, std::ostream* out) {
  PackagesWriter* const writer = new PackagesWriterImpl(count, out);
  return writer;
}

// static
// Parses a package stream (as produced by PackagesWriter) and forwards each
// entry to `delegate`.
//
// Layout consumed:
//   8-byte header: "TPP", version byte 1, 32-bit entry count;
//   per entry: a non-zero 64-bit frame size, then that many bytes holding a
//     serialized Package followed by its payload;
//   trailer: a 64-bit zero.
//
// For each entry, delegate->package() is called once. It is followed by one
// or more delegate->data() calls delivering the payload in order, and the
// last of those calls has its final argument set to true.
//
// Returns:
//   INVALID for a bad or short header, a zero-sized frame, a Package that
//     cannot be parsed from the frame (or from its first 1 MiB), or a
//     missing/non-zero trailer;
//   IO_ERROR when a read inside an entry fails;
//   GOOD otherwise.
PackagesReader::Status PackagesReader::read(std::istream& in,
                                            Delegate* delegate) {
  // Upper bound on how much of a frame is buffered while looking for a
  // complete serialized Package.
  size_t const kMaxPackageBuffer = size_t{1} << 20;

  uint8_t header[8];
  in.read(reinterpret_cast<char*>(header), sizeof(header));
  bool const header_ok =
      in.good() && memcmp(header, "TPP", 3) == 0 && header[3] == 1;
  if (!header_ok) return INVALID;

  uint8_t scratch[8192];
  for (auto left = read_u32(header + 4); left != 0; --left) {
    uint8_t frame[8];
    in.read(reinterpret_cast<char*>(frame), sizeof(frame));
    if (!in.good()) return IO_ERROR;
    size_t remaining = read_u64(frame);
    if (remaining == 0) return INVALID;

    // First try to parse the Package from a single stack-sized chunk.
    size_t const first = std::min(remaining, sizeof(scratch));
    in.read(reinterpret_cast<char*>(scratch), first);
    if (!in.good()) return IO_ERROR;

    Package pkg;
    uint8_t* chunk = scratch;
    size_t chunk_len = first;
    std::unique_ptr<uint8_t[]> spill;
    auto pkg_size = read_package(&pkg, chunk, chunk_len);
    if (pkg_size == 0) {
      // The whole frame is already buffered, so more bytes cannot help.
      if (first == remaining) return INVALID;
      // Extend into a heap buffer of up to kMaxPackageBuffer and retry.
      chunk_len = std::min(kMaxPackageBuffer, remaining);
      spill.reset(new uint8_t[chunk_len]);
      memcpy(spill.get(), scratch, first);
      in.read(reinterpret_cast<char*>(spill.get()) + first, chunk_len - first);
      if (!in.good()) return IO_ERROR;
      chunk = spill.get();
      pkg_size = read_package(&pkg, chunk, chunk_len);
      if (pkg_size == 0) return INVALID;
    }

    // Whatever followed the Package inside the buffered chunk is payload.
    delegate->package(pkg);
    delegate->data(pkg.id, reinterpret_cast<char*>(chunk) + pkg_size,
                   chunk_len - pkg_size, chunk_len == remaining);
    remaining -= chunk_len;
    spill.reset();  // Release the heap copy before streaming the rest.

    // Stream the payload bytes that did not fit in the first chunk.
    while (remaining != 0) {
      size_t const piece = std::min(sizeof(scratch), remaining);
      in.read(reinterpret_cast<char*>(scratch), piece);
      if (!in.good()) return IO_ERROR;
      delegate->data(pkg.id, reinterpret_cast<char*>(scratch), piece,
                     piece == remaining);
      remaining -= piece;
    }
  }

  in.read(reinterpret_cast<char*>(header), sizeof(header));
  if (!in.good() || read_u64(header) != 0) return INVALID;
  return GOOD;
}