21 #include <fst/flags.h> 23 #include <fst/extensions/far/far-class.h> 25 #include <fst/script/fst-class.h> 38 using fst::script::FarReaderClass;
39 using fst::script::FstClass;
40 using fst::script::MutableFstClass;
42 std::string usage =
"Trains a WFST model\n\n Usage: ";
44 usage +=
" input.f(ar|st) output.far model.fst [out.fst]\n";
46 SET_FLAGS(usage.c_str(), &argc, &argv,
true);
48 if (argc < 4 || argc > 5) {
53 const std::string input_name = strcmp(argv[1],
"-") != 0 ? argv[1] :
"";
54 const std::string output_name = strcmp(argv[2],
"-") != 0 ? argv[2] :
"";
55 const std::string model_name = strcmp(argv[3],
"-") != 0 ? argv[3] :
"";
56 const std::string out_name = argc > 4 ? argv[4] :
"";
58 if (input_name.empty() && (output_name.empty() || model_name.empty())) {
59 LOG(ERROR) << argv[0] <<
": Can't take more than one input from standard " 63 if (output_name.empty() && model_name.empty()) {
64 LOG(ERROR) << argv[0] <<
": Can't take more than one input from standard " 69 const std::unique_ptr<FarReaderClass> input(FarReaderClass::Open(input_name));
72 const std::unique_ptr<FarReaderClass> output(
73 FarReaderClass::Open(output_name));
74 if (!output)
return 1;
76 const std::unique_ptr<MutableFstClass> model(
77 MutableFstClass::Read(model_name));
80 const TrainOptions opts(
86 s::Train(*input, *output, model.get(), FST_FLAGS_normalize_ilabel,
90 FSTERROR() <<
"Error reading FAR: " << input_name;
93 if (output->Error()) {
94 FSTERROR() <<
"Error reading FAR: " << output_name;
98 return !model->Write(out_name);
// Arc-templated training routine (the `template <class Arc>` header is not
// shown in this excerpt).
//
// Parameters:
//   input  - FAR reader over the input FSTs.
//   output - FAR reader over the output FSTs; presumably paired with `input`
//            entry by entry (TODO confirm against the implementation).
//   model  - mutable model FST; passed as a non-const pointer, so it is
//            presumably updated in place with the trained parameters
//            (NOTE(review): confirm).
//   opts   - training options; defaults to a value-initialized TrainOptions().
//
// Returns an Arc::Weight. Its meaning is not visible here; it is presumably an
// aggregate score such as the final likelihood (TODO confirm).
// NOTE(review): given the enclosing tool name (baumwelchtrain), this is likely
// Baum-Welch/EM training — verify against the definition.
Arc::Weight Train(FarReader< Arc > &input, FarReader< Arc > &output, MutableFst< Arc > *model, const TrainOptions &opts=TrainOptions())
// Command-line flag declarations; the definitions live elsewhere.
// batch_size: integer flag. Its use is not visible in this excerpt; it
// presumably feeds the TrainOptions built by the driver (TODO confirm).
DECLARE_int32(batch_size)
// normalize_ilabel: boolean flag, read by the driver as
// FST_FLAGS_normalize_ilabel and forwarded to the script-level Train() call.
DECLARE_bool(normalize_ilabel)
// Driver for the "Trains a WFST model" command-line tool.
//
// Usage: <prog> input.f(ar|st) output.far model.fst [out.fst]
//   - Flags are parsed with SET_FLAGS. After parsing, 3 or 4 positional
//     arguments are required.
//   - A "-" for input, output or model is mapped to the empty name, presumably
//     meaning standard input (NOTE(review): confirm). At most one of the three
//     may be "-"; otherwise an error is logged.
//   - If out.fst is omitted, the trained model is written to the empty name,
//     presumably standard output (TODO confirm).
//
// Returns 0 when the trained model is written successfully. Returns nonzero
// on failure, e.g. 1 when the output FAR cannot be opened, or when
// model->Write() fails.
int baumwelchtrain_main(int argc, char **argv)