/usr/include/tesseract/lstmtrainer.h is in libtesseract-dev 4.00~git2288-10f4998a-2.

This file is owned by root:root, with mode 0o644.

The actual contents of the file can be viewed below.

///////////////////////////////////////////////////////////////////////
// File:        lstmtrainer.h
// Description: Top-level line trainer class for LSTM-based networks.
// Author:      Ray Smith
// Created:     Fri May 03 09:07:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#ifndef TESSERACT_LSTM_LSTMTRAINER_H_
#define TESSERACT_LSTM_LSTMTRAINER_H_

#include "imagedata.h"
#include "lstmrecognizer.h"
#include "rect.h"
#include "tesscallback.h"

namespace tesseract {

class LSTM;
class LSTMTrainer;
class Parallel;
class Reversed;
class Softmax;
class Series;

// Enum for the types of errors that are counted.
enum ErrorTypes {
  ET_RMS,          // RMS activation error.
  ET_DELTA,        // Number of big errors in deltas.
  ET_WORD_RECERR,  // Output text string word recall error.
  ET_CHAR_ERROR,   // Output text string total char error.
  ET_SKIP_RATIO,   // Fraction of samples skipped.
  ET_COUNT         // For array sizing.
};

// Enum for the trainability_ flags.
enum Trainability {
  TRAINABLE,         // Non-zero delta error.
  PERFECT,           // Zero delta error.
  UNENCODABLE,       // Not trainable due to coding/alignment trouble.
  HI_PRECISION_ERR,  // High-confidence disagreement.
  NOT_BOXED,         // Early in training and has no character boxes.
};

// Enum to define the amount of data to get serialized.
enum SerializeAmount {
  LIGHT,            // Minimal data for remote training.
  NO_BEST_TRAINER,  // Save an empty vector in place of best_trainer_.
  FULL,             // All data including best_trainer_.
};

// Enum to indicate how the sub_trainer_ training went.
enum SubTrainerResult {
  STR_NONE,     // Did nothing, as the sub_trainer_ was not good enough.
  STR_UPDATED,  // Subtrainer was updated, but didn't replace *this.
  STR_REPLACED  // Subtrainer replaced *this.
};

// Function to restore the trainer state from a given checkpoint.
// Returns false on failure.
typedef TessResultCallback2<bool, const GenericVector<char>&, LSTMTrainer*>*
    CheckPointReader;
// Function to save a checkpoint of the current trainer state.
// Returns false on failure. SerializeAmount determines the amount of the
// trainer to serialize, typically used for saving the best state.
typedef TessResultCallback3<bool, SerializeAmount, const LSTMTrainer*,
                            GenericVector<char>*>* CheckPointWriter;
// Function to compute and record error rates on some external test set(s).
// Args are: iteration, mean errors, model, training stage.
// Returns a STRING containing logging information about the tests.
typedef TessResultCallback4<STRING, int, const double*, const TessdataManager&,
                            int>* TestCallback;
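
// A minimal sketch of providing these callbacks (the free-function names
// ReadCheckpoint/WriteCheckpoint are hypothetical; passing null callbacks to
// the LSTMTrainer constructor selects equivalent defaults):
//
//   bool ReadCheckpoint(const GenericVector<char>& data, LSTMTrainer* t) {
//     return t->ReadTrainingDump(data, t);  // Restore state into t.
//   }
//   bool WriteCheckpoint(SerializeAmount amount, const LSTMTrainer* t,
//                        GenericVector<char>* data) {
//     return t->SaveTrainingDump(amount, t, data);  // Snapshot t's state.
//   }
//   CheckPointReader reader = NewPermanentTessCallback(ReadCheckpoint);
//   CheckPointWriter writer = NewPermanentTessCallback(WriteCheckpoint);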

// Trainer class for LSTM networks. Most of the effort is in creating the
// ideal target outputs from the transcription. A box file is used if it is
// available, otherwise estimates of the char widths from the unicharset are
// used to guide a DP search for the best fit to the transcription.
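//
// A minimal end-to-end usage sketch (the paths, VGSL spec and hyperparameters
// below are illustrative assumptions, not prescribed defaults):
//
//   LSTMTrainer trainer(nullptr, nullptr, nullptr, nullptr,  // Default I/O.
//                       "output/base", "output/base_checkpoint",
//                       0 /*debug_interval*/, 0 /*max_memory*/);
//   trainer.InitCharSet("eng/eng.traineddata");  // Must precede InitNetwork.
//   trainer.InitNetwork("[1,36,0,1 Lbx100 O1c1]", -1 /*append_index*/,
//                       NF_LAYER_SPECIFIC_LR, 0.1f /*weight_range*/,
//                       1e-3f /*learning_rate*/, 0.5f /*momentum*/,
//                       0.999f /*adam_beta*/);
//   GenericVector<STRING> filenames;
//   filenames.push_back(STRING("eng.lstmf"));
//   trainer.LoadAllTrainingData(filenames, CS_SEQUENTIAL, false);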
class LSTMTrainer : public LSTMRecognizer {
 public:
  LSTMTrainer();
  // Callbacks may be null, in which case defaults are used.
  LSTMTrainer(FileReader file_reader, FileWriter file_writer,
              CheckPointReader checkpoint_reader,
              CheckPointWriter checkpoint_writer,
              const char* model_base, const char* checkpoint_name,
              int debug_interval, int64_t max_memory);
  virtual ~LSTMTrainer();

  // Tries to deserialize a trainer from the given file and silently returns
  // false in case of failure. If old_traineddata is not null, then it is
  // assumed that the character set is to be re-mapped from old_traineddata to
  // the new, with consequent change in weight matrices etc.
  bool TryLoadingCheckpoint(const char* filename, const char* old_traineddata);

  // Initializes the character set encode/decode mechanism directly from a
  // previously set-up traineddata containing dawgs, UNICHARSET and
  // UnicharCompress. Note: Call InitCharSet before InitNetwork!
  void InitCharSet(const std::string& traineddata_path) {
    ASSERT_HOST(mgr_.Init(traineddata_path.c_str()));
    InitCharSet();
  }
  void InitCharSet(const TessdataManager& mgr) {
    mgr_ = mgr;
    InitCharSet();
  }

  // Initializes the trainer with a network_spec in the network description
  // language. net_flags control network behavior according to the
  // NetworkFlags enum; there isn't much difference between the flags, only
  // where their effects are implemented.
  // For the other args see NetworkBuilder::InitNetwork.
  // Note: Be sure to call InitCharSet before InitNetwork!
  bool InitNetwork(const STRING& network_spec, int append_index, int net_flags,
                   float weight_range, float learning_rate, float momentum,
                   float adam_beta);
  // Initializes a trainer from a serialized TFNetworkModel proto.
  // Returns the global step of TensorFlow graph or 0 if failed.
  // Building a compatible TF graph: See tfnetwork.proto.
  int InitTensorFlowNetwork(const std::string& tf_proto);
  // Resets all the iteration counters for fine tuning or training a head,
  // where we want the error reporting to reset.
  void InitIterations();

  // Accessors.
  double ActivationError() const {
    return error_rates_[ET_DELTA];
  }
  double CharError() const { return error_rates_[ET_CHAR_ERROR]; }
  const double* error_rates() const {
    return error_rates_;
  }
  double best_error_rate() const {
    return best_error_rate_;
  }
  int best_iteration() const {
    return best_iteration_;
  }
  int learning_iteration() const { return learning_iteration_; }
  int improvement_steps() const { return improvement_steps_; }
  void set_perfect_delay(int delay) { perfect_delay_ = delay; }
  const GenericVector<char>& best_trainer() const { return best_trainer_; }
  // Returns the error that was just calculated by PrepareForBackward.
  double NewSingleError(ErrorTypes type) const {
    return error_buffers_[type][training_iteration() % kRollingBufferSize_];
  }
  // Returns the error that was just calculated by TrainOnLine. Since
  // TrainOnLine rolls the error buffers, this is one further back than
  // NewSingleError.
  double LastSingleError(ErrorTypes type) const {
    return error_buffers_[type]
                         [(training_iteration() + kRollingBufferSize_ - 1) %
                          kRollingBufferSize_];
  }
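  // Worked example of the rolling-buffer indexing: with kRollingBufferSize_
  // == 1000 and training_iteration() == 2500, NewSingleError reads slot
  // 2500 % 1000 == 500, while LastSingleError reads slot
  // (2500 + 1000 - 1) % 1000 == 499, the previous iteration's entry.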
  const DocumentCache& training_data() const {
    return training_data_;
  }
  DocumentCache* mutable_training_data() { return &training_data_; }

  // If the training sample is usable, grid searches for the optimal
  // dict_ratio/cert_offset, and returns the results in a string of space-
  // separated triplets of ratio,offset=worderr.
  Trainability GridSearchDictParams(
      const ImageData* trainingdata, int iteration, double min_dict_ratio,
      double dict_ratio_step, double max_dict_ratio, double min_cert_offset,
      double cert_offset_step, double max_cert_offset, STRING* results);
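  // For example, the results string might look like (values hypothetical):
  //   "0.1,-2.5=0.42 0.1,-2.0=0.40 0.2,-2.5=0.38 ..."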

  // Provides output on the distribution of weight values.
  void DebugNetwork();

  // Loads a set of lstmf files (created by running tesseract with the
  // lstm.train config) into memory, ready for training. Returns false if
  // nothing was loaded.
  bool LoadAllTrainingData(const GenericVector<STRING>& filenames,
                           CachingStrategy cache_strategy,
                           bool randomly_rotate);

  // Keeps track of best and locally worst error rate, using internally computed
  // values. See MaintainCheckpointsSpecific for more detail.
  bool MaintainCheckpoints(TestCallback tester, STRING* log_msg);
  // Keeps track of best and locally worst error_rate (whatever it is) and
  // launches tests using rec_model, when a new min or max is reached.
  // Writes checkpoints using train_model at appropriate times and builds and
  // returns a log message to indicate progress. Returns false if nothing
  // interesting happened.
  bool MaintainCheckpointsSpecific(int iteration,
                                   const GenericVector<char>* train_model,
                                   const GenericVector<char>* rec_model,
                                   TestCallback tester, STRING* log_msg);
  // Builds a string containing a progress message with current error rates.
  void PrepareLogMsg(STRING* log_msg) const;
  // Appends <intro_str> iteration learning_iteration()/training_iteration()/
  // sample_iteration() to the log_msg.
  void LogIterations(const char* intro_str, STRING* log_msg) const;

  // TODO(rays) Add curriculum learning.
  // Returns true and increments the training_stage_ if the error rate has just
  // passed through the given threshold for the first time.
  bool TransitionTrainingStage(float error_threshold);
  // Returns the current training stage.
  int CurrentTrainingStage() const { return training_stage_; }

  // Writes to the given file. Returns false in case of error.
  virtual bool Serialize(SerializeAmount serialize_amount,
                         const TessdataManager* mgr, TFile* fp) const;
  // Reads from the given file. Returns false in case of error.
  virtual bool DeSerialize(const TessdataManager* mgr, TFile* fp);

  // De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the
  // learning rates (by scaling reduction, or layer specific, according to
  // NF_LAYER_SPECIFIC_LR).
  void StartSubtrainer(STRING* log_msg);
  // While the sub_trainer_ is behind the current training iteration and its
  // training error is at least kSubTrainerMarginFraction better than the
  // current training error, trains the sub_trainer_, and returns STR_UPDATED if
  // it did anything. If it catches up, and has a better error rate than the
  // current best, as well as a margin over the current error rate, then the
  // trainer in *this is replaced with sub_trainer_, and STR_REPLACED is
  // returned. STR_NONE is returned if the subtrainer wasn't good enough to
  // receive any training iterations.
  SubTrainerResult UpdateSubtrainer(STRING* log_msg);
  // Reduces network learning rates, either for everything, or for layers
  // independently, according to NF_LAYER_SPECIFIC_LR.
  void ReduceLearningRates(LSTMTrainer* samples_trainer, STRING* log_msg);
  // Considers reducing the learning rate for each layer independently, down
  // by the given factor (< 1), or leaving it the same, by double-training the
  // given number of samples and minimizing the number of sign changes in the
  // weight updates.
  // Even if it looks like all weights should remain the same, an adjustment
  // will be made to guarantee a different result when reverting to an old best.
  // Returns the number of layer learning rates that were reduced.
  int ReduceLayerLearningRates(double factor, int num_samples,
                               LSTMTrainer* samples_trainer);

  // Converts the string to integer class labels, with appropriate null_char_s
  // in between if not in SimpleTextOutput mode. Returns false on failure.
  bool EncodeString(const STRING& str, GenericVector<int>* labels) const {
    return EncodeString(str, GetUnicharset(), IsRecoding() ? &recoder_ : nullptr,
                        SimpleTextOutput(), null_char_, labels);
  }
  // Static version operates on supplied unicharset, encoder, simple_text.
  static bool EncodeString(const STRING& str, const UNICHARSET& unicharset,
                           const UnicharCompress* recoder, bool simple_text,
                           int null_char, GenericVector<int>* labels);
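  // Encoding sketch: turning a transcription into training labels.
  //   GenericVector<int> labels;
  //   if (!trainer.EncodeString(STRING("hello"), &labels)) {
  //     // Failure: the text contains characters outside the unicharset.
  //   }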

  // Performs forward-backward on the given trainingdata.
  // Returns the sample that was used or nullptr if the next sample was deemed
  // unusable. samples_trainer could be this or an alternative trainer that
  // holds the training samples.
  const ImageData* TrainOnLine(LSTMTrainer* samples_trainer, bool batch) {
    int sample_index = sample_iteration();
    const ImageData* image =
        samples_trainer->training_data_.GetPageBySerial(sample_index);
    if (image != nullptr) {
      Trainability trainable = TrainOnLine(image, batch);
      if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
        return nullptr;  // Sample was unusable.
      }
    } else {
      ++sample_iteration_;
    }
    return image;
  }
  Trainability TrainOnLine(const ImageData* trainingdata, bool batch);
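
  // A minimal training-loop sketch using the methods above (max_iterations
  // and target_error are assumed caller-defined; the tester may be null):
  //   STRING log_msg;
  //   while (trainer.training_iteration() < max_iterations &&
  //          trainer.CharError() > target_error) {
  //     trainer.TrainOnLine(&trainer, false);
  //     if (trainer.MaintainCheckpoints(nullptr, &log_msg))
  //       tprintf("%s\n", log_msg.string());
  //   }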

  // Prepares the ground truth, runs forward, and prepares the targets.
  // Returns a Trainability enum to indicate the suitability of the sample.
  Trainability PrepareForBackward(const ImageData* trainingdata,
                                  NetworkIO* fwd_outputs, NetworkIO* targets);

  // Writes the trainer to memory, so that the current training state can be
  // restored.  *this must always be the master trainer that retains the only
  // copy of the training data and language model. trainer is the model that is
  // actually serialized.
  bool SaveTrainingDump(SerializeAmount serialize_amount,
                        const LSTMTrainer* trainer,
                        GenericVector<char>* data) const;

  // Reads previously saved trainer from memory. *this must always be the
  // master trainer that retains the only copy of the training data and
  // language model. trainer is the model that is restored.
  bool ReadTrainingDump(const GenericVector<char>& data,
                        LSTMTrainer* trainer) const {
    if (data.empty()) return false;
    return ReadSizedTrainingDump(&data[0], data.size(), trainer);
  }
  bool ReadSizedTrainingDump(const char* data, int size,
                             LSTMTrainer* trainer) const {
    return trainer->ReadLocalTrainingDump(&mgr_, data, size);
  }
  // Restores the model to *this.
  bool ReadLocalTrainingDump(const TessdataManager* mgr, const char* data,
                             int size);
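
  // Round-trip sketch: the master trainer snapshots its state and can later
  // restore it (LIGHT shown; see the SerializeAmount enum above):
  //   GenericVector<char> dump;
  //   trainer.SaveTrainingDump(LIGHT, &trainer, &dump);
  //   ...
  //   trainer.ReadTrainingDump(dump, &trainer);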

  // Sets up the data for MaintainCheckpoints from a light ReadTrainingDump.
  void SetupCheckpointInfo();

  // Writes the full recognition traineddata to the given filename.
  bool SaveTraineddata(const STRING& filename);

  // Writes the recognizer to memory, so that it can be used for testing later.
  void SaveRecognitionDump(GenericVector<char>* data) const;

  // Returns a suitable filename for a training dump, based on the model_base_,
  // the iteration and the error rates.
  STRING DumpFilename() const;

  // Fills the whole error buffer of the given type with the given value.
  void FillErrorBuffer(double new_error, ErrorTypes type);
  // Helper generates a map from each current recoder_ code (i.e. softmax
  // index) to the corresponding old_recoder code, or -1 if there isn't one.
  std::vector<int> MapRecoder(const UNICHARSET& old_chset,
                              const UnicharCompress& old_recoder) const;

 protected:
  // Private version of InitCharSet above finishes the job after initializing
  // the mgr_ data member.
  void InitCharSet();
  // Helper computes and sets the null_char_.
  void SetNullChar();

  // Factored sub-constructor sets up reasonable default values.
  void EmptyConstructor();

  // Outputs the string and periodically displays the given network inputs as
  // an image in the given window, with the corresponding labels at their
  // x_starts.
  // Returns false if the truth string is empty.
  bool DebugLSTMTraining(const NetworkIO& inputs,
                         const ImageData& trainingdata,
                         const NetworkIO& fwd_outputs,
                         const GenericVector<int>& truth_labels,
                         const NetworkIO& outputs);
  // Displays the network targets as a line graph.
  void DisplayTargets(const NetworkIO& targets, const char* window_name,
                      ScrollView** window);

  // Builds a no-compromises target where the first positions should be the
  // truth labels and the rest is padded with the null_char_.
  bool ComputeTextTargets(const NetworkIO& outputs,
                          const GenericVector<int>& truth_labels,
                          NetworkIO* targets);

  // Builds a target using standard CTC. truth_labels should be pre-padded with
  // nulls wherever desired. They don't have to be between all labels.
  // outputs is input-output, as it gets clipped to minimum probability.
  bool ComputeCTCTargets(const GenericVector<int>& truth_labels,
                         NetworkIO* outputs, NetworkIO* targets);

  // Computes network errors, and stores the results in the rolling buffers,
  // along with the supplied text_error.
  // Returns the delta error of the current sample (not the running average).
  double ComputeErrorRates(const NetworkIO& deltas, double char_error,
                           double word_error);

  // Computes the network activation RMS error rate.
  double ComputeRMSError(const NetworkIO& deltas);

  // Computes network activation winner error rate. (Number of values that are
  // in error by >= 0.5 divided by number of time-steps.) More closely related
  // to final character error than RMS, but still directly calculable from
  // just the deltas. Because of the binary nature of the targets, zero winner
  // error is a sufficient but not necessary condition for zero char error.
  double ComputeWinnerError(const NetworkIO& deltas);

  // Computes a very simple bag of chars char error rate.
  double ComputeCharError(const GenericVector<int>& truth_str,
                          const GenericVector<int>& ocr_str);
  // Computes a very simple bag of words word recall error rate.
  // NOTE that this is destructive on both input strings.
  double ComputeWordError(STRING* truth_str, STRING* ocr_str);

  // Updates the error buffer and corresponding mean of the given type with
  // the new_error.
  void UpdateErrorBuffer(double new_error, ErrorTypes type);

  // Rolls error buffers and reports the current means.
  void RollErrorBuffers();

  // Given that error_rate is either a new min or max, updates the best/worst
  // error rates, and record of progress.
  STRING UpdateErrorGraph(int iteration, double error_rate,
                          const GenericVector<char>& model_data,
                          TestCallback tester);

 protected:
  // Alignment display window.
  ScrollView* align_win_;
  // CTC target display window.
  ScrollView* target_win_;
  // CTC output display window.
  ScrollView* ctc_win_;
  // Reconstructed image window.
  ScrollView* recon_win_;
  // How often to display a debug image.
  int debug_interval_;
  // Iteration at which the last checkpoint was dumped.
  int checkpoint_iteration_;
  // Basename of files to save best models to.
  STRING model_base_;
  // Checkpoint filename.
  STRING checkpoint_name_;
  // Training data.
  bool randomly_rotate_;
  DocumentCache training_data_;
  // Name to use when saving best_trainer_.
  STRING best_model_name_;
  // Number of available training stages.
  int num_training_stages_;
  // Checkpointing callbacks.
  FileReader file_reader_;
  FileWriter file_writer_;
  // TODO(rays) These are pointers, and must be deleted. Switch to unique_ptr
  // when we can commit to c++11.
  CheckPointReader checkpoint_reader_;
  CheckPointWriter checkpoint_writer_;

  // ===Serialized data to ensure that a restart produces the same results.===
  // These members are only serialized when serialize_amount != LIGHT.
  // Best error rate so far.
  double best_error_rate_;
  // Snapshot of all error rates at best_iteration_.
  double best_error_rates_[ET_COUNT];
  // Iteration of best_error_rate_.
  int best_iteration_;
  // Worst error rate since best_error_rate_.
  double worst_error_rate_;
  // Snapshot of all error rates at worst_iteration_.
  double worst_error_rates_[ET_COUNT];
  // Iteration of worst_error_rate_.
  int worst_iteration_;
  // Iteration at which the process will be considered stalled.
  int stall_iteration_;
  // Saved recognition models for computing test error for graph points.
  GenericVector<char> best_model_data_;
  GenericVector<char> worst_model_data_;
  // Saved trainer for reverting back to last known best.
  GenericVector<char> best_trainer_;
  // A subsidiary trainer running with a different learning rate until either
  // *this or sub_trainer_ hits a new best.
  LSTMTrainer* sub_trainer_;
  // Error rate at which last best model was dumped.
  float error_rate_of_last_saved_best_;
  // Current stage of training.
  int training_stage_;
  // History of best error rate against iteration. Used for computing the
  // number of steps to each 2% improvement.
  GenericVector<double> best_error_history_;
  GenericVector<int> best_error_iterations_;
  // Number of iterations since the best_error_rate_ was 2% more than it is now.
  int improvement_steps_;
  // Number of iterations that yielded a non-zero delta error and thus provided
  // significant learning. learning_iteration_ <= training_iteration_.
  // learning_iteration_ is used to measure rate of learning progress.
  int learning_iteration_;
  // Saved value of sample_iteration_ before looking for the next sample.
  int prev_sample_iteration_;
  // How often to include a PERFECT training sample in backprop.
  // A PERFECT training sample is used if the current
  // training_iteration_ > last_perfect_training_iteration_ + perfect_delay_,
  // so with perfect_delay_ == 0, all samples are used, and with
  // perfect_delay_ == 4, at most 1 in 5 samples will be perfect.
  int perfect_delay_;
  // Value of training_iteration_ at which the last PERFECT training sample
  // was used in back prop.
  int last_perfect_training_iteration_;
  // Rolling buffers storing recent training errors are indexed by
  // training_iteration % kRollingBufferSize_.
  static const int kRollingBufferSize_ = 1000;
  GenericVector<double> error_buffers_[ET_COUNT];
  // Rounded mean percent trailing training errors in the buffers, one entry
  // per ErrorTypes value.
  double error_rates_[ET_COUNT];
  // Traineddata file with optional dawgs + UNICHARSET and recoder.
  TessdataManager mgr_;
};

}  // namespace tesseract.

#endif  // TESSERACT_LSTM_LSTMTRAINER_H_