#include <lstmtrainer.h>
|
| LSTMTrainer () |
|
| LSTMTrainer (FileReader file_reader, FileWriter file_writer, CheckPointReader checkpoint_reader, CheckPointWriter checkpoint_writer, const char *model_base, const char *checkpoint_name, int debug_interval, int64_t max_memory) |
|
virtual | ~LSTMTrainer () |
|
bool | TryLoadingCheckpoint (const char *filename, const char *old_traineddata) |
|
void | InitCharSet (const std::string &traineddata_path) |
|
void | InitCharSet (const TessdataManager &mgr) |
|
bool | InitNetwork (const STRING &network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum, float adam_beta) |
|
int | InitTensorFlowNetwork (const std::string &tf_proto) |
|
void | InitIterations () |
|
double | ActivationError () const |
|
double | CharError () const |
|
const double * | error_rates () const |
|
double | best_error_rate () const |
|
int | best_iteration () const |
|
int | learning_iteration () const |
|
int32_t | improvement_steps () const |
|
void | set_perfect_delay (int delay) |
|
const GenericVector< char > & | best_trainer () const |
|
double | NewSingleError (ErrorTypes type) const |
|
double | LastSingleError (ErrorTypes type) const |
|
const DocumentCache & | training_data () const |
|
DocumentCache * | mutable_training_data () |
|
Trainability | GridSearchDictParams (const ImageData *trainingdata, int iteration, double min_dict_ratio, double dict_ratio_step, double max_dict_ratio, double min_cert_offset, double cert_offset_step, double max_cert_offset, STRING *results) |
|
void | DebugNetwork () |
|
bool | LoadAllTrainingData (const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, bool randomly_rotate) |
|
bool | MaintainCheckpoints (TestCallback tester, STRING *log_msg) |
|
bool | MaintainCheckpointsSpecific (int iteration, const GenericVector< char > *train_model, const GenericVector< char > *rec_model, TestCallback tester, STRING *log_msg) |
|
void | PrepareLogMsg (STRING *log_msg) const |
|
void | LogIterations (const char *intro_str, STRING *log_msg) const |
|
bool | TransitionTrainingStage (float error_threshold) |
|
int | CurrentTrainingStage () const |
|
bool | Serialize (SerializeAmount serialize_amount, const TessdataManager *mgr, TFile *fp) const |
|
bool | DeSerialize (const TessdataManager *mgr, TFile *fp) |
|
void | StartSubtrainer (STRING *log_msg) |
|
SubTrainerResult | UpdateSubtrainer (STRING *log_msg) |
|
void | ReduceLearningRates (LSTMTrainer *samples_trainer, STRING *log_msg) |
|
int | ReduceLayerLearningRates (double factor, int num_samples, LSTMTrainer *samples_trainer) |
|
bool | EncodeString (const STRING &str, GenericVector< int > *labels) const |
|
const ImageData * | TrainOnLine (LSTMTrainer *samples_trainer, bool batch) |
|
Trainability | TrainOnLine (const ImageData *trainingdata, bool batch) |
|
Trainability | PrepareForBackward (const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets) |
|
bool | SaveTrainingDump (SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const |
|
bool | ReadTrainingDump (const GenericVector< char > &data, LSTMTrainer *trainer) const |
|
bool | ReadSizedTrainingDump (const char *data, int size, LSTMTrainer *trainer) const |
|
bool | ReadLocalTrainingDump (const TessdataManager *mgr, const char *data, int size) |
|
void | SetupCheckpointInfo () |
|
bool | SaveTraineddata (const STRING &filename) |
|
void | SaveRecognitionDump (GenericVector< char > *data) const |
|
STRING | DumpFilename () const |
|
void | FillErrorBuffer (double new_error, ErrorTypes type) |
|
std::vector< int > | MapRecoder (const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const |
|
| LSTMRecognizer () |
|
| LSTMRecognizer (const STRING language_data_path_prefix) |
|
| ~LSTMRecognizer () |
|
int | NumOutputs () const |
|
int | training_iteration () const |
|
int | sample_iteration () const |
|
double | learning_rate () const |
|
LossType | OutputLossType () const |
|
bool | SimpleTextOutput () const |
|
bool | IsIntMode () const |
|
bool | IsRecoding () const |
|
bool | IsTensorFlow () const |
|
GenericVector< STRING > | EnumerateLayers () const |
|
Network * | GetLayer (const STRING &id) const |
|
float | GetLayerLearningRate (const STRING &id) const |
|
void | ScaleLearningRate (double factor) |
|
void | ScaleLayerLearningRate (const STRING &id, double factor) |
|
void | ConvertToInt () |
|
const UNICHARSET & | GetUnicharset () const |
|
const UnicharCompress & | GetRecoder () const |
|
const Dict * | GetDict () const |
|
void | SetIteration (int iteration) |
|
int | NumInputs () const |
|
int | null_char () const |
|
bool | Load (const ParamsVectors *params, const char *lang, TessdataManager *mgr) |
|
bool | Serialize (const TessdataManager *mgr, TFile *fp) const |
|
bool | DeSerialize (const TessdataManager *mgr, TFile *fp) |
|
bool | LoadCharsets (const TessdataManager *mgr) |
|
bool | LoadRecoder (TFile *fp) |
|
bool | LoadDictionary (const ParamsVectors *params, const char *lang, TessdataManager *mgr) |
|
void | RecognizeLine (const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words, int lstm_choice_mode=0) |
|
void | OutputStats (const NetworkIO &outputs, float *min_output, float *mean_output, float *sd) |
|
bool | RecognizeLine (const ImageData &image_data, bool invert, bool debug, bool re_invert, bool upside_down, float *scale_factor, NetworkIO *inputs, NetworkIO *outputs) |
|
STRING | DecodeLabels (const GenericVector< int > &labels) |
|
void | DisplayForward (const NetworkIO &inputs, const GenericVector< int > &labels, const GenericVector< int > &label_coords, const char *window_name, ScrollView **window) |
|
void | LabelsFromOutputs (const NetworkIO &outputs, GenericVector< int > *labels, GenericVector< int > *xcoords) |
|
|
void | InitCharSet () |
|
void | SetNullChar () |
|
void | EmptyConstructor () |
|
bool | DebugLSTMTraining (const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const GenericVector< int > &truth_labels, const NetworkIO &outputs) |
|
void | DisplayTargets (const NetworkIO &targets, const char *window_name, ScrollView **window) |
|
bool | ComputeTextTargets (const NetworkIO &outputs, const GenericVector< int > &truth_labels, NetworkIO *targets) |
|
bool | ComputeCTCTargets (const GenericVector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets) |
|
double | ComputeErrorRates (const NetworkIO &deltas, double char_error, double word_error) |
|
double | ComputeRMSError (const NetworkIO &deltas) |
|
double | ComputeWinnerError (const NetworkIO &deltas) |
|
double | ComputeCharError (const GenericVector< int > &truth_str, const GenericVector< int > &ocr_str) |
|
double | ComputeWordError (STRING *truth_str, STRING *ocr_str) |
|
void | UpdateErrorBuffer (double new_error, ErrorTypes type) |
|
void | RollErrorBuffers () |
|
STRING | UpdateErrorGraph (int iteration, double error_rate, const GenericVector< char > &model_data, TestCallback tester) |
|
void | SetRandomSeed () |
|
void | DisplayLSTMOutput (const GenericVector< int > &labels, const GenericVector< int > &xcoords, int height, ScrollView *window) |
|
void | DebugActivationPath (const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > &xcoords) |
|
void | DebugActivationRange (const NetworkIO &outputs, const char *label, int best_choice, int x_start, int x_end) |
|
void | LabelsViaReEncode (const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords) |
|
void | LabelsViaSimpleText (const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords) |
|
const char * | DecodeLabel (const GenericVector< int > &labels, int start, int *end, int *decoded) |
|
const char * | DecodeSingleLabel (int label) |
|
Definition at line 89 of file lstmtrainer.h.
◆ LSTMTrainer() [1/2]
tesseract::LSTMTrainer::LSTMTrainer |
( |
| ) |
|
◆ LSTMTrainer() [2/2]
◆ ~LSTMTrainer()
tesseract::LSTMTrainer::~LSTMTrainer |
( |
| ) |
|
|
virtual |
◆ ActivationError()
double tesseract::LSTMTrainer::ActivationError |
( |
| ) |
const |
|
inline |
◆ best_error_rate()
double tesseract::LSTMTrainer::best_error_rate |
( |
| ) |
const |
|
inline |
◆ best_iteration()
int tesseract::LSTMTrainer::best_iteration |
( |
| ) |
const |
|
inline |
◆ best_trainer()
const GenericVector<char>& tesseract::LSTMTrainer::best_trainer |
( |
| ) |
const |
|
inline |
◆ CharError()
double tesseract::LSTMTrainer::CharError |
( |
| ) |
const |
|
inline |
◆ ComputeCharError()
Definition at line 1191 of file lstmtrainer.cpp.
1196 for (
int i = 0; i < truth_str.
size(); ++i) {
1198 ++label_counts[truth_str[i]];
1202 for (
int i = 0; i < ocr_str.
size(); ++i) {
1204 --label_counts[ocr_str[i]];
1207 int char_errors = 0;
1208 for (
int i = 0; i < label_counts.
size(); ++i) {
1209 char_errors += abs(label_counts[i]);
1211 if (truth_size == 0) {
1212 return (char_errors == 0) ? 0.0 : 1.0;
1214 return static_cast<double>(char_errors) / truth_size;
◆ ComputeCTCTargets()
◆ ComputeErrorRates()
double tesseract::LSTMTrainer::ComputeErrorRates |
( |
const NetworkIO & |
deltas, |
|
|
double |
char_error, |
|
|
double |
word_error |
|
) |
| |
|
protected |
◆ ComputeRMSError()
double tesseract::LSTMTrainer::ComputeRMSError |
( |
const NetworkIO & |
deltas | ) |
|
|
protected |
Definition at line 1154 of file lstmtrainer.cpp.
1155 double total_error = 0.0;
1156 int width = deltas.Width();
1157 int num_classes = deltas.NumFeatures();
1158 for (
int t = 0; t < width; ++t) {
1159 const float* class_errs = deltas.f(t);
1160 for (
int c = 0; c < num_classes; ++c) {
1161 double error = class_errs[c];
1162 total_error += error * error;
1165 return sqrt(total_error / (width * num_classes));
◆ ComputeTextTargets()
Definition at line 1103 of file lstmtrainer.cpp.
1106 if (truth_labels.
size() > targets->Width()) {
1107 tprintf(
"Error: transcription %s too long to fit into target of width %d\n",
1108 DecodeLabels(truth_labels).
string(), targets->Width());
1111 for (
int i = 0; i < truth_labels.
size() && i < targets->Width(); ++i) {
1112 targets->SetActivations(i, truth_labels[i], 1.0);
1114 for (
int i = truth_labels.
size(); i < targets->Width(); ++i) {
◆ ComputeWinnerError()
double tesseract::LSTMTrainer::ComputeWinnerError |
( |
const NetworkIO & |
deltas | ) |
|
|
protected |
Definition at line 1173 of file lstmtrainer.cpp.
1175 int width = deltas.Width();
1176 int num_classes = deltas.NumFeatures();
1177 for (
int t = 0; t < width; ++t) {
1178 const float* class_errs = deltas.f(t);
1179 for (
int c = 0; c < num_classes; ++c) {
1180 float abs_delta = fabs(class_errs[c]);
1183 if (0.5 <= abs_delta)
1187 return static_cast<double>(num_errors) / width;
◆ ComputeWordError()
double tesseract::LSTMTrainer::ComputeWordError |
( |
STRING * |
truth_str, |
|
|
STRING * |
ocr_str |
|
) |
| |
|
protected |
Definition at line 1219 of file lstmtrainer.cpp.
1220 using StrMap = std::unordered_map<std::string, int, std::hash<std::string>>;
1222 truth_str->
split(
' ', &truth_words);
1223 if (truth_words.
empty())
return 0.0;
1224 ocr_str->
split(
' ', &ocr_words);
1226 for (
int i = 0; i < truth_words.
size(); ++i) {
1227 std::string truth_word(truth_words[i].
string());
1228 auto it = word_counts.find(truth_word);
1229 if (it == word_counts.end())
1230 word_counts.insert(std::make_pair(truth_word, 1));
1234 for (
int i = 0; i < ocr_words.
size(); ++i) {
1235 std::string ocr_word(ocr_words[i].
string());
1236 auto it = word_counts.find(ocr_word);
1237 if (it == word_counts.end())
1238 word_counts.insert(std::make_pair(ocr_word, -1));
1242 int word_recall_errs = 0;
1243 for (StrMap::const_iterator it = word_counts.begin(); it != word_counts.end();
1245 if (it->second > 0) word_recall_errs += it->second;
1247 return static_cast<double>(word_recall_errs) / truth_words.
size();
◆ CurrentTrainingStage()
int tesseract::LSTMTrainer::CurrentTrainingStage |
( |
| ) |
const |
|
inline |
◆ DebugLSTMTraining()
Definition at line 1029 of file lstmtrainer.cpp.
1035 if (truth_text.
string() ==
nullptr || truth_text.
length() <= 0) {
1036 tprintf(
"Empty truth string at decode time!\n");
1045 tprintf(
"Iteration %d: GROUND TRUTH : %s\n",
1047 if (truth_text != text) {
1048 tprintf(
"Iteration %d: ALIGNED TRUTH : %s\n",
1052 tprintf(
"TRAINING activation path for truth string %s\n",
◆ DebugNetwork()
void tesseract::LSTMTrainer::DebugNetwork |
( |
| ) |
|
◆ DeSerialize()
Definition at line 466 of file lstmtrainer.cpp.
472 tprintf(
"Warning: LSTMTrainer deserialized an LSTMRecognizer!\n");
481 if (!error_buffer.DeSerialize(fp))
return false;
486 if (!fp->DeSerialize(&amount))
return false;
487 if (amount ==
LIGHT)
return true;
501 if (sub_data.
empty()) {
◆ DisplayTargets()
void tesseract::LSTMTrainer::DisplayTargets |
( |
const NetworkIO & |
targets, |
|
|
const char * |
window_name, |
|
|
ScrollView ** |
window |
|
) |
| |
|
protected |
Definition at line 1066 of file lstmtrainer.cpp.
1068 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics.
1069 int width = targets.Width();
1070 int num_features = targets.NumFeatures();
1073 for (
int c = 0; c < num_features; ++c) {
1077 for (
int t = 0; t < width; ++t) {
1078 double target = targets.f(t)[c];
1082 (*window)->SetCursor(t - 1, 0);
1085 (*window)->DrawTo(t, target);
1086 }
else if (start_t >= 0) {
1087 (*window)->DrawTo(t, 0);
1088 (*window)->DrawTo(start_t - 1, 0);
1093 (*window)->DrawTo(width, 0);
1094 (*window)->DrawTo(start_t - 1, 0);
1097 (*window)->Update();
1098 #endif // GRAPHICS_DISABLED
◆ DumpFilename()
STRING tesseract::LSTMTrainer::DumpFilename |
( |
| ) |
const |
◆ EmptyConstructor()
void tesseract::LSTMTrainer::EmptyConstructor |
( |
| ) |
|
|
protected |
◆ EncodeString() [1/2]
Definition at line 716 of file lstmtrainer.cpp.
720 tprintf(
"Empty truth string!\n");
728 if (unicharset.
encode_string(cleaned.c_str(),
true, &internal_labels,
nullptr,
731 for (
int i = 0; i < internal_labels.
size(); ++i) {
732 if (recoder !=
nullptr) {
735 int len = recoder->EncodeUnichar(internal_labels[i], &code);
737 for (
int j = 0; j < len; ++j) {
751 if (success)
return true;
753 tprintf(
"Encoding of string failed! Failure bytes:");
754 while (err_index < cleaned.size()) {
755 tprintf(
" %x", cleaned[err_index++]);
◆ EncodeString() [2/2]
bool tesseract::LSTMTrainer::EncodeString |
( |
const STRING & |
str, |
|
|
GenericVector< int > * |
labels |
|
) |
| const |
|
inline |
◆ error_rates()
const double* tesseract::LSTMTrainer::error_rates |
( |
| ) |
const |
|
inline |
◆ FillErrorBuffer()
void tesseract::LSTMTrainer::FillErrorBuffer |
( |
double |
new_error, |
|
|
ErrorTypes |
type |
|
) |
| |
◆ GridSearchDictParams()
Trainability tesseract::LSTMTrainer::GridSearchDictParams |
( |
const ImageData * |
trainingdata, |
|
|
int |
iteration, |
|
|
double |
min_dict_ratio, |
|
|
double |
dict_ratio_step, |
|
|
double |
max_dict_ratio, |
|
|
double |
min_cert_offset, |
|
|
double |
cert_offset_step, |
|
|
double |
max_cert_offset, |
|
|
STRING * |
results |
|
) |
| |
Definition at line 241 of file lstmtrainer.cpp.
246 NetworkIO fwd_outputs, targets;
259 base_search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
266 for (
double r = min_dict_ratio; r < max_dict_ratio; r += dict_ratio_step) {
267 for (
double c = min_cert_offset; c < max_cert_offset;
268 c += cert_offset_step) {
270 search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
275 if ((r == min_dict_ratio && c == min_cert_offset) ||
276 !std::isfinite(word_error)) {
279 tprintf(
"r=%g, c=%g, truth=%s, ocr=%s, wderr=%g, truth[0]=%d\n", r, c,
◆ improvement_steps()
int32_t tesseract::LSTMTrainer::improvement_steps |
( |
| ) |
const |
|
inline |
◆ InitCharSet() [1/3]
void tesseract::LSTMTrainer::InitCharSet |
( |
| ) |
|
|
protected |
Definition at line 992 of file lstmtrainer.cpp.
998 "Must provide a traineddata containing lstm_unicharset and"
999 " lstm_recoder!\n" !=
nullptr);
◆ InitCharSet() [2/3]
void tesseract::LSTMTrainer::InitCharSet |
( |
const std::string & |
traineddata_path | ) |
|
|
inline |
◆ InitCharSet() [3/3]
◆ InitIterations()
void tesseract::LSTMTrainer::InitIterations |
( |
| ) |
|
◆ InitNetwork()
bool tesseract::LSTMTrainer::InitNetwork |
( |
const STRING & |
network_spec, |
|
|
int |
append_index, |
|
|
int |
net_flags, |
|
|
float |
weight_range, |
|
|
float |
learning_rate, |
|
|
float |
momentum, |
|
|
float |
adam_beta |
|
) |
| |
Definition at line 172 of file lstmtrainer.cpp.
182 append_index, net_flags, weight_range,
187 tprintf(
"Built network:%s from request %s\n",
190 "Training parameters:\n Debug interval = %d,"
191 " weights = %g, learning rate = %g, momentum=%g\n",
◆ InitTensorFlowNetwork()
int tesseract::LSTMTrainer::InitTensorFlowNetwork |
( |
const std::string & |
tf_proto | ) |
|
◆ LastSingleError()
double tesseract::LSTMTrainer::LastSingleError |
( |
ErrorTypes |
type | ) |
const |
|
inline |
◆ learning_iteration()
int tesseract::LSTMTrainer::learning_iteration |
( |
| ) |
const |
|
inline |
◆ LoadAllTrainingData()
◆ LogIterations()
void tesseract::LSTMTrainer::LogIterations |
( |
const char * |
intro_str, |
|
|
STRING * |
log_msg |
|
) |
| const |
◆ MaintainCheckpoints()
bool tesseract::LSTMTrainer::MaintainCheckpoints |
( |
TestCallback |
tester, |
|
|
STRING * |
log_msg |
|
) |
| |
Definition at line 310 of file lstmtrainer.cpp.
336 *log_msg +=
UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
349 *log_msg +=
" failed to write best model:";
351 *log_msg +=
" wrote best model:";
354 *log_msg += best_model_name;
359 *log_msg +=
UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
363 *log_msg +=
"\nDivergence! ";
380 result = sub_trainer_result !=
STR_NONE;
388 *log_msg +=
" failed to write checkpoint.";
390 *log_msg +=
" wrote checkpoint.";
◆ MaintainCheckpointsSpecific()
◆ MapRecoder()
std::vector< int > tesseract::LSTMTrainer::MapRecoder |
( |
const UNICHARSET & |
old_chset, |
|
|
const UnicharCompress & |
old_recoder |
|
) |
| const |
Definition at line 957 of file lstmtrainer.cpp.
961 std::vector<int> code_map(num_new_codes, -1);
962 for (
int c = 0; c < num_new_codes; ++c) {
966 for (
int uid = 0; uid <= num_new_unichars; ++uid) {
970 while (code_index < length && codes(code_index) != c) ++code_index;
971 if (code_index == length)
continue;
974 uid < num_new_unichars
976 : old_chset.
size() - 1;
977 if (old_uid == INVALID_UNICHAR_ID)
continue;
979 RecodedCharID old_codes;
980 if (code_index < old_recoder.EncodeUnichar(old_uid, &old_codes)) {
981 old_code = old_codes(code_index);
985 code_map[c] = old_code;
◆ mutable_training_data()
◆ NewSingleError()
double tesseract::LSTMTrainer::NewSingleError |
( |
ErrorTypes |
type | ) |
const |
|
inline |
◆ PrepareForBackward()
Definition at line 796 of file lstmtrainer.cpp.
799 if (trainingdata ==
nullptr) {
800 tprintf(
"Null trainingdata.\n");
807 if (!
EncodeString(trainingdata->transcription(), &truth_labels)) {
808 tprintf(
"Can't encode transcription: '%s' in language '%s'\n",
809 trainingdata->transcription().string(),
810 trainingdata->language().string());
813 bool upside_down =
false;
823 for (
int c = 0; c < truth_labels.
size(); ++c) {
831 while (w < truth_labels.
size() &&
834 if (w == truth_labels.
size()) {
835 tprintf(
"Blank transcription: %s\n",
836 trainingdata->transcription().string());
841 bool invert = trainingdata->boxes().empty();
842 if (!
RecognizeLine(*trainingdata, invert, debug, invert, upside_down,
843 &image_scale, &inputs, fwd_outputs)) {
844 tprintf(
"Image not trainable\n");
851 tprintf(
"Compute simple targets failed!\n");
854 }
else if (loss_type ==
LT_CTC) {
856 tprintf(
"Compute CTC targets failed!\n");
860 tprintf(
"Logistic outputs not implemented yet!\n");
867 if (loss_type !=
LT_CTC) {
872 tprintf(
"Input width was %d\n", inputs.Width());
877 targets->SubtractAllFromFloat(*fwd_outputs);
879 if (truth_text != ocr_text) {
880 tprintf(
"Iteration %d: BEST OCR TEXT : %s\n",
888 tprintf(
"File %s line %d %s:\n", trainingdata->imagefilename().string(),
889 trainingdata->page_number(), delta_error == 0.0 ?
"(Perfect)" :
"");
891 if (delta_error == 0.0)
return PERFECT;
◆ PrepareLogMsg()
void tesseract::LSTMTrainer::PrepareLogMsg |
( |
STRING * |
log_msg | ) |
const |
◆ ReadLocalTrainingDump()
bool tesseract::LSTMTrainer::ReadLocalTrainingDump |
( |
const TessdataManager * |
mgr, |
|
|
const char * |
data, |
|
|
int |
size |
|
) |
| |
Definition at line 909 of file lstmtrainer.cpp.
912 tprintf(
"Warning: data size is 0 in LSTMTrainer::ReadLocalTrainingDump\n");
◆ ReadSizedTrainingDump()
bool tesseract::LSTMTrainer::ReadSizedTrainingDump |
( |
const char * |
data, |
|
|
int |
size, |
|
|
LSTMTrainer * |
trainer |
|
) |
| const |
|
inline |
Definition at line 296 of file lstmtrainer.h.
298 return trainer->ReadLocalTrainingDump(&
mgr_, data, size);
◆ ReadTrainingDump()
◆ ReduceLayerLearningRates()
int tesseract::LSTMTrainer::ReduceLayerLearningRates |
( |
double |
factor, |
|
|
int |
num_samples, |
|
|
LSTMTrainer * |
samples_trainer |
|
) |
| |
Definition at line 607 of file lstmtrainer.cpp.
615 int num_layers = layers.
size();
620 for (
int i = 0; i < LR_COUNT; ++i) {
624 double momentum_factor = 1.0 / (1.0 -
momentum_);
626 samples_trainer->SaveTrainingDump(
LIGHT,
this, &orig_trainer);
627 for (
int i = 0; i < num_layers; ++i) {
628 Network* layer =
GetLayer(layers[i]);
629 num_weights[i] = layer->IsTraining() ? layer->num_weights() : 0;
632 for (
int s = 0; s < num_samples; ++s) {
634 for (
int ww = 0; ww < LR_COUNT; ++ww) {
636 float ww_factor = momentum_factor;
637 if (ww == LR_DOWN) ww_factor *= factor;
640 samples_trainer->ReadTrainingDump(orig_trainer, &copy_trainer);
642 copy_trainer.network_->Update(0.0, 0.0, 0.0, 0);
644 for (
int i = 0; i < num_layers; ++i) {
645 if (num_weights[i] == 0)
continue;
646 copy_trainer.ScaleLayerLearningRate(layers[i], ww_factor);
648 copy_trainer.SetIteration(iteration);
651 const ImageData* trainingdata =
652 copy_trainer.TrainOnLine(samples_trainer,
true);
653 if (trainingdata ==
nullptr)
continue;
656 samples_trainer->SaveTrainingDump(
LIGHT, &copy_trainer, &updated_trainer);
657 for (
int i = 0; i < num_layers; ++i) {
658 if (num_weights[i] == 0)
continue;
660 samples_trainer->ReadTrainingDump(updated_trainer, &layer_trainer);
661 Network* layer = layer_trainer.GetLayer(layers[i]);
664 layer_trainer.training_iteration_ + 1);
666 layer->Update(0.0, 0.0, 0.0, 0);
668 layer_trainer.TrainOnLine(trainingdata,
true);
670 float before_bad = bad_sums[ww][i];
671 float before_ok = ok_sums[ww][i];
672 layer->CountAlternators(*copy_trainer.GetLayer(layers[i]),
673 &ok_sums[ww][i], &bad_sums[ww][i]);
675 bad_sums[ww][i] + ok_sums[ww][i] - before_bad - before_ok;
677 bad_frac = (bad_sums[ww][i] - before_bad) / bad_frac;
683 for (
int i = 0; i < num_layers; ++i) {
684 if (num_weights[i] == 0)
continue;
685 Network* layer =
GetLayer(layers[i]);
687 double total_down = bad_sums[LR_DOWN][i] + ok_sums[LR_DOWN][i];
688 double total_same = bad_sums[LR_SAME][i] + ok_sums[LR_SAME][i];
689 double frac_down = bad_sums[LR_DOWN][i] / total_down;
690 double frac_same = bad_sums[LR_SAME][i] / total_same;
691 tprintf(
"Layer %d=%s: lr %g->%g%%, lr %g->%g%%", i, layer->name().string(),
692 lr * factor, 100.0 * frac_down, lr, 100.0 * frac_same);
701 if (num_lowered == 0) {
703 for (
int i = 0; i < num_layers; ++i) {
704 if (num_weights[i] > 0) {
◆ ReduceLearningRates()
void tesseract::LSTMTrainer::ReduceLearningRates |
( |
LSTMTrainer * |
samples_trainer, |
|
|
STRING * |
log_msg |
|
) |
| |
◆ RollErrorBuffers()
void tesseract::LSTMTrainer::RollErrorBuffers |
( |
| ) |
|
|
protected |
◆ SaveRecognitionDump()
void tesseract::LSTMTrainer::SaveRecognitionDump |
( |
GenericVector< char > * |
data | ) |
const |
◆ SaveTraineddata()
bool tesseract::LSTMTrainer::SaveTraineddata |
( |
const STRING & |
filename | ) |
|
◆ SaveTrainingDump()
◆ Serialize()
Definition at line 429 of file lstmtrainer.cpp.
437 if (!error_buffer.Serialize(fp))
return false;
441 uint8_t amount = serialize_amount;
442 if (!fp->Serialize(&amount))
return false;
443 if (serialize_amount ==
LIGHT)
return true;
458 if (!sub_data.
Serialize(fp))
return false;
◆ set_perfect_delay()
void tesseract::LSTMTrainer::set_perfect_delay |
( |
int |
delay | ) |
|
|
inline |
◆ SetNullChar()
void tesseract::LSTMTrainer::SetNullChar |
( |
| ) |
|
|
protected |
◆ SetupCheckpointInfo()
void tesseract::LSTMTrainer::SetupCheckpointInfo |
( |
| ) |
|
◆ StartSubtrainer()
void tesseract::LSTMTrainer::StartSubtrainer |
( |
STRING * |
log_msg | ) |
|
Definition at line 515 of file lstmtrainer.cpp.
519 *log_msg +=
" Failed to revert to previous best for trial!";
523 log_msg->
add_str_int(
" Trial sub_trainer_ from iteration ",
◆ training_data()
const DocumentCache& tesseract::LSTMTrainer::training_data |
( |
| ) |
const |
|
inline |
◆ TrainOnLine() [1/2]
Definition at line 763 of file lstmtrainer.cpp.
765 NetworkIO fwd_outputs, targets;
784 #ifndef GRAPHICS_DISABLED
788 #endif // GRAPHICS_DISABLED
◆ TrainOnLine() [2/2]
Definition at line 259 of file lstmtrainer.h.
261 const ImageData* image =
262 samples_trainer->training_data_.GetPageBySerial(sample_index);
263 if (image !=
nullptr) {
◆ TransitionTrainingStage()
bool tesseract::LSTMTrainer::TransitionTrainingStage |
( |
float |
error_threshold | ) |
|
◆ TryLoadingCheckpoint()
bool tesseract::LSTMTrainer::TryLoadingCheckpoint |
( |
const char * |
filename, |
|
|
const char * |
old_traineddata |
|
) |
| |
Definition at line 129 of file lstmtrainer.cpp.
133 tprintf(
"Loaded file %s, unpacking...\n", filename);
136 if (((old_traineddata ==
nullptr || *old_traineddata ==
'\0') &&
138 filename == old_traineddata) {
143 if (old_traineddata ==
nullptr || *old_traineddata ==
'\0') {
144 tprintf(
"Must supply the old traineddata for code conversion!\n");
147 TessdataManager old_mgr;
154 UnicharCompress old_recoder;
155 if (!old_recoder.DeSerialize(&fp))
return false;
156 std::vector<int> code_map =
MapRecoder(old_chset, old_recoder);
◆ UpdateErrorBuffer()
void tesseract::LSTMTrainer::UpdateErrorBuffer |
( |
double |
new_error, |
|
|
ErrorTypes |
type |
|
) |
| |
|
protected |
Definition at line 1252 of file lstmtrainer.cpp.
1257 double buffer_sum = 0.0;
1259 double mean = buffer_sum / mean_count;
◆ UpdateErrorGraph()
Definition at line 1284 of file lstmtrainer.cpp.
1321 double two_percent_more = error_rate + 2.0;
1328 tprintf(
"2 Percent improvement time=%d, best error was %g @ %d\n",
1333 if (tester !=
nullptr) {
◆ UpdateSubtrainer()
Definition at line 545 of file lstmtrainer.cpp.
548 double sub_margin = (training_error - sub_error) / sub_error;
557 int target_iteration =
562 STRING batch_log =
"Sub:";
566 *log_msg += batch_log;
568 sub_margin = (training_error - sub_error) / sub_error;
576 log_msg->
add_str_int(
" Sub trainer wins at iteration ",
◆ align_win_
◆ best_error_history_
GenericVector<double> tesseract::LSTMTrainer::best_error_history_ |
|
protected |
◆ best_error_iterations_
GenericVector<int> tesseract::LSTMTrainer::best_error_iterations_ |
|
protected |
◆ best_error_rate_
double tesseract::LSTMTrainer::best_error_rate_ |
|
protected |
◆ best_error_rates_
double tesseract::LSTMTrainer::best_error_rates_[ET_COUNT] |
|
protected |
◆ best_iteration_
int tesseract::LSTMTrainer::best_iteration_ |
|
protected |
◆ best_model_data_
◆ best_model_name_
STRING tesseract::LSTMTrainer::best_model_name_ |
|
protected |
◆ best_trainer_
◆ checkpoint_iteration_
int tesseract::LSTMTrainer::checkpoint_iteration_ |
|
protected |
◆ checkpoint_name_
STRING tesseract::LSTMTrainer::checkpoint_name_ |
|
protected |
◆ checkpoint_reader_
◆ checkpoint_writer_
◆ ctc_win_
◆ debug_interval_
int tesseract::LSTMTrainer::debug_interval_ |
|
protected |
◆ error_buffers_
◆ error_rate_of_last_saved_best_
float tesseract::LSTMTrainer::error_rate_of_last_saved_best_ |
|
protected |
◆ error_rates_
double tesseract::LSTMTrainer::error_rates_[ET_COUNT] |
|
protected |
◆ file_reader_
◆ file_writer_
◆ improvement_steps_
int32_t tesseract::LSTMTrainer::improvement_steps_ |
|
protected |
◆ kRollingBufferSize_
const int tesseract::LSTMTrainer::kRollingBufferSize_ = 1000 |
|
static protected
◆ last_perfect_training_iteration_
int tesseract::LSTMTrainer::last_perfect_training_iteration_ |
|
protected |
◆ learning_iteration_
int tesseract::LSTMTrainer::learning_iteration_ |
|
protected |
◆ mgr_
◆ model_base_
STRING tesseract::LSTMTrainer::model_base_ |
|
protected |
◆ num_training_stages_
int tesseract::LSTMTrainer::num_training_stages_ |
|
protected |
◆ perfect_delay_
int tesseract::LSTMTrainer::perfect_delay_ |
|
protected |
◆ prev_sample_iteration_
int tesseract::LSTMTrainer::prev_sample_iteration_ |
|
protected |
◆ randomly_rotate_
bool tesseract::LSTMTrainer::randomly_rotate_ |
|
protected |
◆ recon_win_
◆ stall_iteration_
int tesseract::LSTMTrainer::stall_iteration_ |
|
protected |
◆ sub_trainer_
◆ target_win_
◆ training_data_
◆ training_stage_
int tesseract::LSTMTrainer::training_stage_ |
|
protected |
◆ worst_error_rate_
double tesseract::LSTMTrainer::worst_error_rate_ |
|
protected |
◆ worst_error_rates_
double tesseract::LSTMTrainer::worst_error_rates_[ET_COUNT] |
|
protected |
◆ worst_iteration_
int tesseract::LSTMTrainer::worst_iteration_ |
|
protected |
◆ worst_model_data_
The documentation for this class was generated from the following files:
int IntCastRounded(double x)
void init_to_size(int size, const T &t)
const double kMinDivergenceRate
virtual R Run(A1, A2, A3)=0
void RecognizeLine(const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words, int lstm_choice_mode=0)
static std::string CleanupString(const char *utf8_str)
int CurrentTrainingStage() const
bool ReadTrainingDump(const GenericVector< char > &data, LSTMTrainer *trainer) const
bool LoadDataFromFile(const char *filename, GenericVector< char > *data)
const char * string() const
NetworkScratch scratch_space_
void SaveRecognitionDump(GenericVector< char > *data) const
const int kNumAdjustmentIterations
void UpdateErrorBuffer(double new_error, ErrorTypes type)
void add_str_int(const char *str, int number)
LossType OutputLossType() const
bool encode_string(const char *str, bool give_up_on_failure, GenericVector< UNICHAR_ID > *encoding, GenericVector< char > *lengths, int *encoded_length) const
constexpr size_t countof(T const (&)[N]) noexcept
static constexpr float kMinCertainty
double NewSingleError(ErrorTypes type) const
double ComputeRMSError(const NetworkIO &deltas)
double learning_rate() const
double ComputeWordError(STRING *truth_str, STRING *ocr_str)
bool TransitionTrainingStage(float error_threshold)
void DisplayForward(const NetworkIO &inputs, const GenericVector< int > &labels, const GenericVector< int > &label_coords, const char *window_name, ScrollView **window)
GenericVector< char > best_trainer_
SVEvent * AwaitEvent(SVEventType type)
virtual StaticShape InputShape() const
virtual void SetEnableTraining(TrainingState state)
int last_perfect_training_iteration_
CheckPointReader checkpoint_reader_
void SetVersionString(const std::string &v_str)
SubTrainerResult UpdateSubtrainer(STRING *log_msg)
bool SaveFile(const STRING &filename, FileWriter writer) const
virtual STRING spec() const
int sample_iteration() const
const int kMinStartedErrorRate
void ScaleLayerLearningRate(const STRING &id, double factor)
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
virtual void Update(float learning_rate, float momentum, float adam_beta, int num_samples)
STRING UpdateErrorGraph(int iteration, double error_rate, const GenericVector< char > &model_data, TestCallback tester)
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
void LogIterations(const char *intro_str, STRING *log_msg) const
CheckPointWriter checkpoint_writer_
void LabelsFromOutputs(const NetworkIO &outputs, GenericVector< int > *labels, GenericVector< int > *xcoords)
DLLSYM void tprintf(const char *format,...)
double error_rates_[ET_COUNT]
UNICHAR_ID unichar_to_id(const char *const unichar_repr) const
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
static bool ComputeCTCTargets(const GenericVector< int > &truth_labels, int null_char, const GENERIC_2D_ARRAY< float > &outputs, NetworkIO *targets)
void StartSubtrainer(STRING *log_msg)
double ComputeCharError(const GenericVector< int > &truth_str, const GenericVector< int > &ocr_str)
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
int checkpoint_iteration_
void DebugActivationPath(const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > &xcoords)
bool ComputeTextTargets(const NetworkIO &outputs, const GenericVector< int > &truth_labels, NetworkIO *targets)
bool LoadDocuments(const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, FileReader reader)
bool Serialize(FILE *fp) const
double worst_error_rates_[ET_COUNT]
const double kLearningRateDecay
_ConstTessMemberResultCallback_5_0< false, R, T1, P1, P2, P3, P4, P5 >::base * NewPermanentTessCallback(const T1 *obj, R(T2::*member)(P1, P2, P3, P4, P5) const, typename Identity< P1 >::type p1, typename Identity< P2 >::type p2, typename Identity< P3 >::type p3, typename Identity< P4 >::type p4, typename Identity< P5 >::type p5)
DocumentCache training_data_
LIST search(LIST list, void *key, int_compare is_equal)
std::vector< int > MapRecoder(const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const
GenericVector< double > best_error_history_
GenericVector< double > error_buffers_[ET_COUNT]
int prev_sample_iteration_
bool has_special_codes() const
float error_rate_of_last_saved_best_
int learning_iteration() const
virtual void DebugWeights()=0
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
int32_t improvement_steps_
GenericVector< char > worst_model_data_
const int kNumPagesPerBatch
const double kImprovementFraction
std::string VersionString() const
const double kStageTransitionThreshold
LSTMTrainer * sub_trainer_
int training_iteration() const
bool DeSerialize(bool swap, FILE *fp)
bool EncodeString(const STRING &str, GenericVector< int > *labels) const
double ComputeWinnerError(const NetworkIO &deltas)
double SignedRand(double range)
static void NormalizeProbs(NetworkIO *probs)
bool SimpleTextOutput() const
bool TestFlag(NetworkFlags flag) const
Network * GetLayer(const STRING &id) const
GenericVector< STRING > EnumerateLayers() const
void DisplayTargets(const NetworkIO &targets, const char *window_name, ScrollView **window)
bool load_from_file(const char *const filename, bool skip_fragments)
void PrepareLogMsg(STRING *log_msg) const
virtual int RemapOutputs(int old_no, const std::vector< int > &code_map)
void ReduceLearningRates(LSTMTrainer *samples_trainer, STRING *log_msg)
const int kErrorGraphInterval
bool ReadSizedTrainingDump(const char *data, int size, LSTMTrainer *trainer) const
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)=0
bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const GenericVector< int > &truth_labels, const NetworkIO &outputs)
void add_str_double(const char *str, double number)
virtual StaticShape OutputShape(const StaticShape &input_shape) const
const double kBestCheckpointFraction
bool ComputeCTCTargets(const GenericVector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets)
const UNICHARSET & GetUnicharset() const
float GetLayerLearningRate(const STRING &id) const
bool LoadCharsets(const TessdataManager *mgr)
GenericVector< char > best_model_data_
Trainability PrepareForBackward(const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
static const int kRollingBufferSize_
int EncodeUnichar(int unichar_id, RecodedCharID *code) const
int32_t sample_iteration_
bool Init(const char *data_file_name)
const int kMinStallIterations
int ReduceLayerLearningRates(double factor, int num_samples, LSTMTrainer *samples_trainer)
bool Serialize(const TessdataManager *mgr, TFile *fp) const
GenericVector< int > best_error_iterations_
STRING DecodeLabels(const GenericVector< int > &labels)
void ScaleLearningRate(double factor)
double best_error_rates_[ET_COUNT]
void OverwriteEntry(TessdataType type, const char *data, int size)
const double kHighConfidence
int32_t training_iteration_
STRING DumpFilename() const
const char * c_str() const
bool SaveDataToFile(const GenericVector< char > &data, const STRING &filename)
void split(char c, GenericVector< STRING > *splited)
@ TESSDATA_LSTM_UNICHARSET
static bool InitNetwork(int num_outputs, STRING network_spec, int append_index, int net_flags, float weight_range, TRand *randomizer, Network **network)
const double kSubTrainerMarginFraction
double ComputeErrorRates(const NetworkIO &deltas, double char_error, double word_error)