blob: 8c69cf5d152a0925240781096a537c24dcd93ecc [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.task.vision.segmenter;
import static org.tensorflow.lite.DataType.FLOAT32;
import static org.tensorflow.lite.DataType.UINT8;
import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument;
import static org.tensorflow.lite.support.image.ColorSpaceType.GRAYSCALE;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
/**
* Output mask type. This allows specifying the type of post-processing to perform on the raw model
* results.
*/
/**
 * Output mask type. This allows specifying the type of post-processing to perform on the raw model
 * results.
 */
public enum OutputType {
  /**
   * Gives a single output mask where each pixel represents the class which the pixel in the
   * original image was predicted to belong to.
   */
  CATEGORY_MASK(0) {
    /**
     * {@inheritDoc}
     *
     * <p>{@code coloredLabels} is not consulted for this output type.
     *
     * @throws IllegalArgumentException if {@code masks} does not contain exactly one {@link
     *     TensorImage}, or if the color space of that {@link TensorImage} is not {@link
     *     org.tensorflow.lite.support.image.ColorSpaceType#GRAYSCALE}
     */
    @Override
    void assertMasksMatchColoredLabels(
        List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
      checkArgument(
          masks.size() == 1,
          "CATEGORY_MASK only allows one TensorImage in the list, providing " + masks.size());
      TensorImage mask = masks.get(0);
      checkArgument(
          mask.getColorSpaceType() == GRAYSCALE,
          "CATEGORY_MASK only supports masks of ColorSpaceType, GRAYSCALE, providing "
              + mask.getColorSpaceType());
    }

    /**
     * {@inheritDoc}
     *
     * <p>The single buffer is loaded into a {@code UINT8} {@link TensorBuffer} with shape {@code
     * maskShape}, which is then wrapped in a {@code GRAYSCALE} {@link TensorImage}. The returned
     * list therefore always contains exactly one element.
     *
     * @throws IllegalArgumentException if {@code buffers} does not contain exactly one {@link
     *     ByteBuffer}
     */
    @Override
    List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape) {
      checkArgument(
          buffers.size() == 1,
          "CATEGORY_MASK only allows one mask in the buffer list, providing " + buffers.size());
      List<TensorImage> masks = new ArrayList<>();
      TensorBuffer tensorBuffer = TensorBuffer.createDynamic(UINT8);
      tensorBuffer.loadBuffer(buffers.get(0), maskShape);
      TensorImage tensorImage = new TensorImage(UINT8);
      tensorImage.load(tensorBuffer, GRAYSCALE);
      masks.add(tensorImage);
      return masks;
    }
  },
  /**
   * Gives a list of output masks where, for each mask, each pixel represents the prediction
   * confidence, usually in the [0, 1] range.
   */
  CONFIDENCE_MASK(1) {
    /**
     * {@inheritDoc}
     *
     * @throws IllegalArgumentException if the size of the {@code masks} list does not match the
     *     size of the {@code coloredLabels} list, or if the color space type of any mask is not
     *     {@link org.tensorflow.lite.support.image.ColorSpaceType#GRAYSCALE}
     */
    @Override
    void assertMasksMatchColoredLabels(
        List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
      checkArgument(
          masks.size() == coloredLabels.size(),
          String.format(
              "When using CONFIDENCE_MASK, the number of masks (%d) should match the number of"
                  + " coloredLabels (%d).",
              masks.size(), coloredLabels.size()));
      for (TensorImage mask : masks) {
        checkArgument(
            mask.getColorSpaceType() == GRAYSCALE,
            "CONFIDENCE_MASK only supports masks of ColorSpaceType, GRAYSCALE, providing "
                + mask.getColorSpaceType());
      }
    }

    /**
     * {@inheritDoc}
     *
     * <p>Each buffer is loaded into its own {@code FLOAT32} {@link TensorBuffer} with shape {@code
     * maskShape} and wrapped in a {@code GRAYSCALE} {@link TensorImage}. Masks are returned in the
     * same order as {@code buffers}. An empty {@code buffers} list yields an empty result.
     */
    @Override
    List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape) {
      List<TensorImage> masks = new ArrayList<>();
      for (ByteBuffer buffer : buffers) {
        TensorBuffer tensorBuffer = TensorBuffer.createDynamic(FLOAT32);
        tensorBuffer.loadBuffer(buffer, maskShape);
        TensorImage tensorImage = new TensorImage(FLOAT32);
        tensorImage.load(tensorBuffer, GRAYSCALE);
        masks.add(tensorImage);
      }
      return masks;
    }
  };

  // Integer code bound to each constant: 0 for CATEGORY_MASK, 1 for CONFIDENCE_MASK.
  // NOTE(review): presumably mirrors a native-side enum; confirm before renumbering.
  private final int value;

  OutputType(int value) {
    this.value = value;
  }

  /**
   * Returns the integer code of this output type ({@code 0} for {@link #CATEGORY_MASK}, {@code 1}
   * for {@link #CONFIDENCE_MASK}).
   */
  public int getValue() {
    return value;
  }

  /**
   * Verifies that the given list of masks matches the list of colored labels.
   *
   * @param masks the output masks to validate
   * @param coloredLabels the colored labels the masks are expected to correspond to
   * @throws IllegalArgumentException if {@code masks} and {@code coloredLabels} do not match the
   *     output type
   */
  abstract void assertMasksMatchColoredLabels(
      List<TensorImage> masks, List<ColoredLabel> coloredLabels);

  /**
   * Creates the masks in {@link TensorImage} based on the data in {@link ByteBuffer}.
   *
   * @param buffers the raw mask data, one buffer per mask
   * @param maskShape the shape used when loading each buffer into a {@link TensorBuffer}
   * @return a newly allocated list of grayscale mask images
   */
  abstract List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape);
}