BaumWelch  baumwelch-0.3.8
OpenGrm-BaumWelch library
train.h
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2017 and onwards Google, Inc.

#ifndef NLP_GRM2_BAUMWELCH_TRAIN_H_
#define NLP_GRM2_BAUMWELCH_TRAIN_H_

#include <cmath>
#include <cstdint>
#include <vector>

#include <fst/log.h>
#include <fst/extensions/far/far.h>
#include <fst/arcfilter.h>
#include <fst/compose.h>
#include <fst/fst.h>
#include <fst/mutable-fst.h>
#include <fst/queue.h>
#include <fst/shortest-distance.h>
#include <fst/weight.h>
#include <baumwelch/cascade.h>
// Declares the expectation tables used below (header name assumed).
#include <baumwelch/expectation-table.h>
#include <baumwelch/log-adder.h>
#include <baumwelch/util.h>

namespace fst {

// Some defaults.
constexpr float kAlpha = 1.;
constexpr int kMaxIters = 50;

// Helper for training options. If batch_size is 0, or larger than the data,
// full-batch training is performed.
struct TrainOptions {
  explicit TrainOptions(int max_iters = kMaxIters, float alpha = kAlpha,
                        int batch_size = 0, float delta = kDelta,
                        const CascadeOptions &copts = CascadeOptions())
      : max_iters(max_iters),
        alpha(alpha),
        batch_size(batch_size),
        delta(delta),
        copts(copts) {
    if (alpha == 0.0) {
      // When alpha is 0, full-batch training is required. The parameter
      // shadows the member here, so we assign through this.
      this->batch_size = 0;
    }
  }

  // Maximum number of iterations to perform.
  int max_iters;
  // Step size reduction power. When non-zero, the step size is
  // (k + 2)^{-alpha}, where k is the step.
  float alpha;
  // Maximum size of a batch. If set to 0, full-batch training occurs.
  int batch_size;
  // Comparison/quantization delta used to determine convergence.
  float delta;
  // Options passed to the trainer.
  CascadeOptions copts;
};
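
// For illustration, options for 25 iterations of stepwise training with
// batches of 32 and the default convergence delta (values hypothetical):
//
//   const TrainOptions opts(/*max_iters=*/25, /*alpha=*/0.7,
//                           /*batch_size=*/32);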

namespace internal {

// Class storing forward and backward weights.
//
// For idempotent semirings, this uses A* search during the alpha computation.
// If a state is not visited during the search, its estimate is taken to be
// semiring Zero. This estimate of alpha has the true value as an upper bound:
// because the search terminates once the shortest path is found (due to
// first_path=true), some states it never visits may in fact have non-Zero
// true values.
template <class Arc>
class ForwardBackward {
 public:
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  explicit ForwardBackward(const ComposeFst<Arc> &ico) {
    ShortestDistance(ico, &beta_, /*reverse=*/true);
    if constexpr (IsIdempotent<typename Arc::Weight>::value) {
      // Computes alpha using an A* approximation.
      using StateId = typename Arc::StateId;
      using Weight = typename Arc::Weight;
      using MyEstimate = NaturalAStarEstimate<StateId, Weight>;
      using MyQueue = NaturalAStarQueue<StateId, Weight, MyEstimate>;
      using MyArcFilter = AnyArcFilter<Arc>;
      using MyShortestDistanceOptions =
          ShortestDistanceOptions<Arc, MyQueue, MyArcFilter>;
      const MyEstimate estimate(beta_);
      MyQueue queue(alpha_, estimate);
      static constexpr MyArcFilter arc_filter{};
      const MyShortestDistanceOptions opts(
          &queue, arc_filter,
          /*source=*/kNoStateId,     // Default.
          /*delta=*/kShortestDelta,  // Default.
          /*first_path=*/true);      // Heuristic is admissible.
      ShortestDistance(ico, &alpha_, opts);
      VLOG(1) << ExploredStates<Weight>(alpha_) << " alpha states explored";
    } else {
      ShortestDistance(ico, &alpha_, /*reverse=*/false);
    }
  }

  const Weight &Alpha(StateId s) const {
    return ForwardBackward::WeightOrZero(s, alpha_);
  }

  const Weight &Beta(StateId s) const {
    return ForwardBackward::WeightOrZero(s, beta_);
  }

 private:
  static constexpr Weight kZero = Weight::Zero();

  // Returns the shortest distance weight, or semiring Zero if the state was
  // not visited during the respective shortest distance computation.
  static const Weight &WeightOrZero(StateId s,
                                    const std::vector<Weight> &weights) {
    return (s < weights.size()) ? weights[s] : kZero;
  }

  std::vector<Weight> alpha_;
  std::vector<Weight> beta_;
};
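
// A minimal usage sketch, as in the trainer below; assumes ico is the
// composed input-channel-output lattice over the log semiring and s is one
// of its states:
//
//   const ForwardBackward<LogArc> fb(ico);
//   const auto &likelihood = fb.Beta(ico.Start());  // Observation likelihood.
//   const auto &alpha = fb.Alpha(s);                // Forward weight of s.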

// Object which holds all necessary information for stepwise or minibatch
// training. It stores the (initial) learning rate and the step counter. For
// more information, see the "sEM" pseudocode (p. 613) in:
//
// Liang, P., and Klein, D. 2009. Online EM for unsupervised models. In
// Proceedings of Human Language Technologies: The 2009 Annual Conference of
// the North American Chapter of the Association for Computational Linguistics,
// pages 611-619.
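//
// Concretely, each batch k gathers expected counts s_k under the current
// model and interpolates them into the parameters mu with step size
// nu_k = (k + 2)^{-alpha}:
//
//   mu <- (1 - nu_k) * mu + nu_k * s_k
//
// (see Interpolate() and Backward() below). Setting alpha = 0 forces
// nu_k = 1, discarding the old parameters and recovering full-batch EM.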
template <class Arc, class ExpectationTable>
class StepwiseBaumWelchTrainer {
 public:
  using Weight = typename Arc::Weight;
  // Stable accumulator over weights (cf. <baumwelch/log-adder.h>); the alias
  // target is assumed, but it must provide the Add() and Sum() used below.
  using Sum = Adder<Weight>;

  // Valid values of alpha usually lie in [0.5, 1.0].
  explicit StepwiseBaumWelchTrainer(
      float alpha = kAlpha, int batch_size = 0,
      const CascadeOptions &opts = CascadeOptions())
      : alpha_(alpha), batch_size_(batch_size), opts_(opts), step_(0) {}

  // Performs a batch of training, returning the likelihood. Semiring Zero is
  // returned in the case of composition failure.
  Weight Batch(FarReader<Arc> &input, FarReader<Arc> &output,
               MutableFst<Arc> *model) {
    ExpectationTable table(*model);
    Sum likelihood;     // Tracks batch likelihood.
    int batch_idx = 0;  // Tracks actual batch size.
    for (; !input.Done() && !output.Done() &&
           (!batch_size_ || batch_idx < batch_size_);
         ++batch_idx) {
      likelihood.Add(
          Forward(*input.GetFst(), *output.GetFst(), *model, &table));
      if (input.Type() != FarType::FST) input.Next();
      output.Next();
    }
    Backward(table, model);
    ++step_;
    const auto batch_likelihood = likelihood.Sum();
    LOG(INFO) << "Step " << step_ << " (batch size " << batch_idx
              << ") average likelihood: "
              << batch_likelihood.Value() / batch_idx;
    return batch_likelihood;
  }

  // Repeatedly performs the stepwise computation until the data is exhausted.
  Weight Train(FarReader<Arc> &input, FarReader<Arc> &output,
               MutableFst<Arc> *model) {
    Sum likelihood;  // Tracks iteration likelihood.
    while (!input.Done() && !output.Done()) {
      likelihood.Add(Batch(input, output, model));
    }
    Normalize(model);
    return likelihood.Sum();
  }

  // Normalizes the model.
  void Normalize(MutableFst<Arc> *model) {
    ExpectationTable table(*model);
    StateIterator<MutableFst<Arc>> siter(*model);
    for (; !siter.Done(); siter.Next()) {
      const auto state = siter.Value();
      for (ArcIterator<MutableFst<Arc>> aiter(*model, state); !aiter.Done();
           aiter.Next()) {
        const auto &arc = aiter.Value();
        table.Forward(state, arc.ilabel, arc.olabel, arc.weight, arc.nextstate);
      }
      const auto weight = model->Final(state);
      if (weight == Weight::Zero()) continue;
      table.Forward(state, weight);
    }
    for (siter.Reset(); !siter.Done(); siter.Next()) {
      const auto state = siter.Value();
      for (MutableArcIterator<MutableFst<Arc>> aiter(model, state);
           !aiter.Done(); aiter.Next()) {
        auto arc = aiter.Value();
        arc.weight = table.Backward(state, arc);
        aiter.SetValue(arc);
      }
      model->SetFinal(state, table.Backward(state));
    }
  }

 private:
  Weight Forward(const Fst<Arc> &input, const Fst<Arc> &output,
                 const Fst<Arc> &model, ExpectationTable *table) {
    const ChannelStateCascade<Arc> cascade(input, output, model, opts_);
    const auto &ico = cascade.GetFst();
    const auto start = ico.Start();
    if (start == kNoStateId) {
      VLOG(1) << "Empty lattice";
      return Weight::Zero();
    }
    const ForwardBackward<Arc> fb(ico);
    const auto &likelihood = fb.Beta(start);
    if (likelihood == Weight::Zero()) {
      VLOG(1) << "Start state not coaccessible";
      return Weight::Zero();
    }
    for (StateIterator<ComposeFst<Arc>> siter(ico); !siter.Done();
         siter.Next()) {
      const auto state = siter.Value();
      // Non-coaccessible source state.
      if (fb.Beta(state) == Weight::Zero()) continue;
      const auto ch_state = cascade.ChannelState(state);
      const auto &alpha = fb.Alpha(state);
      for (ArcIterator<ComposeFst<Arc>> aiter(ico, state); !aiter.Done();
           aiter.Next()) {
        const auto &arc = aiter.Value();
        const auto &beta = fb.Beta(arc.nextstate);
        // Non-coaccessible destination state.
        if (beta == Weight::Zero()) continue;
        // The arc expectation is the product of the current weight, alpha,
        // and beta, divided by the overall observation likelihood.
        table->Forward(
            ch_state, arc.ilabel, arc.olabel,
            Divide(Times(Times(alpha, arc.weight), beta), likelihood),
            cascade.ChannelState(arc.nextstate));
      }
      const auto weight = ico.Final(state);
      if (weight == Weight::Zero()) continue;
      // The final state expectation is the product of the current weight and
      // alpha, divided by the overall observation likelihood.
      table->Forward(ch_state, Divide(Times(alpha, weight), likelihood));
    }
    return likelihood;
  }

  // TODO(kbg): Add a way to disable interpolation.
  static Weight Interpolate(const Weight &old_weight, const Weight &new_weight,
                            double nu_k) {
    if (nu_k == 1.0) {
      // The contribution of old_weight is 0, so just returns new_weight.
      // This corresponds to standard full-batch EM.
      return new_weight;
    }
    const auto old_term = Times(1 - nu_k, old_weight);
    const auto new_term = Times(nu_k, new_weight);
    Sum plus(old_term);
    plus.Add(new_term);
    return plus.Sum();
  }

  void Backward(const ExpectationTable &table, MutableFst<Arc> *model) {
    const double nu_k = alpha_ == 0.0 ? 1.0 : std::pow(step_ + 2, -alpha_);
    for (StateIterator<MutableFst<Arc>> siter(*model); !siter.Done();
         siter.Next()) {
      const auto state = siter.Value();
      // Sets new arc weights.
      for (MutableArcIterator<MutableFst<Arc>> aiter(model, state);
           !aiter.Done(); aiter.Next()) {
        auto arc = aiter.Value();
        arc.weight = Interpolate(arc.weight, table.Backward(state, arc), nu_k);
        aiter.SetValue(arc);
      }
      // Sets new final weights.
      model->SetFinal(
          state, Interpolate(model->Final(state), table.Backward(state), nu_k));
    }
  }

  const float alpha_;          // Step size reduction power.
  const int batch_size_;       // Batch size hyperparameter.
  const CascadeOptions opts_;  // Cascade construction options.
  uint64_t step_;              // Iteration/step number.
};

// Full training setup, templated on the expectation table.
template <class Arc, class ExpectationTable>
typename Arc::Weight Train(FarReader<Arc> &input, FarReader<Arc> &output,
                           MutableFst<Arc> *model,
                           const TrainOptions &opts = TrainOptions()) {
  using Weight = typename Arc::Weight;
  auto last_likelihood = Weight::Zero();
  StepwiseBaumWelchTrainer<Arc, ExpectationTable> trainer(
      opts.alpha, opts.batch_size, opts.copts);
  trainer.Normalize(model);
  for (int iteration = 0; iteration < opts.max_iters; ++iteration) {
    input.Reset();
    output.Reset();
    const auto total_likelihood = trainer.Train(input, output, model);
    LOG(INFO) << "Iteration " << iteration + 1
              << " total likelihood: " << total_likelihood;
    if (ApproxEqual(last_likelihood, total_likelihood, opts.delta)) {
      return total_likelihood;
    }
    last_likelihood = total_likelihood;
  }
  return last_likelihood;
}

}  // namespace internal

// Full training setup.
template <class Arc>
typename Arc::Weight Train(FarReader<Arc> &input, FarReader<Arc> &output,
                           MutableFst<Arc> *model, bool normalize_ilabel = true,
                           const TrainOptions &opts = TrainOptions()) {
  if (normalize_ilabel) {
    return internal::Train<Arc, StateILabelExpectationTable<Arc>>(input, output,
                                                                  model, opts);
  } else {
    return internal::Train<Arc, StateExpectationTable<Arc>>(input, output,
                                                            model, opts);
  }
}
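
// A minimal end-to-end sketch over the log semiring, with hypothetical file
// names and error checking omitted; FarReader::Open and MutableFst::Read are
// standard OpenFst:
//
//   std::unique_ptr<FarReader<LogArc>> input(
//       FarReader<LogArc>::Open("input.far"));
//   std::unique_ptr<FarReader<LogArc>> output(
//       FarReader<LogArc>::Open("output.far"));
//   std::unique_ptr<MutableFst<LogArc>> model(
//       MutableFst<LogArc>::Read("model.fst"));
//   const auto likelihood = Train(*input, *output, model.get());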

}  // namespace fst

#endif  // NLP_GRM2_BAUMWELCH_TRAIN_H_