BaumWelch  baumwelch-0.3.4
OpenGrm-BaumWelch library
cascade.h
Go to the documentation of this file.
1 // Licensed under the Apache License, Version 2.0 (the "License");
2 // you may not use this file except in compliance with the License.
3 // You may obtain a copy of the License at
4 //
5 // http://www.apache.org/licenses/LICENSE-2.0
6 //
7 // Unless required by applicable law or agreed to in writing, software
8 // distributed under the License is distributed on an "AS IS" BASIS,
9 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 // See the License for the specific language governing permissions and
11 // limitations under the License.
12 //
13 // Copyright 2017 and onwards Google, Inc.
14 
15 #ifndef BAUMWELCH_CASCADE_H_
16 #define BAUMWELCH_CASCADE_H_
17 
18 #include <fst/compose.h>
19 #include <fst/fst-decl.h>
20 
21 // Cascade objects used during the E-step.
22 
23 namespace fst {
24 
25 // Struct holding two cache options structs.
27  CascadeOptions(const CascadeOptions &) = default;
28 
29  explicit CascadeOptions(
30  const CacheOptions &co_cache_options = CacheOptions(),
31  const CacheOptions &ico_cache_options = CacheOptions())
34 
35  const CacheOptions co_cache_options;
36  const CacheOptions ico_cache_options;
37 };
38 
// Cascade objects represent the composition of an input WFSA (usually a
// string or an LM), the model, and an output string FSA; the cascades below
// build it as input o (model o output). They minimally have the following
// interface:
//
// template <class Arc>
// class CascadeInterface {
//  public:
//   using StateId = typename Arc::StateId;
//
//   // Required constructor, which builds the cascade.
//   CascadeInterface(const Fst<Arc> &input, const Fst<Arc> &output,
//                    const Fst<Arc> &model);
//
//   // Returns reference to the cascade.
//   const ComposeFst<Arc> &GetFst() const;
// };
55 
56 // Simple cascade object.
57 template <class Arc>
59  public:
60  using StateId = typename Arc::StateId;
61 
62  SimpleCascade(const Fst<Arc> &input, const Fst<Arc> &output,
63  const Fst<Arc> &model,
64  const CascadeOptions &opts = CascadeOptions())
65  : co_options_(opts.co_cache_options),
66  co_(model, output, co_options_),
67  ico_options_(opts.ico_cache_options),
68  ico_(input, co_, ico_options_) {}
69 
70  const ComposeFst<Arc> &GetFst() const { return ico_; }
71 
72  private:
73  SimpleCascade(const SimpleCascade &) = delete;
74  SimpleCascade &operator=(const SimpleCascade &) = delete;
75 
76  const ComposeFstOptions<Arc> co_options_;
77  const ComposeFst<Arc> co_;
78  const ComposeFstOptions<Arc> ico_options_;
79  const ComposeFst<Arc> ico_;
80 };
81 
82 // Cascade object that also keeps track of state IDs in the original model.
83 template <class Arc, class M = Matcher<Fst<Arc>>,
84  class Filter = SequenceComposeFilter<M>,
85  class StateTable =
86  GenericComposeStateTable<Arc, typename Filter::FilterState>>
88  public:
89  using StateId = typename Arc::StateId;
90 
91  ChannelStateCascade(const Fst<Arc> &input, const Fst<Arc> &output,
92  const Fst<Arc> &model,
93  const CascadeOptions &opts = CascadeOptions())
94  : co_options_(opts.co_cache_options, nullptr, nullptr, nullptr,
95  new StateTable(model, output)),
96  co_(model, output, co_options_),
97  ico_options_(opts.ico_cache_options, nullptr, nullptr, nullptr,
98  new StateTable(input, co_)),
99  ico_(input, co_, ico_options_) {}
100 
101  const ComposeFst<Arc> &GetFst() const { return ico_; }
102 
103  StateId ChannelState(StateId ico_state) const {
104  const auto ic_state = InputChannelState(ico_state);
105  return co_options_.state_table->Tuple(ic_state).StateId1();
106  }
107 
108  private:
109  ChannelStateCascade(const ChannelStateCascade &) = delete;
110  ChannelStateCascade &operator=(const ChannelStateCascade &) = delete;
111 
112  StateId InputChannelState(StateId ico_state) const {
113  return ico_options_.state_table->Tuple(ico_state).StateId2();
114  }
115 
116  const ComposeFstOptions<Arc> co_options_;
117  const ComposeFst<Arc> co_;
118  const ComposeFstOptions<Arc> ico_options_;
119  const ComposeFst<Arc> ico_;
120 };
121 
122 } // namespace fst
123 
124 #endif // BAUMWELCH_CASCADE_H_
125 
StateId ChannelState(StateId ico_state) const
Definition: cascade.h:103
const CacheOptions co_cache_options
Definition: cascade.h:35
Definition: a-star.h:27
SimpleCascade(const Fst< Arc > &input, const Fst< Arc > &output, const Fst< Arc > &model, const CascadeOptions &opts=CascadeOptions())
Definition: cascade.h:62
typename Arc::StateId StateId
Definition: cascade.h:60
ChannelStateCascade(const Fst< Arc > &input, const Fst< Arc > &output, const Fst< Arc > &model, const CascadeOptions &opts=CascadeOptions())
Definition: cascade.h:91
const ComposeFst< Arc > & GetFst() const
Definition: cascade.h:101
const ComposeFst< Arc > & GetFst() const
Definition: cascade.h:70
CascadeOptions(const CacheOptions &co_cache_options=CacheOptions(), const CacheOptions &ico_cache_options=CacheOptions())
Definition: cascade.h:29
const CacheOptions ico_cache_options
Definition: cascade.h:36
typename Arc::StateId StateId
Definition: cascade.h:89
CascadeOptions(const CascadeOptions &)=default