xgboost
base64.h
Go to the documentation of this file.
1 
8 #ifndef XGBOOST_COMMON_BASE64_H_
9 #define XGBOOST_COMMON_BASE64_H_
10 
11 #include <xgboost/logging.h>
12 #include <cctype>
13 #include <cstdio>
14 #include <string>
15 #include "./io.h"
16 
17 namespace xgboost {
18 namespace common {
21  public:
22  explicit StreamBufferReader(size_t buffer_size)
23  :stream_(NULL),
24  read_len_(1), read_ptr_(1) {
25  buffer_.resize(buffer_size);
26  }
30  inline void set_stream(dmlc::Stream *stream) {
31  stream_ = stream;
32  read_len_ = read_ptr_ = 1;
33  }
37  inline char GetChar(void) {
38  while (true) {
39  if (read_ptr_ < read_len_) {
40  return buffer_[read_ptr_++];
41  } else {
42  read_len_ = stream_->Read(&buffer_[0], buffer_.length());
43  if (read_len_ == 0) return EOF;
44  read_ptr_ = 0;
45  }
46  }
47  }
49  inline bool AtEnd(void) const {
50  return read_len_ == 0;
51  }
52 
53  private:
55  dmlc::Stream *stream_;
57  std::string buffer_;
59  size_t read_len_;
61  size_t read_ptr_;
62 };
63 
namespace base64 {
/*!
 * \brief maps an ASCII base64 character to its 6-bit value
 *
 * Sized to 256 so that every unsigned char value is a valid index. Positions
 * not listed below (including every byte above 'z') are zero-initialized.
 * Note that 0 is also the value of 'A', so the table alone cannot tell
 * invalid characters apart; callers must validate input before decoding.
 */
const char DecodeTable[256] = {
  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  62,  // '+'
  0, 0, 0,
  63,  // '/'
  52, 53, 54, 55, 56, 57, 58, 59, 60, 61,  // '0'-'9'
  0, 0, 0, 0, 0, 0, 0,
  0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
  13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,  // 'A'-'Z'
  0, 0, 0, 0, 0, 0,
  26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
  39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,  // 'a'-'z'
};
/*! \brief maps a 6-bit value (0-63) to its base64 character */
static const char EncodeTable[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
}  // namespace base64
84 class Base64InStream: public dmlc::Stream {
85  public:
86  explicit Base64InStream(dmlc::Stream *fs) : reader_(256) {
87  reader_.set_stream(fs);
88  num_prev = 0; tmp_ch = 0;
89  }
94  inline void InitPosition(void) {
95  // get a character
96  do {
97  tmp_ch = reader_.GetChar();
98  } while (isspace(tmp_ch));
99  }
101  inline bool IsEOF(void) const {
102  return num_prev == 0 && (tmp_ch == EOF || isspace(tmp_ch));
103  }
104  virtual size_t Read(void *ptr, size_t size) {
105  using base64::DecodeTable;
106  if (size == 0) return 0;
107  // use tlen to record left size
108  size_t tlen = size;
109  unsigned char *cptr = static_cast<unsigned char*>(ptr);
110  // if anything left, load from previous buffered result
111  if (num_prev != 0) {
112  if (num_prev == 2) {
113  if (tlen >= 2) {
114  *cptr++ = buf_prev[0];
115  *cptr++ = buf_prev[1];
116  tlen -= 2;
117  num_prev = 0;
118  } else {
119  // assert tlen == 1
120  *cptr++ = buf_prev[0]; --tlen;
121  buf_prev[0] = buf_prev[1];
122  num_prev = 1;
123  }
124  } else {
125  // assert num_prev == 1
126  *cptr++ = buf_prev[0]; --tlen; num_prev = 0;
127  }
128  }
129  if (tlen == 0) return size;
130  int nvalue;
131  // note: everything goes with 4 bytes in Base64
132  // so we process 4 bytes a unit
133  while (tlen && tmp_ch != EOF && !isspace(tmp_ch)) {
134  // first byte
135  nvalue = DecodeTable[tmp_ch] << 18;
136  {
137  // second byte
138  tmp_ch = reader_.GetChar();
139  CHECK(tmp_ch != EOF && !isspace(tmp_ch)) << "invalid base64 format";
140  nvalue |= DecodeTable[tmp_ch] << 12;
141  *cptr++ = (nvalue >> 16) & 0xFF; --tlen;
142  }
143  {
144  // third byte
145  tmp_ch = reader_.GetChar();
146  CHECK(tmp_ch != EOF && !isspace(tmp_ch)) << "invalid base64 format";
147  // handle termination
148  if (tmp_ch == '=') {
149  tmp_ch = reader_.GetChar();
150  CHECK(tmp_ch == '=') << "invalid base64 format";
151  tmp_ch = reader_.GetChar();
152  CHECK(tmp_ch == EOF || isspace(tmp_ch))
153  << "invalid base64 format";
154  break;
155  }
156  nvalue |= DecodeTable[tmp_ch] << 6;
157  if (tlen) {
158  *cptr++ = (nvalue >> 8) & 0xFF; --tlen;
159  } else {
160  buf_prev[num_prev++] = (nvalue >> 8) & 0xFF;
161  }
162  }
163  {
164  // fourth byte
165  tmp_ch = reader_.GetChar();
166  CHECK(tmp_ch != EOF && !isspace(tmp_ch))
167  << "invalid base64 format";
168  if (tmp_ch == '=') {
169  tmp_ch = reader_.GetChar();
170  CHECK(tmp_ch == EOF || isspace(tmp_ch))
171  << "invalid base64 format";
172  break;
173  }
174  nvalue |= DecodeTable[tmp_ch];
175  if (tlen) {
176  *cptr++ = nvalue & 0xFF; --tlen;
177  } else {
178  buf_prev[num_prev ++] = nvalue & 0xFF;
179  }
180  }
181  // get next char
182  tmp_ch = reader_.GetChar();
183  }
184  if (kStrictCheck) {
185  CHECK_EQ(tlen, 0) << "Base64InStream: read incomplete";
186  }
187  return size - tlen;
188  }
189  virtual void Write(const void *ptr, size_t size) {
190  LOG(FATAL) << "Base64InStream do not support write";
191  }
192 
193  private:
194  StreamBufferReader reader_;
195  int tmp_ch;
196  int num_prev;
197  unsigned char buf_prev[2];
198  // whether we need to do strict check
199  static const bool kStrictCheck = false;
200 };
202 class Base64OutStream: public dmlc::Stream {
203  public:
204  explicit Base64OutStream(dmlc::Stream *fp) : fp(fp) {
205  buf_top = 0;
206  }
207  virtual void Write(const void *ptr, size_t size) {
208  using base64::EncodeTable;
209  size_t tlen = size;
210  const unsigned char *cptr = static_cast<const unsigned char*>(ptr);
211  while (tlen) {
212  while (buf_top < 3 && tlen != 0) {
213  buf[++buf_top] = *cptr++; --tlen;
214  }
215  if (buf_top == 3) {
216  // flush 4 bytes out
217  PutChar(EncodeTable[buf[1] >> 2]);
218  PutChar(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F]);
219  PutChar(EncodeTable[((buf[2] << 2) | (buf[3] >> 6)) & 0x3F]);
220  PutChar(EncodeTable[buf[3] & 0x3F]);
221  buf_top = 0;
222  }
223  }
224  }
225  virtual size_t Read(void *ptr, size_t size) {
226  LOG(FATAL) << "Base64OutStream do not support read";
227  return 0;
228  }
233  inline void Finish(char endch = EOF) {
234  using base64::EncodeTable;
235  if (buf_top == 1) {
236  PutChar(EncodeTable[buf[1] >> 2]);
237  PutChar(EncodeTable[(buf[1] << 4) & 0x3F]);
238  PutChar('=');
239  PutChar('=');
240  }
241  if (buf_top == 2) {
242  PutChar(EncodeTable[buf[1] >> 2]);
243  PutChar(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F]);
244  PutChar(EncodeTable[(buf[2] << 2) & 0x3F]);
245  PutChar('=');
246  }
247  buf_top = 0;
248  if (endch != EOF) PutChar(endch);
249  this->Flush();
250  }
251 
252  private:
253  dmlc::Stream *fp;
254  int buf_top;
255  unsigned char buf[4];
256  std::string out_buf;
257  static const size_t kBufferSize = 256;
258 
259  inline void PutChar(char ch) {
260  out_buf += ch;
261  if (out_buf.length() >= kBufferSize) Flush();
262  }
263  inline void Flush(void) {
264  if (out_buf.length() != 0) {
265  fp->Write(&out_buf[0], out_buf.length());
266  out_buf.clear();
267  }
268  }
269 };
270 } // namespace common
271 } // namespace xgboost
272 #endif // XGBOOST_COMMON_BASE64_H_
buffer reader over a stream that allows you to get characters one at a time from an internal buffer
Definition: base64.h:20
const char DecodeTable[]
Definition: base64.h:66
bool AtEnd(void) const
whether the end of the file has been reached
Definition: base64.h:49
char GetChar(void)
allows quick character-by-character reads
Definition: base64.h:37
virtual size_t Read(void *ptr, size_t size)
Definition: base64.h:104
general stream interface for serialization, I/O
void Finish(char endch=EOF)
finish writing the current base64 stream: emit any pending bytes with '=' padding, optionally append an end character, and flush the output
Definition: base64.h:233
void set_stream(dmlc::Stream *stream)
set the input stream and reset the read buffer
Definition: base64.h:30
the stream that writes base64-encoded text; note that it writes through an underlying stream pointer
Definition: base64.h:202
the stream that reads and decodes base64 text; note that it reads from an underlying stream pointer
Definition: base64.h:84
Base64OutStream(dmlc::Stream *fp)
Definition: base64.h:204
void InitPosition(void)
initialize the stream position to the beginning of the next base64 stream; call this function before actually reading
Definition: base64.h:94
namespace of xgboost
Definition: base.h:102
virtual void Write(const void *ptr, size_t size)
Definition: base64.h:207
Base64InStream(dmlc::Stream *fs)
Definition: base64.h:86
bool IsEOF(void) const
whether the current position is the end of a base64 stream
Definition: base64.h:101
virtual size_t Read(void *ptr, size_t size)
Definition: base64.h:225
virtual void Write(const void *ptr, size_t size)
Definition: base64.h:189
StreamBufferReader(size_t buffer_size)
Definition: base64.h:22