BaumWelch  baumwelch-0.3.8
OpenGrm-BaumWelch library
randomize.h
Go to the documentation of this file.
1 // Licensed under the Apache License, Version 2.0 (the "License");
2 // you may not use this file except in compliance with the License.
3 // You may obtain a copy of the License at
4 //
5 // http://www.apache.org/licenses/LICENSE-2.0
6 //
7 // Unless required by applicable law or agreed to in writing, software
8 // distributed under the License is distributed on an "AS IS" BASIS,
9 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 // See the License for the specific language governing permissions and
11 // limitations under the License.
12 //
13 // Copyright 2017 and onwards Google, Inc.
14 
15 #ifndef NLP_GRM2_BAUMWELCH_RANDOMIZE_H_
16 #define NLP_GRM2_BAUMWELCH_RANDOMIZE_H_
17 
18 #include <cmath>
19 #include <cstdint>
20 #include <random>
21 
22 #include <fst/fst.h>
23 #include <fst/mutable-fst.h>
24 #include <fst/weight.h>
25 
26 namespace fst {
27 namespace internal {
28 
29 // Random weight generator in the (real) interval [kDelta, 1).
30 template <class Weight>
32  public:
33  using ValueType = typename Weight::ValueType;
34 
35  explicit LogUniformGenerator(uint64_t seed = std::random_device()())
36  : rand_(seed), dist_(kDelta, 1.0) {}
37 
38  Weight operator()() const { return Weight(-std::log(dist_(rand_))); }
39 
40  private:
41  mutable std::mt19937_64 rand_;
42  mutable std::uniform_real_distribution<ValueType> dist_;
43 };
44 
45 } // namespace internal
46 
47 template <class Arc>
48 void Randomize(MutableFst<Arc> *fst, uint64_t seed = std::random_device()()) {
49  using Weight = typename Arc::Weight;
50  const internal::LogUniformGenerator<Weight> generator(seed);
51  for (StateIterator<MutableFst<Arc>> siter(*fst); !siter.Done();
52  siter.Next()) {
53  const auto state = siter.Value();
54  // Arcs leaving this state.
55  for (MutableArcIterator<MutableFst<Arc>> aiter(fst, state); !aiter.Done();
56  aiter.Next()) {
57  auto arc = aiter.Value();
58  arc.weight = generator();
59  aiter.SetValue(arc);
60  }
61  // Final weight.
62  if (fst->Final(state) != Weight::Zero()) fst->SetFinal(state, generator());
63  }
64 }
65 
66 } // namespace fst
67 
68 #endif // NLP_GRM2_BAUMWELCH_RANDOMIZE_H_
69 
using ValueType = typename Weight::ValueType
Definition: randomize.h:33
LogUniformGenerator(uint64_t seed=std::random_device()())
Definition: randomize.h:35
Definition: a-star.h:30
void Randomize(MutableFst< Arc > *fst, uint64_t seed=std::random_device()())
Definition: randomize.h:48