// blob: 0c40168930dd6267516c36796300d99ddf20ccd2 [file] [log] [blame]
/*
* Copyright (C) 2013 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "suggest/core/policy/weighting.h"
#include "defines.h"
#include "suggest/core/dicnode/dic_node.h"
#include "suggest/core/dicnode/dic_node_profiler.h"
#include "suggest/core/dicnode/dic_node_utils.h"
#include "suggest/core/session/dic_traverse_session.h"
namespace latinime {
// Forward declaration: this file only passes MultiBigramMap pointers through to other calls,
// so the full class definition is not needed here.
class MultiBigramMap;
/**
 * Records the applied correction type on the profiler of the given node.
 *
 * The switch is compiled only when DEBUG_DICT evaluates to non-zero; otherwise the function body
 * is empty. Correction types without a dedicated PROF_* macro below are silently ignored.
 *
 * @param correctionType the correction that has just been applied
 * @param node the node whose mProfiler member is handed to the PROF_* macro
 */
static inline void profile(const CorrectionType correctionType, DicNode *const node) {
#if DEBUG_DICT
    switch (correctionType) {
        case CT_OMISSION:
            PROF_OMISSION(node->mProfiler);
            break;
        case CT_ADDITIONAL_PROXIMITY:
            PROF_ADDITIONAL_PROXIMITY(node->mProfiler);
            break;
        case CT_SUBSTITUTION:
            PROF_SUBSTITUTION(node->mProfiler);
            break;
        case CT_NEW_WORD_SPACE_OMISSION:
            PROF_NEW_WORD(node->mProfiler);
            break;
        case CT_MATCH:
            PROF_MATCH(node->mProfiler);
            break;
        case CT_COMPLETION:
            PROF_COMPLETION(node->mProfiler);
            break;
        case CT_TERMINAL:
            PROF_TERMINAL(node->mProfiler);
            break;
        case CT_TERMINAL_INSERTION:
            PROF_TERMINAL_INSERTION(node->mProfiler);
            break;
        case CT_NEW_WORD_SPACE_SUBSTITUTION:
            PROF_SPACE_SUBSTITUTION(node->mProfiler);
            break;
        case CT_INSERTION:
            PROF_INSERTION(node->mProfiler);
            break;
        case CT_TRANSPOSITION:
            PROF_TRANSPOSITION(node->mProfiler);
            break;
        default:
            // Nothing to record for any other correction type.
            break;
    }
#else
    // Profiling is compiled out; nothing to do.
#endif
}
/* static */ void Weighting::addCostAndForwardInputIndex(const Weighting *const weighting,
const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, DicNode *const dicNode,
MultiBigramMap *const multiBigramMap) {
const int inputSize = traverseSession->getInputSize();
DicNode_InputStateG inputStateG;
inputStateG.mNeedsToUpdateInputStateG = false; // Don't use input info by default
const float spatialCost = Weighting::getSpatialCost(weighting, correctionType,
traverseSession, parentDicNode, dicNode, &inputStateG);
const float languageCost = Weighting::getLanguageCost(weighting, correctionType,
traverseSession, parentDicNode, dicNode, multiBigramMap);
const ErrorType errorType = weighting->getErrorType(correctionType, traverseSession,
parentDicNode, dicNode);
profile(correctionType, dicNode);
if (inputStateG.mNeedsToUpdateInputStateG) {
dicNode->updateInputIndexG(&inputStateG);
} else {
dicNode->forwardInputIndex(0, getForwardInputCount(correctionType),
(correctionType == CT_TRANSPOSITION));
}
dicNode->addCost(spatialCost, languageCost, weighting->needsToNormalizeCompoundDistance(),
inputSize, errorType);
if (CT_NEW_WORD_SPACE_OMISSION == correctionType) {
// When we are on a terminal, we save the current distance for evaluating
// when to auto-commit partial suggestions.
dicNode->saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet();
}
}
/**
 * Dispatches to the cost getter of the weighting policy that matches correctionType and returns
 * its result as the spatial cost.
 *
 * @param weighting the policy object that supplies the costs
 * @param correctionType the correction being evaluated
 * @param traverseSession passed through to the getters that take a session
 * @param parentDicNode passed through to the getters that take a parent node
 * @param dicNode passed through to the getters that take a node
 * @param inputStateG handed only to the getters for CT_NEW_WORD_SPACE_OMISSION and CT_MATCH
 * @return the cost from the matching getter, or 0.0f for a correction type not listed below
 */
/* static */ float Weighting::getSpatialCost(const Weighting *const weighting,
        const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
        const DicNode *const parentDicNode, const DicNode *const dicNode,
        DicNode_InputStateG *const inputStateG) {
    float cost = 0.0f;
    switch (correctionType) {
        case CT_OMISSION:
            cost = weighting->getOmissionCost(parentDicNode, dicNode);
            break;
        case CT_ADDITIONAL_PROXIMITY:
            // only used for typing
            cost = weighting->getAdditionalProximityCost();
            break;
        case CT_SUBSTITUTION:
            // only used for typing
            cost = weighting->getSubstitutionCost();
            break;
        case CT_NEW_WORD_SPACE_OMISSION:
            cost = weighting->getNewWordSpatialCost(traverseSession, dicNode, inputStateG);
            break;
        case CT_MATCH:
            cost = weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
            break;
        case CT_COMPLETION:
            cost = weighting->getCompletionCost(traverseSession, dicNode);
            break;
        case CT_TERMINAL:
            cost = weighting->getTerminalSpatialCost(traverseSession, dicNode);
            break;
        case CT_TERMINAL_INSERTION:
            cost = weighting->getTerminalInsertionCost(traverseSession, dicNode);
            break;
        case CT_NEW_WORD_SPACE_SUBSTITUTION:
            cost = weighting->getSpaceSubstitutionCost(traverseSession, dicNode);
            break;
        case CT_INSERTION:
            cost = weighting->getInsertionCost(traverseSession, parentDicNode, dicNode);
            break;
        case CT_TRANSPOSITION:
            cost = weighting->getTranspositionCost(traverseSession, parentDicNode, dicNode);
            break;
        default:
            // Unlisted correction types keep the initial cost of 0.0f.
            break;
    }
    return cost;
}
/* static */ float Weighting::getLanguageCost(const Weighting *const weighting,
const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode,
MultiBigramMap *const multiBigramMap) {
switch(correctionType) {
case CT_OMISSION:
return 0.0f;
case CT_SUBSTITUTION:
return 0.0f;
case CT_NEW_WORD_SPACE_OMISSION:
return weighting->getNewWordBigramLanguageCost(
traverseSession, parentDicNode, multiBigramMap);
case CT_MATCH:
return 0.0f;
case CT_COMPLETION:
return 0.0f;
case CT_TERMINAL: {
const float languageImprobability =
DicNodeUtils::getBigramNodeImprobability(
traverseSession->getDictionaryStructurePolicy(), dicNode, multiBigramMap);
return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability);
}
case CT_TERMINAL_INSERTION:
return 0.0f;
case CT_NEW_WORD_SPACE_SUBSTITUTION:
return weighting->getNewWordBigramLanguageCost(
traverseSession, parentDicNode, multiBigramMap);
case CT_INSERTION:
return 0.0f;
case CT_TRANSPOSITION:
return 0.0f;
default:
return 0.0f;
}
}
/**
 * Returns the count that addCostAndForwardInputIndex() passes to DicNode::forwardInputIndex()
 * for the given correction type.
 *
 * @param correctionType the correction being applied
 * @return 0, 1 or 2 depending on the correction type; 0 for any type not listed below
 */
/* static */ int Weighting::getForwardInputCount(const CorrectionType correctionType) {
    switch (correctionType) {
        case CT_INSERTION:
        case CT_TRANSPOSITION:
            return 2; /* look ahead + skip the current char */
        case CT_MATCH:
        case CT_COMPLETION:
        case CT_TERMINAL_INSERTION:
        case CT_NEW_WORD_SPACE_SUBSTITUTION:
            return 1;
        case CT_ADDITIONAL_PROXIMITY:
        case CT_SUBSTITUTION:
            return 0; /* 0 because CT_MATCH will be called */
        case CT_OMISSION:
        case CT_NEW_WORD_SPACE_OMISSION:
        case CT_TERMINAL:
        default:
            return 0;
    }
}
} // namespace latinime