BaumWelch  baumwelch-0.3.8
OpenGrm-BaumWelch library
baumwelchtrain-main.cc
Go to the documentation of this file.
1 // Licensed under the Apache License, Version 2.0 (the "License");
2 // you may not use this file except in compliance with the License.
3 // You may obtain a copy of the License at
4 //
5 // http://www.apache.org/licenses/LICENSE-2.0
6 //
7 // Unless required by applicable law or agreed to in writing, software
8 // distributed under the License is distributed on an "AS IS" BASIS,
9 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 // See the License for the specific language governing permissions and
11 // limitations under the License.
12 //
13 // Copyright 2017 and onwards Google, Inc.
14 
15 // Trains Baum-Welch model.
16 
17 #include <cstring>
18 #include <memory>
19 #include <string>
20 
21 #include <fst/flags.h>
22 
23 #include <fst/extensions/far/far-class.h>
24 #include <fst/util.h>
25 #include <fst/script/fst-class.h>
26 #include <baumwelch/train.h>
27 #include <baumwelch/trainscript.h>
28 
// Command-line flags consumed by baumwelchtrain_main() below through the
// FST_FLAGS_<name> accessors: max_iters, alpha, batch_size and delta populate
// TrainOptions, and normalize_ilabel is passed directly to Train().
// NOTE(review): these are declarations only; the matching definitions are
// presumably in another translation unit -- confirm.
29 DECLARE_int32(batch_size);
30 DECLARE_double(delta);
31 DECLARE_double(alpha);
32 DECLARE_int32(max_iters);
33 DECLARE_bool(normalize_ilabel);
34 
35 int baumwelchtrain_main(int argc, char **argv) {
36  namespace s = fst::script;
37  using fst::TrainOptions;
38  using fst::script::FarReaderClass;
39  using fst::script::FstClass;
40  using fst::script::MutableFstClass;
41 
42  std::string usage = "Trains a WFST model\n\n Usage: ";
43  usage += argv[0];
44  usage += " input.f(ar|st) output.far model.fst [out.fst]\n";
45 
46  SET_FLAGS(usage.c_str(), &argc, &argv, true);
47 
48  if (argc < 4 || argc > 5) {
49  ShowUsage();
50  return 1;
51  }
52 
53  const std::string input_name = strcmp(argv[1], "-") != 0 ? argv[1] : "";
54  const std::string output_name = strcmp(argv[2], "-") != 0 ? argv[2] : "";
55  const std::string model_name = strcmp(argv[3], "-") != 0 ? argv[3] : "";
56  const std::string out_name = argc > 4 ? argv[4] : "";
57 
58  if (input_name.empty() && (output_name.empty() || model_name.empty())) {
59  LOG(ERROR) << argv[0] << ": Can't take more than one input from standard "
60  << "input";
61  return 1;
62  }
63  if (output_name.empty() && model_name.empty()) {
64  LOG(ERROR) << argv[0] << ": Can't take more than one input from standard "
65  << "input";
66  return 1;
67  }
68 
69  const std::unique_ptr<FarReaderClass> input(FarReaderClass::Open(input_name));
70  if (!input) return 1;
71 
72  const std::unique_ptr<FarReaderClass> output(
73  FarReaderClass::Open(output_name));
74  if (!output) return 1;
75 
76  const std::unique_ptr<MutableFstClass> model(
77  MutableFstClass::Read(model_name));
78  if (!model) return 1;
79 
80  const TrainOptions opts(
81  /*max_iters=*/FST_FLAGS_max_iters,
82  /*alpha=*/FST_FLAGS_alpha,
83  /*batch_size=*/FST_FLAGS_batch_size,
84  /*delta=*/FST_FLAGS_delta);
85 
86  s::Train(*input, *output, model.get(), FST_FLAGS_normalize_ilabel,
87  opts);
88 
89  if (input->Error()) {
90  FSTERROR() << "Error reading FAR: " << input_name;
91  return 1;
92  }
93  if (output->Error()) {
94  FSTERROR() << "Error reading FAR: " << output_name;
95  return 1;
96  }
97 
98  return !model->Write(out_name);
99 }
100 
Arc::Weight Train(FarReader< Arc > &input, FarReader< Arc > &output, MutableFst< Arc > *model, const TrainOptions &opts=TrainOptions())
Definition: train.h:304
DECLARE_int32(batch_size)
DECLARE_bool(normalize_ilabel)
int baumwelchtrain_main(int argc, char **argv)
DECLARE_double(delta)