Time-based resampling
Currently, we are relying on the number of samples in order to do certain
actions, like checking for hold or for running the rejection model.
Before this CL, the resampling could end up producing several samples at
once. This would cause the actions above to potentially be skipped,
because we performed the check of the number of samples before the
resampling code ran. An exact match for certain number of samples was
performed, so sometimes, palm rejection model would not get executed.
In this CL, we change the approach to rely on the time instead. We check
whether a specific time threshold is crossed with each sample received
from the touchscreen. If it's crossed, then the action is taken.
Additional changes here:
1) All samples are now stored in unmodified manner.
2) It's possible to exceed max_sample_count when resampling is enabled
3) samples_seen() is replaced with Duration()
4) When a resampled value is needed, we get getSampleAt(time). The
resampled value is never stored.
5) Old variables such as last_sample and old function Resample were
deleted
6) We are now storing the first sample's timestamp in first_sample_time.
It's needed to calculate the duration of the stroke, since at some
point, we will drop the first sample
After this change, any references to samples_seen and samples().size()
should consider the resampled and non-resampled cases.
Therefore, the "Filter" function should not directly check the number of
samples anywhere.
Bug: b/240168494
Change-Id: I983aa93ce7ae352aca748d4bbab8fd2d1af28df5
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3857630
Commit-Queue: Xinglong Luan <alanlxl@chromium.org>
Reviewed-by: Rob Schonberger <robsc@chromium.org>
Auto-Submit: Siarhei Vishniakou <svv@google.com>
Reviewed-by: Xinglong Luan <alanlxl@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1049397}
diff --git a/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.cc b/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.cc
index 586b30c..1365bbc 100644
--- a/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.cc
+++ b/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.cc
@@ -33,20 +33,49 @@
bool IsEarlyStageSample(
const PalmFilterStroke& stroke,
const NeuralStylusPalmDetectionFilterModelConfig& config) {
- return config.early_stage_sample_counts.find(stroke.samples_seen()) !=
- config.early_stage_sample_counts.end();
+ if (!config.resample_period) {
+ return config.early_stage_sample_counts.find(stroke.samples_seen()) !=
+ config.early_stage_sample_counts.end();
+ }
+ // Duration is not well-defined for sample_count <= 1, so we handle
+ // it separately.
+ if (stroke.samples().empty()) {
+ return false;
+ }
+ if (stroke.samples().size() == 1) {
+ return config.early_stage_sample_counts.find(1) !=
+ config.early_stage_sample_counts.end();
+ }
+ for (const uint32_t sample_count : config.early_stage_sample_counts) {
+ const base::TimeDelta duration = config.GetEquivalentDuration(sample_count);
+ // Previous sample must not have passed the 'duration' threshold, but the
+ // current sample must pass the threshold
+ if (stroke.LastSampleCrossed(duration)) {
+ return true;
+ }
+ }
+ return false;
}
bool HasDecidedStroke(
const PalmFilterStroke& stroke,
const NeuralStylusPalmDetectionFilterModelConfig& config) {
- return stroke.samples_seen() >= config.max_sample_count;
+ if (!config.resample_period) {
+ return stroke.samples_seen() >= config.max_sample_count;
+ }
+ const base::TimeDelta max_duration =
+ config.GetEquivalentDuration(config.max_sample_count);
+ return stroke.Duration() >= max_duration;
}
bool IsVeryShortStroke(
const PalmFilterStroke& stroke,
const NeuralStylusPalmDetectionFilterModelConfig& config) {
- return stroke.samples_seen() < config.min_sample_count;
+ if (!config.resample_period) {
+ return stroke.samples_seen() < config.min_sample_count;
+ }
+ return stroke.Duration() <
+ config.GetEquivalentDuration(config.min_sample_count);
}
/**
@@ -54,9 +83,15 @@
* being evaluated. The parameter 'neighbor_min_sample_count' might be different
* from the config, depending on the specific usage in the caller.
*/
-bool HasInsufficientDataAsNeighbor(const PalmFilterStroke& neighbor_stroke,
- size_t neighbor_min_sample_count) {
- return neighbor_stroke.samples().size() < neighbor_min_sample_count;
+bool HasInsufficientDataAsNeighbor(
+ const PalmFilterStroke& neighbor_stroke,
+ size_t neighbor_min_sample_count,
+ const NeuralStylusPalmDetectionFilterModelConfig& config) {
+ if (!config.resample_period) {
+ return neighbor_stroke.samples().size() < neighbor_min_sample_count;
+ }
+ return neighbor_stroke.Duration() <
+ config.GetEquivalentDuration(neighbor_min_sample_count);
}
} // namespace
@@ -91,7 +126,8 @@
if (neighbor.tracking_id() == stroke.tracking_id()) {
continue;
}
- if (HasInsufficientDataAsNeighbor(neighbor, neighbor_min_sample_count)) {
+ if (HasInsufficientDataAsNeighbor(neighbor, neighbor_min_sample_count,
+ model_->config())) {
continue;
}
float distance =
@@ -131,7 +167,8 @@
if (neighbor.tracking_id() == stroke.tracking_id()) {
continue;
}
- if (HasInsufficientDataAsNeighbor(neighbor, neighbor_min_sample_count)) {
+ if (HasInsufficientDataAsNeighbor(neighbor, neighbor_min_sample_count,
+ model_->config())) {
continue;
}
float distance =
@@ -286,15 +323,29 @@
const PalmFilterStroke& stroke) const {
const NeuralStylusPalmDetectionFilterModelConfig& config = model_->config();
// Inference only executed once per stroke
- return stroke.samples_seen() == config.max_sample_count;
+ if (!config.resample_period) {
+ return stroke.samples_seen() == config.max_sample_count;
+ }
+ return stroke.LastSampleCrossed(
+ config.GetEquivalentDuration(config.max_sample_count));
}
bool NeuralStylusPalmDetectionFilter::IsHeuristicPalmStroke(
const PalmFilterStroke& stroke) const {
const auto& config = model_->config();
- if (stroke.samples().size() >= config.max_sample_count) {
- LOG(DFATAL) << "Should not call this method on long strokes.";
- return false;
+ if (config.resample_period) {
+ if (stroke.Duration() >
+ config.GetEquivalentDuration(config.max_sample_count)) {
+ LOG(DFATAL)
+ << "Should not call this method on long strokes. Got duration = "
+ << stroke.Duration();
+ return false;
+ }
+ } else {
+ if (stroke.samples().size() >= config.max_sample_count) {
+ LOG(DFATAL) << "Should not call this method on long strokes.";
+ return false;
+ }
}
if (config.heuristic_palm_touch_limit > 0.0) {
@@ -387,6 +438,9 @@
void NeuralStylusPalmDetectionFilter::AppendFeatures(
const PalmFilterStroke& stroke,
std::vector<float>* features) const {
+ if (model_->config().resample_period) {
+ return AppendResampledFeatures(stroke, features);
+ }
const int size = stroke.samples().size();
for (int i = 0; i < size; ++i) {
const PalmFilterSample& sample = stroke.samples()[i];
@@ -421,6 +475,59 @@
features->push_back(samples_seen - model_->config().max_sample_count);
}
}
+
+/**
+ * The flow here is similar to 'AppendFeatures' above, but we rely on the
+ * timing of the samples rather than on the explicit number / position of
+ * samples.
+ */
+void NeuralStylusPalmDetectionFilter::AppendResampledFeatures(
+ const PalmFilterStroke& stroke,
+ std::vector<float>* features) const {
+ size_t sample_count = 0;
+ const base::TimeTicks& first_time = stroke.samples()[0].time;
+ const base::TimeDelta& resample_period = *model_->config().resample_period;
+ const base::TimeDelta max_duration =
+ model_->config().GetEquivalentDuration(model_->config().max_sample_count);
+ for (auto time = first_time; (time - first_time) <= max_duration &&
+ time <= stroke.samples().back().time;
+ time += resample_period) {
+ sample_count++;
+ const PalmFilterSample& sample = stroke.GetSampleAt(time);
+ features->push_back(sample.major_radius);
+ features->push_back(sample.minor_radius <= 0.0 ? sample.major_radius
+ : sample.minor_radius);
+ float distance = 0;
+ if (time != first_time) {
+ distance = EuclideanDistance(
+ stroke.GetSampleAt(time - resample_period).point, sample.point);
+ }
+ features->push_back(distance);
+ features->push_back(sample.edge);
+ features->push_back(1.0); // existence.
+ }
+ const int padding = model_->config().max_sample_count - sample_count;
+ DCHECK_GE(padding, 0);
+
+ for (int i = 0; i < padding * kFeaturesPerSample; ++i) {
+ features->push_back(0.0);
+ }
+ // "fill proportion."
+ features->push_back(static_cast<float>(sample_count) /
+ model_->config().max_sample_count);
+ features->push_back(EuclideanDistance(stroke.samples().front().point,
+ stroke.samples().back().point));
+
+ // Start sequence number. 0 is min.
+ uint32_t samples_seen =
+ (stroke.Duration() / (*model_->config().resample_period)) + 1;
+ if (samples_seen < model_->config().max_sample_count) {
+ features->push_back(0);
+ } else {
+ features->push_back(samples_seen - model_->config().max_sample_count);
+ }
+}
+
void NeuralStylusPalmDetectionFilter::AppendFeaturesAsNeighbor(
const PalmFilterStroke& stroke,
float distance,
diff --git a/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.h b/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.h
index 9b5ecae..7e54c0a 100644
--- a/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.h
+++ b/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.h
@@ -81,6 +81,8 @@
std::vector<float> ExtractFeatures(int tracking_id) const;
void AppendFeatures(const PalmFilterStroke& stroke,
std::vector<float>* features) const;
+ void AppendResampledFeatures(const PalmFilterStroke& stroke,
+ std::vector<float>* features) const;
void AppendFeaturesAsNeighbor(const PalmFilterStroke& stroke,
float distance,
std::vector<float>* features) const;
diff --git a/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_model.cc b/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_model.cc
index c1d9dfea..dcf8ae5 100644
--- a/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_model.cc
+++ b/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_model.cc
@@ -4,6 +4,8 @@
#include "ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_model.h"
+#include "base/logging.h"
+
namespace ui {
NeuralStylusPalmDetectionFilterModelConfig::
@@ -15,4 +17,19 @@
NeuralStylusPalmDetectionFilterModelConfig::
~NeuralStylusPalmDetectionFilterModelConfig() = default;
+
+base::TimeDelta
+NeuralStylusPalmDetectionFilterModelConfig::GetEquivalentDuration(
+ uint32_t sample_count) const {
+ if (!resample_period) {
+ LOG(DFATAL) << __func__
+ << " should only be called if resampling is enabled";
+ return base::Microseconds(0);
+ }
+ if (sample_count <= 1) {
+ return base::Microseconds(0);
+ }
+ return (sample_count - 1) * (*resample_period);
+}
+
} // namespace ui
diff --git a/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_model.h b/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_model.h
index 76f8cb23..adcae3f 100644
--- a/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_model.h
+++ b/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_model.h
@@ -38,6 +38,10 @@
// Maximum sample count.
uint32_t max_sample_count = 0;
+ // Convert the provided 'sample_count' to an equivalent time duration.
+ // Should only be called when resampling is enabled.
+ base::TimeDelta GetEquivalentDuration(uint32_t sample_count) const;
+
// Minimum count of samples for a stroke to be considered as a neighbor.
uint32_t neighbor_min_sample_count = 0;
diff --git a/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_unittest.cc b/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_unittest.cc
index 5f6fb94..c9d75c2 100644
--- a/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_unittest.cc
+++ b/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_unittest.cc
@@ -26,7 +26,8 @@
const NeuralStylusPalmDetectionFilterModelConfig&());
};
-class NeuralStylusPalmDetectionFilterTest : public testing::Test {
+class NeuralStylusPalmDetectionFilterTest
+ : public testing::TestWithParam<float> {
public:
NeuralStylusPalmDetectionFilterTest() = default;
@@ -45,6 +46,11 @@
model_config_.heuristic_palm_touch_limit = 40;
model_config_.heuristic_palm_area_limit = 1000;
model_config_.max_dead_neighbor_time = base::Milliseconds(100);
+ const float resample_period = GetParam();
+ if (resample_period != 0.0) {
+ model_config_.resample_period = base::Milliseconds(resample_period);
+ sample_period_ = *model_config_.resample_period;
+ }
EXPECT_CALL(*model_, config())
.Times(testing::AnyNumber())
.WillRepeatedly(testing::ReturnRef(model_config_));
@@ -68,12 +74,18 @@
raw_ptr<MockNeuralModel> model_;
NeuralStylusPalmDetectionFilterModelConfig model_config_;
std::unique_ptr<PalmDetectionFilter> palm_detection_filter_;
+ base::TimeDelta sample_period_ = base::Milliseconds(8);
};
-class NeuralStylusPalmDetectionFilterDeathTest
- : public NeuralStylusPalmDetectionFilterTest {};
+INSTANTIATE_TEST_SUITE_P(ParametricFilterTest,
+ NeuralStylusPalmDetectionFilterTest,
+ ::testing::Values(0, 8.0),
+ [](const auto& paramInfo) {
+ return paramInfo.param != 0.0 ? "ResamplingEnabled"
+ : "ResamplingDisabled";
+ });
-TEST_F(NeuralStylusPalmDetectionFilterTest, EventDeviceSimpleTest) {
+TEST_P(NeuralStylusPalmDetectionFilterTest, EventDeviceSimpleTest) {
EventDeviceInfo devinfo;
std::vector<std::pair<DeviceCapabilities, bool>> devices = {
{kNocturneTouchScreen, true},
@@ -103,23 +115,25 @@
}
}
-TEST_F(NeuralStylusPalmDetectionFilterDeathTest, EventDeviceConstructionDeath) {
+TEST(NeuralStylusPalmDetectionFilterDeathTest, EventDeviceConstructionDeath) {
EventDeviceInfo bad_devinfo;
EXPECT_TRUE(CapabilitiesToDeviceInfo(kNocturneStylus, &bad_devinfo));
+ std::unique_ptr<NeuralStylusPalmDetectionFilterModel> model_(
+ new testing::StrictMock<MockNeuralModel>);
+ std::unique_ptr<SharedPalmDetectionFilterState> shared_palm_state =
+ std::make_unique<SharedPalmDetectionFilterState>();
EXPECT_DCHECK_DEATH({
- NeuralStylusPalmDetectionFilter f(
- bad_devinfo,
- std::unique_ptr<NeuralStylusPalmDetectionFilterModel>(model_),
- shared_palm_state.get());
+ NeuralStylusPalmDetectionFilter f(bad_devinfo, std::move(model_),
+ shared_palm_state.get());
});
}
-TEST_F(NeuralStylusPalmDetectionFilterTest, NameTest) {
+TEST_P(NeuralStylusPalmDetectionFilterTest, NameTest) {
EXPECT_EQ("NeuralStylusPalmDetectionFilter",
palm_detection_filter_->FilterNameForTesting());
}
-TEST_F(NeuralStylusPalmDetectionFilterTest, ShortTouchPalmAreaTest) {
+TEST_P(NeuralStylusPalmDetectionFilterTest, ShortTouchPalmAreaTest) {
std::bitset<kNumTouchEvdevSlots> actual_held, actual_cancelled,
expected_cancelled;
touch_[0].touching = true;
@@ -146,7 +160,7 @@
EXPECT_EQ(expected_cancelled, actual_cancelled);
}
-TEST_F(NeuralStylusPalmDetectionFilterTest, ShortTouchPalmSizeTest) {
+TEST_P(NeuralStylusPalmDetectionFilterTest, ShortTouchPalmSizeTest) {
std::bitset<kNumTouchEvdevSlots> actual_held, actual_cancelled;
touch_[0].touching = true;
touch_[0].tracking_id = 600;
@@ -164,7 +178,7 @@
touch_[0].was_touching = true;
touch_[0].touching = false;
touch_[0].tracking_id = -1;
- touch_time += base::Milliseconds(8.0f);
+ touch_time += sample_period_;
palm_detection_filter_->Filter(touch_, touch_time, &actual_held,
&actual_cancelled);
EXPECT_TRUE(actual_held.none());
@@ -182,7 +196,7 @@
touch_[0].was_touching = true;
touch_[0].touching = false;
touch_[0].tracking_id = -1;
- touch_time += base::Milliseconds(8.0f);
+ touch_time += sample_period_;
palm_detection_filter_->Filter(touch_, touch_time, &actual_held,
&actual_cancelled);
EXPECT_TRUE(actual_held.none());
@@ -191,7 +205,7 @@
EXPECT_TRUE(actual_cancelled.none());
}
-TEST_F(NeuralStylusPalmDetectionFilterTest, CallFilterTest) {
+TEST_P(NeuralStylusPalmDetectionFilterTest, CallFilterTest) {
// Set up 3 touches as touching:
// Touch 0 starts off and is sent twice
// Touch 1 and 2 are then added on: 2 is far away, 1 is nearby. We expect a
@@ -238,7 +252,7 @@
touch_[2].tracking_id = 502;
touch_[2].slot = 2;
- touch_time += base::Milliseconds(8.0f);
+ touch_time += sample_period_;
palm_detection_filter_->Filter(touch_, touch_time, &actual_held,
&actual_cancelled);
EXPECT_TRUE(actual_held.none());
@@ -263,7 +277,7 @@
Inference(testing::Pointwise(testing::FloatEq(), features)))
.Times(1)
.WillOnce(testing::Return(0.5));
- touch_time += base::Milliseconds(8.0f);
+ touch_time += sample_period_;
palm_detection_filter_->Filter(touch_, touch_time, &actual_held,
&actual_cancelled);
EXPECT_TRUE(actual_held.none());
@@ -285,7 +299,7 @@
Inference(testing::Pointwise(testing::FloatEq(), features)))
.Times(1)
.WillOnce(testing::Return(0.5));
- touch_time += base::Milliseconds(8.0f);
+ touch_time += sample_period_;
palm_detection_filter_->Filter(touch_, touch_time, &actual_held,
&actual_cancelled);
EXPECT_TRUE(actual_held.none());
@@ -293,7 +307,7 @@
EXPECT_EQ(actual_cancelled, expected_cancelled);
}
-TEST_F(NeuralStylusPalmDetectionFilterTest, CallFilterTestWithAdaptiveHold) {
+TEST_P(NeuralStylusPalmDetectionFilterTest, CallFilterTestWithAdaptiveHold) {
std::bitset<kNumTouchEvdevSlots> actual_held, actual_cancelled;
std::bitset<kNumTouchEvdevSlots> expected_held, expected_cancelled;
@@ -348,7 +362,7 @@
Inference(testing::Pointwise(testing::FloatEq(), features)))
.Times(1)
.WillOnce(testing::Return(0.5));
- touch_time += base::Milliseconds(8.0f);
+ touch_time += sample_period_;
palm_detection_filter_->Filter(touch_, touch_time, &actual_held,
&actual_cancelled);
// Slot 0 is held.
@@ -404,7 +418,7 @@
.Times(1)
.WillOnce(testing::Return(0.5));
- touch_time += base::Milliseconds(8.0f);
+ touch_time += sample_period_;
palm_detection_filter_->Filter(touch_, touch_time, &actual_held,
&actual_cancelled);
@@ -443,7 +457,7 @@
Inference(testing::Pointwise(testing::FloatEq(), features)))
.Times(1)
.WillOnce(testing::Return(0.5));
- touch_time += base::Milliseconds(8.0f);
+ touch_time += sample_period_;
palm_detection_filter_->Filter(touch_, touch_time, &actual_held,
&actual_cancelled);
@@ -462,7 +476,7 @@
EXPECT_EQ(actual_cancelled, expected_cancelled);
}
-TEST_F(NeuralStylusPalmDetectionFilterTest, InferenceOnceNotPalm) {
+TEST_P(NeuralStylusPalmDetectionFilterTest, InferenceOnceNotPalm) {
std::bitset<kNumTouchEvdevSlots> actual_held, actual_cancelled;
base::TimeTicks touch_time =
base::TimeTicks::UnixEpoch() + base::Milliseconds(10.0);
@@ -480,7 +494,7 @@
if (i != 0) {
touch_[0].was_touching = true;
}
- touch_time += base::Milliseconds(8.0f);
+ touch_time += sample_period_;
palm_detection_filter_->Filter(touch_, touch_time, &actual_held,
&actual_cancelled);
ASSERT_TRUE(actual_held.none()) << " Failed at " << i;
@@ -488,7 +502,7 @@
}
}
-TEST_F(NeuralStylusPalmDetectionFilterTest, InferenceOncePalm) {
+TEST_P(NeuralStylusPalmDetectionFilterTest, InferenceOncePalm) {
std::bitset<kNumTouchEvdevSlots> actual_held, actual_cancelled;
std::bitset<kNumTouchEvdevSlots> expected_cancelled;
base::TimeTicks touch_time =
@@ -512,7 +526,7 @@
if (i != 0) {
touch_[0].was_touching = true;
}
- touch_time += base::Milliseconds(8.0f);
+ touch_time += sample_period_;
palm_detection_filter_->Filter(touch_, touch_time, &actual_held,
&actual_cancelled);
ASSERT_EQ(original_finger_time,
@@ -532,7 +546,7 @@
}
}
-TEST_F(NeuralStylusPalmDetectionFilterTest, DelayShortFingerTouch) {
+TEST_P(NeuralStylusPalmDetectionFilterTest, DelayShortFingerTouch) {
std::bitset<kNumTouchEvdevSlots> actual_held, actual_cancelled;
std::bitset<kNumTouchEvdevSlots> expected_held, expected_cancelled;
model_config_.heuristic_delay_start_if_palm = true;
@@ -552,7 +566,7 @@
EXPECT_EQ(expected_cancelled, actual_cancelled);
}
-TEST_F(NeuralStylusPalmDetectionFilterTest, DelayShortPalmTouch) {
+TEST_P(NeuralStylusPalmDetectionFilterTest, DelayShortPalmTouch) {
std::bitset<kNumTouchEvdevSlots> actual_held, actual_cancelled;
std::bitset<kNumTouchEvdevSlots> expected_held, expected_cancelled;
model_config_.heuristic_delay_start_if_palm = true;
@@ -575,7 +589,7 @@
// Delay continues even afterwards, until inference time: then it's off.
for (uint32_t i = 1; i < model_config_.max_sample_count - 1; ++i) {
touch_[0].was_touching = true;
- touch_time += base::Milliseconds(10.0);
+ touch_time += sample_period_;
touch_[0].major = 15;
touch_[0].minor = 15;
touch_[0].x += 1;
@@ -589,7 +603,7 @@
EXPECT_CALL(*model_, Inference(testing::_))
.Times(1)
.WillOnce(testing::Return(-0.1));
- touch_time = base::TimeTicks::UnixEpoch() + base::Milliseconds(10.0);
+ touch_time += sample_period_;
palm_detection_filter_->Filter(touch_, touch_time, &actual_held,
&actual_cancelled);
expected_held.set(0, false);
diff --git a/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_util.cc b/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_util.cc
index 89ff01e..b0c7fc8d 100644
--- a/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_util.cc
+++ b/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_util.cc
@@ -4,6 +4,7 @@
#include "ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_util.h"
+#include <base/logging.h>
#include <algorithm>
namespace ui {
@@ -77,7 +78,7 @@
* not interpolated, the values are taken from the 'after' sample unless the
* requested time is very close to the 'before' sample.
*/
-PalmFilterSample getSampleAtTime(base::TimeTicks time,
+PalmFilterSample GetSampleAtTime(base::TimeTicks time,
const PalmFilterSample& before,
const PalmFilterSample& after) {
// Use the newest sample as the base, except when the requested time is very
@@ -150,16 +151,30 @@
void PalmFilterStroke::ProcessSample(const PalmFilterSample& sample) {
DCHECK_EQ(tracking_id_, sample.tracking_id);
- if (resample_period_.has_value()) {
- Resample(sample);
- return;
+ if (samples_seen_ == 0) {
+ first_sample_time_ = sample.time;
}
AddSample(sample);
- while (samples_.size() > max_sample_count_) {
- AddToUnscaledCentroid(-samples_.front().point.OffsetFromOrigin());
- samples_.pop_front();
+ if (resample_period_.has_value()) {
+ // Prune based on time
+ const base::TimeDelta max_duration =
+ (*resample_period_) * (max_sample_count_ - 1);
+ while (samples_.size() > 2 &&
+ samples_.back().time - samples_[1].time >= max_duration) {
+ // We can only discard the sample if after it's discarded, we still cover
+ // the entire range. If we don't, we need to keep this sample for
+ // calculating resampled values.
+ AddToUnscaledCentroid(-samples_.front().point.OffsetFromOrigin());
+ samples_.pop_front();
+ }
+ } else {
+ // Prune based on number of samples
+ while (samples_.size() > max_sample_count_) {
+ AddToUnscaledCentroid(-samples_.front().point.OffsetFromOrigin());
+ samples_.pop_front();
+ }
}
}
@@ -169,36 +184,6 @@
samples_seen_++;
}
-/**
- * When resampling is enabled, we don't store all samples. Only the resampled
- * values are stored into samples_. In addition, the last real event is stored
- * into last_sample_, which is used to calculate the resampled values.
- */
-void PalmFilterStroke::Resample(const PalmFilterSample& sample) {
- if (samples_seen_ == 0) {
- AddSample(sample);
- last_sample_ = sample;
- return;
- }
-
- // We already have a valid last sample here.
- DCHECK_LE(last_sample_.time, sample.time);
- // Generate resampled values
- base::TimeTicks next_sample_time = samples_.back().time + *resample_period_;
- while (next_sample_time <= sample.time) {
- AddSample(getSampleAtTime(next_sample_time, last_sample_, sample));
- next_sample_time = samples_.back().time + (*resample_period_);
- }
- last_sample_ = sample;
-
- // Prune the resampled collection
- while ((samples_.back().time - samples_.front().time) >=
- (*resample_period_) * max_sample_count_) {
- AddToUnscaledCentroid(-samples_.front().point.OffsetFromOrigin());
- samples_.pop_front();
- }
-}
-
void PalmFilterStroke::AddToUnscaledCentroid(const gfx::Vector2dF point) {
const gfx::Vector2dF corrected_point = point - unscaled_centroid_sum_error_;
const gfx::PointF new_unscaled_centroid =
@@ -223,6 +208,47 @@
return tracking_id_;
}
+base::TimeDelta PalmFilterStroke::Duration() const {
+ if (samples_.empty()) {
+ LOG(DFATAL) << "No samples available";
+ return base::Milliseconds(0);
+ }
+ return samples_.back().time - first_sample_time_;
+}
+
+base::TimeDelta PalmFilterStroke::PreviousDuration() const {
+ if (samples_.size() <= 1) {
+ LOG(DFATAL) << "Not enough samples";
+ return base::Milliseconds(0);
+ }
+ const PalmFilterSample& secondToLastSample = samples_.rbegin()[1];
+ return secondToLastSample.time - first_sample_time_;
+}
+
+bool PalmFilterStroke::LastSampleCrossed(base::TimeDelta duration) const {
+ if (samples_.size() <= 1) {
+ // If there's only 1 sample, stroke just started and Duration() is zero.
+ return false;
+ }
+ return PreviousDuration() < duration && duration <= Duration();
+}
+
+PalmFilterSample PalmFilterStroke::GetSampleAt(base::TimeTicks time) const {
+ size_t i = 0;
+ for (; i < samples_.size() && samples_[i].time < time; ++i) {
+ }
+
+ if (i < samples_.size() && !samples_.empty() && samples_[i].time == time) {
+ return samples_[i];
+ }
+ if (i == 0 || i == samples_.size()) {
+ LOG(DFATAL) << "Invalid index: " << i
+ << ", can't interpolate for time: " << time;
+ return {};
+ }
+ return GetSampleAtTime(time, samples_[i - 1], samples_[i]);
+}
+
uint64_t PalmFilterStroke::samples_seen() const {
return samples_seen_;
}
diff --git a/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_util.h b/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_util.h
index 72c34f7..ac72463 100644
--- a/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_util.h
+++ b/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_util.h
@@ -69,6 +69,25 @@
float BiggestSize() const;
// If no elements in stroke, returns 0.0;
float MaxMajorRadius() const;
+ /**
+ * Return the time duration of this stroke.
+ */
+ base::TimeDelta Duration() const;
+ /**
+ * Provide a (potentially resampled) sample at the requested time.
+ * Only interpolation is allowed.
+ * The requested time must be within the window at which the gesture occurred.
+ */
+ PalmFilterSample GetSampleAt(base::TimeTicks time) const;
+
+ /**
+ * Return true if the provided duration is between the duration of the
+ * previous sample and the current sample. In other words, if the addition of
+ * the last sample caused the total stroke duration to exceed the provided
+ * duration. Return false otherwise.
+ */
+ bool LastSampleCrossed(base::TimeDelta duration) const;
+
const std::deque<PalmFilterSample>& samples() const;
uint64_t samples_seen() const;
int tracking_id() const;
@@ -76,10 +95,8 @@
private:
void AddToUnscaledCentroid(const gfx::Vector2dF point);
void AddSample(const PalmFilterSample& sample);
- /**
- * Process the sample. Potentially store the resampled sample into samples_.
- */
- void Resample(const PalmFilterSample& sample);
+
+ base::TimeDelta PreviousDuration() const;
std::deque<PalmFilterSample> samples_;
const int tracking_id_;
@@ -92,13 +109,9 @@
* number of times 'AddSample' has been called.
*/
uint64_t samples_seen_ = 0;
- /**
- * The last sample seen by the model. Used when resampling is enabled in order
- * to compute the resampled value.
- */
- PalmFilterSample last_sample_;
const uint64_t max_sample_count_;
+ base::TimeTicks first_sample_time_;
const absl::optional<base::TimeDelta> resample_period_;
gfx::PointF unscaled_centroid_ = gfx::PointF(0., 0.);
diff --git a/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_util_unittest.cc b/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_util_unittest.cc
index 4d7a82e..6116736 100644
--- a/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_util_unittest.cc
+++ b/ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_util_unittest.cc
@@ -293,23 +293,24 @@
ASSERT_THAT(stroke.samples(), ElementsAre(SampleTime(down_time)));
ASSERT_EQ(1u, stroke.samples_seen());
- // Add second sample at time = T + 2ms. It's not yet time for the new frame,
- // so no new sample should be generated.
+ // Add second sample at time = T + 4ms. All samples are stored, even if it's
+ // before the next resample time.
base::TimeTicks time = down_time + base::Milliseconds(4);
sample = CreatePalmFilterSample(touch_, time, model_config_, device_info);
stroke.ProcessSample(sample);
- ASSERT_THAT(stroke.samples(), ElementsAre(SampleTime(down_time)));
- ASSERT_EQ(1u, stroke.samples_seen());
+ ASSERT_THAT(stroke.samples(),
+ ElementsAre(SampleTime(down_time), SampleTime(time)));
+ ASSERT_EQ(2u, stroke.samples_seen());
- // Add third sample at time = T + 10ms. An event at time = T + 8ms should be
- // generated.
+ // Add third sample at time = T + 10ms.
time = down_time + base::Milliseconds(10);
sample = CreatePalmFilterSample(touch_, time, model_config_, device_info);
stroke.ProcessSample(sample);
ASSERT_THAT(stroke.samples(),
ElementsAre(SampleTime(down_time),
- SampleTime(down_time + base::Milliseconds(8))));
- ASSERT_EQ(2u, stroke.samples_seen());
+ SampleTime(down_time + base::Milliseconds(4)),
+ SampleTime(down_time + base::Milliseconds(10))));
+ ASSERT_EQ(3u, stroke.samples_seen());
}
TEST(PalmFilterStrokeTest, ResamplingTest) {
@@ -346,7 +347,7 @@
CreatePalmFilterSample(touch_, time, model_config_, device_info);
stroke.ProcessSample(sample2);
// The samples should remain unchanged
- ASSERT_THAT(stroke.samples(), ElementsAre(sample1));
+ ASSERT_THAT(stroke.samples(), ElementsAre(sample1, sample2));
// Add third sample at time = T + 12ms. A resampled event at time = T + 8ms
// should be generated.
@@ -358,16 +359,16 @@
PalmFilterSample sample3 =
CreatePalmFilterSample(touch_, time, model_config_, device_info);
stroke.ProcessSample(sample3);
- ASSERT_THAT(
- stroke.samples(),
- ElementsAre(sample1, SampleTime(down_time + base::Milliseconds(8))));
-
- EXPECT_EQ(150, stroke.samples().back().point.x());
- EXPECT_EQ(22, stroke.samples().back().point.y());
- EXPECT_EQ(14, stroke.samples().back().major_radius);
- EXPECT_EQ(13, stroke.samples().back().minor_radius);
+ ASSERT_THAT(stroke.samples(), ElementsAre(sample1, sample2, sample3));
}
+/**
+ * There should always be at least (max_sample_count - 1) * resample_period
+ * worth of samples. However, that's not sufficient. In the cases where the gap
+ * between samples is very large (larger than the sample horizon), there needs
+ * to be another sample in order to calculate resampled values throughout the
+ * entire range.
+ */
TEST(PalmFilterStrokeTest, MultipleResampledValues) {
NeuralStylusPalmDetectionFilterModelConfig model_config_;
model_config_.max_sample_count = 3;
@@ -391,8 +392,7 @@
// First sample should go in as is
ASSERT_THAT(stroke.samples(), ElementsAre(sample1));
- // Add second sample at time = T + 20ms. Two resampled values should be
- // generated: 1) at time = T+8ms 2) at time = T+16ms
+ // Add second sample at time = T + 20ms.
base::TimeTicks time = down_time + base::Milliseconds(20);
touch_.x = 20;
touch_.y = 30;
@@ -401,22 +401,23 @@
PalmFilterSample sample2 =
CreatePalmFilterSample(touch_, time, model_config_, device_info);
stroke.ProcessSample(sample2);
- ASSERT_THAT(stroke.samples(),
- ElementsAre(SampleTime(down_time),
- SampleTime(down_time + base::Milliseconds(8)),
- SampleTime(down_time + base::Milliseconds(16))));
- // First sample : time = T + 8ms
- EXPECT_EQ(8, stroke.samples()[1].point.x());
- EXPECT_EQ(18, stroke.samples()[1].point.y());
- EXPECT_EQ(220, stroke.samples()[1].major_radius);
- EXPECT_EQ(120, stroke.samples()[1].minor_radius);
+ ASSERT_THAT(stroke.samples(), ElementsAre(sample1, sample2));
- // Second sample : time = T + 16ms
- EXPECT_EQ(16, stroke.samples().back().point.x());
- EXPECT_EQ(26, stroke.samples().back().point.y());
- EXPECT_EQ(220, stroke.samples().back().major_radius);
- EXPECT_EQ(120, stroke.samples().back().minor_radius);
+ // Verify resampled sample : time = T + 8ms
+ PalmFilterSample resampled_sample =
+ stroke.GetSampleAt(down_time + base::Milliseconds(8));
+ EXPECT_EQ(8, resampled_sample.point.x());
+ EXPECT_EQ(18, resampled_sample.point.y());
+ EXPECT_EQ(220, resampled_sample.major_radius);
+ EXPECT_EQ(120, resampled_sample.minor_radius);
+
+ // Verify resampled sample : time = T + 16ms
+ resampled_sample = stroke.GetSampleAt(down_time + base::Milliseconds(16));
+ EXPECT_EQ(16, resampled_sample.point.x());
+ EXPECT_EQ(26, resampled_sample.point.y());
+ EXPECT_EQ(220, resampled_sample.major_radius);
+ EXPECT_EQ(120, resampled_sample.minor_radius);
}
} // namespace ui