tesseract  4.1.1
tesseract::LSTM Class Reference

#include <lstm.h>

Inheritance diagram for tesseract::LSTM:
tesseract::Network

Public Types

enum  WeightType {
  CI, GI, GF1, GO,
  GFS, WT_COUNT
}
 

Public Member Functions

 LSTM (const STRING &name, int num_inputs, int num_states, int num_outputs, bool two_dimensional, NetworkType type)
 
 ~LSTM () override
 
StaticShape OutputShape (const StaticShape &input_shape) const override
 
STRING spec () const override
 
void SetEnableTraining (TrainingState state) override
 
int InitWeights (float range, TRand *randomizer) override
 
int RemapOutputs (int old_no, const std::vector< int > &code_map) override
 
void ConvertToInt () override
 
void DebugWeights () override
 
bool Serialize (TFile *fp) const override
 
bool DeSerialize (TFile *fp) override
 
void Forward (bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
 
bool Backward (bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
 
void Update (float learning_rate, float momentum, float adam_beta, int num_samples) override
 
void CountAlternators (const Network &other, double *same, double *changed) const override
 
void PrintW ()
 
void PrintDW ()
 
bool Is2D () const
 
- Public Member Functions inherited from tesseract::Network
 Network ()
 
 Network (NetworkType type, const STRING &name, int ni, int no)
 
virtual ~Network ()=default
 
NetworkType type () const
 
bool IsTraining () const
 
bool needs_to_backprop () const
 
int num_weights () const
 
int NumInputs () const
 
int NumOutputs () const
 
virtual StaticShape InputShape () const
 
const STRING & name () const
 
bool TestFlag (NetworkFlags flag) const
 
virtual bool IsPlumbingType () const
 
virtual void SetNetworkFlags (uint32_t flags)
 
virtual void SetRandomizer (TRand *randomizer)
 
virtual bool SetupNeedsBackprop (bool needs_backprop)
 
virtual int XScaleFactor () const
 
virtual void CacheXScaleFactor (int factor)
 
void DisplayForward (const NetworkIO &matrix)
 
void DisplayBackward (const NetworkIO &matrix)
 

Additional Inherited Members

- Static Public Member Functions inherited from tesseract::Network
static Network * CreateFromFile (TFile *fp)
 
static void ClearWindow (bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
 
static int DisplayImage (Pix *pix, ScrollView *window)
 
- Protected Member Functions inherited from tesseract::Network
double Random (double range)
 
- Protected Attributes inherited from tesseract::Network
NetworkType type_
 
TrainingState training_
 
bool needs_to_backprop_
 
int32_t network_flags_
 
int32_t ni_
 
int32_t no_
 
int32_t num_weights_
 
STRING name_
 
ScrollView * forward_win_
 
ScrollView * backward_win_
 
TRand * randomizer_
 

Detailed Description

Definition at line 28 of file lstm.h.

Member Enumeration Documentation

◆ WeightType

Enumerator
CI 
GI 
GF1 
GO 
GFS 
WT_COUNT 

Definition at line 33 of file lstm.h.

33  {
34  CI, // Cell Inputs.
35  GI, // Gate at the input.
36  GF1, // Forget gate at the memory (1-d or looking back 1 timestep).
37  GO, // Gate at the output.
38  GFS, // Forget gate at the memory, looking back in the other dimension.
39 
40  WT_COUNT // Number of WeightTypes.
41  };

Constructor & Destructor Documentation

◆ LSTM()

tesseract::LSTM::LSTM ( const STRING &  name,
int  num_inputs,
int  num_states,
int  num_outputs,
bool  two_dimensional,
NetworkType  type 
)

Definition at line 99 of file lstm.cpp.

101  : Network(type, name, ni, no),
102  na_(ni + ns),
103  ns_(ns),
104  nf_(0),
105  is_2d_(two_dimensional),
106  softmax_(nullptr),
107  input_width_(0) {
108  if (two_dimensional) na_ += ns_;
109  if (type_ == NT_LSTM || type_ == NT_LSTM_SUMMARY) {
110  nf_ = 0;
111  // networkbuilder ensures this is always true.
112  ASSERT_HOST(no == ns);
113  } else if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
114  nf_ = type_ == NT_LSTM_SOFTMAX ? no_ : ceil_log2(no_);
115  softmax_ = new FullyConnected("LSTM Softmax", ns_, no_, NT_SOFTMAX);
116  } else {
117  tprintf("%d is invalid type of LSTM!\n", type);
118  ASSERT_HOST(false);
119  }
120  na_ += nf_;
121 }

◆ ~LSTM()

tesseract::LSTM::~LSTM ( )
override

Definition at line 123 of file lstm.cpp.

123 { delete softmax_; }

Member Function Documentation

◆ Backward()

bool tesseract::LSTM::Backward ( bool  debug,
const NetworkIO &  fwd_deltas,
NetworkScratch *  scratch,
NetworkIO *  back_deltas 
)
overridevirtual

Implements tesseract::Network.

Definition at line 441 of file lstm.cpp.

443  {
444  if (debug) DisplayBackward(fwd_deltas);
445  back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_, ni_);
446  // ======Scratch space.======
447  // Output errors from deltas with recurrence from sourceerr.
448  NetworkScratch::FloatVec outputerr;
449  outputerr.Init(ns_, scratch);
450  // Recurrent error in the state/source.
451  NetworkScratch::FloatVec curr_stateerr, curr_sourceerr;
452  curr_stateerr.Init(ns_, scratch);
453  curr_sourceerr.Init(na_, scratch);
454  ZeroVector<double>(ns_, curr_stateerr);
455  ZeroVector<double>(na_, curr_sourceerr);
456  // Errors in the gates.
457  NetworkScratch::FloatVec gate_errors[WT_COUNT];
458  for (auto & gate_error : gate_errors) gate_error.Init(ns_, scratch);
459  // Rotating buffers of width buf_width allow storage of the recurrent time-
460  // steps used only for true 2-D. Stores one full strip of the major direction.
461  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
462  GenericVector<NetworkScratch::FloatVec> stateerr, sourceerr;
463  if (Is2D()) {
464  stateerr.init_to_size(buf_width, NetworkScratch::FloatVec());
465  sourceerr.init_to_size(buf_width, NetworkScratch::FloatVec());
466  for (int t = 0; t < buf_width; ++t) {
467  stateerr[t].Init(ns_, scratch);
468  sourceerr[t].Init(na_, scratch);
469  ZeroVector<double>(ns_, stateerr[t]);
470  ZeroVector<double>(na_, sourceerr[t]);
471  }
472  }
473  // Parallel-generated sourceerr from each of the gates.
474  NetworkScratch::FloatVec sourceerr_temps[WT_COUNT];
475  for (auto & sourceerr_temp : sourceerr_temps)
476  sourceerr_temp.Init(na_, scratch);
477  int width = input_width_;
478  // Transposed gate errors stored over all timesteps for sum outer.
479  NetworkScratch::GradientStore gate_errors_t[WT_COUNT];
480  for (auto & w : gate_errors_t) {
481  w.Init(ns_, width, scratch);
482  }
483  // Used only if softmax_ != nullptr.
484  NetworkScratch::FloatVec softmax_errors;
485  NetworkScratch::GradientStore softmax_errors_t;
486  if (softmax_ != nullptr) {
487  softmax_errors.Init(no_, scratch);
488  softmax_errors_t.Init(no_, width, scratch);
489  }
490  double state_clip = Is2D() ? 9.0 : 4.0;
491 #if DEBUG_DETAIL > 1
492  tprintf("fwd_deltas:%s\n", name_.string());
493  fwd_deltas.Print(10);
494 #endif
495  StrideMap::Index dest_index(input_map_);
496  dest_index.InitToLast();
497  // Used only by NT_LSTM_SUMMARY.
498  StrideMap::Index src_index(fwd_deltas.stride_map());
499  src_index.InitToLast();
500  do {
501  int t = dest_index.t();
502  bool at_last_x = dest_index.IsLast(FD_WIDTH);
503  // up_pos is the 2-D back step, down_pos is the 2-D fwd step, and are only
504  // valid if >= 0, which is true if 2d and not on the top/bottom.
505  int up_pos = -1;
506  int down_pos = -1;
507  if (Is2D()) {
508  if (dest_index.index(FD_HEIGHT) > 0) {
509  StrideMap::Index up_index(dest_index);
510  if (up_index.AddOffset(-1, FD_HEIGHT)) up_pos = up_index.t();
511  }
512  if (!dest_index.IsLast(FD_HEIGHT)) {
513  StrideMap::Index down_index(dest_index);
514  if (down_index.AddOffset(1, FD_HEIGHT)) down_pos = down_index.t();
515  }
516  }
517  // Index of the 2-D revolving buffers (sourceerr, stateerr).
518  int mod_t = Modulo(t, buf_width); // Current timestep.
519  // Zero the state in the major direction only at the end of every row.
520  if (at_last_x) {
521  ZeroVector<double>(na_, curr_sourceerr);
522  ZeroVector<double>(ns_, curr_stateerr);
523  }
524  // Setup the outputerr.
525  if (type_ == NT_LSTM_SUMMARY) {
526  if (dest_index.IsLast(FD_WIDTH)) {
527  fwd_deltas.ReadTimeStep(src_index.t(), outputerr);
528  src_index.Decrement();
529  } else {
530  ZeroVector<double>(ns_, outputerr);
531  }
532  } else if (softmax_ == nullptr) {
533  fwd_deltas.ReadTimeStep(t, outputerr);
534  } else {
535  softmax_->BackwardTimeStep(fwd_deltas, t, softmax_errors,
536  softmax_errors_t.get(), outputerr);
537  }
538  if (!at_last_x)
539  AccumulateVector(ns_, curr_sourceerr + ni_ + nf_, outputerr);
540  if (down_pos >= 0)
541  AccumulateVector(ns_, sourceerr[mod_t] + ni_ + nf_ + ns_, outputerr);
542  // Apply the 1-d forget gates.
543  if (!at_last_x) {
544  const float* next_node_gf1 = node_values_[GF1].f(t + 1);
545  for (int i = 0; i < ns_; ++i) {
546  curr_stateerr[i] *= next_node_gf1[i];
547  }
548  }
549  if (Is2D() && t + 1 < width) {
550  for (int i = 0; i < ns_; ++i) {
551  if (which_fg_[t + 1][i] != 1) curr_stateerr[i] = 0.0;
552  }
553  if (down_pos >= 0) {
554  const float* right_node_gfs = node_values_[GFS].f(down_pos);
555  const double* right_stateerr = stateerr[mod_t];
556  for (int i = 0; i < ns_; ++i) {
557  if (which_fg_[down_pos][i] == 2) {
558  curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
559  }
560  }
561  }
562  }
563  state_.FuncMultiply3Add<HPrime>(node_values_[GO], t, outputerr,
564  curr_stateerr);
565  // Clip stateerr_ to a sane range.
566  ClipVector<double>(ns_, -state_clip, state_clip, curr_stateerr);
567 #if DEBUG_DETAIL > 1
568  if (t + 10 > width) {
569  tprintf("t=%d, stateerr=", t);
570  for (int i = 0; i < ns_; ++i)
571  tprintf(" %g,%g,%g", curr_stateerr[i], outputerr[i],
572  curr_sourceerr[ni_ + nf_ + i]);
573  tprintf("\n");
574  }
575 #endif
576  // Matrix multiply to get the source errors.
578 
579  // Cell inputs.
580  node_values_[CI].FuncMultiply3<GPrime>(t, node_values_[GI], t,
581  curr_stateerr, gate_errors[CI]);
582  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[CI].get());
583  gate_weights_[CI].VectorDotMatrix(gate_errors[CI], sourceerr_temps[CI]);
584  gate_errors_t[CI].get()->WriteStrided(t, gate_errors[CI]);
585 
587  // Input Gates.
588  node_values_[GI].FuncMultiply3<FPrime>(t, node_values_[CI], t,
589  curr_stateerr, gate_errors[GI]);
590  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GI].get());
591  gate_weights_[GI].VectorDotMatrix(gate_errors[GI], sourceerr_temps[GI]);
592  gate_errors_t[GI].get()->WriteStrided(t, gate_errors[GI]);
593 
595  // 1-D forget Gates.
596  if (t > 0) {
597  node_values_[GF1].FuncMultiply3<FPrime>(t, state_, t - 1, curr_stateerr,
598  gate_errors[GF1]);
599  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GF1].get());
600  gate_weights_[GF1].VectorDotMatrix(gate_errors[GF1],
601  sourceerr_temps[GF1]);
602  } else {
603  memset(gate_errors[GF1], 0, ns_ * sizeof(gate_errors[GF1][0]));
604  memset(sourceerr_temps[GF1], 0, na_ * sizeof(*sourceerr_temps[GF1]));
605  }
606  gate_errors_t[GF1].get()->WriteStrided(t, gate_errors[GF1]);
607 
608  // 2-D forget Gates.
609  if (up_pos >= 0) {
610  node_values_[GFS].FuncMultiply3<FPrime>(t, state_, up_pos, curr_stateerr,
611  gate_errors[GFS]);
612  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GFS].get());
613  gate_weights_[GFS].VectorDotMatrix(gate_errors[GFS],
614  sourceerr_temps[GFS]);
615  } else {
616  memset(gate_errors[GFS], 0, ns_ * sizeof(gate_errors[GFS][0]));
617  memset(sourceerr_temps[GFS], 0, na_ * sizeof(*sourceerr_temps[GFS]));
618  }
619  if (Is2D()) gate_errors_t[GFS].get()->WriteStrided(t, gate_errors[GFS]);
620 
622  // Output gates.
623  state_.Func2Multiply3<HFunc, FPrime>(node_values_[GO], t, outputerr,
624  gate_errors[GO]);
625  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GO].get());
626  gate_weights_[GO].VectorDotMatrix(gate_errors[GO], sourceerr_temps[GO]);
627  gate_errors_t[GO].get()->WriteStrided(t, gate_errors[GO]);
629 
630  SumVectors(na_, sourceerr_temps[CI], sourceerr_temps[GI],
631  sourceerr_temps[GF1], sourceerr_temps[GO], sourceerr_temps[GFS],
632  curr_sourceerr);
633  back_deltas->WriteTimeStep(t, curr_sourceerr);
634  // Save states for use by the 2nd dimension only if needed.
635  if (Is2D()) {
636  CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
637  CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
638  }
639  } while (dest_index.Decrement());
640 #if DEBUG_DETAIL > 2
641  for (int w = 0; w < WT_COUNT; ++w) {
642  tprintf("%s gate errors[%d]\n", name_.string(), w);
643  gate_errors_t[w].get()->PrintUnTransposed(10);
644  }
645 #endif
646  // Transposed source_ used to speed-up SumOuter.
647  NetworkScratch::GradientStore source_t, state_t;
648  source_t.Init(na_, width, scratch);
649  source_.Transpose(source_t.get());
650  state_t.Init(ns_, width, scratch);
651  state_.Transpose(state_t.get());
652 #ifdef _OPENMP
653 #pragma omp parallel for num_threads(GFS) if (!Is2D())
654 #endif
655  for (int w = 0; w < WT_COUNT; ++w) {
656  if (w == GFS && !Is2D()) continue;
657  gate_weights_[w].SumOuterTransposed(*gate_errors_t[w], *source_t, false);
658  }
659  if (softmax_ != nullptr) {
660  softmax_->FinishBackward(*softmax_errors_t);
661  }
662  return needs_to_backprop_;
663 }

◆ ConvertToInt()

void tesseract::LSTM::ConvertToInt ( )
overridevirtual

Reimplemented from tesseract::Network.

Definition at line 183 of file lstm.cpp.

183  {
184  for (int w = 0; w < WT_COUNT; ++w) {
185  if (w == GFS && !Is2D()) continue;
186  gate_weights_[w].ConvertToInt();
187  }
188  if (softmax_ != nullptr) {
189  softmax_->ConvertToInt();
190  }
191 }

◆ CountAlternators()

void tesseract::LSTM::CountAlternators ( const Network &  other,
double *  same,
double *  changed 
) const
overridevirtual

Reimplemented from tesseract::Network.

Definition at line 687 of file lstm.cpp.

688  {
689  ASSERT_HOST(other.type() == type_);
690  const LSTM* lstm = static_cast<const LSTM*>(&other);
691  for (int w = 0; w < WT_COUNT; ++w) {
692  if (w == GFS && !Is2D()) continue;
693  gate_weights_[w].CountAlternators(lstm->gate_weights_[w], same, changed);
694  }
695  if (softmax_ != nullptr) {
696  softmax_->CountAlternators(*lstm->softmax_, same, changed);
697  }
698 }

◆ DebugWeights()

void tesseract::LSTM::DebugWeights ( )
overridevirtual

Implements tesseract::Network.

Definition at line 194 of file lstm.cpp.

194  {
195  for (int w = 0; w < WT_COUNT; ++w) {
196  if (w == GFS && !Is2D()) continue;
197  STRING msg = name_;
198  msg.add_str_int(" Gate weights ", w);
199  gate_weights_[w].Debug2D(msg.string());
200  }
201  if (softmax_ != nullptr) {
202  softmax_->DebugWeights();
203  }
204 }

◆ DeSerialize()

bool tesseract::LSTM::DeSerialize ( TFile *  fp)
overridevirtual

Implements tesseract::Network.

Definition at line 220 of file lstm.cpp.

220  {
221  if (!fp->DeSerialize(&na_)) return false;
222  if (type_ == NT_LSTM_SOFTMAX) {
223  nf_ = no_;
224  } else if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
225  nf_ = ceil_log2(no_);
226  } else {
227  nf_ = 0;
228  }
229  is_2d_ = false;
230  for (int w = 0; w < WT_COUNT; ++w) {
231  if (w == GFS && !Is2D()) continue;
232  if (!gate_weights_[w].DeSerialize(IsTraining(), fp)) return false;
233  if (w == CI) {
234  ns_ = gate_weights_[CI].NumOutputs();
235  is_2d_ = na_ - nf_ == ni_ + 2 * ns_;
236  }
237  }
238  delete softmax_;
239  if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
240  softmax_ = static_cast<FullyConnected*>(Network::CreateFromFile(fp));
241  if (softmax_ == nullptr) return false;
242  } else {
243  softmax_ = nullptr;
244  }
245  return true;
246 }

◆ Forward()

void tesseract::LSTM::Forward ( bool  debug,
const NetworkIO &  input,
const TransposedArray *  input_transpose,
NetworkScratch *  scratch,
NetworkIO *  output 
)
overridevirtual

Implements tesseract::Network.

Definition at line 250 of file lstm.cpp.

252  {
253  input_map_ = input.stride_map();
254  input_width_ = input.Width();
255  if (softmax_ != nullptr)
256  output->ResizeFloat(input, no_);
257  else if (type_ == NT_LSTM_SUMMARY)
258  output->ResizeXTo1(input, no_);
259  else
260  output->Resize(input, no_);
261  ResizeForward(input);
262  // Temporary storage of forward computation for each gate.
263  NetworkScratch::FloatVec temp_lines[WT_COUNT];
264  for (auto & temp_line : temp_lines) temp_line.Init(ns_, scratch);
265  // Single timestep buffers for the current/recurrent output and state.
266  NetworkScratch::FloatVec curr_state, curr_output;
267  curr_state.Init(ns_, scratch);
268  ZeroVector<double>(ns_, curr_state);
269  curr_output.Init(ns_, scratch);
270  ZeroVector<double>(ns_, curr_output);
271  // Rotating buffers of width buf_width allow storage of the state and output
272  // for the other dimension, used only when working in true 2D mode. The width
273  // is enough to hold an entire strip of the major direction.
274  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
275  GenericVector<NetworkScratch::FloatVec> states, outputs;
276  if (Is2D()) {
277  states.init_to_size(buf_width, NetworkScratch::FloatVec());
278  outputs.init_to_size(buf_width, NetworkScratch::FloatVec());
279  for (int i = 0; i < buf_width; ++i) {
280  states[i].Init(ns_, scratch);
281  ZeroVector<double>(ns_, states[i]);
282  outputs[i].Init(ns_, scratch);
283  ZeroVector<double>(ns_, outputs[i]);
284  }
285  }
286  // Used only if a softmax LSTM.
287  NetworkScratch::FloatVec softmax_output;
288  NetworkScratch::IO int_output;
289  if (softmax_ != nullptr) {
290  softmax_output.Init(no_, scratch);
291  ZeroVector<double>(no_, softmax_output);
292  int rounded_softmax_inputs = gate_weights_[CI].RoundInputs(ns_);
293  if (input.int_mode())
294  int_output.Resize2d(true, 1, rounded_softmax_inputs, scratch);
295  softmax_->SetupForward(input, nullptr);
296  }
297  NetworkScratch::FloatVec curr_input;
298  curr_input.Init(na_, scratch);
299  StrideMap::Index src_index(input_map_);
300  // Used only by NT_LSTM_SUMMARY.
301  StrideMap::Index dest_index(output->stride_map());
302  do {
303  int t = src_index.t();
304  // True if there is a valid old state for the 2nd dimension.
305  bool valid_2d = Is2D();
306  if (valid_2d) {
307  StrideMap::Index dim_index(src_index);
308  if (!dim_index.AddOffset(-1, FD_HEIGHT)) valid_2d = false;
309  }
310  // Index of the 2-D revolving buffers (outputs, states).
311  int mod_t = Modulo(t, buf_width); // Current timestep.
312  // Setup the padded input in source.
313  source_.CopyTimeStepGeneral(t, 0, ni_, input, t, 0);
314  if (softmax_ != nullptr) {
315  source_.WriteTimeStepPart(t, ni_, nf_, softmax_output);
316  }
317  source_.WriteTimeStepPart(t, ni_ + nf_, ns_, curr_output);
318  if (Is2D())
319  source_.WriteTimeStepPart(t, ni_ + nf_ + ns_, ns_, outputs[mod_t]);
320  if (!source_.int_mode()) source_.ReadTimeStep(t, curr_input);
321  // Matrix multiply the inputs with the source.
323  // It looks inefficient to create the threads on each t iteration, but the
324  // alternative of putting the parallel outside the t loop, a single around
325  // the t-loop and then tasks in place of the sections is a *lot* slower.
326  // Cell inputs.
327  if (source_.int_mode())
328  gate_weights_[CI].MatrixDotVector(source_.i(t), temp_lines[CI]);
329  else
330  gate_weights_[CI].MatrixDotVector(curr_input, temp_lines[CI]);
331  FuncInplace<GFunc>(ns_, temp_lines[CI]);
332 
334  // Input Gates.
335  if (source_.int_mode())
336  gate_weights_[GI].MatrixDotVector(source_.i(t), temp_lines[GI]);
337  else
338  gate_weights_[GI].MatrixDotVector(curr_input, temp_lines[GI]);
339  FuncInplace<FFunc>(ns_, temp_lines[GI]);
340 
342  // 1-D forget gates.
343  if (source_.int_mode())
344  gate_weights_[GF1].MatrixDotVector(source_.i(t), temp_lines[GF1]);
345  else
346  gate_weights_[GF1].MatrixDotVector(curr_input, temp_lines[GF1]);
347  FuncInplace<FFunc>(ns_, temp_lines[GF1]);
348 
349  // 2-D forget gates.
350  if (Is2D()) {
351  if (source_.int_mode())
352  gate_weights_[GFS].MatrixDotVector(source_.i(t), temp_lines[GFS]);
353  else
354  gate_weights_[GFS].MatrixDotVector(curr_input, temp_lines[GFS]);
355  FuncInplace<FFunc>(ns_, temp_lines[GFS]);
356  }
357 
359  // Output gates.
360  if (source_.int_mode())
361  gate_weights_[GO].MatrixDotVector(source_.i(t), temp_lines[GO]);
362  else
363  gate_weights_[GO].MatrixDotVector(curr_input, temp_lines[GO]);
364  FuncInplace<FFunc>(ns_, temp_lines[GO]);
366 
367  // Apply forget gate to state.
368  MultiplyVectorsInPlace(ns_, temp_lines[GF1], curr_state);
369  if (Is2D()) {
370  // Max-pool the forget gates (in 2-d) instead of blindly adding.
371  int8_t* which_fg_col = which_fg_[t];
372  memset(which_fg_col, 1, ns_ * sizeof(which_fg_col[0]));
373  if (valid_2d) {
374  const double* stepped_state = states[mod_t];
375  for (int i = 0; i < ns_; ++i) {
376  if (temp_lines[GF1][i] < temp_lines[GFS][i]) {
377  curr_state[i] = temp_lines[GFS][i] * stepped_state[i];
378  which_fg_col[i] = 2;
379  }
380  }
381  }
382  }
383  MultiplyAccumulate(ns_, temp_lines[CI], temp_lines[GI], curr_state);
384  // Clip curr_state to a sane range.
385  ClipVector<double>(ns_, -kStateClip, kStateClip, curr_state);
386  if (IsTraining()) {
387  // Save the gate node values.
388  node_values_[CI].WriteTimeStep(t, temp_lines[CI]);
389  node_values_[GI].WriteTimeStep(t, temp_lines[GI]);
390  node_values_[GF1].WriteTimeStep(t, temp_lines[GF1]);
391  node_values_[GO].WriteTimeStep(t, temp_lines[GO]);
392  if (Is2D()) node_values_[GFS].WriteTimeStep(t, temp_lines[GFS]);
393  }
394  FuncMultiply<HFunc>(curr_state, temp_lines[GO], ns_, curr_output);
395  if (IsTraining()) state_.WriteTimeStep(t, curr_state);
396  if (softmax_ != nullptr) {
397  if (input.int_mode()) {
398  int_output->WriteTimeStepPart(0, 0, ns_, curr_output);
399  softmax_->ForwardTimeStep(int_output->i(0), t, softmax_output);
400  } else {
401  softmax_->ForwardTimeStep(curr_output, t, softmax_output);
402  }
403  output->WriteTimeStep(t, softmax_output);
404  if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
405  CodeInBinary(no_, nf_, softmax_output);
406  }
407  } else if (type_ == NT_LSTM_SUMMARY) {
408  // Output only at the end of a row.
409  if (src_index.IsLast(FD_WIDTH)) {
410  output->WriteTimeStep(dest_index.t(), curr_output);
411  dest_index.Increment();
412  }
413  } else {
414  output->WriteTimeStep(t, curr_output);
415  }
416  // Save states for use by the 2nd dimension only if needed.
417  if (Is2D()) {
418  CopyVector(ns_, curr_state, states[mod_t]);
419  CopyVector(ns_, curr_output, outputs[mod_t]);
420  }
421  // Always zero the states at the end of every row, but only for the major
422  // direction. The 2-D state remains intact.
423  if (src_index.IsLast(FD_WIDTH)) {
424  ZeroVector<double>(ns_, curr_state);
425  ZeroVector<double>(ns_, curr_output);
426  }
427  } while (src_index.Increment());
428 #if DEBUG_DETAIL > 0
429  tprintf("Source:%s\n", name_.string());
430  source_.Print(10);
431  tprintf("State:%s\n", name_.string());
432  state_.Print(10);
433  tprintf("Output:%s\n", name_.string());
434  output->Print(10);
435 #endif
436  if (debug) DisplayForward(*output);
437 }

◆ InitWeights()

int tesseract::LSTM::InitWeights ( float  range,
TRand *  randomizer 
)
overridevirtual

Reimplemented from tesseract::Network.

Definition at line 158 of file lstm.cpp.

158  {
159  Network::SetRandomizer(randomizer);
160  num_weights_ = 0;
161  for (int w = 0; w < WT_COUNT; ++w) {
162  if (w == GFS && !Is2D()) continue;
163  num_weights_ += gate_weights_[w].InitWeightsFloat(
164  ns_, na_ + 1, TestFlag(NF_ADAM), range, randomizer);
165  }
166  if (softmax_ != nullptr) {
167  num_weights_ += softmax_->InitWeights(range, randomizer);
168  }
169  return num_weights_;
170 }

◆ Is2D()

bool tesseract::LSTM::Is2D ( ) const
inline

Definition at line 119 of file lstm.h.

119  {
120  return is_2d_;
121  }

◆ OutputShape()

StaticShape tesseract::LSTM::OutputShape ( const StaticShape &  input_shape) const
overridevirtual

Reimplemented from tesseract::Network.

Definition at line 127 of file lstm.cpp.

127  {
128  StaticShape result = input_shape;
129  result.set_depth(no_);
130  if (type_ == NT_LSTM_SUMMARY) result.set_width(1);
131  if (softmax_ != nullptr) return softmax_->OutputShape(result);
132  return result;
133 }

◆ PrintDW()

void tesseract::LSTM::PrintDW ( )

Definition at line 727 of file lstm.cpp.

727  {
728  tprintf("Delta state:%s\n", name_.string());
729  for (int w = 0; w < WT_COUNT; ++w) {
730  if (w == GFS && !Is2D()) continue;
731  tprintf("Gate %d, inputs\n", w);
732  for (int i = 0; i < ni_; ++i) {
733  tprintf("Row %d:", i);
734  for (int s = 0; s < ns_; ++s)
735  tprintf(" %g", gate_weights_[w].GetDW(s, i));
736  tprintf("\n");
737  }
738  tprintf("Gate %d, outputs\n", w);
739  for (int i = ni_; i < ni_ + ns_; ++i) {
740  tprintf("Row %d:", i - ni_);
741  for (int s = 0; s < ns_; ++s)
742  tprintf(" %g", gate_weights_[w].GetDW(s, i));
743  tprintf("\n");
744  }
745  tprintf("Gate %d, bias\n", w);
746  for (int s = 0; s < ns_; ++s)
747  tprintf(" %g", gate_weights_[w].GetDW(s, na_));
748  tprintf("\n");
749  }
750 }

◆ PrintW()

void tesseract::LSTM::PrintW ( )

Definition at line 701 of file lstm.cpp.

701  {
702  tprintf("Weight state:%s\n", name_.string());
703  for (int w = 0; w < WT_COUNT; ++w) {
704  if (w == GFS && !Is2D()) continue;
705  tprintf("Gate %d, inputs\n", w);
706  for (int i = 0; i < ni_; ++i) {
707  tprintf("Row %d:", i);
708  for (int s = 0; s < ns_; ++s)
709  tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
710  tprintf("\n");
711  }
712  tprintf("Gate %d, outputs\n", w);
713  for (int i = ni_; i < ni_ + ns_; ++i) {
714  tprintf("Row %d:", i - ni_);
715  for (int s = 0; s < ns_; ++s)
716  tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
717  tprintf("\n");
718  }
719  tprintf("Gate %d, bias\n", w);
720  for (int s = 0; s < ns_; ++s)
721  tprintf(" %g", gate_weights_[w].GetWeights(s)[na_]);
722  tprintf("\n");
723  }
724 }

◆ RemapOutputs()

int tesseract::LSTM::RemapOutputs ( int  old_no,
const std::vector< int > &  code_map 
)
overridevirtual

Reimplemented from tesseract::Network.

Definition at line 174 of file lstm.cpp.

174  {
175  if (softmax_ != nullptr) {
176  num_weights_ -= softmax_->num_weights();
177  num_weights_ += softmax_->RemapOutputs(old_no, code_map);
178  }
179  return num_weights_;
180 }

◆ Serialize()

bool tesseract::LSTM::Serialize ( TFile *  fp) const
overridevirtual

Reimplemented from tesseract::Network.

Definition at line 207 of file lstm.cpp.

207  {
208  if (!Network::Serialize(fp)) return false;
209  if (!fp->Serialize(&na_)) return false;
210  for (int w = 0; w < WT_COUNT; ++w) {
211  if (w == GFS && !Is2D()) continue;
212  if (!gate_weights_[w].Serialize(IsTraining(), fp)) return false;
213  }
214  if (softmax_ != nullptr && !softmax_->Serialize(fp)) return false;
215  return true;
216 }

◆ SetEnableTraining()

void tesseract::LSTM::SetEnableTraining ( TrainingState  state)
overridevirtual

Reimplemented from tesseract::Network.

Definition at line 137 of file lstm.cpp.

137  {
138  if (state == TS_RE_ENABLE) {
139  // Enable only from temp disabled.
141  } else if (state == TS_TEMP_DISABLE) {
142  // Temp disable only from enabled.
143  if (training_ == TS_ENABLED) training_ = state;
144  } else {
145  if (state == TS_ENABLED && training_ != TS_ENABLED) {
146  for (int w = 0; w < WT_COUNT; ++w) {
147  if (w == GFS && !Is2D()) continue;
148  gate_weights_[w].InitBackward();
149  }
150  }
151  training_ = state;
152  }
153  if (softmax_ != nullptr) softmax_->SetEnableTraining(state);
154 }

◆ spec()

STRING tesseract::LSTM::spec ( ) const
inlineoverridevirtual

Reimplemented from tesseract::Network.

Definition at line 58 of file lstm.h.

58  {
59  STRING spec;
60  if (type_ == NT_LSTM)
61  spec.add_str_int("Lfx", ns_);
62  else if (type_ == NT_LSTM_SUMMARY)
63  spec.add_str_int("Lfxs", ns_);
64  else if (type_ == NT_LSTM_SOFTMAX)
65  spec.add_str_int("LS", ns_);
66  else if (type_ == NT_LSTM_SOFTMAX_ENCODED)
67  spec.add_str_int("LE", ns_);
68  if (softmax_ != nullptr) spec += softmax_->spec();
69  return spec;
70  }

◆ Update()

void tesseract::LSTM::Update ( float  learning_rate,
float  momentum,
float  adam_beta,
int  num_samples 
)
overridevirtual

Reimplemented from tesseract::Network.

Definition at line 667 of file lstm.cpp.

668  {
669 #if DEBUG_DETAIL > 3
670  PrintW();
671 #endif
672  for (int w = 0; w < WT_COUNT; ++w) {
673  if (w == GFS && !Is2D()) continue;
674  gate_weights_[w].Update(learning_rate, momentum, adam_beta, num_samples);
675  }
676  if (softmax_ != nullptr) {
677  softmax_->Update(learning_rate, momentum, adam_beta, num_samples);
678  }
679 #if DEBUG_DETAIL > 3
680  PrintDW();
681 #endif
682 }

The documentation for this class was generated from the following files:
tesseract::WeightMatrix::MatrixDotVector
void MatrixDotVector(const double *u, double *v) const
Definition: weightmatrix.cpp:243
GenericVector::init_to_size
void init_to_size(int size, const T &t)
Definition: genericvector.h:744
tesseract::NT_SOFTMAX
@ NT_SOFTMAX
Definition: network.h:68
tesseract::Network::training_
TrainingState training_
Definition: network.h:294
tesseract::Network::type_
NetworkType type_
Definition: network.h:293
STRING::string
const char * string() const
Definition: strngs.cpp:194
tesseract::TS_ENABLED
@ TS_ENABLED
Definition: network.h:95
tesseract::FullyConnected::DebugWeights
void DebugWeights() override
Definition: fullyconnected.cpp:101
tesseract::FullyConnected::FinishBackward
void FinishBackward(const TransposedArray &errors_t)
Definition: fullyconnected.cpp:289
tesseract::LSTM::DeSerialize
bool DeSerialize(TFile *fp) override
Definition: lstm.cpp:220
tesseract::WeightMatrix::ConvertToInt
void ConvertToInt()
Definition: weightmatrix.cpp:125
tesseract::WeightMatrix::InitWeightsFloat
int InitWeightsFloat(int no, int ni, bool use_adam, float weight_range, TRand *randomizer)
Definition: weightmatrix.cpp:76
tesseract::CopyVector
void CopyVector(int n, const double *src, double *dest)
Definition: functions.h:169
STRING::add_str_int
void add_str_int(const char *str, int number)
Definition: strngs.cpp:377
tesseract::LSTM::GF1
@ GF1
Definition: lstm.h:36
tesseract::Network::Serialize
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:151
tesseract::LSTM::spec
STRING spec() const override
Definition: lstm.h:58
tesseract::MultiplyVectorsInPlace
void MultiplyVectorsInPlace(int n, const double *src, double *inout)
Definition: functions.h:179
tesseract::FullyConnected::ForwardTimeStep
void ForwardTimeStep(int t, double *output_line)
Definition: fullyconnected.cpp:185
tesseract::LSTM::Is2D
bool Is2D() const
Definition: lstm.h:119
tesseract::MultiplyAccumulate
void MultiplyAccumulate(int n, const double *u, const double *v, double *out)
Definition: functions.h:184
tesseract::NT_LSTM
@ NT_LSTM
Definition: network.h:60
tesseract::Network::num_weights_
int32_t num_weights_
Definition: network.h:299
tesseract::FullyConnected::BackwardTimeStep
void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, double *curr_errors, TransposedArray *errors_t, double *backprop)
Definition: fullyconnected.cpp:265
tesseract::NetworkIO::int_mode
bool int_mode() const
Definition: networkio.h:127
tesseract::Network::CreateFromFile
static Network * CreateFromFile(TFile *fp)
Definition: network.cpp:187
GenericVector
Definition: baseapi.h:37
tesseract::NetworkIO::ReadTimeStep
void ReadTimeStep(int t, double *output) const
Definition: networkio.cpp:598
tesseract::FullyConnected::ConvertToInt
void ConvertToInt() override
Definition: fullyconnected.cpp:96
tesseract::WeightMatrix::CountAlternators
void CountAlternators(const WeightMatrix &other, double *same, double *changed) const
Definition: weightmatrix.cpp:346
tesseract::Network::no_
int32_t no_
Definition: network.h:298
tesseract::CodeInBinary
void CodeInBinary(int n, int nf, double *vec)
Definition: functions.h:214
tesseract::Network::name
const STRING & name() const
Definition: network.h:138
tprintf
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:35
tesseract::WeightMatrix::NumOutputs
int NumOutputs() const
Definition: weightmatrix.h:101
tesseract::WeightMatrix::VectorDotMatrix
void VectorDotMatrix(const double *u, double *v) const
Definition: weightmatrix.cpp:274
tesseract::kErrClip
const double kErrClip
Definition: lstm.cpp:72
ASSERT_HOST
#define ASSERT_HOST(x)
Definition: errcode.h:88
tesseract::FullyConnected::Update
void Update(float learning_rate, float momentum, float adam_beta, int num_samples) override
Definition: fullyconnected.cpp:298
tesseract::NetworkIO::CopyTimeStepGeneral
void CopyTimeStepGeneral(int dest_t, int dest_offset, int num_features, const NetworkIO &src, int src_t, int src_offset)
Definition: networkio.cpp:393
tesseract::Network::ni_
int32_t ni_
Definition: network.h:297
tesseract::Network::num_weights
int num_weights() const
Definition: network.h:119
tesseract::LSTM::GI
@ GI
Definition: lstm.h:35
tesseract::LSTM::WT_COUNT
@ WT_COUNT
Definition: lstm.h:40
tesseract::FullyConnected::InitWeights
int InitWeights(float range, TRand *randomizer) override
Definition: fullyconnected.cpp:77
tesseract::NF_ADAM
@ NF_ADAM
Definition: network.h:88
tesseract::NetworkIO::Print
void Print(int num) const
Definition: networkio.cpp:366
tesseract::FullyConnected::Serialize
bool Serialize(TFile *fp) const override
Definition: fullyconnected.cpp:106
tesseract::LSTM::LSTM
LSTM(const STRING &name, int num_inputs, int num_states, int num_outputs, bool two_dimensional, NetworkType type)
Definition: lstm.cpp:99
tesseract::FullyConnected::SetupForward
void SetupForward(const NetworkIO &input, const TransposedArray *input_transpose)
Definition: fullyconnected.cpp:173
tesseract::TS_RE_ENABLE
@ TS_RE_ENABLE
Definition: network.h:99
tesseract::FD_WIDTH
@ FD_WIDTH
Definition: stridemap.h:35
tesseract::LSTM::PrintW
void PrintW()
Definition: lstm.cpp:701
tesseract::ClipVector
void ClipVector(int n, T lower, T upper, T *vec)
Definition: functions.h:208
tesseract::AccumulateVector
void AccumulateVector(int n, const double *src, double *dest)
Definition: functions.h:174
PARALLEL_IF_OPENMP
#define PARALLEL_IF_OPENMP(__num_threads)
Definition: lstm.cpp:60
tesseract::Network::IsTraining
bool IsTraining() const
Definition: network.h:115
tesseract::Network::DisplayBackward
void DisplayBackward(const NetworkIO &matrix)
Definition: network.cpp:299
tesseract::WeightMatrix::Debug2D
void Debug2D(const char *msg)
Definition: weightmatrix.cpp:377
tesseract::NT_LSTM_SOFTMAX_ENCODED
@ NT_LSTM_SOFTMAX_ENCODED
Definition: network.h:76
tesseract::Network::SetRandomizer
virtual void SetRandomizer(TRand *randomizer)
Definition: network.cpp:138
tesseract::Network::DisplayForward
void DisplayForward(const NetworkIO &matrix)
Definition: network.cpp:288
tesseract::Network::type
NetworkType type() const
Definition: network.h:112
tesseract::LSTM::Serialize
bool Serialize(TFile *fp) const override
Definition: lstm.cpp:207
tesseract::WeightMatrix::InitBackward
void InitBackward()
Definition: weightmatrix.cpp:153
tesseract::LSTM::GFS
@ GFS
Definition: lstm.h:38
tesseract::TS_TEMP_DISABLE
@ TS_TEMP_DISABLE
Definition: network.h:97
tesseract::Network::needs_to_backprop_
bool needs_to_backprop_
Definition: network.h:295
tesseract::Network::Network
Network()
Definition: network.cpp:76
tesseract::WeightMatrix::RoundInputs
int RoundInputs(int size) const
Definition: weightmatrix.h:92
tesseract::StrideMap::Size
int Size(FlexDimensions dimension) const
Definition: stridemap.h:114
tesseract::Network::TestFlag
bool TestFlag(NetworkFlags flag) const
Definition: network.h:144
tesseract::WeightMatrix::SumOuterTransposed
void SumOuterTransposed(const TransposedArray &u, const TransposedArray &v, bool parallel)
Definition: weightmatrix.cpp:284
tesseract::LSTM::CI
@ CI
Definition: lstm.h:34
tesseract::kStateClip
const double kStateClip
Definition: lstm.cpp:70
SECTION_IF_OPENMP
#define SECTION_IF_OPENMP
Definition: lstm.cpp:61
tesseract::FD_HEIGHT
@ FD_HEIGHT
Definition: stridemap.h:34
tesseract::NetworkIO::FuncMultiply3Add
void FuncMultiply3Add(const NetworkIO &v_io, int t, const double *w, double *product) const
Definition: networkio.h:299
tesseract::FullyConnected::CountAlternators
void CountAlternators(const Network &other, double *same, double *changed) const override
Definition: fullyconnected.cpp:306
tesseract::WeightMatrix::Update
void Update(double learning_rate, double momentum, double adam_beta, int num_samples)
Definition: weightmatrix.cpp:314
Modulo
int Modulo(int a, int b)
Definition: helpers.h:158
tesseract::NT_LSTM_SOFTMAX
@ NT_LSTM_SOFTMAX
Definition: network.h:75
tesseract::FullyConnected::SetEnableTraining
void SetEnableTraining(TrainingState state) override
Definition: fullyconnected.cpp:61
tesseract::NT_LSTM_SUMMARY
@ NT_LSTM_SUMMARY
Definition: network.h:61
tesseract::LSTM::GO
@ GO
Definition: lstm.h:37
tesseract::Network::name_
STRING name_
Definition: network.h:300
tesseract::FullyConnected::RemapOutputs
int RemapOutputs(int old_no, const std::vector< int > &code_map) override
Definition: fullyconnected.cpp:87
tesseract::FullyConnected::OutputShape
StaticShape OutputShape(const StaticShape &input_shape) const override
Definition: fullyconnected.cpp:46
STRING
Definition: strngs.h:45
tesseract::FullyConnected::spec
STRING spec() const override
Definition: fullyconnected.h:37
tesseract::NetworkIO::WriteTimeStep
void WriteTimeStep(int t, const double *input)
Definition: networkio.cpp:645
tesseract::SumVectors
void SumVectors(int n, const double *v1, const double *v2, const double *v3, const double *v4, const double *v5, double *sum)
Definition: functions.h:192
END_PARALLEL_IF_OPENMP
#define END_PARALLEL_IF_OPENMP
Definition: lstm.cpp:62
tesseract::NetworkIO::f
float * f(int t)
Definition: networkio.h:115
tesseract::NetworkIO::WriteTimeStepPart
void WriteTimeStepPart(int t, int offset, int num_features, const double *input)
Definition: networkio.cpp:651
tesseract::LSTM::PrintDW
void PrintDW()
Definition: lstm.cpp:727
tesseract::NetworkIO::Transpose
void Transpose(TransposedArray *dest) const
Definition: networkio.cpp:964
tesseract::NetworkIO::i
const int8_t * i(int t) const
Definition: networkio.h:123