15 #ifndef NLP_GRM2_BAUMWELCH_TRAIN_H_ 16 #define NLP_GRM2_BAUMWELCH_TRAIN_H_ 23 #include <fst/extensions/far/far.h> 24 #include <fst/arcfilter.h> 25 #include <fst/compose.h> 27 #include <fst/mutable-fst.h> 28 #include <fst/queue.h> 29 #include <fst/shortest-distance.h> 30 #include <fst/weight.h> 89 ShortestDistance(ico, &beta_,
true);
90 if constexpr (IsIdempotent<typename Arc::Weight>::value) {
92 using StateId =
typename Arc::StateId;
93 using Weight =
typename Arc::Weight;
94 using MyEstimate = NaturalAStarEstimate<StateId, Weight>;
95 using MyQueue = NaturalAStarQueue<StateId, Weight, MyEstimate>;
96 using MyArcFilter = AnyArcFilter<Arc>;
97 using MyShortestDistanceOptions =
98 ShortestDistanceOptions<Arc, MyQueue, MyArcFilter>;
99 const MyEstimate estimate(beta_);
100 MyQueue queue(alpha_, estimate);
101 static constexpr MyArcFilter arc_filter;
102 const MyShortestDistanceOptions opts(
107 ShortestDistance(ico, &alpha_, opts);
108 VLOG(1) << ExploredStates<Weight>(alpha_) <<
" alpha states explored";
110 ShortestDistance(ico, &alpha_,
false);
115 return ForwardBackward::WeightOrZero(s, alpha_);
119 return ForwardBackward::WeightOrZero(s, beta_);
123 static constexpr
Weight kZero = Weight::Zero();
128 const std::vector<Weight> &weights) {
129 return (s < weights.size()) ? weights[s] : kZero;
132 std::vector<Weight> alpha_;
133 std::vector<Weight> beta_;
144 template <
class Arc,
class ExpectationTable>
159 MutableFst<Arc> *model) {
160 ExpectationTable table(*model);
163 for (; !input.Done() && !output.Done() &&
164 (!batch_size_ || batch_idx < batch_size_);
167 Forward(*input.GetFst(), *output.GetFst(), *model, &table));
168 if (input.Type() != FarType::FST) input.Next();
171 Backward(table, model);
173 const auto batch_likelihood = likelihood.
Sum();
174 LOG(INFO) <<
"Step " << step_ <<
" (batch size " << batch_idx
175 <<
") average likelihood: " 176 << batch_likelihood.Value() / batch_idx;
177 return batch_likelihood;
182 MutableFst<Arc> *model) {
184 while (!input.Done() && !output.Done()) {
185 likelihood.
Add(Batch(input, output, model));
188 return likelihood.
Sum();
193 ExpectationTable table(*model);
194 StateIterator<MutableFst<Arc>> siter(*model);
195 for (; !siter.Done(); siter.Next()) {
196 const auto state = siter.Value();
197 for (ArcIterator<MutableFst<Arc>> aiter(*model, state); !aiter.Done();
199 const auto &arc = aiter.Value();
200 table.Forward(state, arc.ilabel, arc.olabel, arc.weight, arc.nextstate);
202 const auto weight = model->Final(state);
203 if (weight == Weight::Zero())
continue;
204 table.Forward(state, weight);
206 for (siter.Reset(); !siter.Done(); siter.Next()) {
207 const auto state = siter.Value();
208 for (MutableArcIterator<MutableFst<Arc>> aiter(model, state);
209 !aiter.Done(); aiter.Next()) {
210 auto arc = aiter.Value();
211 arc.weight = table.Backward(state, arc);
214 model->SetFinal(state, table.Backward(state));
219 Weight Forward(
const Fst<Arc> &input,
const Fst<Arc> &output,
220 const Fst<Arc> &model, ExpectationTable *table) {
222 const auto &ico = cascade.
GetFst();
223 const auto start = ico.Start();
224 if (start == kNoStateId) {
225 VLOG(1) <<
"Empty lattice";
229 const auto &likelihood = fb.
Beta(start);
230 if (likelihood == Weight::Zero()) {
231 VLOG(1) <<
"Start state not coaccessible";
232 return Weight::Zero();
234 for (StateIterator<ComposeFst<Arc>> siter(ico); !siter.Done();
236 const auto state = siter.Value();
238 if (fb.
Beta(state) == Weight::Zero())
continue;
241 for (ArcIterator<ComposeFst<Arc>> aiter(ico, state); !aiter.Done();
243 const auto &arc = aiter.Value();
244 const auto &beta = fb.
Beta(arc.nextstate);
246 if (beta == Weight::Zero())
continue;
250 ch_state, arc.ilabel, arc.olabel,
251 Divide(Times(Times(
alpha, arc.weight), beta), likelihood),
254 const auto weight = ico.Final(state);
255 if (weight == Weight::Zero())
continue;
258 table->Forward(ch_state, Divide(Times(
alpha, weight), likelihood));
271 const auto old_term = Times(1 - nu_k, old_weight);
272 const auto new_term = Times(nu_k, new_weight);
278 void Backward(
const ExpectationTable &table, MutableFst<Arc> *model) {
279 const double nu_k = alpha_ == 0.0 ? 1.0 : std::pow(step_ + 2, -alpha_);
280 for (StateIterator<MutableFst<Arc>> siter(*model); !siter.Done();
282 const auto state = siter.Value();
284 for (MutableArcIterator<MutableFst<Arc>> aiter(model, state);
285 !aiter.Done(); aiter.Next()) {
286 auto arc = aiter.Value();
287 arc.weight = Interpolate(arc.weight, table.Backward(state, arc), nu_k);
292 state, Interpolate(model->Final(state), table.Backward(state), nu_k));
297 const int batch_size_;
303 template <
class Arc,
class ExpectationTable>
304 typename Arc::Weight
Train(FarReader<Arc> &input, FarReader<Arc> &output,
305 MutableFst<Arc> *model,
307 using Weight =
typename Arc::Weight;
308 auto last_likelihood = Weight::Zero();
310 opts.alpha, opts.batch_size, opts.copts);
312 for (
int iteration = 0; iteration < opts.max_iters; ++iteration) {
315 const auto total_likelihood = trainer.
Train(input, output, model);
316 LOG(INFO) <<
"Iteration " << iteration + 1
317 <<
" total likelihood: " << total_likelihood;
318 if (ApproxEqual(last_likelihood, total_likelihood, opts.delta)) {
319 return total_likelihood;
321 last_likelihood = total_likelihood;
323 return last_likelihood;
330 typename Arc::Weight
Train(FarReader<Arc> &input, FarReader<Arc> &output,
331 MutableFst<Arc> *model,
bool normalize_ilabel =
true,
333 if (normalize_ilabel) {
334 return internal::Train<Arc, StateILabelExpectationTable<Arc>>(input, output,
337 return internal::Train<Arc, StateExpectationTable<Arc>>(input, output,
344 #endif // NLP_GRM2_BAUMWELCH_TRAIN_H_ typename Arc::StateId StateId
void Add(const Weight &weight)
StateId ChannelState(StateId ico_state) const
typename Arc::Weight Weight
void Normalize(MutableFst< Arc > *model)
TrainOptions(int max_iters=kMaxIters, float alpha=kAlpha, int batch_size=0, float delta=kDelta, const CascadeOptions &copts=CascadeOptions())
StepwiseBaumWelchTrainer(float alpha=kAlpha, int batch_size=0, const CascadeOptions &opts=CascadeOptions())
Weight Train(FarReader< Arc > &input, FarReader< Arc > &output, MutableFst< Arc > *model)
const Weight & Alpha(StateId s) const
ForwardBackward(const ComposeFst< Arc > &ico)
Arc::Weight Train(FarReader< Arc > &input, FarReader< Arc > &output, MutableFst< Arc > *model, bool normalize_ilabel=true, const TrainOptions &opts=TrainOptions())
const ComposeFst< Arc > & GetFst() const
const Weight & Beta(StateId s) const
Weight Batch(FarReader< Arc > &input, FarReader< Arc > &output, MutableFst< Arc > *model)
typename Arc::Weight Weight