| # Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for losses.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import numpy as np |
| |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import errors_impl |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import random_seed |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import init_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import random_ops |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.ops import variables |
| from tensorflow.python.ops.losses import losses |
| from tensorflow.python.ops.losses import util |
| from tensorflow.python.platform import test |
| from tensorflow.python.training import momentum as momentum_lib |
| |
| |
| @test_util.run_deprecated_v1 |
| class AbsoluteDifferenceLossTest(test.TestCase): |
| |
| def setUp(self): |
| super(AbsoluteDifferenceLossTest, self).setUp() |
| self._predictions = constant_op.constant([4, 8, 12, 8, 1, 3], shape=(2, 3)) |
| self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) |
| |
| def testValueErrorThrownWhenWeightIsNone(self): |
| with self.cached_session(): |
| with self.assertRaises(ValueError): |
| losses.absolute_difference( |
| self._predictions, self._predictions, weights=None) |
| |
| def testAllCorrectNoLossWeight(self): |
| loss = losses.absolute_difference(self._predictions, self._predictions) |
| with self.cached_session(): |
| self.assertAlmostEqual(0.0, self.evaluate(loss), 3) |
| |
| def testNonZeroLoss(self): |
| loss = losses.absolute_difference(self._labels, self._predictions) |
| with self.cached_session(): |
| self.assertAlmostEqual(5.5, self.evaluate(loss), 3) |
| |
| def testNonZeroLossWithPythonScalarWeight(self): |
| weights = 2.3 |
| loss = losses.absolute_difference(self._labels, self._predictions, weights) |
| with self.cached_session(): |
| self.assertAlmostEqual(5.5 * weights, self.evaluate(loss), 3) |
| |
| def testNonZeroLossWithScalarTensorWeight(self): |
| weights = 2.3 |
| loss = losses.absolute_difference(self._labels, self._predictions, |
| constant_op.constant(weights)) |
| with self.cached_session(): |
| self.assertAlmostEqual(5.5 * weights, self.evaluate(loss), 3) |
| |
| def testNonZeroLossWithOneDimBatchSpecificWeights(self): |
| weights = constant_op.constant((1.2, 0.0), shape=(2, 1)) |
| loss = losses.absolute_difference(self._labels, self._predictions, weights) |
| with self.cached_session(): |
| self.assertAlmostEqual(5.6, self.evaluate(loss), 3) |
| |
| def testNonZeroLossWithTwoDimBatchSpecificWeights(self): |
| weights = constant_op.constant([1.2, 0.0], shape=[2, 1]) |
| loss = losses.absolute_difference(self._labels, self._predictions, weights) |
| with self.cached_session(): |
| self.assertAlmostEqual(5.6, self.evaluate(loss), 3) |
| |
| def testNonZeroLossWithSampleSpecificWeights(self): |
| weights = constant_op.constant([3, 6, 5, 0, 4, 2], shape=[2, 3]) |
| loss = losses.absolute_difference(self._labels, self._predictions, weights) |
| with self.cached_session(): |
| self.assertAlmostEqual(16.6, self.evaluate(loss), 3) |
| |
| def testNonZeroLossWithSampleSpecificWeightsMostZero(self): |
| weights = constant_op.constant([0, 0, 0, 0, 0, 2], shape=[2, 3]) |
| loss = losses.absolute_difference(self._labels, self._predictions, weights) |
| with self.cached_session(): |
| self.assertAlmostEqual(6.0, self.evaluate(loss), 3) |
| |
| def testLossWithSampleSpecificWeightsAllZero(self): |
| weights = array_ops.zeros((2, 3)) |
| loss = losses.absolute_difference(self._labels, self._predictions, weights) |
| with self.cached_session(): |
| self.assertAlmostEqual(0.0, self.evaluate(loss), 3) |
| |
| @test_util.assert_no_new_pyobjects_executing_eagerly |
| def testEagerNoMemoryLeaked(self): |
| # This is a somewhat convoluted way of testing that nothing gets added to |
| # a global collection. |
| predictions = constant_op.constant([4, 8, 12, 8, 1, 3], shape=(2, 3)) |
| labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) |
| losses.absolute_difference(labels, predictions) |
| |
| |
| class SoftmaxCrossEntropyLossTest(test.TestCase): |
| |
| def testNoneWeightRaisesValueError(self): |
| logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]]) |
| labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) |
| with self.cached_session(): |
| with self.assertRaises(ValueError): |
| losses.softmax_cross_entropy(labels, logits, weights=None) |
| |
| @test_util.run_deprecated_v1 |
| def testAllCorrect(self): |
| with self.cached_session(): |
| logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]]) |
| labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) |
| loss = losses.softmax_cross_entropy(labels, logits) |
| self.assertEquals('softmax_cross_entropy_loss/value', loss.op.name) |
| self.assertAlmostEqual(loss.eval(), 0.0, 3) |
| |
| @test_util.run_deprecated_v1 |
| def testAllWrong(self): |
| logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]]) |
| labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) |
| |
| with self.cached_session(): |
| loss = losses.softmax_cross_entropy(labels, logits) |
| self.assertEquals(loss.op.name, 'softmax_cross_entropy_loss/value') |
| self.assertAlmostEqual(loss.eval(), 10.0, 3) |
| |
| @test_util.run_deprecated_v1 |
| def testNonZeroLossWithPythonScalarWeight(self): |
| logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]]) |
| labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) |
| weights = 2.3 |
| with self.cached_session(): |
| loss = losses.softmax_cross_entropy(labels, logits, weights) |
| self.assertAlmostEqual(weights * 10.0, self.evaluate(loss), 3) |
| |
| @test_util.run_deprecated_v1 |
| def testNonZeroLossWithScalarTensorWeight(self): |
| logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]]) |
| labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) |
| weights = 2.3 |
| with self.cached_session(): |
| loss = losses.softmax_cross_entropy(labels, logits, |
| constant_op.constant(weights)) |
| self.assertAlmostEqual(weights * 10.0, self.evaluate(loss), 3) |
| |
| def testNonZeroLossWithOneDimBatchSpecificWeights(self): |
| logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]]) |
| labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) |
| weights = constant_op.constant((1.2, 3.4, 5.6)) |
| with self.cached_session(): |
| loss = losses.softmax_cross_entropy(labels, logits, weights) |
| self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, |
| self.evaluate(loss), 3) |
| |
| def testAllWrongAllWeightsMissing(self): |
| logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]]) |
| labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) |
| weights = constant_op.constant([0, 0, 0], shape=[3]) |
| with self.cached_session(): |
| loss = losses.softmax_cross_entropy(labels, logits, weights) |
| self.assertAlmostEqual(0.0, self.evaluate(loss), 3) |
| |
| def testSomeWeightsMissing(self): |
| logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]]) |
| labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) |
| weights = constant_op.constant([1.2, 0, 0], shape=[3]) |
| with self.cached_session(): |
| loss = losses.softmax_cross_entropy(labels, logits, weights) |
| self.assertAlmostEqual(12.0, self.evaluate(loss), 3) |
| |
| def testSoftmaxWithMeasurementSpecificWeightsRaisesException(self): |
| with self.cached_session(): |
| logits = constant_op.constant([[100.0, -100.0, -100.0], |
| [-100.0, 100.0, -100.0], |
| [-100.0, -100.0, 100.0]]) |
| labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) |
| weights = constant_op.constant([[3, 4, 5], [2, 6, 0], [8, 0, 1]]) |
| |
| with self.assertRaises(ValueError): |
| losses.softmax_cross_entropy(labels, logits, weights=weights).eval() |
| |
| @test_util.run_deprecated_v1 |
| def testSoftmaxLabelSmoothing(self): |
| with self.cached_session(): |
| # Softmax Cross Entropy Loss is: |
| # -\sum_i p_i \log q_i |
| # where for a softmax activation |
| # \log q_i = x_i - \log \sum_j \exp x_j |
| # = x_i - x_max - \log \sum_j \exp (x_j - x_max) |
| # For our activations, [100, -100, -100] the log partion function becomes |
| # \log ( exp(0) + exp(-200) + exp(-200) ) = 0 |
| # so our log softmaxes become: [0, -200, -200] |
| # so our cross entropy loss is: |
| # -(1 - L + L/n) * 0 + 400 * L/n = 400 L/n |
| logits = constant_op.constant([[100.0, -100.0, -100.0]]) |
| labels = constant_op.constant([[1, 0, 0]]) |
| label_smoothing = 0.1 |
| loss = losses.softmax_cross_entropy( |
| labels, logits, label_smoothing=label_smoothing) |
| self.assertEquals(loss.op.name, 'softmax_cross_entropy_loss/value') |
| expected_value = 400.0 * label_smoothing / 3.0 |
| self.assertAlmostEqual(loss.eval(), expected_value, 3) |
| |
| |
| class SparseSoftmaxCrossEntropyLossTest(test.TestCase): |
| |
| def testNoneWeightRaisesValueError(self): |
| logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]]) |
| labels = constant_op.constant([[0], [1], [2]]) |
| with self.cached_session(): |
| with self.assertRaises(ValueError): |
| losses.sparse_softmax_cross_entropy(labels, logits, weights=None) |
| |
| @test_util.run_deprecated_v1 |
| def testAllCorrectInt32Labels(self): |
| with self.cached_session(): |
| logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]]) |
| labels = constant_op.constant([[0], [1], [2]], dtype=dtypes.int32) |
| loss = losses.sparse_softmax_cross_entropy(labels, logits) |
| self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value') |
| self.assertAlmostEqual(loss.eval(), 0.0, 3) |
| |
| @test_util.assert_no_new_pyobjects_executing_eagerly |
| def testEagerNoMemoryLeaked(self): |
| logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]]) |
| labels = constant_op.constant([[0], [1], [2]], dtype=dtypes.int32) |
| losses.sparse_softmax_cross_entropy(labels, logits) |
| |
| @test_util.run_deprecated_v1 |
| def testAllCorrectInt64Labels(self): |
| with self.cached_session(): |
| logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]]) |
| labels = constant_op.constant([[0], [1], [2]], dtype=dtypes.int64) |
| loss = losses.sparse_softmax_cross_entropy(labels, logits) |
| self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value') |
| self.assertAlmostEqual(loss.eval(), 0.0, 3) |
| |
| @test_util.run_deprecated_v1 |
| def testAllCorrectNonColumnLabels(self): |
| with self.cached_session(): |
| logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]]) |
| labels = constant_op.constant([0, 1, 2]) |
| loss = losses.sparse_softmax_cross_entropy(labels, logits) |
| self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value') |
| self.assertAlmostEqual(loss.eval(), 0.0, 3) |
| |
| @test_util.run_deprecated_v1 |
| def testAllWrongInt32Labels(self): |
| logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]]) |
| labels = constant_op.constant([[2], [0], [1]], dtype=dtypes.int32) |
| |
| with self.cached_session(): |
| loss = losses.sparse_softmax_cross_entropy(labels, logits) |
| self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value') |
| self.assertAlmostEqual(loss.eval(), 10.0, 3) |
| |
| @test_util.run_deprecated_v1 |
| def testAllWrongInt64Labels(self): |
| logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]]) |
| labels = constant_op.constant([[2], [0], [1]], dtype=dtypes.int64) |
| |
| with self.cached_session(): |
| loss = losses.sparse_softmax_cross_entropy(labels, logits) |
| self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value') |
| self.assertAlmostEqual(loss.eval(), 10.0, 3) |
| |
| @test_util.run_deprecated_v1 |
| def testAllWrongNonColumnLabels(self): |
| logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]]) |
| labels = constant_op.constant([2, 0, 1]) |
| |
| with self.cached_session(): |
| loss = losses.sparse_softmax_cross_entropy(labels, logits) |
| self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value') |
| self.assertAlmostEqual(loss.eval(), 10.0, 3) |
| |
| @test_util.run_deprecated_v1 |
| def testNonZeroLossWithPythonScalarWeight(self): |
| logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]]) |
| labels = constant_op.constant([[2], [0], [1]]) |
| weights = 2.3 |
| with self.cached_session(): |
| loss = losses.sparse_softmax_cross_entropy(labels, logits, weights) |
| self.assertAlmostEqual(weights * 10.0, self.evaluate(loss), 3) |
| |
| @test_util.run_deprecated_v1 |
| def testNonZeroLossWithScalarTensorWeight(self): |
| logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]]) |
| labels = constant_op.constant([[2], [0], [1]]) |
| weights = 2.3 |
| with self.cached_session(): |
| loss = losses.sparse_softmax_cross_entropy(labels, logits, |
| constant_op.constant(weights)) |
| self.assertAlmostEqual(weights * 10.0, self.evaluate(loss), 3) |
| |
| def testNonZeroLossWith1DTensorWeight(self): |
| logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]]) |
| labels = constant_op.constant([[2], [0], [1]]) |
| weights = 2.3 |
| with self.cached_session(): |
| loss = losses.sparse_softmax_cross_entropy( |
| labels, logits, constant_op.constant((weights,))) |
| self.assertAlmostEqual(weights * 10.0, self.evaluate(loss), 3) |
| |
| @test_util.run_deprecated_v1 |
| def testNonZeroLossWithPlaceholderForWeights(self): |
| logits = constant_op.constant([[10.0, 0.0, 0.0], |
| [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]]) |
| labels = constant_op.constant([[2], [0], [1]]) |
| weights = array_ops.placeholder(dtypes.float32) |
| with self.cached_session() as sess: |
| loss = losses.sparse_softmax_cross_entropy(labels, logits, weights) |
| loss_val = sess.run(loss, |
| feed_dict={weights: ((1.2,), (3.4,), (5.6,))}) |
| self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss_val, 3) |
| |
| @test_util.run_deprecated_v1 |
| def testUnknownShapePlaceholderForLogitsLabelsButScalarWeights(self): |
| logits = array_ops.placeholder(dtypes.float32) |
| labels = array_ops.placeholder(dtypes.int32) |
| weights = 1.0 |
| with self.cached_session() as sess: |
| loss = losses.sparse_softmax_cross_entropy(labels, logits, weights) |
| loss_val = sess.run(loss, |
| feed_dict={ |
| logits: [[10.0, 0.0, 0.0], |
| [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]], |
| labels: [[2], [0], [1]], |
| }) |
| self.assertAlmostEqual((1.0 + 1.0 + 1.0) * 10.0 / 3.0, loss_val, 3) |
| |
| @test_util.run_deprecated_v1 |
| def testNonZeroLossWithPlaceholderForLogitsLabelsAndWeights(self): |
| logits = array_ops.placeholder(dtypes.float32, shape=(None, 3)) |
| labels = array_ops.placeholder(dtypes.int32, shape=(None, 1)) |
| weights = array_ops.placeholder(dtypes.float32) |
| with self.cached_session() as sess: |
| loss = losses.sparse_softmax_cross_entropy(labels, logits, weights) |
| loss_val = sess.run(loss, |
| feed_dict={ |
| logits: [[10.0, 0.0, 0.0], |
| [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]], |
| labels: [[2], [0], [1]], |
| weights: ((1.2,), (3.4,), (5.6,)), |
| }) |
| self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss_val, 3) |
| |
| def testNonZeroLossWithOneDimBatchSpecificWeights(self): |
| logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]]) |
| labels = constant_op.constant([[2], [0], [1]]) |
| weights = constant_op.constant([1.2, 3.4, 5.6], shape=(3, 1)) |
| with self.cached_session(): |
| loss = losses.sparse_softmax_cross_entropy(labels, logits, weights) |
| self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, |
| self.evaluate(loss), 3) |
| |
| def testNonZeroLossWithColumnWeights(self): |
| logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]]) |
| labels = constant_op.constant([[2], [0], [1]]) |
| weights = constant_op.constant([[1.2], [3.4], [5.6]]) |
| with self.cached_session(): |
| loss = losses.sparse_softmax_cross_entropy(labels, logits, weights) |
| self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, |
| self.evaluate(loss), 3) |
| |
| def testAllWrongAllWeightsMissing(self): |
| logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]]) |
| labels = constant_op.constant([[2], [0], [1]]) |
| weights = constant_op.constant([0, 0, 0], shape=(3, 1)) |
| with self.cached_session(): |
| loss = losses.sparse_softmax_cross_entropy(labels, logits, weights) |
| self.assertAlmostEqual(0.0, self.evaluate(loss), 3) |
| |
| def testSomeWeightsMissing(self): |
| logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], |
| [0.0, 0.0, 10.0]]) |
| labels = constant_op.constant([[2], [0], [1]]) |
| weights = constant_op.constant([1.2, 0, 0], shape=(3, 1)) |
| with self.cached_session(): |
| loss = losses.sparse_softmax_cross_entropy(labels, logits, weights) |
| self.assertAlmostEqual(12.0, self.evaluate(loss), 3) |
| |
| @test_util.run_deprecated_v1 |
| def testMeasurementSpecificWeightsRaisesException(self): |
| with self.cached_session(): |
| logits = constant_op.constant([[100.0, -100.0, -100.0], |
| [-100.0, 100.0, -100.0], |
| [-100.0, -100.0, 100.0]]) |
| labels = constant_op.constant([[0], [1], [2]]) |
| weights = constant_op.constant([[3, 4, 5], [2, 6, 0], [8, 0, 1]]) |
| |
| with self.assertRaises(ValueError): |
| losses.sparse_softmax_cross_entropy( |
| labels, logits, weights=weights).eval() |
| |
| def testInconsistentWeightSizeRaisesException(self): |
| """The weight tensor has incorrect number of elements.""" |
| with self.cached_session(): |
| logits = constant_op.constant([[100.0, -100.0, -100.0], |
| [-100.0, 100.0, -100.0], |
| [-100.0, -100.0, 100.0]]) |
| labels = constant_op.constant([[0], [1], [2]]) |
| weights = constant_op.constant([1.2, 3.4, 5.6, 7.8]) |
| |
| with self.assertRaises(ValueError): |
| losses.sparse_softmax_cross_entropy( |
| labels, logits, weights=weights).eval() |
| |
| def testInconsistentLabelSizeRaisesException(self): |
| """The label tensor has incorrect number of elements.""" |
| with self.cached_session(): |
| logits = constant_op.constant([[100.0, -100.0, -100.0], |
| [-100.0, 100.0, -100.0], |
| [-100.0, -100.0, 100.0]]) |
| labels = constant_op.constant([[0], [1], [2], [3]]) |
| weights = constant_op.constant([1.2, 3.4, 5.6]) |
| |
| with self.assertRaises(ValueError): |
| losses.sparse_softmax_cross_entropy( |
| labels, logits, weights=weights).eval() |
| |
| @test_util.run_deprecated_v1 |
| def testInconsistentWeightShapeRaisesException(self): |
| """The weight tensor has incorrect shape.""" |
| with self.cached_session(): |
| logits = constant_op.constant([[100.0, -100.0, -100.0, -100.0], |
| [-100.0, 100.0, -100.0, -100.0], |
| [-100.0, -100.0, 100.0, -100.0], |
| [-100.0, -100.0, -100.0, 100.0]]) |
| labels = constant_op.constant([[0], [1], [2], [3]]) |
| weights = constant_op.constant([[1.2, 3.4], [5.6, 7.8]]) |
| |
| with self.assertRaises(ValueError): |
| losses.sparse_softmax_cross_entropy( |
| labels, logits, weights=weights).eval() |
| |
| @test_util.run_deprecated_v1 |
| def testInconsistentLabelShapeRaisesException(self): |
| """The label tensor has incorrect shape.""" |
| with self.cached_session(): |
| logits = constant_op.constant([[100.0, -100.0, -100.0, -100.0], |
| [-100.0, 100.0, -100.0, -100.0], |
| [-100.0, -100.0, 100.0, -100.0], |
| [-100.0, -100.0, -100.0, 100.0]]) |
| labels = constant_op.constant([[0, 1], [2, 3]]) |
| weights = constant_op.constant(1.2) |
| |
| with self.assertRaisesRegexp(ValueError, 'dimension'): |
| losses.sparse_softmax_cross_entropy( |
| labels, logits, weights=weights).eval() |
| |
| |
| class SigmoidCrossEntropyLossTest(test.TestCase): |
| |
| @test_util.run_deprecated_v1 |
| def testAllCorrectSigmoid(self): |
| with self.cached_session(): |
| logits = constant_op.constant([[100.0, -100.0, -100.0], |
| [-100.0, 100.0, -100.0], |
| [-100.0, -100.0, 100.0]]) |
| labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) |
| loss = losses.sigmoid_cross_entropy(labels, logits) |
| self.assertEquals(logits.dtype, loss.dtype) |
| self.assertEquals('sigmoid_cross_entropy_loss/value', loss.op.name) |
| self.assertAlmostEqual(0.0, self.evaluate(loss), 3) |
| |
| @test_util.run_deprecated_v1 |
| def testLossWithSingleDimPlaceholderForLogitsAndWeights1(self): |
| logits = array_ops.placeholder(dtypes.float32, shape=(None, 1)) |
| labels = array_ops.placeholder(dtypes.float32, shape=(None, 1)) |
| weights = array_ops.ones_like(logits, dtype=dtypes.float32) |
| |
| loss = losses.sigmoid_cross_entropy(labels, logits, weights) |
| self.assertEquals(logits.dtype, loss.dtype) |
| |
| with self.cached_session() as sess: |
| loss = sess.run(loss, |
| feed_dict={ |
| logits: np.ones((32, 1)), |
| labels: np.ones((32, 1)), |
| }) |
| self.assertAlmostEqual(0.313, loss, 3) |
| |
| @test_util.run_deprecated_v1 |
| def testLossWithSingleDimPlaceholderForLogitsAndWeights2(self): |
| logits = array_ops.placeholder(dtypes.float32, shape=(None, 2)) |
| labels = array_ops.placeholder(dtypes.float32, shape=(None, 2)) |
| weights = array_ops.ones_like(logits, dtype=dtypes.float32) |
| |
| loss = losses.sigmoid_cross_entropy(labels, logits, weights) |
| self.assertEquals(logits.dtype, loss.dtype) |
| |
| with self.cached_session() as sess: |
| loss = sess.run(loss, |
| feed_dict={ |
| logits: np.ones((32, 2)), |
| labels: np.ones((32, 2)), |
| }) |
| self.assertAlmostEqual(0.313, loss, 3) |
| |
| @test_util.run_deprecated_v1 |
| def testAllWrongSigmoid(self): |
| with self.cached_session(): |
| logits = constant_op.constant([[100.0, -100.0, -100.0], |
| [-100.0, 100.0, -100.0], |
| [-100.0, -100.0, 100.0]]) |
| labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) |
| loss = losses.sigmoid_cross_entropy(labels, logits) |
| self.assertEquals(logits.dtype, loss.dtype) |
| self.assertEquals('sigmoid_cross_entropy_loss/value', loss.op.name) |
| self.assertAlmostEqual(loss.eval(), 600.0 / 9.0, 3) |
| |
| @test_util.run_deprecated_v1 |
| def testAllWrongSigmoidWithMeasurementSpecificWeights(self): |
| with self.cached_session(): |
| logits = constant_op.constant([[100.0, -100.0, -100.0], |
| [-100.0, 100.0, -100.0], |
| [-100.0, -100.0, 100.0]]) |
| labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) |
| weights = constant_op.constant([[3, 4, 5], [2, 6, 0], [8, 0, 1]]) |
| loss = losses.sigmoid_cross_entropy(labels, logits, weights) |
| self.assertEquals(logits.dtype, loss.dtype) |
| self.assertEquals('sigmoid_cross_entropy_loss/value', loss.op.name) |
| self.assertAlmostEqual(1700.0 / 7.0, self.evaluate(loss), 3) |
| |
| @test_util.run_deprecated_v1 |
| def testMultiCorrectSigmoid(self): |
| logits = constant_op.constant([[100.0, -100.0, 100.0], |
| [100.0, 100.0, -100.0], |
| [-100.0, 100.0, 100.0]]) |
| labels = constant_op.constant([[1, 0, 1], [1, 1, 0], [0, 1, 1]]) |
| loss = losses.sigmoid_cross_entropy(labels, logits) |
| self.assertEquals(logits.dtype, loss.dtype) |
| self.assertEquals('sigmoid_cross_entropy_loss/value', loss.op.name) |
| |
| with self.cached_session(): |
| self.assertAlmostEqual(0.0, self.evaluate(loss), 3) |
| |
| def testSigmoidFloat64(self): |
| logits = constant_op.constant(( |
| (100.0, -100.0, 100.0), |
| (100.0, -100.0, 100.0), |
| (100.0, 100.0, -100.0) |
| ), dtype=dtypes.float64) |
| labels = constant_op.constant(( |
| (1, 0, 1), (1, 1, 0), (0, 1, 1) |
| ), dtype=dtypes.int64) |
| loss = losses.sigmoid_cross_entropy(labels, logits) |
| self.assertEquals(logits.dtype, loss.dtype) |
| |
| with self.cached_session(): |
| self.assertAlmostEqual(44.444, self.evaluate(loss), 3) |
| |
| def testSigmoidNoReduction(self): |
| logits = constant_op.constant(( |
| (100.0, -100.0, 100.0), |
| (100.0, -100.0, 100.0), |
| (100.0, 100.0, -100.0))) |
| labels = constant_op.constant(((1, 0, 1), (1, 1, 0), (0, 1, 1))) |
| loss = losses.sigmoid_cross_entropy( |
| labels, logits, reduction=losses.Reduction.NONE) |
| self.assertEquals(logits.dtype, loss.dtype) |
| |
| with self.cached_session(): |
| self.assertAllClose(((0., 0., 0.), (0., 100., 100.), (100., 0., 100.)), |
| self.evaluate(loss), 3) |
| |
| @test_util.run_deprecated_v1 |
| def testSigmoidLabelSmoothingCorrect(self): |
| with self.cached_session(): |
| logits = constant_op.constant([[100.0, -100.0, -100.0]]) |
| labels = constant_op.constant([[1, 0, 1]]) |
| # Sigmoid cross entropy loss is: |
| # max(x,0) - x*z + log(1 + exp(-abs(x))) |
| # The new labels are: |
| # z' = z * (1 - L) + 0.5 L |
| # 1 -> 1 - 0.5 L |
| # 0 -> 0.5 L |
| # here we expect: |
| # 1/3 * (100 - 100 * (1 - 0.5 L) + 0 |
| # + 0 + 100 * (0.5 L) + 0 |
| # + 0 + 100 * (1 - 0.5 L) + 0) |
| # = 1/3 * (100 + 50 L) |
| label_smoothing = 0.1 |
| loss = losses.sigmoid_cross_entropy( |
| labels, logits, label_smoothing=label_smoothing) |
| self.assertEquals(logits.dtype, loss.dtype) |
| self.assertEquals('sigmoid_cross_entropy_loss/value', loss.op.name) |
| expected_value = (100.0 + 50.0 * label_smoothing) / 3.0 |
| self.assertAlmostEqual(loss.eval(), expected_value, 3) |
| |
| @test_util.run_deprecated_v1 |
| def testSigmoidLabelSmoothingEqualsSoftmaxTwoLabel(self): |
| with self.cached_session(): |
| label_smoothing = 0.1 |
| sigmoid_logits = constant_op.constant([[100.0, -100.0, -100.0]]) |
| sigmoid_labels = constant_op.constant([[1, 0, 1]]) |
| sigmoid_loss = losses.sigmoid_cross_entropy( |
| sigmoid_labels, sigmoid_logits, label_smoothing=label_smoothing) |
| self.assertEquals(sigmoid_logits.dtype, sigmoid_loss.dtype) |
| |
| softmax_logits = constant_op.constant( |
| [[0.0, 100.0], [100.0, 0.0], [100.0, 0.0]]) |
| softmax_labels = constant_op.constant([[0, 1], [1, 0], [0, 1]]) |
| softmax_loss = losses.softmax_cross_entropy( |
| softmax_labels, softmax_logits, label_smoothing=label_smoothing) |
| self.assertAlmostEqual(sigmoid_loss.eval(), self.evaluate(softmax_loss), |
| 3) |
| |
| |
| @test_util.run_deprecated_v1 |
| class LogLossTest(test.TestCase): |
| |
| def setUp(self): |
| super(LogLossTest, self).setUp() |
| predictions = np.asarray([.9, .2, .2, .8, .4, .6]).reshape((2, 3)) |
| labels = np.asarray([1.0, 0.0, 1.0, 1.0, 0.0, 0.0]).reshape((2, 3)) |
| |
| self._np_predictions = predictions |
| self._np_labels = labels |
| |
| epsilon = 1e-7 |
| self._expected_losses = np.multiply( |
| labels, np.log(predictions + epsilon)) + np.multiply( |
| 1 - labels, np.log(1 - predictions + epsilon)) |
| |
| self._predictions = constant_op.constant(predictions) |
| self._labels = constant_op.constant(labels) |
| |
| def testValueErrorThrownWhenWeightIsNone(self): |
| with self.cached_session(): |
| with self.assertRaises(ValueError): |
| losses.log_loss(self._labels, self._labels, weights=None) |
| |
| def testAllCorrectNoLossWeight(self): |
| loss = losses.log_loss(self._labels, self._labels) |
| with self.cached_session(): |
| self.assertAlmostEqual(0.0, self.evaluate(loss), 3) |
| |
| def testAllCorrectNoLossWeightWithPlaceholder(self): |
| tf_predictions = array_ops.placeholder( |
| dtypes.float32, shape=self._np_labels.shape) |
| loss = losses.log_loss(self._labels, tf_predictions) |
| with self.cached_session(): |
| self.assertAlmostEqual( |
| 0.0, loss.eval(feed_dict={tf_predictions: self._np_labels}), 3) |
| |
| def testNonZeroLoss(self): |
| loss = losses.log_loss(self._labels, self._predictions) |
| with self.cached_session(): |
| self.assertAlmostEqual(-np.sum(self._expected_losses) / 6.0, |
| self.evaluate(loss), 3) |
| |
| def testNonZeroLossWithPythonScalarWeight(self): |
| weights = 2.3 |
| loss = losses.log_loss(self._labels, self._predictions, weights) |
| with self.cached_session(): |
| self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0, |
| self.evaluate(loss), 3) |
| |
| def testNonZeroLossWithScalarTensorWeight(self): |
| weights = 2.3 |
| loss = losses.log_loss(self._labels, self._predictions, |
| constant_op.constant(weights)) |
| with self.cached_session(): |
| self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0, |
| self.evaluate(loss), 3) |
| |
| def testNonZeroLossWithScalarTensorWeightAndPlaceholder(self): |
| tf_predictions = array_ops.placeholder( |
| dtypes.float32, shape=self._np_predictions.shape) |
| weights = 2.3 |
| loss = losses.log_loss(self._labels, tf_predictions, |
| constant_op.constant(weights)) |
| with self.cached_session() as sess: |
| loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions}) |
| self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0, |
| loss, 3) |
| |
| def testNonZeroLossWithScalarTensorWeightAndPlaceholderWithRankOnly(self): |
| tf_predictions = array_ops.placeholder(dtypes.float32, shape=[None, None]) |
| weights = 2.3 |
| loss = losses.log_loss(self._labels, tf_predictions, |
| constant_op.constant(weights)) |
| with self.cached_session() as sess: |
| loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions}) |
| self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0, |
| loss, 3) |
| |
| def testNonZeroLossWithOneDimBatchSpecificWeights(self): |
| weights = constant_op.constant((1.2, 3.4), shape=(2, 1)) |
| expected_losses = np.multiply( |
| self._expected_losses, |
| np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3))) |
| loss = losses.log_loss(self._labels, self._predictions, weights) |
| with self.cached_session(): |
| self.assertAlmostEqual(-np.sum(expected_losses) / 6.0, |
| self.evaluate(loss), 3) |
| |
| def testNonZeroLossWithOneDimBatchSpecificWeightsSomeZero(self): |
| weights = constant_op.constant((1.2, 0), shape=(2, 1)) |
| expected_losses = np.multiply(self._expected_losses, |
| np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape( |
| (2, 3))) |
| loss = losses.log_loss(self._labels, self._predictions, weights) |
| with self.cached_session(): |
| self.assertAlmostEqual(-np.sum(expected_losses) / 3.0, |
| self.evaluate(loss), 3) |
| |
| def testNonZeroLossWithTwoDimBatchSpecificWeightsSomeZero(self): |
| weights = constant_op.constant([1.2, 0], shape=[2, 1]) |
| expected_losses = np.multiply(self._expected_losses, |
| np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape( |
| (2, 3))) |
| loss = losses.log_loss(self._labels, self._predictions, weights) |
| with self.cached_session(): |
| self.assertAlmostEqual(-np.sum(expected_losses) / 3.0, |
| self.evaluate(loss), 3) |
| |
| def testWeightsWithSameNumDimsButWrongShapeThrowsException(self): |
| weights = constant_op.constant(np.random.normal(size=(2, 4)), shape=[2, 4]) |
| with self.cached_session(): |
| with self.assertRaises(ValueError): |
| losses.log_loss(self._labels, self._predictions, weights) |
| |
| def testNonZeroLossWithMeasurementSpecificWeights(self): |
| weights = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3)) |
| expected_losses = np.multiply(self._expected_losses, weights) |
| |
| loss = losses.log_loss( |
| self._labels, |
| self._predictions, |
| constant_op.constant( |
| weights, shape=(2, 3))) |
| with self.cached_session(): |
| self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, |
| self.evaluate(loss), 3) |
| |
| def testNonZeroLossWithMeasurementSpecificWeightsWithPlaceholder(self): |
| weights = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3)) |
| expected_losses = np.multiply(self._expected_losses, weights) |
| |
| tf_predictions = array_ops.placeholder(dtypes.float32, shape=[2, 3]) |
| loss = losses.log_loss( |
| self._labels, |
| tf_predictions, |
| constant_op.constant( |
| weights, shape=(2, 3))) |
| |
| with self.cached_session() as sess: |
| loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions}) |
| self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss, 3) |
| |
| def testNonZeroLossWithSampleSpecificWeightsMostZero(self): |
| weights = np.array([0, 0, 0, 0, 0, 2]).reshape((2, 3)) |
| expected_losses = np.multiply(self._expected_losses, weights) |
| |
| loss = losses.log_loss( |
| self._labels, |
| self._predictions, |
| constant_op.constant( |
| weights, shape=(2, 3))) |
| with self.cached_session(): |
| self.assertAlmostEqual(-np.sum(expected_losses), self.evaluate(loss), 3) |
| |
| def testNonZeroLossWithSampleSpecificWeightsMostZeroWithPlaceholder(self): |
| weights = np.array([0, 0, 0, 0, 0, 2]).reshape((2, 3)) |
| expected_losses = np.multiply(self._expected_losses, weights) |
| |
| tf_predictions = array_ops.placeholder(dtypes.float32, shape=[2, 3]) |
| tf_weights = constant_op.constant(weights, shape=(2, 3)) |
| loss = losses.log_loss(self._labels, tf_predictions, tf_weights) |
| |
| with self.cached_session() as sess: |
| loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions}) |
| self.assertAlmostEqual(-np.sum(expected_losses), loss, 3) |
| |
| def testLossWithSampleSpecificWeightsAllZero(self): |
| tf_weights = array_ops.zeros(shape=(2, 3)) |
| loss = losses.log_loss(self._labels, self._predictions, tf_weights) |
| with self.cached_session(): |
| self.assertAlmostEqual(0.0, self.evaluate(loss), 3) |
| |
| |
| class HingeLossTest(test.TestCase): |
| |
| def testIncompatibleShapes(self): |
| with self.cached_session(): |
| logits = constant_op.constant([[-1.0], [2.1]]) |
| labels = constant_op.constant([0.0, 1.0]) |
| with self.assertRaises(ValueError): |
| _ = losses.hinge_loss(labels, logits).eval() |
| |
| @test_util.run_deprecated_v1 |
| def testAllOutsideMargin(self): |
| with self.cached_session(): |
| logits = constant_op.constant([1.2, -1.4, -1.0, 2.1]) |
| labels = constant_op.constant([1.0, 0.0, 0.0, 1.0]) |
| loss = losses.hinge_loss(labels, logits) |
| self.assertAllClose(loss.eval(), 0.0, atol=1e-3) |
| |
| @test_util.run_deprecated_v1 |
| def testSomeInsideMargin(self): |
| with self.cached_session(): |
| logits = constant_op.constant([[-0.7], [-1.4], [1.4], [0.6]]) |
| labels = constant_op.constant([[0.0], [0.0], [1.0], [1.0]]) |
| loss = losses.hinge_loss(labels, logits) |
| # Examples 1 and 4 are on the correct side of the hyperplane but within |
| # the margin so they incur some (small) loss. |
| self.assertAllClose(loss.eval(), 0.175, atol=1e-3) |
| |
| @test_util.run_deprecated_v1 |
| def testSomeMisclassified(self): |
| with self.cached_session(): |
| logits = constant_op.constant([[[1.2], [0.4], [-1.0], [-1.1]]]) |
| labels = constant_op.constant([[[1.0], [0.0], [0.0], [1.0]]]) |
| loss = losses.hinge_loss(labels, logits) |
| # Examples 2 and 4 are on the wrong side of the hyperplane so they incur |
| # some (fairly large) loss. |
| self.assertAllClose(loss.eval(), 0.875, atol=1e-3) |
| |
| |
| class HuberLossTest(test.TestCase): |
| |
| def testIncompatibleShapes(self): |
| with self.cached_session(): |
| predictions = constant_op.constant([[-1.0], [2.1]]) |
| labels = constant_op.constant([0.0, 1.0]) |
| with self.assertRaises(ValueError): |
| _ = losses.huber_loss(labels, predictions).eval() |
| |
| @test_util.run_deprecated_v1 |
| def testAllQuadratic(self): |
| with self.cached_session(): |
| predictions = constant_op.constant([1.5, -1.4, -1.0, 0.0]) |
| labels = constant_op.constant([1.0, -1.0, 0.0, 0.5]) |
| loss = losses.huber_loss(labels, predictions) |
| self.assertAllClose(loss.eval(), |
| 0.5 * (0.25 + 0.16 + 1.0 + 0.25) / 4., atol=1e-5) |
| |
| @test_util.run_deprecated_v1 |
| def testAllLinear(self): |
| with self.cached_session(): |
| predictions = constant_op.constant([1.5, -1.4, -1.0, 0.0]) |
| labels = constant_op.constant([0.0, 1.0, 0.0, 1.5]) |
| loss = losses.huber_loss(labels, predictions) |
| self.assertAllClose(loss.eval(), |
| (1.5 + 2.4 + 1.0 + 1.5) / 4. - 0.5, atol=1e-5) |
| |
| @test_util.run_deprecated_v1 |
| def testMixedQuadraticLinear(self): |
| with self.cached_session(): |
| predictions = constant_op.constant([[1.5, -1.4, -1.0, 0.0], |
| [1.5, -1.4, -1.0, 0.0]]) |
| labels = constant_op.constant([[1.0, -1.0, 0.0, 0.5], |
| [0.0, 1.0, 0.0, 1.5]]) |
| loss = losses.huber_loss(labels, predictions) |
| quadratic = 0.5 * (0.25 + 0.16 + 1.0 + 0.25) / 4. |
| linear = (1.5 + 2.4 + 1.0 + 1.5) / 4. - 0.5 |
| expected_loss = (quadratic + linear) / 2. |
| self.assertAllClose(loss.eval(), expected_loss, atol=1e-5) |
| |
| def testAllQuadraticDelta(self): |
| with self.cached_session(): |
| delta = 0.5 |
| predictions = constant_op.constant([1.5, -1.4, -0.5, 0.0]) |
| labels = constant_op.constant([1.0, -1.0, 0.0, 0.5]) |
| expected = 0.5 * np.array([0.5**2, 0.4**2, 0.5**2, 0.5**2]).mean() |
| loss = losses.huber_loss(labels, predictions, delta=delta) |
| self.assertAllClose(expected, self.evaluate(loss), atol=1e-5) |
| |
| def testAllLinearDelta(self): |
| delta = 0.5 |
| predictions = constant_op.constant([1.5, -1.4, -1.0, 0.0]) |
| labels = constant_op.constant([0.0, 1.0, 0.0, 1.5]) |
| expected = delta * np.array([1.5, 2.4, 1.0, 1.5]).mean() |
| expected -= 0.5 * delta**2 |
| loss = losses.huber_loss(labels, predictions, delta=delta) |
| with self.cached_session(): |
| self.assertAllClose(expected, self.evaluate(loss), atol=1e-5) |
| |
| |
| @test_util.run_deprecated_v1 |
| class MeanSquaredErrorTest(test.TestCase): |
| |
| def setUp(self): |
| super(MeanSquaredErrorTest, self).setUp() |
| self._predictions = constant_op.constant([4, 8, 12, 8, 1, 3], shape=(2, 3)) |
| self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) |
| |
| def testValueErrorThrownWhenWeightIsNone(self): |
| with self.cached_session(): |
| with self.assertRaises(ValueError): |
| losses.mean_squared_error( |
| self._predictions, self._predictions, weights=None) |
| |
| @test_util.run_deprecated_v1 |
| def testScalar(self): |
| with self.cached_session(): |
| self.assertEqual( |
| 0.0, |
| losses.mean_squared_error(predictions=constant_op.constant(0), |
| labels=constant_op.constant(0)).eval()) |
| |
| @test_util.run_deprecated_v1 |
| def testAllCorrectNoLossWeight(self): |
| loss = losses.mean_squared_error(self._predictions, self._predictions) |
| with self.cached_session(): |
| self.assertAlmostEqual(0.0, self.evaluate(loss), 3) |
| |
| @test_util.run_deprecated_v1 |
| def testNonZeroLoss(self): |
| loss = losses.mean_squared_error(self._labels, self._predictions) |
| with self.cached_session(): |
| self.assertAlmostEqual(49.5, self.evaluate(loss), 3) |
| |
| @test_util.run_deprecated_v1 |
| def testNonZeroLossWithPythonScalarWeight(self): |
| weights = 2.3 |
| loss = losses.mean_squared_error(self._labels, self._predictions, weights) |
| with self.cached_session(): |
| self.assertAlmostEqual(49.5 * weights, self.evaluate(loss), 3) |
| |
| @test_util.run_deprecated_v1 |
| def testNonZeroLossWithScalarTensorWeight(self): |
| weights = 2.3 |
| loss = losses.mean_squared_error(self._labels, self._predictions, |
| constant_op.constant(weights)) |
| with self.cached_session(): |
| self.assertAlmostEqual(49.5 * weights, self.evaluate(loss), 3) |
| |
| def testNonZeroLossWithOneDimBatchSpecificWeights(self): |
| weights = constant_op.constant([1.2, 3.4], shape=(2, 1)) |
| loss = losses.mean_squared_error(self._labels, self._predictions, weights) |
| with self.cached_session(): |
| self.assertAlmostEqual(767.8 / 6.0, self.evaluate(loss), 3) |
| |
| def testNonZeroLossWithTwoDimBatchSpecificWeights(self): |
| weights = constant_op.constant([1.2, 3.4], shape=[2, 1]) |
| loss = losses.mean_squared_error(self._labels, self._predictions, weights) |
| with self.cached_session(): |
| self.assertAlmostEqual(767.8 / 6.0, self.evaluate(loss), 3) |
| |
| def testNonZeroLossWithSampleSpecificWeights(self): |
| weights = constant_op.constant([3, 6, 5, 0, 4, 2], shape=[2, 3]) |
| loss = losses.mean_squared_error(self._labels, self._predictions, weights) |
| with self.cached_session(): |
| self.assertAlmostEqual(587 / 5.0, self.evaluate(loss), 3) |
| |
| def testNonZeroLossWithSampleSpecificWeightsMostZero(self): |
| weights = constant_op.constant([0, 0, 0, 0, 0, 2], shape=[2, 3]) |
| loss = losses.mean_squared_error(self._labels, self._predictions, weights) |
| with self.cached_session(): |
| self.assertAlmostEqual(18.0, self.evaluate(loss), 3) |
| |
| def testLossWithSampleSpecificWeightsAllZero(self): |
| weights = array_ops.zeros((2, 3)) |
| loss = losses.mean_squared_error(self._labels, self._predictions, weights) |
| with self.cached_session(): |
| self.assertAlmostEqual(0.0, self.evaluate(loss), 3) |
| |
| |
| @test_util.run_deprecated_v1 |
| class MeanPairwiseSquaredErrorTest(test.TestCase): |
| |
| def setUp(self): |
| super(MeanPairwiseSquaredErrorTest, self).setUp() |
| self._predictions = np.array([[4, 8, 12], [8, 1, 3]]) |
| self._labels = np.array([[1, 9, 2], [-5, -5, 7]]) |
| |
| batch_size, dims = self._labels.shape |
| |
| # Compute the expected loss 'manually'. |
| total = np.zeros((batch_size,)) |
| for b in range(batch_size): |
| for i in range(dims - 1): |
| for j in range(i + 1, dims): |
| x = self._predictions[b, i].item() - self._predictions[b, j].item() |
| y = self._labels[b, i].item() - self._labels[b, j].item() |
| diff = (x - y) |
| total[b] += (diff * diff) |
| |
| self._expected_losses = np.divide(total, 3.0) |
| |
| def testValueErrorThrownWhenWeightIsNone(self): |
| with self.cached_session(): |
| with self.assertRaises(ValueError): |
| losses.mean_pairwise_squared_error( |
| predictions=constant_op.constant(self._labels), |
| labels=constant_op.constant(self._labels), |
| weights=None) |
| |
| def _test_valid_weights( |
| self, labels, predictions, expected_loss, weights=1.0): |
| with self.cached_session(): |
| static_inputs_op = losses.mean_pairwise_squared_error( |
| predictions=predictions, labels=labels, weights=weights) |
| self.assertAlmostEqual( |
| expected_loss, self.evaluate(static_inputs_op), places=3) |
| |
| predictions_placeholder = array_ops.placeholder( |
| dtypes.float32, shape=np.asarray(predictions.shape)) |
| labels_placeholder = array_ops.placeholder( |
| dtypes.int32, shape=np.asarray(labels.shape)) |
| weights_placeholder = array_ops.placeholder( |
| dtypes.float32, shape=np.asarray(weights).shape) |
| dynamic_inputs_op = losses.mean_pairwise_squared_error( |
| predictions=predictions_placeholder, |
| labels=labels_placeholder, |
| weights=weights_placeholder) |
| feed_dict = { |
| predictions_placeholder: predictions, |
| labels_placeholder: labels, |
| weights_placeholder: weights, |
| } |
| self.assertAlmostEqual( |
| expected_loss, dynamic_inputs_op.eval(feed_dict=feed_dict), places=3) |
| |
| def testAllCorrectNoLossWeight(self): |
| self._test_valid_weights( |
| self._labels, self._labels, expected_loss=0.0) |
| |
| def testNonZeroLoss(self): |
| self._test_valid_weights( |
| self._labels, self._predictions, |
| expected_loss=np.sum(self._expected_losses)) |
| |
| def testGradientWithZeroWeight(self): |
| with ops.Graph().as_default(): |
| random_seed.set_random_seed(0) |
| |
| inputs = array_ops.ones((2, 3)) |
| weights = variable_scope.get_variable( |
| 'weights', |
| shape=[3, 4], |
| initializer=init_ops.truncated_normal_initializer()) |
| predictions = math_ops.matmul(inputs, weights) |
| |
| optimizer = momentum_lib.MomentumOptimizer( |
| learning_rate=0.001, momentum=0.9) |
| loss = losses.mean_pairwise_squared_error(predictions, predictions, 0) |
| |
| gradients_to_variables = optimizer.compute_gradients(loss) |
| |
| init_op = variables.global_variables_initializer() |
| |
| with self.cached_session() as sess: |
| self.evaluate(init_op) |
| for grad, _ in gradients_to_variables: |
| np_grad = self.evaluate(grad) |
| self.assertFalse(np.isnan(np_grad).any()) |
| |
| def testNonZeroLossWithPythonScalarWeight(self): |
| weight = 2.3 |
| self._test_valid_weights( |
| self._labels, self._predictions, |
| expected_loss=weight * np.sum(self._expected_losses), |
| weights=weight) |
| |
| def testNonZeroLossWithScalarTensorWeight(self): |
| weights = 2.3 |
| loss = losses.mean_pairwise_squared_error( |
| predictions=constant_op.constant(self._predictions), |
| labels=constant_op.constant(self._labels), |
| weights=constant_op.constant(weights)) |
| with self.cached_session(): |
| self.assertAlmostEqual(weights * np.sum(self._expected_losses), |
| self.evaluate(loss), 3) |
| |
| def testNonZeroLossWithScalarZeroWeight(self): |
| self._test_valid_weights( |
| self._labels, self._predictions, expected_loss=0.0, weights=0.0) |
| |
| def test3d(self): |
| labels = np.array([ |
| [[1, 9, 2], [12, 11, 10], [9, 8, 7]], |
| [[-5, -5, 7], [6, 5, 4], [3, 2, 1]], |
| ]) |
| predictions = np.array([ |
| [[4, 8, 12], [1, 2, 3], [4, 5, 6]], |
| [[8, 1, 3], [7, 8, 9], [10, 11, 12]], |
| ]) |
| self._test_valid_weights(labels, predictions, expected_loss=137.5) |
| |
| def test3dWeightedScalar(self): |
| labels = np.array([ |
| [[1, 9, 2], [12, 11, 10], [9, 8, 7]], |
| [[-5, -5, 7], [6, 5, 4], [3, 2, 1]], |
| ]) |
| predictions = np.array([ |
| [[4, 8, 12], [1, 2, 3], [4, 5, 6]], |
| [[8, 1, 3], [7, 8, 9], [10, 11, 12]], |
| ]) |
| weight = 3.0 |
| self._test_valid_weights( |
| labels, predictions, expected_loss=weight * 137.5, weights=weight) |
| |
| def _test_invalid_weights( |
| self, labels, predictions, weights=1.0): |
| expected_error_msg = 'weights can not be broadcast to values' |
| |
| # Static check. |
| with self.assertRaisesRegexp(ValueError, expected_error_msg): |
| losses.mean_pairwise_squared_error( |
| predictions=predictions, labels=labels, weights=weights) |
| |
| # Dynamic check. |
| predictions_placeholder = array_ops.placeholder(dtypes.float32) |
| labels_placeholder = array_ops.placeholder(dtypes.int32) |
| weights_placeholder = array_ops.placeholder(dtypes.float32) |
| dynamic_inputs_op = losses.mean_pairwise_squared_error( |
| predictions=predictions_placeholder, |
| labels=labels_placeholder, |
| weights=weights_placeholder) |
| with self.cached_session(): |
| with self.assertRaisesRegexp(errors_impl.OpError, expected_error_msg): |
| dynamic_inputs_op.eval(feed_dict={ |
| predictions_placeholder: predictions, |
| labels_placeholder: labels, |
| weights_placeholder: weights, |
| }) |
| |
| def testInvalid3dWeighted2x0(self): |
| labels = np.array([ |
| [[1, 9, 2], [12, 11, 10], [9, 8, 7]], |
| [[-5, -5, 7], [6, 5, 4], [3, 2, 1]], |
| ]) |
| predictions = np.array([ |
| [[4, 8, 12], [1, 2, 3], [4, 5, 6]], |
| [[8, 1, 3], [7, 8, 9], [10, 11, 12]], |
| ]) |
| self._test_invalid_weights( |
| labels, predictions, weights=np.asarray((1.2, 3.4))) |
| |
| def test3dWeighted2x3x3(self): |
| labels = np.array([ |
| [[1, 9, 2], [12, 11, 10], [9, 8, 7]], |
| [[-5, -5, 7], [6, 5, 4], [3, 2, 1]], |
| ]) |
| predictions = np.array([ |
| [[4, 8, 12], [1, 2, 3], [4, 5, 6]], |
| [[8, 1, 3], [7, 8, 9], [10, 11, 12]], |
| ]) |
| self._test_valid_weights( |
| # TODO(ptucker): This doesn't look right. |
| labels, |
| predictions, |
| expected_loss=9 * 137.5, |
| weights=np.ones((2, 3, 3))) |
| |
| def testLossWithAllZeroBatchSpecificWeights(self): |
| self._test_valid_weights( |
| self._labels, self._predictions, expected_loss=0.0, |
| weights=np.zeros((2, 1))) |
| |
| def testLossIsAssociativeAcrossBatchElements(self): |
| with ops.Graph().as_default(): |
| random_seed.set_random_seed(0) |
| |
| height = 3 |
| width = 4 |
| shape = (1, height, width, 1) |
| |
| labels0 = random_ops.random_uniform( |
| shape, minval=0, maxval=1, dtype=dtypes.float32) |
| predictions0 = random_ops.random_uniform( |
| shape, minval=0, maxval=1, dtype=dtypes.float32) |
| |
| labels1 = random_ops.random_uniform( |
| shape, minval=0, maxval=1, dtype=dtypes.float32) |
| predictions1 = random_ops.random_uniform( |
| shape, minval=0, maxval=1, dtype=dtypes.float32) |
| |
| loss0 = losses.mean_pairwise_squared_error( |
| labels=labels0, |
| predictions=predictions0) |
| loss1 = losses.mean_pairwise_squared_error( |
| labels=labels1, |
| predictions=predictions1) |
| loss0_1 = losses.mean_pairwise_squared_error( |
| labels=array_ops.concat([labels0, labels1], 0), |
| predictions=array_ops.concat([predictions0, predictions1], 0)) |
| |
| with self.cached_session() as session: |
| loss0, loss1, loss0_1 = session.run([loss0, loss1, loss0_1]) |
| |
| self.assertTrue(loss0 > 0) |
| self.assertTrue(loss1 > 0) |
| self.assertAlmostEqual(loss0 + loss1, loss0_1, 5) |
| |
| |
| @test_util.run_deprecated_v1 |
| class CosineDistanceLossTest(test.TestCase): |
| |
| def setUp(self): |
| super(CosineDistanceLossTest, self).setUp() |
| self._predictions = np.asarray([ |
| [1, 0, 0], # Batch 1 |
| [0, 0, -1], |
| [1, 0, 0], # Batch 2 |
| [1, 0, 0], |
| [0, 0, -1], # Batch 3 |
| [1, 0, 0] |
| ]).reshape((3, 2, 3)) |
| |
| self._labels = np.asarray([[1, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], |
| [0, 0, 1], [0, 1, 0]]).reshape((3, 2, 3)) |
| |
| def testValueErrorThrownWhenWeightIsNone(self): |
| with self.cached_session(): |
| with self.assertRaises(ValueError): |
| losses.cosine_distance( |
| predictions=constant_op.constant(self._labels), |
| labels=constant_op.constant(self._labels), |
| dim=2, |
| weights=None) |
| |
| def testAllCorrectNoWeights(self): |
| loss = losses.cosine_distance( |
| predictions=constant_op.constant(self._labels), |
| labels=constant_op.constant(self._labels), |
| dim=2) |
| with self.cached_session(): |
| self.assertAlmostEqual(0, self.evaluate(loss), 5) |
| |
| def testPartiallyCorrectWithIntegerValues(self): |
| loss = losses.cosine_distance( |
| predictions=constant_op.constant(self._predictions), |
| labels=constant_op.constant(self._labels), |
| dim=2) |
| with self.cached_session(): |
| self.assertAlmostEqual(1, self.evaluate(loss), 5) |
| |
| def testPartiallyCorrectFloatingPointValues(self): |
| predictions = np.matrix( |
| ('0.819031913261206 0.567041924552012 0.087465312324590;' |
| '-0.665139432070255 -0.739487441769973 -0.103671883216994;' |
| '0.707106781186548 -0.707106781186548 0')) |
| labels = np.matrix(('0.819031913261206 0.567041924552012 0.087465312324590;' |
| '0.665139432070255 0.739487441769973 0.103671883216994;' |
| '0.707106781186548 0.707106781186548 0')) |
| |
| tf_preds = constant_op.constant( |
| predictions, shape=(3, 1, 3), dtype=dtypes.float32) |
| tf_labels = constant_op.constant( |
| labels, shape=(3, 1, 3), dtype=dtypes.float32) |
| loss = losses.cosine_distance(tf_labels, tf_preds, dim=2) |
| |
| with self.cached_session(): |
| self.assertAlmostEqual(1.0, self.evaluate(loss), 5) |
| |
| def testSampleSpecificWeights(self): |
| loss = losses.cosine_distance( |
| predictions=constant_op.constant(self._predictions), |
| labels=constant_op.constant(self._labels), |
| dim=2, |
| weights=np.asarray((1, 0, 0)).reshape((3, 1, 1))) |
| with self.cached_session(): |
| self.assertEqual(1.0, self.evaluate(loss)) |
| |
| def testMeasurementSpecificWeights(self): |
| loss = losses.cosine_distance( |
| predictions=constant_op.constant(self._predictions), |
| labels=constant_op.constant(self._labels), |
| dim=2, |
| weights=constant_op.constant( |
| [1, 0, 0, 1, 1, 1], shape=(3, 2, 1))) |
| with self.cached_session(): |
| self.assertEqual(3.0 / 4.0, self.evaluate(loss)) |
| |
| def testMeasurementSpecificWeightsWithPlaceholderWithShape(self): |
| tf_predictions = array_ops.placeholder( |
| dtypes.float32, shape=self._labels.shape) |
| loss = losses.cosine_distance( |
| predictions=tf_predictions, |
| labels=constant_op.constant(self._labels), |
| dim=2, |
| weights=constant_op.constant( |
| [1, 0, 0, 1, 1, 1], shape=(3, 2, 1))) |
| with self.cached_session() as sess: |
| loss = sess.run(loss, feed_dict={tf_predictions: self._predictions}) |
| self.assertEqual(3.0 / 4.0, loss) |
| |
| def testZeroLossWhenAllSampleSpecificWeightsAreZero(self): |
| loss = losses.cosine_distance( |
| predictions=constant_op.constant(self._predictions), |
| labels=constant_op.constant(self._labels), |
| dim=2, |
| weights=array_ops.zeros((3, 1, 1))) |
| with self.cached_session(): |
| self.assertEqual(0, self.evaluate(loss)) |
| |
| def testZeroLossWhenAllMeasurementSpecificWeightsAreZero(self): |
| loss = losses.cosine_distance( |
| predictions=constant_op.constant(self._predictions), |
| labels=constant_op.constant(self._labels), |
| dim=2, |
| weights=array_ops.zeros((3, 2, 1))) |
| with self.cached_session(): |
| self.assertEqual(0, self.evaluate(loss)) |
| |
| |
| class AddLossTest(test.TestCase): |
| |
| def testNoCollectLossesBatch2(self): |
| logits = constant_op.constant([[1.2, 0.4, -1.0, -1.1]] * 2) |
| labels = constant_op.constant([[1.0, 0.0, 0.0, 1.0]] * 2) |
| self.assertFalse(util.get_losses()) |
| losses.absolute_difference(logits, labels, loss_collection=None) |
| losses.log_loss(logits, labels, loss_collection=None) |
| losses.mean_squared_error(logits, labels, loss_collection=None) |
| losses.sigmoid_cross_entropy(logits, labels, loss_collection=None) |
| losses.softmax_cross_entropy(logits, labels, loss_collection=None) |
| self.assertFalse(util.get_losses()) |
| |
| |
| class ComputeWeightedLossTest(test.TestCase): |
| |
| def setUp(self): |
| super(ComputeWeightedLossTest, self).setUp() |
| self._shape = (3, 2, 4) |
| raw_losses = np.zeros(self._shape) |
| next_loss = 0.0 |
| for i in range(self._shape[0]): |
| for j in range(self._shape[1]): |
| for k in range(self._shape[2]): |
| raw_losses[i][j][k] = next_loss |
| next_loss += 1.0 |
| raw_losses.setflags(write=False) |
| self._raw_losses = raw_losses |
| |
| def testUnweighted(self): |
| for reduction in losses.Reduction.all(): |
| with ops.Graph().as_default() as g: |
| self.assertEqual(0, len(util.get_losses())) |
| raw_losses = self._raw_losses |
| unweighted_losses = ( |
| losses.compute_weighted_loss(raw_losses, reduction=reduction), |
| losses.compute_weighted_loss( |
| raw_losses, weights=np.ones((1, 1, 1)), reduction=reduction), |
| losses.compute_weighted_loss( |
| raw_losses, weights=np.ones((1, 1, 4)), reduction=reduction), |
| losses.compute_weighted_loss( |
| raw_losses, weights=np.ones((1, 2, 1)), reduction=reduction), |
| losses.compute_weighted_loss( |
| raw_losses, weights=np.ones((1, 2, 4)), reduction=reduction), |
| losses.compute_weighted_loss( |
| raw_losses, weights=np.ones((3, 1, 1)), reduction=reduction), |
| losses.compute_weighted_loss( |
| raw_losses, weights=np.ones((3, 1, 4)), reduction=reduction), |
| losses.compute_weighted_loss( |
| raw_losses, weights=np.ones((3, 2, 1)), reduction=reduction), |
| losses.compute_weighted_loss( |
| raw_losses, weights=np.ones(self._shape), reduction=reduction) |
| ) |
| self.assertEqual(9, len(util.get_losses())) |
| with self.session(g): |
| for unweighted_loss in unweighted_losses: |
| if reduction == losses.Reduction.NONE: |
| self.assertAllClose(self._raw_losses, |
| self.evaluate(unweighted_loss)) |
| elif reduction == losses.Reduction.SUM: |
| self.assertAllClose( |
| np.sum(self._raw_losses), self.evaluate(unweighted_loss)) |
| else: |
| # reduction one of MEAN, SUM_OVER_NONZERO_WEIGHTS, |
| # SUM_BY_NONZERO_WEIGHTS or SUM_OVER_BATCH_SIZE. |
| self.assertAllClose( |
| np.mean(self._raw_losses), self.evaluate(unweighted_loss)) |
| |
| def testUnweightedFromPlaceholder(self): |
| for reduction in losses.Reduction.all(): |
| with ops.Graph().as_default() as g: |
| self.assertEqual(0, len(util.get_losses())) |
| raw_losses = array_ops.placeholder(dtype=dtypes.float32) |
| feed_dict = {raw_losses: self._raw_losses} |
| unweighted_losses = ( |
| losses.compute_weighted_loss(raw_losses, reduction=reduction), |
| losses.compute_weighted_loss( |
| raw_losses, weights=np.ones((1, 1, 1)), reduction=reduction), |
| losses.compute_weighted_loss( |
| raw_losses, weights=np.ones((1, 1, 4)), reduction=reduction), |
| ) |
| self.assertEqual(3, len(util.get_losses())) |
| with self.session(g): |
| for unweighted_loss in unweighted_losses: |
| if reduction == losses.Reduction.NONE: |
| self.assertAllClose( |
| self._raw_losses, unweighted_loss.eval(feed_dict)) |
| elif reduction == losses.Reduction.SUM: |
| self.assertAllClose( |
| np.sum(self._raw_losses), unweighted_loss.eval(feed_dict)) |
| else: |
| # reduction one of MEAN, SUM_OVER_NONZERO_WEIGHTS, |
| # SUM_BY_NONZERO_WEIGHTS or SUM_OVER_BATCH_SIZE. |
| self.assertAllClose( |
| np.mean(self._raw_losses), unweighted_loss.eval(feed_dict)) |
| |
| def testScalarWeight(self): |
| with ops.Graph().as_default(): |
| self.assertEqual(0, len(util.get_losses())) |
| weight = 17.0 |
| weighted_loss = losses.compute_weighted_loss( |
| self._raw_losses, weights=weight) |
| self.assertEqual(1, len(util.get_losses())) |
| with self.cached_session(): |
| self.assertAllClose( |
| np.mean(weight * self._raw_losses), self.evaluate(weighted_loss)) |
| |
| def _test_invalid_weights(self, weights): |
| with ops.Graph().as_default(): |
| self.assertEqual(0, len(util.get_losses())) |
| expected_error_msg = 'weights can not be broadcast to values' |
| |
| # Static check. |
| with self.assertRaisesRegexp(ValueError, expected_error_msg): |
| losses.compute_weighted_loss(self._raw_losses, weights=weights) |
| |
| # Dynamic check. |
| weights_placeholder = array_ops.placeholder(dtypes.float32) |
| weighted_loss = losses.compute_weighted_loss( |
| self._raw_losses, weights=weights_placeholder) |
| self.assertEqual(1, len(util.get_losses())) |
| with self.cached_session(): |
| with self.assertRaisesRegexp(errors_impl.OpError, expected_error_msg): |
| weighted_loss.eval(feed_dict={weights_placeholder: weights}) |
| |
| def testInvalidWeightTooManyDims(self): |
| self._test_invalid_weights(np.zeros(shape=(2, 2, 2, 2))) |
| |
| def testInvalidWeightMismatchedDim(self): |
| with ops.Graph().as_default(): |
| raw_losses = array_ops.reshape(self._raw_losses, shape=(3, 2, 4, 1)) |
| weights = np.ones(shape=(3, 2, 4, 2)) |
| expected_error_msg = 'weights can not be broadcast to values' |
| self.assertEqual(0, len(util.get_losses())) |
| |
| # Static check. |
| with self.assertRaisesRegexp(ValueError, expected_error_msg): |
| losses.compute_weighted_loss(raw_losses, weights=weights) |
| |
| # Dynamic check. |
| weights_placeholder = array_ops.placeholder(dtypes.float32) |
| weighted_loss = losses.compute_weighted_loss( |
| raw_losses, weights=weights_placeholder) |
| self.assertEqual(1, len(util.get_losses())) |
| with self.cached_session(): |
| with self.assertRaisesRegexp(errors_impl.OpError, expected_error_msg): |
| weighted_loss.eval(feed_dict={weights_placeholder: weights}) |
| |
| def testInvalid3Weight(self): |
| self._test_invalid_weights((17.0, 5.0, 2.0)) |
| |
| def testInvalid3x1Weight(self): |
| self._test_invalid_weights(((17.0,), (5.0,), (2.0,),)) |
| |
| def testInvalid3x2Weight(self): |
| self._test_invalid_weights(( |
| (17.0, 3.0), |
| (5.0, 31.0), |
| (2.0, 7.0),)) |
| |
| def testInvalid1x2Weight(self): |
| self._test_invalid_weights((17.0, 3.0,),) |
| |
| def testInvalidScalar1DWeight(self): |
| self._test_invalid_weights((17.0,),) |
| |
| def _test_valid_weights(self, weights): |
| for reduction in losses.Reduction.all(): |
| with ops.Graph().as_default() as g: |
| self.assertEqual(0, len(util.get_losses())) |
| weighted_loss = losses.compute_weighted_loss( |
| self._raw_losses, weights=weights, reduction=reduction) |
| self.assertEqual(1, len(util.get_losses())) |
| with self.session(g): |
| weighted_losses = weights * self._raw_losses |
| weighted_sum = np.sum(weighted_losses) |
| if reduction == losses.Reduction.NONE: |
| self.assertAllClose(weighted_losses, self.evaluate(weighted_loss)) |
| elif reduction == losses.Reduction.SUM: |
| self.assertAllClose(weighted_sum, self.evaluate(weighted_loss)) |
| else: |
| broadcast_weights = weights * np.ones_like(self._raw_losses) |
| if reduction == losses.Reduction.MEAN: |
| self.assertAllClose(weighted_sum / np.sum(broadcast_weights), |
| self.evaluate(weighted_loss)) |
| elif (reduction == losses.Reduction.SUM_OVER_NONZERO_WEIGHTS or |
| reduction == losses.Reduction.SUM_BY_NONZERO_WEIGHTS): |
| self.assertAllClose( |
| weighted_sum / np.count_nonzero(broadcast_weights), |
| self.evaluate(weighted_loss)) |
| elif reduction == losses.Reduction.SUM_OVER_BATCH_SIZE: |
| self.assertAllClose(weighted_sum / self._raw_losses.size, |
| self.evaluate(weighted_loss)) |
| |
| def test1x1x1Weight(self): |
| self._test_valid_weights((((17.0,),),)) |
| |
| def test1x2x1Weight(self): |
| self._test_valid_weights((((17.0,), (3.0,),),)) |
| |
| def test1x1x4Weight(self): |
| self._test_valid_weights((((17.0, 0.0, 2.0, 5.0),),)) |
| |
| def test3x1x1Weight(self): |
| self._test_valid_weights((((17.0,),), ((5.0,),), ((2.0,),),)) |
| |
| def test3x2x1Weight(self): |
| self._test_valid_weights(( |
| ((17.0,), (3.0,)), |
| ((5.0,), (31.0,)), |
| ((2.0,), (7.0,)), |
| )) |
| |
| def test3x1x4Weight(self): |
| self._test_valid_weights(( |
| ((17.0, 0.0, 2.0, 5.0),), |
| ((5.0, 31.0, 17.0, 5.0),), |
| ((7.0, 3.0, 11.0, 5.0),), |
| )) |
| |
| def test1x2x4Weight(self): |
| self._test_valid_weights((( |
| (17.0, 0.0, 2.0, 5.0), |
| (3.0, 13.0, 11.0, 2.0), |
| ),)) |
| |
| def test3x2x4Weight(self): |
| self._test_valid_weights(( |
| ((17.0, 0.0, 2.0, 5.0), (3.0, 13.0, 11.0, 2.0),), |
| ((5.0, 31.0, 17.0, 5.0), (13.0, 3.0, 0.0, 11.0),), |
| ((0.0, 3.0, 11.0, 5.0), (13.0, 11.0, 1.0, 7.0),), |
| )) |
| |
| |
| if __name__ == '__main__': |
| test.main() |