BaumWelch  baumwelch-0.3.8
OpenGrm-BaumWelch library
trainscript.h
Go to the documentation of this file.
1 // Licensed under the Apache License, Version 2.0 (the "License");
2 // you may not use this file except in compliance with the License.
3 // You may obtain a copy of the License at
4 //
5 // http://www.apache.org/licenses/LICENSE-2.0
6 //
7 // Unless required by applicable law or agreed to in writing, software
8 // distributed under the License is distributed on an "AS IS" BASIS,
9 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 // See the License for the specific language governing permissions and
11 // limitations under the License.
12 //
13 // Copyright 2017 and onwards Google, Inc.
14 
15 #ifndef NLP_GRM2_BAUMWELCH_TRAINSCRIPT_H_
16 #define NLP_GRM2_BAUMWELCH_TRAINSCRIPT_H_
17 
18 #include <tuple>
19 
20 #include <fst/extensions/far/far-class.h>
21 #include <fst/extensions/far/far.h>
22 #include <fst/mutable-fst.h>
23 #include <fst/script/fst-class.h>
24 #include <baumwelch/train.h>
25 
26 namespace fst {
27 namespace script {
28 
29 using BaumWelchTrainArgs =
30  std::tuple<FarReaderClass &, FarReaderClass &, MutableFstClass *, bool,
31  const TrainOptions &>;
32 
33 template <class Arc>
35  FarReader<Arc> &input = *std::get<0>(*args).GetFarReader<Arc>();
36  FarReader<Arc> &output = *std::get<1>(*args).GetFarReader<Arc>();
37  MutableFst<Arc> *model = std::get<2>(*args)->GetMutableFst<Arc>();
38  Train(input, output, model, std::get<3>(*args), std::get<4>(*args));
39 }
40 
41 void Train(FarReaderClass &input, FarReaderClass &output,
42  MutableFstClass *model, bool normalize_ilabel = true,
43  const TrainOptions &opts = TrainOptions());
44 
45 } // namespace script
46 } // namespace fst
47 
48 #endif // NLP_GRM2_BAUMWELCH_TRAINSCRIPT_H_
49 
std::tuple< FarReaderClass &, FarReaderClass &, MutableFstClass *, bool, const TrainOptions & > BaumWelchTrainArgs
Definition: trainscript.h:31
Definition: a-star.h:30
void Train(BaumWelchTrainArgs *args)
Definition: trainscript.h:34