BaumWelch  baumwelch-0.3.8
OpenGrm-BaumWelch library
baumwelchdecode-main.cc
Go to the documentation of this file.
1 // Licensed under the Apache License, Version 2.0 (the "License");
2 // you may not use this file except in compliance with the License.
3 // You may obtain a copy of the License at
4 //
5 // http://www.apache.org/licenses/LICENSE-2.0
6 //
7 // Unless required by applicable law or agreed to in writing, software
8 // distributed under the License is distributed on an "AS IS" BASIS,
9 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 // See the License for the specific language governing permissions and
11 // limitations under the License.
12 //
13 // Copyright 2017 and onwards Google, Inc.
14 
15 // Decodes Baum-Welch model.
16 
17 #include <cstring>
18 #include <memory>
19 #include <string>
20 
21 #include <fst/flags.h>
22 
23 #include <fst/extensions/far/far-class.h>
24 #include <fst/extensions/far/far.h>
25 #include <fst/util.h>
26 #include <fst/script/fst-class.h>
27 #include <baumwelch/decodescript.h>
28 
29 int baumwelchdecode_main(int argc, char **argv) {
30  namespace s = fst::script;
31  using fst::FarType;
32  using fst::script::FarReaderClass;
33  using fst::script::FarWriterClass;
34  using fst::script::FstClass;
35 
36  std::string usage = "Decodes a WFST model\n\n Usage: ";
37  usage += argv[0];
38  usage += " input.f(ar|st) output.far model.fst [out.far]\n";
39 
40  SET_FLAGS(usage.c_str(), &argc, &argv, true);
41 
42  if (argc < 4 || argc > 5) {
43  ShowUsage();
44  return 1;
45  }
46 
47  const std::string input_name = strcmp(argv[1], "-") != 0 ? argv[1] : "";
48  const std::string output_name = strcmp(argv[2], "-") != 0 ? argv[2] : "";
49  const std::string model_name = strcmp(argv[3], "-") != 0 ? argv[3] : "";
50  const std::string hypotext_name = argc > 4 ? argv[4] : "";
51 
52  if (input_name.empty() && (output_name.empty() || model_name.empty())) {
53  LOG(ERROR) << argv[0] << ": Can't take more than one input from standard "
54  << "input";
55  return 1;
56  }
57  if (output_name.empty() && model_name.empty()) {
58  LOG(ERROR) << argv[0] << ": Can't take more than one input from standard "
59  << "input";
60  return 1;
61  }
62 
63  const std::unique_ptr<FarReaderClass> input(FarReaderClass::Open(input_name));
64  if (!input) return 1;
65 
66  const std::unique_ptr<FarReaderClass> output(
67  FarReaderClass::Open(output_name));
68  if (!output) return 1;
69 
70  const std::unique_ptr<const FstClass> model(FstClass::Read(model_name));
71  if (!model) return 1;
72 
73  const std::unique_ptr<FarWriterClass> hypotext(
74  FarWriterClass::Create(hypotext_name, output->ArcType()));
75  if (!hypotext) return 1;
76 
77  s::Decode(*input, *output, *model, *hypotext);
78 
79  if (input->Error()) {
80  FSTERROR() << "Error reading FAR: " << input_name;
81  return 1;
82  }
83  if (output->Error()) {
84  FSTERROR() << "Error reading FAR: " << output_name;
85  return 1;
86  }
87  if (hypotext->Error()) {
88  FSTERROR() << "Error writing FAR: " << hypotext_name;
89  return 1;
90  }
91 
92  return 0;
93 }
94 
void Decode(FarReader< Arc > &input, FarReader< Arc > &output, const Fst< Arc > &model, FarWriter< Arc > &hypotext)
Definition: decode.h:89
int baumwelchdecode_main(int argc, char **argv)