tesseract  4.1.1
lstmrecognizer.h
Go to the documentation of this file.
1 // File: lstmrecognizer.h
3 // Description: Top-level line recognizer class for LSTM-based networks.
4 // Author: Ray Smith
5 // Created: Thu May 02 08:57:06 PST 2013
6 //
7 // (C) Copyright 2013, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 #ifndef TESSERACT_LSTM_LSTMRECOGNIZER_H_
20 #define TESSERACT_LSTM_LSTMRECOGNIZER_H_
21 
22 #include "ccutil.h"
23 #include "helpers.h"
24 #include "imagedata.h"
25 #include "matrix.h"
26 #include "network.h"
27 #include "networkscratch.h"
28 #include "params.h"
29 #include "recodebeam.h"
30 #include "series.h"
31 #include "strngs.h"
32 #include "unicharcompress.h"
33 
34 class BLOB_CHOICE_IT;
35 struct Pix;
36 class ROW_RES;
37 class ScrollView;
38 class TBOX;
39 class WERD_RES;
40 
41 namespace tesseract {
42 
43 class Dict;
44 class ImageData;
45 
46 // Enum indicating training mode control flags.
50 };
51 
52 // Top-level line recognizer class for LSTM-based networks.
// Note that a sub-class, LSTMTrainer, is used for training.
55  public:
57  LSTMRecognizer(const STRING language_data_path_prefix);
59 
60  int NumOutputs() const { return network_->NumOutputs(); }
61  int training_iteration() const { return training_iteration_; }
62  int sample_iteration() const { return sample_iteration_; }
63  double learning_rate() const { return learning_rate_; }
65  if (network_ == nullptr) return LT_NONE;
66  StaticShape shape;
67  shape = network_->OutputShape(shape);
68  return shape.loss_type();
69  }
70  bool SimpleTextOutput() const { return OutputLossType() == LT_SOFTMAX; }
71  bool IsIntMode() const { return (training_flags_ & TF_INT_MODE) != 0; }
72  // True if recoder_ is active to re-encode text to a smaller space.
73  bool IsRecoding() const {
75  }
76  // Returns true if the network is a TensorFlow network.
77  bool IsTensorFlow() const { return network_->type() == NT_TENSORFLOW; }
78  // Returns a vector of layer ids that can be passed to other layer functions
79  // to access a specific layer.
81  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
82  auto* series = static_cast<Series*>(network_);
83  GenericVector<STRING> layers;
84  series->EnumerateLayers(nullptr, &layers);
85  return layers;
86  }
87  // Returns a specific layer from its id (from EnumerateLayers).
88  Network* GetLayer(const STRING& id) const {
89  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
90  ASSERT_HOST(id.length() > 1 && id[0] == ':');
91  auto* series = static_cast<Series*>(network_);
92  return series->GetLayer(&id[1]);
93  }
94  // Returns the learning rate of the layer from its id.
95  float GetLayerLearningRate(const STRING& id) const {
96  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
98  ASSERT_HOST(id.length() > 1 && id[0] == ':');
99  auto* series = static_cast<Series*>(network_);
100  return series->LayerLearningRate(&id[1]);
101  } else {
102  return learning_rate_;
103  }
104  }
  // Multiplies all the learning rate(s) by the given factor.
106  void ScaleLearningRate(double factor) {
107  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
108  learning_rate_ *= factor;
111  for (int i = 0; i < layers.size(); ++i) {
112  ScaleLayerLearningRate(layers[i], factor);
113  }
114  }
115  }
116  // Multiplies the learning rate of the layer with id, by the given factor.
117  void ScaleLayerLearningRate(const STRING& id, double factor) {
118  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
119  ASSERT_HOST(id.length() > 1 && id[0] == ':');
120  auto* series = static_cast<Series*>(network_);
121  series->ScaleLayerLearningRate(&id[1], factor);
122  }
123 
124  // Converts the network to int if not already.
125  void ConvertToInt() {
126  if ((training_flags_ & TF_INT_MODE) == 0) {
129  }
130  }
131 
132  // Provides access to the UNICHARSET that this classifier works with.
133  const UNICHARSET& GetUnicharset() const { return ccutil_.unicharset; }
134  // Provides access to the UnicharCompress that this classifier works with.
135  const UnicharCompress& GetRecoder() const { return recoder_; }
136  // Provides access to the Dict that this classifier works with.
137  const Dict* GetDict() const { return dict_; }
138  // Sets the sample iteration to the given value. The sample_iteration_
139  // determines the seed for the random number generator. The training
140  // iteration is incremented only by a successful training iteration.
141  void SetIteration(int iteration) { sample_iteration_ = iteration; }
142  // Accessors for textline image normalization.
143  int NumInputs() const { return network_->NumInputs(); }
144  int null_char() const { return null_char_; }
145 
146  // Loads a model from mgr, including the dictionary only if lang is not null.
147  bool Load(const ParamsVectors* params, const char* lang,
148  TessdataManager* mgr);
149 
150  // Writes to the given file. Returns false in case of error.
151  // If mgr contains a unicharset and recoder, then they are not encoded to fp.
152  bool Serialize(const TessdataManager* mgr, TFile* fp) const;
153  // Reads from the given file. Returns false in case of error.
154  // If mgr contains a unicharset and recoder, then they are taken from there,
155  // otherwise, they are part of the serialization in fp.
156  bool DeSerialize(const TessdataManager* mgr, TFile* fp);
157  // Loads the charsets from mgr.
158  bool LoadCharsets(const TessdataManager* mgr);
159  // Loads the Recoder.
160  bool LoadRecoder(TFile* fp);
161  // Loads the dictionary if possible from the traineddata file.
  // Prints a warning message and returns false if loading fails, but otherwise
  // continues to work without the dictionary.
164  // Note that dictionary load is independent from DeSerialize, but dependent
165  // on the unicharset matching. This enables training to deserialize a model
166  // from checkpoint or restore without having to go back and reload the
167  // dictionary.
168  bool LoadDictionary(const ParamsVectors* params, const char* lang,
169  TessdataManager* mgr);
170 
171  // Recognizes the line image, contained within image_data, returning the
172  // recognized tesseract WERD_RES for the words.
173  // If invert, tries inverted as well if the normal interpretation doesn't
174  // produce a good enough result. The line_box is used for computing the
175  // box_word in the output words. worst_dict_cert is the worst certainty that
176  // will be used in a dictionary word.
177  void RecognizeLine(const ImageData& image_data, bool invert, bool debug,
178  double worst_dict_cert, const TBOX& line_box,
179  PointerVector<WERD_RES>* words, int lstm_choice_mode = 0);
180 
  // Helper computes the min, mean, and standard deviation (sd) of the best
  // results in the output.
182  void OutputStats(const NetworkIO& outputs, float* min_output,
183  float* mean_output, float* sd);
  // Recognizes the image_data, running the network forward and returning the
  // network outputs in outputs.
186  // Returned in scale_factor is the reduction factor
187  // between the image and the output coords, for computing bounding boxes.
188  // If re_invert is true, the input is inverted back to its original
189  // photometric interpretation if inversion is attempted but fails to
190  // improve the results. This ensures that outputs contains the correct
191  // forward outputs for the best photometric interpretation.
192  // inputs is filled with the used inputs to the network.
193  bool RecognizeLine(const ImageData& image_data, bool invert, bool debug,
194  bool re_invert, bool upside_down, float* scale_factor,
195  NetworkIO* inputs, NetworkIO* outputs);
196 
197  // Converts an array of labels to utf-8, whether or not the labels are
198  // augmented with character boundaries.
199  STRING DecodeLabels(const GenericVector<int>& labels);
200 
201  // Displays the forward results in a window with the characters and
202  // boundaries as determined by the labels and label_coords.
203  void DisplayForward(const NetworkIO& inputs, const GenericVector<int>& labels,
204  const GenericVector<int>& label_coords,
205  const char* window_name, ScrollView** window);
206  // Converts the network output to a sequence of labels. Outputs labels, scores
207  // and start xcoords of each char, and each null_char_, with an additional
208  // final xcoord for the end of the output.
209  // The conversion method is determined by internal state.
210  void LabelsFromOutputs(const NetworkIO& outputs, GenericVector<int>* labels,
211  GenericVector<int>* xcoords);
212 
213  protected:
  // Sets the random seed from the sample_iteration_.
215  void SetRandomSeed() {
216  int64_t seed = static_cast<int64_t>(sample_iteration_) * 0x10000001;
217  randomizer_.set_seed(seed);
219  }
220 
221  // Displays the labels and cuts at the corresponding xcoords.
222  // Size of labels should match xcoords.
223  void DisplayLSTMOutput(const GenericVector<int>& labels,
224  const GenericVector<int>& xcoords, int height,
225  ScrollView* window);
226 
227  // Prints debug output detailing the activation path that is implied by the
228  // xcoords.
229  void DebugActivationPath(const NetworkIO& outputs,
230  const GenericVector<int>& labels,
231  const GenericVector<int>& xcoords);
232 
233  // Prints debug output detailing activations and 2nd choice over a range
234  // of positions.
235  void DebugActivationRange(const NetworkIO& outputs, const char* label,
236  int best_choice, int x_start, int x_end);
237 
  // Converts the network output to a sequence of labels and xcoords,
  // constructing the best path that contains only legal sequences of subcodes
  // for recoder_.
240  void LabelsViaReEncode(const NetworkIO& output, GenericVector<int>* labels,
241  GenericVector<int>* xcoords);
242  // Converts the network output to a sequence of labels, with scores, using
243  // the simple character model (each position is a char, and the null_char_ is
244  // mainly intended for tail padding.)
245  void LabelsViaSimpleText(const NetworkIO& output, GenericVector<int>* labels,
246  GenericVector<int>* xcoords);
247 
248  // Returns a string corresponding to the label starting at start. Sets *end
249  // to the next start and if non-null, *decoded to the unichar id.
250  const char* DecodeLabel(const GenericVector<int>& labels, int start, int* end,
251  int* decoded);
252 
253  // Returns a string corresponding to a given single label id, falling back to
254  // a default of ".." for part of a multi-label unichar-id.
255  const char* DecodeSingleLabel(int label);
256 
257  protected:
258  // The network hierarchy.
260  // The unicharset. Only the unicharset element is serialized.
261  // Has to be a CCUtil, so Dict can point to it.
263  // For backward compatibility, recoder_ is serialized iff
264  // training_flags_ & TF_COMPRESS_UNICHARSET.
265  // Further encode/decode ccutil_.unicharset's ids to simplify the unicharset.
267 
268  // ==Training parameters that are serialized to provide a record of them.==
270  // Flags used to determine the training method of the network.
271  // See enum TrainingFlags above.
273  // Number of actual backward training steps used.
  // Index into training sample set. sample_iteration_ >= training_iteration_.
277  // Index in softmax of null character. May take the value UNICHAR_BROKEN or
278  // ccutil_.unicharset.size().
279  int32_t null_char_;
280  // Learning rate and momentum multipliers of deltas in backprop.
282  float momentum_;
283  // Smoothing factor for 2nd moment of gradients.
284  float adam_beta_;
285 
286  // === NOT SERIALIZED.
289  // Language model (optional) to use with the beam search.
291  // Beam search held between uses to optimize memory allocation/use.
293 
294  // == Debugging parameters.==
295  // Recognition debug display window.
297 };
298 
299 } // namespace tesseract.
300 
301 #endif // TESSERACT_LSTM_LSTMRECOGNIZER_H_
#define ASSERT_HOST(x)
Definition: errcode.h:88
@ TF_COMPRESS_UNICHARSET
@ NT_TENSORFLOW
Definition: network.h:78
@ NT_SERIES
Definition: network.h:54
@ NF_LAYER_SPECIFIC_LR
Definition: network.h:87
int size() const
Definition: genericvector.h:72
Definition: rect.h:34
UNICHARSET unicharset
Definition: ccutil.h:73
int32_t IntRand()
Definition: helpers.h:50
void set_seed(uint64_t seed)
Definition: helpers.h:40
Definition: strngs.h:45
LossType OutputLossType() const
Network * GetLayer(const STRING &id) const
bool Load(const ParamsVectors *params, const char *lang, TessdataManager *mgr)
NetworkScratch scratch_space_
double learning_rate() const
const char * DecodeSingleLabel(int label)
void DisplayLSTMOutput(const GenericVector< int > &labels, const GenericVector< int > &xcoords, int height, ScrollView *window)
bool LoadCharsets(const TessdataManager *mgr)
const char * DecodeLabel(const GenericVector< int > &labels, int start, int *end, int *decoded)
const UNICHARSET & GetUnicharset() const
void RecognizeLine(const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words, int lstm_choice_mode=0)
STRING DecodeLabels(const GenericVector< int > &labels)
void LabelsViaReEncode(const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
void OutputStats(const NetworkIO &outputs, float *min_output, float *mean_output, float *sd)
void SetIteration(int iteration)
void DebugActivationPath(const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > &xcoords)
RecodeBeamSearch * search_
float GetLayerLearningRate(const STRING &id) const
GenericVector< STRING > EnumerateLayers() const
void ScaleLearningRate(double factor)
const Dict * GetDict() const
void LabelsFromOutputs(const NetworkIO &outputs, GenericVector< int > *labels, GenericVector< int > *xcoords)
bool LoadDictionary(const ParamsVectors *params, const char *lang, TessdataManager *mgr)
const UnicharCompress & GetRecoder() const
bool Serialize(const TessdataManager *mgr, TFile *fp) const
void DebugActivationRange(const NetworkIO &outputs, const char *label, int best_choice, int x_start, int x_end)
void ScaleLayerLearningRate(const STRING &id, double factor)
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
void DisplayForward(const NetworkIO &inputs, const GenericVector< int > &labels, const GenericVector< int > &label_coords, const char *window_name, ScrollView **window)
void LabelsViaSimpleText(const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
int NumOutputs() const
Definition: network.h:123
virtual StaticShape OutputShape(const StaticShape &input_shape) const
Definition: network.h:133
bool TestFlag(NetworkFlags flag) const
Definition: network.h:144
int NumInputs() const
Definition: network.h:120
NetworkType type() const
Definition: network.h:112
virtual void ConvertToInt()
Definition: network.h:191
float LayerLearningRate(const char *id) const
Definition: plumbing.h:105
void ScaleLayerLearningRate(const char *id, double factor)
Definition: plumbing.h:111
Network * GetLayer(const char *id) const
Definition: plumbing.cpp:155
LossType loss_type() const
Definition: static_shape.h:50