BaumWelch  baumwelch-0.3.8
OpenGrm-BaumWelch library
cascade.h
Go to the documentation of this file.
1 // Licensed under the Apache License, Version 2.0 (the "License");
2 // you may not use this file except in compliance with the License.
3 // You may obtain a copy of the License at
4 //
5 // http://www.apache.org/licenses/LICENSE-2.0
6 //
7 // Unless required by applicable law or agreed to in writing, software
8 // distributed under the License is distributed on an "AS IS" BASIS,
9 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 // See the License for the specific language governing permissions and
11 // limitations under the License.
12 //
13 // Copyright 2017 and onwards Google, Inc.
14 
15 #ifndef NLP_GRM2_BAUMWELCH_CASCADE_H_
16 #define NLP_GRM2_BAUMWELCH_CASCADE_H_
17 
18 #include <fst/cache.h>
19 #include <fst/compose.h>
20 #include <fst/fst-decl.h>
21 #include <fst/state-table.h>
22 
23 // Cascade objects used during the E-step.
24 
25 namespace fst {
26 
27 // Struct holding two cache options structs, one per composition performed by
// the cascade classes in this file.
// NOTE(review): listing lines 28 and 34-35 are missing from this extract
// (presumably the struct header and the constructor's member-initializer
// list); confirm against the original header.
    // Copy construction is explicitly defaulted.
29  CascadeOptions(const CascadeOptions &) = default;
30 
    // Each argument falls back to a default-constructed CacheOptions when
    // omitted.
31  explicit CascadeOptions(
32  const CacheOptions &co_cache_options = CacheOptions(),
33  const CacheOptions &ico_cache_options = CacheOptions())
36 
    // Cache options the cascade classes use for the (model, output)
    // composition.
37  const CacheOptions co_cache_options;
    // Cache options the cascade classes use for the composition of the input
    // with the (model, output) composition.
38  const CacheOptions ico_cache_options;
39 };
40 
41 // Cascade objects represent the composition of an input WFSA (usually a
42 // string or an LM), the model, and an output string FSA. They minimally have the
43 // following interface:
44 //
45 // template <class Arc>
46 // class CascadeInterface {
47 // public:
48 // using StateId = typename Arc::StateId;
49 //
50 // // Required constructor, which builds the cascade.
51 // CascadeInterface(const Fst<Arc> &input, const Fst<Arc> &output,
52 // const Fst<Arc> &model);
53 //
54 // // Returns reference to the cascade.
55 // const ComposeFst<Arc> &GetFst() const;
56 // };
57 
58 // Simple cascade object: composes the model with the output, then composes
// the input with that result, and exposes the outer composition via GetFst().
// NOTE(review): listing line 60 (presumably the class header) is missing from
// this extract; confirm against the original header.
59 template <class Arc>
61  public:
62  using StateId = typename Arc::StateId;
63 
    // Builds both compositions. `opts` supplies the cache options for each one
    // and defaults to a default-constructed CascadeOptions.
64  SimpleCascade(const Fst<Arc> &input, const Fst<Arc> &output,
65  const Fst<Arc> &model,
66  const CascadeOptions &opts = CascadeOptions())
67  : co_options_(opts.co_cache_options),
    // Inner composition: model with output.
68  co_(model, output, co_options_),
69  ico_options_(opts.ico_cache_options),
    // Outer composition: input with the inner composition.
70  ico_(input, co_, ico_options_) {}
71 
    // Returns a reference to the outer (input, (model, output)) composition.
72  const ComposeFst<Arc> &GetFst() const { return ico_; }
73 
74  private:
    // Copying is disabled.
75  SimpleCascade(const SimpleCascade &) = delete;
76  SimpleCascade &operator=(const SimpleCascade &) = delete;
77 
    // Declaration order matters: members are initialized in this order, and
    // co_ is built from co_options_ while ico_ is built from co_ and
    // ico_options_.
78  const ComposeFstOptions<Arc> co_options_;
79  const ComposeFst<Arc> co_;
80  const ComposeFstOptions<Arc> ico_options_;
81  const ComposeFst<Arc> ico_;
82 };
83 
84 // Cascade object that also keeps track of state IDs in the original model.
// It performs the same two compositions as SimpleCascade, but hands each one
// an explicitly constructed StateTable so that a state of the outer
// composition can be mapped back through both tables (see ChannelState()).
// NOTE(review): listing line 89 (presumably the class header) is missing from
// this extract; confirm against the original header.
85 template <class Arc, class M = Matcher<Fst<Arc>>,
86  class Filter = SequenceComposeFilter<M>,
87  class StateTable =
88  GenericComposeStateTable<Arc, typename Filter::FilterState>>
90  public:
91  using StateId = typename Arc::StateId;
92 
    // Builds both compositions, passing a freshly allocated StateTable as the
    // last argument of each options object.
    // NOTE(review): the tables are created with raw `new` and never deleted
    // in this class; presumably the compose options/FST take ownership --
    // confirm against the ComposeFstOptions documentation.
93  ChannelStateCascade(const Fst<Arc> &input, const Fst<Arc> &output,
94  const Fst<Arc> &model,
95  const CascadeOptions &opts = CascadeOptions())
96  : co_options_(opts.co_cache_options, nullptr, nullptr, nullptr,
97  new StateTable(model, output)),
98  co_(model, output, co_options_),
99  ico_options_(opts.ico_cache_options, nullptr, nullptr, nullptr,
100  new StateTable(input, co_)),
101  ico_(input, co_, ico_options_) {}
102 
    // Returns a reference to the outer (input, (model, output)) composition.
103  const ComposeFst<Arc> &GetFst() const { return ico_; }
104 
    // Maps a state of the outer composition to StateId1() of the matching
    // (model, output) tuple, in two steps: outer state -> inner composition
    // state -> first component of the inner tuple. Going by the class comment,
    // this is presumably the state ID in the original model.
105  StateId ChannelState(StateId ico_state) const {
106  const auto ic_state = InputChannelState(ico_state);
107  return co_options_.state_table->Tuple(ic_state).StateId1();
108  }
109 
110  private:
    // Copying is disabled.
111  ChannelStateCascade(const ChannelStateCascade &) = delete;
112  ChannelStateCascade &operator=(const ChannelStateCascade &) = delete;
113 
    // Returns StateId2() of the outer composition's tuple for `ico_state`.
    // Since the outer composition is (input, co_), this is the state on the
    // co_ side.
114  StateId InputChannelState(StateId ico_state) const {
115  return ico_options_.state_table->Tuple(ico_state).StateId2();
116  }
117 
    // Declaration order matters: members are initialized in this order, and
    // the second StateTable and ico_ are both built from co_.
    // NOTE(review): the options members are declared as ComposeFstOptions<Arc>
    // and do not mention M, Filter or StateTable; verify that non-default
    // template arguments are actually supported here.
118  const ComposeFstOptions<Arc> co_options_;
119  const ComposeFst<Arc> co_;
120  const ComposeFstOptions<Arc> ico_options_;
121  const ComposeFst<Arc> ico_;
122 };
123 
124 } // namespace fst
125 
126 #endif // NLP_GRM2_BAUMWELCH_CASCADE_H_
127 
StateId ChannelState(StateId ico_state) const
Definition: cascade.h:105
const CacheOptions co_cache_options
Definition: cascade.h:37
Definition: a-star.h:30
SimpleCascade(const Fst< Arc > &input, const Fst< Arc > &output, const Fst< Arc > &model, const CascadeOptions &opts=CascadeOptions())
Definition: cascade.h:64
typename Arc::StateId StateId
Definition: cascade.h:62
ChannelStateCascade(const Fst< Arc > &input, const Fst< Arc > &output, const Fst< Arc > &model, const CascadeOptions &opts=CascadeOptions())
Definition: cascade.h:93
const ComposeFst< Arc > & GetFst() const
Definition: cascade.h:103
const ComposeFst< Arc > & GetFst() const
Definition: cascade.h:72
CascadeOptions(const CacheOptions &co_cache_options=CacheOptions(), const CacheOptions &ico_cache_options=CacheOptions())
Definition: cascade.h:31
const CacheOptions ico_cache_options
Definition: cascade.h:38
typename Arc::StateId StateId
Definition: cascade.h:91
CascadeOptions(const CascadeOptions &)=default