Skip to content

Commit

Permalink
Internal change.
Browse files: browse the repository at this point in the history
PiperOrigin-RevId: 628170507
  • Loading branch information
kylebgorman authored and copybara-github committed May 16, 2024
1 parent bb4fd3e commit 03dc28e
Show file tree
Hide file tree
Showing 9 changed files with 21 additions and 28 deletions.
5 changes: 2 additions & 3 deletions mozolm/models/ngram_char_fst_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,13 @@ class NGramCharFstModel : public NGramFstModel {
// Computes negative log probability for observing the supplied label in a
// given state.
fst::StdArc::Weight LabelCostInState(fst::StdArc::StateId state,
fst::StdArc::Label label) const;
fst::StdArc::Label label) const;

private:
fst::StdArc::Label SymLabel(int utf8_sym) const;

// Returns negative log probability of the end-of-string at the given state.
fst::StdArc::Weight FinalCostInState(
fst::StdArc::StateId state) const;
fst::StdArc::Weight FinalCostInState(fst::StdArc::StateId state) const;
};

} // namespace models
Expand Down
2 changes: 1 addition & 1 deletion mozolm/models/ngram_char_fst_model_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
#include "nisaba/port/utf8_util.h"
#include "nisaba/port/test_utils.h"

using nisaba::testing::TestFilePath;
using fst::StdArc;
using nisaba::testing::TestFilePath;

namespace mozolm {
namespace models {
Expand Down
8 changes: 3 additions & 5 deletions mozolm/models/ngram_fst_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,8 @@ class NGramFstModel : public LanguageModel {

// Returns the next state reached by the arc labeled with label from current_state.
// If the label is out-of-vocabulary, it will return the unigram state.
fst::StdArc::StateId NextModelState(
fst::StdArc::StateId current_state,
fst::StdArc::Label label) const;
fst::StdArc::StateId NextModelState(fst::StdArc::StateId current_state,
fst::StdArc::Label label) const;

// Language model represented by vector FST.
std::unique_ptr<const fst::StdVectorFst> fst_;
Expand All @@ -66,8 +65,7 @@ class NGramFstModel : public LanguageModel {

// Checks the current state and sets it to the unigram state if less than
// zero.
fst::StdArc::StateId CheckCurrentState(
fst::StdArc::StateId state) const;
fst::StdArc::StateId CheckCurrentState(fst::StdArc::StateId state) const;

private:
// Performs model sanity check.
Expand Down
4 changes: 2 additions & 2 deletions mozolm/models/ngram_word_fst_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ class NGramImplicitStates {
public:
NGramImplicitStates() = default;

NGramImplicitStates(const fst::StdVectorFst& fst,
int first_char_begin_index, int first_char_end_index);
NGramImplicitStates(const fst::StdVectorFst& fst, int first_char_begin_index,
int first_char_end_index);

// Returns the state if it already exists; creates it otherwise.
absl::StatusOr<int> GetState(int model_state, int prefix_length,
Expand Down
2 changes: 1 addition & 1 deletion mozolm/models/ngram_word_fst_model_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@
#include "nisaba/port/test_utils.h"
#include "nisaba/port/utf8_util.h"

using ::nisaba::testing::TestFilePath;
using ::fst::ArcSort;
using ::fst::ILabelCompare;
using ::fst::StdArc;
using ::fst::StdVectorFst;
using ::fst::SymbolTable;
using ::nisaba::testing::TestFilePath;

namespace mozolm {
namespace models {
Expand Down
19 changes: 8 additions & 11 deletions mozolm/models/ppm_as_fst_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,9 @@ class PpmAsFstModel : public LanguageModel {
// Fills in cache vectors of negative log probabilities and destination states
// for each item in the vocabulary, matching indices with the symbol table. By
// convention, index 0 is for final cost.
absl::Status UpdateCacheAtNonEmptyState(
fst::StdArc::StateId s, fst::StdArc::StateId backoff_state,
const PpmStateCache& backoff_cache);
absl::Status UpdateCacheAtNonEmptyState(fst::StdArc::StateId s,
fst::StdArc::StateId backoff_state,
const PpmStateCache& backoff_cache);

// Checks if lower order state caches have updated more recently.
bool LowerOrderCacheUpdated(fst::StdArc::StateId s) const;
Expand All @@ -277,16 +277,14 @@ class PpmAsFstModel : public LanguageModel {
absl::StatusOr<int> AddNewState(fst::StdArc::StateId backoff_dest_state);

// Returns origin state of arc with symbol from state s.
absl::StatusOr<int> GetArcOriginState(fst::StdArc::StateId s,
int sym_index);
absl::StatusOr<int> GetArcOriginState(fst::StdArc::StateId s, int sym_index);

// Returns destination state of arc with symbol from state s.
absl::StatusOr<int> GetDestinationState(fst::StdArc::StateId s,
int sym_index);

// Returns negative log probability of symbol leaving the current state.
absl::StatusOr<double> GetNegLogProb(fst::StdArc::StateId s,
int sym_index);
absl::StatusOr<double> GetNegLogProb(fst::StdArc::StateId s, int sym_index);

// Returns normalization value at the current state.
absl::StatusOr<double> GetNormalization(fst::StdArc::StateId s);
Expand All @@ -303,13 +301,12 @@ class PpmAsFstModel : public LanguageModel {

// Updates model with an observation of the sym_index at curr_state.
absl::StatusOr<fst::StdArc::StateId> UpdateModel(
fst::StdArc::StateId curr_state,
fst::StdArc::StateId highest_found_state, int sym_index);
fst::StdArc::StateId curr_state, fst::StdArc::StateId highest_found_state,
int sym_index);

// Converts input string into linear FST at the character level, replacing
// characters not in possible_characters_ set (if non-empty) with kOovSymbol.
absl::StatusOr<fst::StdVectorFst> String2Fst(
const std::string& input_string);
absl::StatusOr<fst::StdVectorFst> String2Fst(const std::string& input_string);

// Adds a single unigram count to every character.
absl::Status AddPriorCounts();
Expand Down
4 changes: 2 additions & 2 deletions mozolm/models/ppm_as_fst_model_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ namespace {
constexpr float kFloatDelta = 0.00001; // Delta for float comparisons.
constexpr char kVocabFileName[] = "vocab.txt";

using ::nisaba::file::WriteTempTextFile;
using ::nisaba::utf8::DecodeSingleUnicodeChar;
using ::fst::ArcSort;
using ::fst::ILabelCompare;
using ::fst::Isomorphic;
using ::fst::StdArc;
using ::fst::StdVectorFst;
using ::fst::SymbolTable;
using ::nisaba::file::WriteTempTextFile;
using ::nisaba::utf8::DecodeSingleUnicodeChar;
using ::testing::DoubleEq;
using ::testing::Each;

Expand Down
3 changes: 1 addition & 2 deletions mozolm/utils/ngram_fst_relabel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,7 @@ absl::Status CheckProperties(const StdVectorFst &fst) {
} // namespace

absl::Status RelabelWithCodepoints(
const std::vector<std::string> &keep_symbols_vec,
fst::StdVectorFst *fst) {
const std::vector<std::string> &keep_symbols_vec, fst::StdVectorFst *fst) {
RETURN_IF_ERROR(CheckProperties(*fst));
GOOGLE_LOG(INFO) << "Building input/output mappings and relabeling ...";
const absl::flat_hash_set<std::string> keep_symbols(keep_symbols_vec.begin(),
Expand Down
2 changes: 1 addition & 1 deletion mozolm/utils/ngram_fst_relabel_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@

using fst::ArcIterator;
using fst::FstCompiler;
using fst::kError;
using fst::StateIterator;
using fst::StdArc;
using fst::StdVectorFst;
using fst::SymbolTable;
using fst::kError;

namespace mozolm {
namespace {
Expand Down

0 comments on commit 03dc28e

Please sign in to comment.