BaumWelch  baumwelch-0.3.8
OpenGrm-BaumWelch library
decode.h
Go to the documentation of this file.
1 // Licensed under the Apache License, Version 2.0 (the "License");
2 // you may not use this file except in compliance with the License.
3 // You may obtain a copy of the License at
4 //
5 // http://www.apache.org/licenses/LICENSE-2.0
6 //
7 // Unless required by applicable law or agreed to in writing, software
8 // distributed under the License is distributed on an "AS IS" BASIS,
9 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 // See the License for the specific language governing permissions and
11 // limitations under the License.
12 //
13 // Copyright 2017 and onwards Google, Inc.
14 
15 #ifndef NLP_GRM2_BAUMWELCH_DECODE_H_
16 #define NLP_GRM2_BAUMWELCH_DECODE_H_
17 
18 // This header defines functions for decoding outputs. Both FST and
19 // FAR inputs/outputs are supported.
20 //
21 // By convention, operations in this library that work with FarReader input
22 // reset the FAR to its initial position upon completion.
23 
24 #include <numeric>
25 #include <vector>
26 
27 #include <fst/extensions/far/far.h>
28 #include <fst/compact-fst.h>
29 #include <fst/fst-decl.h>
30 #include <fst/project.h>
31 #include <fst/rmepsilon.h>
32 #include <fst/shortest-path.h>
33 #include <fst/statesort.h>
34 #include <fst/vector-fst.h>
35 #include <fst/weight.h>
36 #include <baumwelch/a-star.h>
37 #include <baumwelch/cascade.h>
38 
39 namespace fst {
40 namespace internal {
41 
42 // Helper function that reverses the state numbering of the output of single
43 // ShortestPath. This is necessary because the single shortest path produces
44 // an automaton with descendingly ordered state numbering, and this in turn
45 // is not compatible with the definition of the kString property assumed by
46 // string compactor.
47 template <class Arc>
48 void ReverseStateNumbering(MutableFst<Arc> *fst) {
49  using StateId = typename Arc::StateId;
50  // TODO(kbg): use std::ranges::iota_view once C++20 is more widely available.
51  std::vector<StateId> order(fst->NumStates());
52  std::iota(order.rbegin(), order.rend(), 0);
53  StateSort(fst, order);
54 }
55 
56 // Pair decoding.
57 template <class Arc>
58 VectorFst<Arc> DecodePair(const Fst<Arc> &ifst) {
59  VectorFst<Arc> ofst;
60  if constexpr (IsPath<typename Arc::Weight>::value) {
61  ShortestPath(ifst, &ofst);
62  } else {
63  AStarSingleShortestString(ifst, &ofst);
64  }
65  ReverseStateNumbering(&ofst);
66  return ofst;
67 }
68 
69 // Decipherment decoding.
70 template <class Arc>
71 CompactWeightedStringFst<Arc> DecodeDecipherment(const Fst<Arc> &ifst) {
72  VectorFst<Arc> lattice;
73  Project(ifst, &lattice, ProjectType::INPUT);
74  RmEpsilon(&lattice);
75  VectorFst<Arc> ofst;
76  if constexpr (IsPath<typename Arc::Weight>::value) {
77  ShortestPath(lattice, &ofst);
78  } else {
79  AStarSingleShortestString(lattice, &ofst);
80  }
81  ReverseStateNumbering(&ofst);
82  return CompactWeightedStringFst<Arc>(ofst);
83 }
84 
85 } // namespace internal
86 
87 // Full decipherment setup.
88 template <class Arc>
89 void Decode(FarReader<Arc> &input, FarReader<Arc> &output,
90  const Fst<Arc> &model, FarWriter<Arc> &hypotext) {
91  while (!input.Done() && !output.Done()) {
92  const SimpleCascade<Arc> cascade(*input.GetFst(), *output.GetFst(), model);
93  if (input.Type() == FarType::FST) {
94  hypotext.Add(output.GetKey(),
95  internal::DecodeDecipherment(cascade.GetFst()));
96  } else {
97  hypotext.Add(input.GetKey() + "_" + output.GetKey(),
98  internal::DecodePair(cascade.GetFst()));
99  input.Next();
100  }
101  output.Next();
102  }
103 }
104 
105 } // namespace fst
106 
107 #endif // NLP_GRM2_BAUMWELCH_DECODE_H_
108 
Definition: a-star.h:30
CompactWeightedStringFst< Arc > DecodeDecipherment(const Fst< Arc > &ifst)
Definition: decode.h:71
VectorFst< Arc > DecodePair(const Fst< Arc > &ifst)
Definition: decode.h:58
void Decode(FarReader< Arc > &input, FarReader< Arc > &output, const Fst< Arc > &model, FarWriter< Arc > &hypotext)
Definition: decode.h:89
void ReverseStateNumbering(MutableFst< Arc > *fst)
Definition: decode.h:48
void AStarSingleShortestString(const Fst< Arc > &ifst, MutableFst< Arc > *ofst, float delta=kShortestDelta)
Definition: a-star.h:42