// blob: 9f0a331e3f990bef16c855141a5c028cdd42010c [file] [log] [blame]
/*
* Copyright (C) 2012 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LATINIME_TYPING_WEIGHTING_H
#define LATINIME_TYPING_WEIGHTING_H
#include "defines.h"
#include "suggest/core/dicnode/dic_node_utils.h"
#include "suggest/core/layout/touch_position_correction_utils.h"
#include "suggest/core/policy/weighting.h"
#include "suggest/core/session/dic_traverse_session.h"
#include "suggest/policyimpl/typing/scoring_params.h"
#include "utils/char_utils.h"
namespace latinime {
class DicNode;
struct DicNode_InputStateG;
class MultiBigramMap;
class TypingWeighting : public Weighting {
public:
// Returns a pointer to the single shared, immutable TypingWeighting object
// (the static member sInstance).
static const TypingWeighting *getInstance() {
    return &sInstance;
}
protected:
// Spatial cost applied at a terminal node. It is the sum of up to three independent
// penalties, each charged only when its condition holds:
//   - the node spans multiple words,
//   - at least one proximity correction was made,
//   - at least one edit correction was made.
// traverseSession is not consulted by this implementation.
float getTerminalSpatialCost(const DicTraverseSession *const traverseSession,
        const DicNode *const dicNode) const {
    const float multiWordPenalty = dicNode->hasMultipleWords()
            ? ScoringParams::HAS_MULTI_WORD_TERMINAL_COST : 0.0f;
    const float proximityPenalty = (dicNode->getProximityCorrectionCount() > 0)
            ? ScoringParams::HAS_PROXIMITY_TERMINAL_COST : 0.0f;
    const float editPenalty = (dicNode->getEditCorrectionCount() > 0)
            ? ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST : 0.0f;
    // Summed left to right starting from zero, in the same order as the conditions above.
    return 0.0f + multiWordPenalty + proximityPenalty + editPenalty;
}
// Cost of omitting a character. In priority order:
//   1. 0 when the parent node reports a zero-cost omission;
//   2. OMISSION_COST_FIRST_CHAR when the omitted letter was the first one;
//   3. OMISSION_COST_SAME_CHAR when dicNode has the same code point as its parent;
//   4. OMISSION_COST otherwise.
float getOmissionCost(const DicNode *const parentDicNode, const DicNode *const dicNode) const {
    const bool omissionIsFree = parentDicNode->isZeroCostOmission();
    const bool repeatsParentCodePoint = dicNode->isSameNodeCodePoint(parentDicNode);
    // If the traversal skipped the first letter, dicNode is now sitting on the second one.
    const bool omittedFirstLetter = dicNode->getNodeCodePointCount() == 2;
    if (omissionIsFree) {
        return 0.0f;
    }
    if (omittedFirstLetter) {
        return ScoringParams::OMISSION_COST_FIRST_CHAR;
    }
    if (repeatsParentCodePoint) {
        return ScoringParams::OMISSION_COST_SAME_CHAR;
    }
    return ScoringParams::OMISSION_COST;
}
// Cost of matching dicNode's code point against the touch point at its current input index.
// The result is a distance term (sweet-spot-adjusted point-to-key length scaled by
// DISTANCE_WEIGHT_LENGTH) plus discrete penalties for proximity matches and for an uppercase
// first character in a second-or-later word. inputStateG is not used by this implementation.
float getMatchedCost(const DicTraverseSession *const traverseSession,
        const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const {
    const int inputIndex = dicNode->getInputIndex(0);
    // NOTE(review): an earlier comment here said a min() was required because the length can
    // be MAX_POINT_TO_KEY_LENGTH for characters not on the keyboard (like accented letters).
    // No min() is applied in this function; confirm whether the clamp now lives in a callee.
    const float squaredLength = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
            inputIndex, CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint()));
    const float sweetSpotDistance = TouchPositionCorrectionUtils::getSweetSpotFactor(
            traverseSession->isTouchPositionCorrectionEnabled(), squaredLength);
    const float distanceTerm = ScoringParams::DISTANCE_WEIGHT_LENGTH * sweetSpotDistance;

    float penalty = 0.0f;
    if (isProximityDicNode(traverseSession, dicNode)) {
        // A proximity match on the very first touch point has its own base cost.
        if (inputIndex == 0) {
            penalty = ScoringParams::FIRST_CHAR_PROXIMITY_COST;
        } else {
            penalty = ScoringParams::PROXIMITY_COST;
        }
        // Extra charge when no proximity correction has been counted on this node yet.
        if (dicNode->getProximityCorrectionCount() == 0) {
            penalty += ScoringParams::FIRST_PROXIMITY_COST;
        }
    }
    // At the second character of the current word, demote the candidate if this is a second
    // or later word of a multiple word suggestion and its first character is uppercase.
    if (dicNode->getNodeCodePointCount() == 2
            && dicNode->hasMultipleWords() && dicNode->isFirstCharUppercase()) {
        penalty += ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE;
    }
    return distanceTerm + penalty;
}
// Returns true when dicNode's code point differs from the primary code point of the touch
// point at dicNode's input index. Both sides are passed through
// CharUtils::toBaseLowerCase() before being compared.
bool isProximityDicNode(const DicTraverseSession *const traverseSession,
        const DicNode *const dicNode) const {
    const int inputIndex = dicNode->getInputIndex(0);
    const int typedBaseLower = CharUtils::toBaseLowerCase(
            traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(inputIndex));
    const int nodeBaseLower = CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint());
    return !(typedBaseLower == nodeBaseLower);
}
float getTranspositionCost(const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode) const {
const int16_t parentPointIndex = parentDicNode->getInputIndex(0);
const int prevCodePoint = parentDicNode->getNodeCodePoint();
const float distance1 = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
parentPointIndex + 1, CharUtils::toBaseLowerCase(prevCodePoint));
const int codePoint = dicNode->getNodeCodePoint();
const float distance2 = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
parentPointIndex, CharUtils::toBaseLowerCase(codePoint));
const float distance = distance1 + distance2;
const float weightedLengthDistance =
distance * ScoringParams::DISTANCE_WEIGHT_LENGTH;
return ScoringParams::TRANSPOSITION_COST + weightedLengthDistance;
}
float getInsertionCost(const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode) const {
const int16_t insertedPointIndex = parentDicNode->getInputIndex(0);
const int prevCodePoint = traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(
insertedPointIndex);
const int currentCodePoint = dicNode->getNodeCodePoint();
const bool sameCodePoint = prevCodePoint == currentCodePoint;
const bool existsAdjacentProximityChars = traverseSession->getProximityInfoState(0)
->existsAdjacentProximityChars(insertedPointIndex);
const float dist = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
insertedPointIndex + 1, CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint()));
const float weightedDistance = dist * ScoringParams::DISTANCE_WEIGHT_LENGTH;
const bool singleChar = dicNode->getNodeCodePointCount() == 1;
float cost = (singleChar ? ScoringParams::INSERTION_COST_FIRST_CHAR : 0.0f);
if (sameCodePoint) {
cost += ScoringParams::INSERTION_COST_SAME_CHAR;
} else if (existsAdjacentProximityChars) {
cost += ScoringParams::INSERTION_COST_PROXIMITY_CHAR;
} else {
cost += ScoringParams::INSERTION_COST;
}
return cost + weightedDistance;
}
// Spatial cost of starting a new word: COST_NEW_WORD scaled by the session's multi-word cost
// multiplier. dicNode and inputStateG are not used by this implementation.
float getNewWordSpatialCost(const DicTraverseSession *const traverseSession,
        const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const {
    return traverseSession->getMultiWordCostMultiplier() * ScoringParams::COST_NEW_WORD;
}
// Language cost of starting a new word: the bigram improbability that DicNodeUtils reports
// for dicNode (using the session's dictionary structure policy and the supplied
// multiBigramMap), scaled by DISTANCE_WEIGHT_LANGUAGE.
float getNewWordBigramLanguageCost(const DicTraverseSession *const traverseSession,
        const DicNode *const dicNode,
        MultiBigramMap *const multiBigramMap) const {
    return ScoringParams::DISTANCE_WEIGHT_LANGUAGE
            * DicNodeUtils::getBigramNodeImprobability(
                    traverseSession->getDictionaryStructurePolicy(), dicNode, multiBigramMap);
}
// Cost of extending a candidate beyond the typed input (look-ahead / auto completion).
// Returns COST_FIRST_LOOKAHEAD for the first completed character and COST_LOOKAHEAD for
// every other one.
float getCompletionCost(const DicTraverseSession *const traverseSession,
        const DicNode *const dicNode) const {
    // Completion begins exactly when the node's input index equals the input size.
    // TODO: Change the cost for the first completion for the gesture?
    if (dicNode->getInputIndex(0) == traverseSession->getInputSize()) {
        return ScoringParams::COST_FIRST_LOOKAHEAD;
    }
    return ScoringParams::COST_LOOKAHEAD;
}
// Language cost at a terminal node: the supplied improbability scaled by
// DISTANCE_WEIGHT_LANGUAGE. traverseSession and dicNode are not used by this implementation.
float getTerminalLanguageCost(const DicTraverseSession *const traverseSession,
        const DicNode *const dicNode, const float dicNodeLanguageImprobability) const {
    return ScoringParams::DISTANCE_WEIGHT_LANGUAGE * dicNodeLanguageImprobability;
}
// Cost charged at a terminal node for input points that have not been consumed yet:
// TERMINAL_INSERTION_COST multiplied by (input size - node's input index).
// Asserts that the node's input index is strictly less than the input size.
float getTerminalInsertionCost(const DicTraverseSession *const traverseSession,
        const DicNode *const dicNode) const {
    const int inputSize = traverseSession->getInputSize();
    const int inputIndex = dicNode->getInputIndex(0);
    ASSERT(inputIndex < inputSize);
    // TODO: Implement more efficient logic
    const int unconsumedPointCount = inputSize - inputIndex;
    return ScoringParams::TERMINAL_INSERTION_COST * unconsumedPointCount;
}
// This weighting policy never asks for the compound distance to be normalized:
// the answer is unconditionally false.
AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const {
    return false;
}
// Returns the fixed ScoringParams::ADDITIONAL_PROXIMITY_COST.
AK_FORCE_INLINE float getAdditionalProximityCost() const {
    const float additionalProximityCost = ScoringParams::ADDITIONAL_PROXIMITY_COST;
    return additionalProximityCost;
}
// Returns the fixed ScoringParams::SUBSTITUTION_COST.
AK_FORCE_INLINE float getSubstitutionCost() const {
    const float substitutionCost = ScoringParams::SUBSTITUTION_COST;
    return substitutionCost;
}
// Cost of substituting a space: (SPACE_SUBSTITUTION_COST + COST_NEW_WORD) scaled by the
// session's multi-word cost multiplier. dicNode is not used by this implementation.
AK_FORCE_INLINE float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession,
        const DicNode *const dicNode) const {
    const float unscaledCost =
            ScoringParams::SPACE_SUBSTITUTION_COST + ScoringParams::COST_NEW_WORD;
    return unscaledCost * traverseSession->getMultiWordCostMultiplier();
}
// Maps a correction type, in the context of the given session and parent/child nodes, to an
// ErrorType. Only declared here; the definition is not part of this header
// (presumably in the corresponding .cpp file -- confirm).
ErrorType getErrorType(const CorrectionType correctionType,
const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode) const;
private:
// Project macro; by its name it disables copy construction and assignment for this class.
DISALLOW_COPY_AND_ASSIGN(TypingWeighting);
// The single shared instance handed out by getInstance(). Only declared here; its definition
// must live outside this header.
static const TypingWeighting sInstance;
// Constructor and destructor are private, so code outside the class cannot create or
// destroy TypingWeighting objects.
TypingWeighting() {}
~TypingWeighting() {}
};
} // namespace latinime
#endif // LATINIME_TYPING_WEIGHTING_H