tensorflow.python.ops.check_ops 源代码

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-short-docstring-punctuation
"""Asserts and Boolean Checks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

# Immutable set of the dtypes this module lists as numeric.
# NOTE(review): presumably consumed by `is_numeric_tensor` (exported below but
# not visible in this part of the file) -- confirm against its definition.
NUMERIC_TYPES = frozenset(
    [dtypes.float32, dtypes.float64, dtypes.int8, dtypes.int16, dtypes.int32,
     dtypes.int64, dtypes.uint8, dtypes.qint8, dtypes.qint32, dtypes.quint8,
     dtypes.complex64])

# Names exported by `from <this module> import *`.
__all__ = [
    'assert_negative',
    'assert_positive',
    'assert_proper_iterable',
    'assert_non_negative',
    'assert_non_positive',
    'assert_equal',
    'assert_none_equal',
    'assert_near',
    'assert_integer',
    'assert_less',
    'assert_less_equal',
    'assert_greater',
    'assert_greater_equal',
    'assert_rank',
    'assert_rank_at_least',
    'assert_rank_in',
    'assert_same_float_dtype',
    'assert_scalar',
    'assert_type',
    'assert_shapes',
    'is_non_decreasing',
    'is_numeric_tensor',
    'is_strictly_increasing',
]


def _maybe_constant_value_string(t):
  """Returns a printable string for `t`, preferring its static value.

  Args:
    t: A `Tensor` or any other Python object.

  Returns:
    `str(t)` when `t` is not a `Tensor`; the stringified constant value when
    `tensor_util.constant_value` can resolve one; otherwise `str(t)`. The
    result is always a string, so it can safely be passed to `str.join`.
  """
  if isinstance(t, ops.Tensor):
    const_t = tensor_util.constant_value(t)
    if const_t is not None:
      return str(const_t)
  # Non-tensors and tensors without a static value both fall back to str().
  # Returning the raw Tensor here would make '\n'.join(...) in _assert_static
  # fail with a TypeError instead of raising the intended assertion error.
  return str(t)


def _assert_static(condition, data):
  """Raises `InvalidArgumentError` built from `data` when `condition` is falsy.

  Args:
    condition: Statically known truth value of the assertion.
    data: Iterable of items describing the failure. Each item is passed
      through `_maybe_constant_value_string` and the results are joined with
      newlines to form the error message.

  Raises:
    InvalidArgumentError: If `condition` evaluates to False.
  """
  if condition:
    return
  lines = []
  for item in data:
    lines.append(_maybe_constant_value_string(item))
  raise errors.InvalidArgumentError(
      node_def=None, op=None, message='\n'.join(lines))


def _shape_and_dtype_str(tensor):
  """Describes `tensor` as the string 'shape=<shape> dtype=<dtype name>'."""
  shape, dtype_name = tensor.shape, tensor.dtype.name
  return 'shape=%s dtype=%s' % (shape, dtype_name)


def _unary_assert_doc(sym, sym_name):
  """Common docstring for assert_* ops that evaluate a unary predicate over every element of a tensor.

  Args:
    sym: Mathematical symbol for the check performed on each element, i.e. "> 0"
    sym_name: English-language name for the op described by sym

  Returns:
    Decorator that adds the appropriate docstring to the function for symbol
    `sym`.
  """

  def _decorator(func):
    """Generated decorator that adds the appropriate docstring to the function for symbol `sym`.

    Args:
      func: Function for a TensorFlow op

    Returns:
      `func` itself, with its `__doc__` replaced by the generated text.
    """
    opname = func.__name__
    cap_sym_name = sym_name.capitalize()

    # The usage example calls the op with `x` only: these are unary asserts,
    # so a second positional argument would be interpreted as `data`.
    func.__doc__ = """
    Assert the condition `x {sym}` holds element-wise.

    When running in graph mode, you should add a dependency on this operation
    to ensure that it runs. Example of adding a dependency to an operation:

    ```python
    with tf.control_dependencies([tf.debugging.{opname}(x)]):
      output = tf.reduce_sum(x)
    ```

    {sym_name} means, for every element `x[i]` of `x`, we have `x[i] {sym}`.
    If `x` is empty this is trivially satisfied.

    Args:
      x:  Numeric `Tensor`.
      data:  The tensors to print out if the condition is False.  Defaults to
        error message and first few entries of `x`.
      summarize: Print this many entries of each tensor.
      message: A string to prefix to the default message.
      name: A name for this operation (optional).  Defaults to "{opname}".

    Returns:
      Op that raises `InvalidArgumentError` if `x {sym}` is False.
      @compatibility(eager)
        returns None
      @end_compatibility

    Raises:
      InvalidArgumentError: if the check can be performed immediately and
        `x {sym}` is False. The check can be performed immediately during 
        eager execution or if `x` is statically known.
    """.format(
        sym=sym, sym_name=cap_sym_name, opname=opname)
    return func

  return _decorator


def _binary_assert_doc(sym):
  """Creates a decorator installing the shared docstring of binary v1 asserts.

  Args:
    sym: Binary operation symbol, i.e. "=="

  Returns:
    Decorator that replaces a function's `__doc__` with the common text,
    specialised for `sym` and for the decorated function's name.
  """
  # The literal keeps its four-space indentation on purpose: it becomes the
  # decorated function's __doc__ verbatim.
  doc_template = """
    Assert the condition `x {sym} y` holds element-wise.

    This condition holds if for every pair of (possibly broadcast) elements
    `x[i]`, `y[i]`, we have `x[i] {sym} y[i]`.
    If both `x` and `y` are empty, this is trivially satisfied.

    When running in graph mode, you should add a dependency on this operation
    to ensure that it runs. Example of adding a dependency to an operation:

    ```python
    with tf.control_dependencies([tf.compat.v1.{opname}(x, y)]):
      output = tf.reduce_sum(x)
    ```

    Args:
      x:  Numeric `Tensor`.
      y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
      data:  The tensors to print out if the condition is False.  Defaults to
        error message and first few entries of `x`, `y`.
      summarize: Print this many entries of each tensor.
      message: A string to prefix to the default message.
      name: A name for this operation (optional).  Defaults to "{opname}".

    Returns:
      Op that raises `InvalidArgumentError` if `x {sym} y` is False.
      @compatibility(eager)
        returns None
      @end_compatibility

    Raises:
      InvalidArgumentError: if the check can be performed immediately and
        `x {sym} y` is False. The check can be performed immediately during 
        eager execution or if `x` and `y` are statically known.
    """

  def _decorator(func):
    """Attaches the generated documentation to `func` and returns it.

    Args:
      func: Function for a TensorFlow op

    Returns:
      `func` itself, with its `__doc__` replaced.
    """
    func.__doc__ = doc_template.format(sym=sym, opname=func.__name__)
    return func

  return _decorator


def _make_assert_msg_data(sym, x, y, summarize, test_op):
  """Builds the pieces of the default eager-mode error message.

  Helper of `_binary_assert`.

  Args:
    sym: Mathematical symbol for the test to apply to pairs of tensor elements,
      i.e. "=="
    x: First input to the assertion after applying `convert_to_tensor()`
    y: Second input to the assertion
    summarize: Value of the "summarize" parameter to the original assert_* call;
      tells how many elements of each tensor to print.
    test_op: TensorFlow op that returns a Boolean tensor with True in each
      position where the assertion is satisfied.

  Returns:
    List of strings and arrays that, when stringified and concatenated,
    will produce the error message string.
  """
  data = ['Condition x %s y did not hold.' % sym]

  # Nothing beyond the headline is reported unless summarize is positive.
  if not summarize > 0:
    return data

  if x.shape == y.shape and x.shape.as_list():
    # Same (non-scalar) shape: report where the check failed and the
    # offending values. With differing shapes this would confuse more than
    # it helps, so it is skipped.
    failed = math_ops.logical_not(test_op)
    failed_indices = array_ops.where(failed).numpy()
    x_failed = array_ops.boolean_mask(x, failed)
    y_failed = array_ops.boolean_mask(y, failed)
    shown = min(summarize, failed_indices.shape[0])
    data.extend([
        'Indices of first %d different values:' % shown,
        failed_indices[:shown],
        'Corresponding x values:',
        x_failed.numpy().reshape((-1,))[:shown],
        'Corresponding y values:',
        y_failed.numpy().reshape((-1,))[:shown],
    ])

  # Flatten both inputs and report their leading elements.
  x_flat = x.numpy().reshape((-1,))
  y_flat = y.numpy().reshape((-1,))
  x_count = min(x_flat.size, summarize)
  y_count = min(y_flat.size, summarize)
  data.extend([
      'First %d elements of x:' % x_count,
      x_flat[:x_count],
      'First %d elements of y:' % y_count,
      y_flat[:y_count],
  ])
  return data


def _pretty_print(data_item, summarize):
  """Renders one entry of an assert's `data` list for an eager error message.

  Args:
    data_item: One of the items in the "data" argument to an assert_* function.
      Can be a Tensor or a scalar value.
    summarize: How many elements to retain of each tensor-valued entry in data.

  Returns:
    An appropriate string representation of data_item
  """
  if not isinstance(data_item, ops.Tensor):
    return str(data_item)

  value = data_item.numpy()
  if np.isscalar(value):
    # Zero-dimensional tensors come back from numpy() as plain scalars.
    return str(value)

  flattened = value.reshape((-1,))
  shown = [str(elem) for elem in flattened[:summarize]]
  if len(shown) < flattened.size:
    # Mark that some elements were left out.
    shown.append('...')
  return str(shown)


def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize,
                   message, name):
  """Generic binary elementwise assertion.

  Implements the behavior described in _binary_assert_doc() above.

  Args:
    sym: Mathematical symbol for the test to apply to pairs of tensor elements,
      i.e. "=="
    opname: Name of the assert op in the public API, i.e. "assert_equal"
    op_func: Function that, if passed the two Tensor inputs to the assertion (x
      and y), will return the test to be passed to reduce_all(),
      i.e. math_ops.equal for assert_equal()
    static_func: Function that, if passed numpy ndarray versions of the two
      inputs to the assertion, will return a Boolean ndarray containing
      True in all positions where the assertion PASSES.
      i.e. np.equal for assert_equal()
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to the value of
      `opname`.

  Returns:
    See docstring template in _binary_assert_doc().

  Raises:
    InvalidArgumentError: In eager mode when the condition does not hold, or
      in graph mode when both inputs have static values and the condition
      does not hold for them.
  """
  with ops.name_scope(name, opname, [x, y, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y')

    if context.executing_eagerly():
      test_op = op_func(x, y)
      condition = math_ops.reduce_all(test_op)
      if condition:
        return

      # If we get here, the assertion has failed.
      # Default to printing 3 elements like control_flow_ops.Assert (used
      # by graph mode) does. Also treat negative values as "print
      # everything" for consistency with Tensor::SummarizeValue().
      if summarize is None:
        summarize = 3
      elif summarize < 0:
        # Must be an int: _pretty_print uses it as a slice bound, and a float
        # there raises TypeError. Code below will find exact size of x and y.
        summarize = int(1e9)

      if data is None:
        data = _make_assert_msg_data(sym, x, y, summarize, test_op)

      if message is not None:
        data = [message] + list(data)

      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message=('\n'.join(_pretty_print(d, summarize) for d in data)))

    else:  # not context.executing_eagerly()
      if data is None:
        data = [
            'Condition x %s y did not hold element-wise:' % sym,
            'x (%s) = ' % x.name, x,
            'y (%s) = ' % y.name, y
        ]
      if message is not None:
        data = [message] + list(data)
      condition = math_ops.reduce_all(op_func(x, y))
      # When both inputs are statically known, evaluate the check right away
      # so a failure surfaces at graph-construction time.
      x_static = tensor_util.constant_value(x)
      y_static = tensor_util.constant_value(y)
      if x_static is not None and y_static is not None:
        condition_static = np.all(static_func(x_static, y_static))
        _assert_static(condition_static, data)
      return control_flow_ops.Assert(condition, data, summarize=summarize)


@tf_export(
    'debugging.assert_proper_iterable',
    v1=['debugging.assert_proper_iterable', 'assert_proper_iterable'])
@deprecation.deprecated_endpoints('assert_proper_iterable')
def assert_proper_iterable(values):
  """Static assert that values is a "proper" iterable.

  Intended for `Ops` that take an iterable of `Tensor`s and want to reject
  objects that are iterable only incidentally: `Tensor`, `ndarray` and
  byte/text types all support iteration themselves.

  Args:
    values:  Object to be checked.

  Raises:
    TypeError:  If `values` is not iterable or is one of
      `Tensor`, `SparseTensor`, `np.array`, `tf.compat.bytes_or_text_types`.
  """
  rejected_types = (
      ops.Tensor, sparse_tensor.SparseTensor, np.ndarray
  ) + compat.bytes_or_text_types
  if isinstance(values, rejected_types):
    raise TypeError(
        'Expected argument "values" to be a "proper" iterable.  Found: %s' %
        type(values))

  if hasattr(values, '__iter__'):
    return
  raise TypeError(
      'Expected argument "values" to be iterable.  Found: %s' % type(values))


@tf_export('debugging.assert_negative', v1=[])
def assert_negative_v2(x, message=None, summarize=None, name=None):
  """Assert that every element of `x` is strictly negative (`x < 0`).

  The check succeeds when `x[i] < 0` for each element of `x`; an empty `x`
  satisfies it trivially.

  When some element is not negative, `message` and the first `summarize`
  entries of `x` are printed and `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_negative".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all negative. This can be
      used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] < 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Thin wrapper: the v2 signature omits `data` and forwards the rest.
  return assert_negative(x=x, summarize=summarize, message=message, name=name)


@tf_export(v1=['debugging.assert_negative', 'assert_negative'])
@deprecation.deprecated_endpoints('assert_negative')
@_unary_assert_doc('< 0', 'negative')
def assert_negative(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The docstring is attached by the `_unary_assert_doc` decorator.
  with ops.name_scope(name, 'assert_negative', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # In eager mode x is described by shape/dtype; otherwise by its name.
      x_label = (_shape_and_dtype_str(x) if context.executing_eagerly()
                 else x.name)
      data = [
          message or '',
          'Condition x < 0 did not hold element-wise:',
          'x (%s) = ' % x_label,
          x,
      ]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    # x < 0 is expressed as the binary assertion x < zero.
    return assert_less(x, zero, data=data, summarize=summarize)


@tf_export('debugging.assert_positive', v1=[])
def assert_positive_v2(x, message=None, summarize=None, name=None):
  """Assert that every element of `x` is strictly positive (`x > 0`).

  The check succeeds when `x[i] > 0` for each element of `x`; an empty `x`
  satisfies it trivially.

  When some element is not positive, `message` and the first `summarize`
  entries of `x` are printed and `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_positive".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all positive. This can be
      used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] > 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Thin wrapper: the v2 signature omits `data` and forwards the rest.
  return assert_positive(x=x, message=message, summarize=summarize, name=name)


@tf_export(v1=['debugging.assert_positive', 'assert_positive'])
@deprecation.deprecated_endpoints('assert_positive')
@_unary_assert_doc('> 0', 'positive')
def assert_positive(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The docstring is attached by the `_unary_assert_doc` decorator.
  with ops.name_scope(name, 'assert_positive', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # In eager mode x is described by shape/dtype; otherwise by its name.
      x_label = (_shape_and_dtype_str(x) if context.executing_eagerly()
                 else x.name)
      data = [
          message or '',
          'Condition x > 0 did not hold element-wise:',
          'x (%s) = ' % x_label,
          x,
      ]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    # x > 0 is expressed as the binary assertion zero < x.
    return assert_less(zero, x, data=data, summarize=summarize)


@tf_export('debugging.assert_non_negative', v1=[])
def assert_non_negative_v2(x, message=None, summarize=None, name=None):
  """Assert that every element of `x` is non-negative (`x >= 0`).

  The check succeeds when `x[i] >= 0` for each element of `x`; an empty `x`
  satisfies it trivially.

  When some element is below zero, `message` and the first `summarize`
  entries of `x` are printed and `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to
      "assert_non_negative".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all non-negative. This can
      be used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] >= 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Thin wrapper: the v2 signature omits `data` and forwards the rest.
  return assert_non_negative(
      x=x, message=message, summarize=summarize, name=name)


@tf_export(v1=['debugging.assert_non_negative', 'assert_non_negative'])
@deprecation.deprecated_endpoints('assert_non_negative')
@_unary_assert_doc('>= 0', 'non-negative')
def assert_non_negative(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The docstring is attached by the `_unary_assert_doc` decorator.
  with ops.name_scope(name, 'assert_non_negative', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # In eager mode x is described by shape/dtype; otherwise by its name.
      x_label = (_shape_and_dtype_str(x) if context.executing_eagerly()
                 else x.name)
      data = [
          message or '',
          'Condition x >= 0 did not hold element-wise:',
          'x (%s) = ' % x_label,
          x,
      ]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    # x >= 0 is expressed as the binary assertion zero <= x.
    return assert_less_equal(zero, x, data=data, summarize=summarize)


@tf_export('debugging.assert_non_positive', v1=[])
def assert_non_positive_v2(x, message=None, summarize=None, name=None):
  """Assert that every element of `x` is non-positive (`x <= 0`).

  The check succeeds when `x[i] <= 0` for each element of `x`; an empty `x`
  satisfies it trivially.

  When some element is above zero, `message` and the first `summarize`
  entries of `x` are printed and `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to
      "assert_non_positive".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all non-positive. This can
      be used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] <= 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Thin wrapper: the v2 signature omits `data` and forwards the rest.
  return assert_non_positive(
      x=x, message=message, summarize=summarize, name=name)


@tf_export(v1=['debugging.assert_non_positive', 'assert_non_positive'])
@deprecation.deprecated_endpoints('assert_non_positive')
@_unary_assert_doc('<= 0', 'non-positive')
def assert_non_positive(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The docstring is attached by the `_unary_assert_doc` decorator.
  message = message or ''
  with ops.name_scope(name, 'assert_non_positive', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # In eager mode x is described by shape/dtype; otherwise by its name.
      if context.executing_eagerly():
        name = _shape_and_dtype_str(x)
      else:
        name = x.name
      # Each entry is a separate list item. A missing comma after the
      # headline previously fused it with the 'x (...) = ' label through
      # implicit string concatenation.
      data = [
          message,
          'Condition x <= 0 did not hold element-wise:',
          'x (%s) = ' % name, x]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    # x <= 0 is expressed as the binary assertion x <= zero.
    return assert_less_equal(x, zero, data=data, summarize=summarize)


@tf_export('debugging.assert_equal', 'assert_equal', v1=[])
def assert_equal_v2(x, y, message=None, summarize=None, name=None):
  """Assert that `x` and `y` are equal element-wise (`x == y`).

  The check succeeds when `x[i] == y[i]` for every pair of (possibly
  broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is
  trivially satisfied.

  When `x` and `y` differ, `message` and the first `summarize` entries of `x`
  and `y` are printed and `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x == y` is False. This can be
      used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x == y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # Thin wrapper: the v2 signature omits `data` and forwards the rest.
  return assert_equal(x=x, y=y, message=message, summarize=summarize, name=name)


@tf_export(v1=['debugging.assert_equal', 'assert_equal'])
@_binary_assert_doc('==')
def assert_equal(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The docstring is attached by the `_binary_assert_doc` decorator.
  with ops.name_scope(name, 'assert_equal', [x, y, data]):
    # Short-circuit when x and y are the very same object.
    if x is y:
      if context.executing_eagerly():
        return None
      return control_flow_ops.no_op()
  return _binary_assert('==', 'assert_equal', math_ops.equal, np.equal, x, y,
                        data, summarize, message, name)
@tf_export('debugging.assert_none_equal', v1=[])
def assert_none_equal_v2(x, y, summarize=None, message=None, name=None):
  """Assert the condition `x != y` holds for all elements.

  This Op checks that `x[i] != y[i]` holds for every pair of (possibly
  broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is
  trivially satisfied.

  If any elements of `x` and `y` are equal, `message`, as well as the first
  `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError`
  is raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to
    "assert_none_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x != y` is ever False. This can
      be used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x != y` is False for any pair of elements in `x` and `y`. The check can
      be performed immediately during eager execution or if `x` and `y` are
      statically known.
  """
  return assert_none_equal(x=x, y=y, summarize=summarize, message=message,
                           name=name)


@tf_export(v1=['debugging.assert_none_equal', 'assert_none_equal'])
@deprecation.deprecated_endpoints('assert_none_equal')
@_binary_assert_doc('!=')
def assert_none_equal(
    x, y, data=None, summarize=None, message=None, name=None):
  # The docstring is attached by the `_binary_assert_doc` decorator.
  return _binary_assert('!=', 'assert_none_equal', math_ops.not_equal,
                        np.not_equal, x, y, data, summarize, message, name)


@tf_export('debugging.assert_near', v1=[])
def assert_near_v2(x, y, rtol=None, atol=None, message=None, summarize=None,
                   name=None):
  """Assert the condition `x` and `y` are close element-wise.

  This Op checks that `tf.abs(x[i] - y[i]) < atol + rtol * tf.abs(y[i])` holds
  for every pair of (possibly broadcast) elements of `x` and `y`. The
  comparison is strict. If both `x` and `y` are empty, this is trivially
  satisfied.

  If any elements of `x` and `y` are not close, `message`, as well as the first
  `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError`
  is raised.

  The default `atol` and `rtol` is `10 * eps`, where `eps` is the smallest
  representable positive number such that `1 + eps != 1`.  This is about
  `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`.
  See `numpy.finfo`.

  Args:
    x: Float or complex `Tensor`.
    y: Float or complex `Tensor`, same dtype as and broadcastable to `x`.
    rtol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x`.
      The relative tolerance.  Default is `10 * eps`.
    atol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x`.
      The absolute tolerance.  Default is `10 * eps`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_near".

  Returns:
    Op that raises `InvalidArgumentError` if `x` and `y` are not close enough.
      This can be used with `tf.control_dependencies` inside of `tf.function`s
      to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x` and `y` are not close enough for some pair of elements. The check
      can be performed immediately during eager execution.

  @compatibility(numpy)
  Similar to `numpy.assert_allclose`, except tolerance depends on data type.
  This is due to the fact that `TensorFlow` is often used with `32bit`, `64bit`,
  and even `16bit` data.
  @end_compatibility
  """
  return assert_near(x=x, y=y, rtol=rtol, atol=atol, summarize=summarize,
                     message=message, name=name)


@tf_export(v1=['debugging.assert_near', 'assert_near'])
@deprecation.deprecated_endpoints('assert_near')
def assert_near(
    x, y, rtol=None, atol=None, data=None, summarize=None, message=None,
    name=None):
  """Assert the condition `x` and `y` are close element-wise.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_near(x, y)]):
    output = tf.reduce_sum(x)
  ```

  This condition holds if for every pair of (possibly broadcast) elements
  `x[i]`, `y[i]`, we have

  ```tf.abs(x[i] - y[i]) < atol + rtol * tf.abs(y[i])```.

  The comparison is strict. If both `x` and `y` are empty, this is trivially
  satisfied.

  The default `atol` and `rtol` is `10 * eps`, where `eps` is the smallest
  representable positive number such that `1 + eps != 1`.  This is about
  `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`.
  See `numpy.finfo`.

  Args:
    x:  Float or complex `Tensor`.
    y:  Float or complex `Tensor`, same `dtype` as, and broadcastable to, `x`.
    rtol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x`.
      The relative tolerance.  Default is `10 * eps`.
    atol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x`.
      The absolute tolerance.  Default is `10 * eps`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_near".

  Returns:
    Op that raises `InvalidArgumentError` if `x` and `y` are not close enough.

  @compatibility(numpy)
  Similar to `numpy.assert_allclose`, except tolerance depends on data type.
  This is due to the fact that `TensorFlow` is often used with `32bit`, `64bit`,
  and even `16bit` data.
  @end_compatibility
  """
  message = message or ''
  with ops.name_scope(name, 'assert_near', [x, y, rtol, atol, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y', dtype=x.dtype)

    # Default tolerances scale with the machine epsilon of x's dtype.
    eps = np.finfo(x.dtype.as_numpy_dtype).eps
    rtol = 10 * eps if rtol is None else rtol
    atol = 10 * eps if atol is None else atol

    rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=x.dtype)
    atol = ops.convert_to_tensor(atol, name='atol', dtype=x.dtype)

    # In eager mode tensors are described by shape/dtype; otherwise by name.
    if context.executing_eagerly():
      x_name = _shape_and_dtype_str(x)
      y_name = _shape_and_dtype_str(y)
    else:
      x_name = x.name
      y_name = y.name

    if data is None:
      data = [
          message,
          'x and y not equal to tolerance rtol = %s, atol = %s' % (rtol, atol),
          'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y
      ]
    tol = atol + rtol * math_ops.abs(y)
    diff = math_ops.abs(x - y)
    # Strict comparison: a difference equal to the tolerance fails the check.
    condition = math_ops.reduce_all(math_ops.less(diff, tol))
    return control_flow_ops.Assert(condition, data, summarize=summarize)


@tf_export('debugging.assert_less', 'assert_less', v1=[])
def assert_less_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x < y` holds element-wise.

  This Op checks that `x[i] < y[i]` holds for every pair of (possibly
  broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is
  trivially satisfied.

  If `x` is not less than `y` element-wise, `message`, as well as the first
  `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError` is
  raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_less".

  Returns:
    Op that raises `InvalidArgumentError` if `x < y` is False.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x < y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  return assert_less(x=x, y=y, summarize=summarize, message=message, name=name)
@tf_export(v1=['debugging.assert_less', 'assert_less'])
@_binary_assert_doc('<')
def assert_less(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # No inline docstring here: the function is wrapped by
  # `_binary_assert_doc('<')` (presumably the source of its documentation --
  # confirm against that decorator's definition).
  # Delegates to the shared binary-assert helper, passing the graph op
  # (`math_ops.less`) and its NumPy counterpart (`np.less`).
  return _binary_assert('<', 'assert_less', math_ops.less, np.less, x, y, data,
                        summarize, message, name)
@tf_export('debugging.assert_less_equal', v1=[])
def assert_less_equal_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x <= y` holds element-wise.

  Every pair of (possibly broadcast) elements must satisfy `x[i] <= y[i]`.
  The check passes trivially when both `x` and `y` are empty.

  On failure, `message` together with the first `summarize` entries of `x`
  and `y` are printed, and `InvalidArgumentError` is raised.

  Args:
    x: Numeric `Tensor`.
    y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to
      "assert_less_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x <= y` is False. This can be
    used with `tf.control_dependencies` inside of `tf.function`s to block
    followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x <= y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # Forward everything to the v1 implementation; `data` keeps its default.
  return assert_less_equal(
      x=x, y=y, message=message, summarize=summarize, name=name)


@tf_export(v1=['debugging.assert_less_equal', 'assert_less_equal'])
@deprecation.deprecated_endpoints('assert_less_equal')
@_binary_assert_doc('<=')
def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
  assertion = _binary_assert(
      '<=', 'assert_less_equal', math_ops.less_equal, np.less_equal,
      x, y, data, summarize, message, name)
  return assertion


@tf_export('debugging.assert_greater', 'assert_greater', v1=[])
def assert_greater_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x > y` holds element-wise.

  Every pair of (possibly broadcast) elements must satisfy `x[i] > y[i]`.
  The check passes trivially when both `x` and `y` are empty.

  On failure, `message` together with the first `summarize` entries of `x`
  and `y` are printed, and `InvalidArgumentError` is raised.

  Args:
    x: Numeric `Tensor`.
    y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_greater".

  Returns:
    Op that raises `InvalidArgumentError` if `x > y` is False. This can be
    used with `tf.control_dependencies` inside of `tf.function`s to block
    followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x > y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # Forward everything to the v1 implementation; `data` keeps its default.
  return assert_greater(
      x=x, y=y, message=message, summarize=summarize, name=name)
@tf_export(v1=['debugging.assert_greater', 'assert_greater'])
@_binary_assert_doc('>')
def assert_greater(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # No inline docstring here: the function is wrapped by
  # `_binary_assert_doc('>')` (presumably the source of its documentation --
  # confirm against that decorator's definition).
  # Delegates to the shared binary-assert helper, passing the graph op
  # (`math_ops.greater`) and its NumPy counterpart (`np.greater`).
  return _binary_assert('>', 'assert_greater', math_ops.greater, np.greater, x,
                        y, data, summarize, message, name)
@tf_export('debugging.assert_greater_equal', v1=[])
def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x >= y` holds element-wise.

  This Op checks that `x[i] >= y[i]` holds for every pair of (possibly
  broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is
  trivially satisfied.

  If `x` is not greater than or equal to `y` element-wise, `message`, as well
  as the first `summarize` entries of `x` and `y` are printed, and
  `InvalidArgumentError` is raised.

  Args:
    x: Numeric `Tensor`.
    y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to
      "assert_greater_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x >= y` is False. This can be
    used with `tf.control_dependencies` inside of `tf.function`s to block
    followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x >= y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  return assert_greater_equal(x=x, y=y, summarize=summarize, message=message,
                              name=name)


@tf_export(v1=['debugging.assert_greater_equal', 'assert_greater_equal'])
@deprecation.deprecated_endpoints('assert_greater_equal')
@_binary_assert_doc('>=')
def assert_greater_equal(x, y, data=None, summarize=None, message=None,
                         name=None):
  return _binary_assert('>=', 'assert_greater_equal', math_ops.greater_equal,
                        np.greater_equal, x, y, data, summarize, message, name)


def _assert_rank_condition(
    x, rank, static_condition, dynamic_condition, data, summarize):
  """Assert `x` has a rank that satisfies a given condition.

  The check is resolved statically when both `rank` and the rank of `x` are
  known at graph-construction time; otherwise a runtime `Assert` op is built.

  Args:
    x: Numeric `Tensor`.
    rank: Scalar `Tensor` of dtype `int32`.
    static_condition: A python function that takes `[actual_rank, given_rank]`
      and returns `True` if the condition is satisfied, `False` otherwise.
    dynamic_condition: An `op` that takes [actual_rank, given_rank] and return
      `True` if the condition is satisfied, `False` otherwise.
    data: The tensors to print out if the condition is false. Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.

  Returns:
    Op raising `InvalidArgumentError` if `x` fails dynamic_condition. If the
    static check succeeds, a `no_op` is returned instead.

  Raises:
    ValueError: If static checks determine `x` fails static_condition, or if
      `rank` is statically known and is not a scalar.
  """
  assert_type(rank, dtypes.int32)

  # Attempt to determine the given rank statically.
  rank_static = tensor_util.constant_value(rank)
  if rank_static is not None:
    if rank_static.ndim != 0:
      raise ValueError('Rank must be a scalar.')

    x_rank_static = x.get_shape().ndims
    if x_rank_static is not None:
      if not static_condition(x_rank_static, rank_static):
        # The extra args carry the actual and the given rank so that callers
        # can build a more specific error message.
        raise ValueError(
            'Static rank condition failed', x_rank_static, rank_static)
      return control_flow_ops.no_op(name='static_checks_determined_all_ok')

  condition = dynamic_condition(array_ops.rank(x), rank)

  # Add the condition that `rank` must have rank zero. Prevents the bug where
  # someone does assert_rank(x, [n]), rather than assert_rank(x, n).
  if rank_static is None:
    this_data = ['Rank must be a scalar. Received rank: ', rank]
    rank_check = assert_rank(rank, 0, data=this_data)
    condition = control_flow_ops.with_dependencies([rank_check], condition)

  return control_flow_ops.Assert(condition, data, summarize=summarize)


@tf_export('debugging.assert_rank', 'assert_rank', v1=[])
def assert_rank_v2(x, rank, message=None, name=None):
  """Assert that `x` has rank equal to `rank`.

  This Op checks that the rank of `x` is equal to `rank`.

  If `x` has a different rank, `message`, as well as the shape of `x` are
  printed, and `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    rank: Scalar integer `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to
      "assert_rank".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank.
    If static checks determine `x` has correct rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x` does not have rank `rank`. The check can be performed immediately
      during eager execution or if the shape of `x` is statically known.
  """
  return assert_rank(x=x, rank=rank, message=message, name=name)
@tf_export(v1=['debugging.assert_rank', 'assert_rank'])
def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank equal to `rank`.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank(x, 2)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: Numeric `Tensor`.
    rank: Scalar integer `Tensor`.
    data: The tensors to print out if the condition is False. Defaults to
      error message and the shape of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_rank".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank.
    If static checks determine `x` has correct rank, a `no_op` is returned.

  Raises:
    ValueError: If static checks determine `x` has wrong rank.
  """
  with ops.name_scope(name, 'assert_rank', (x, rank) + tuple(data or [])):
    x = ops.convert_to_tensor(x, name='x')
    rank = ops.convert_to_tensor(rank, name='rank')
    message = message or ''

    static_condition = lambda actual_rank, given_rank: actual_rank == given_rank
    dynamic_condition = math_ops.equal

    # From here on `name` is the tensor's display name used in error messages;
    # it is left empty when executing eagerly.
    if context.executing_eagerly():
      name = ''
    else:
      name = x.name

    if data is None:
      data = [
          message,
          'Tensor %s must have rank' % name, rank, 'Received shape: ',
          array_ops.shape(x)
      ]

    try:
      assert_op = _assert_rank_condition(x, rank, static_condition,
                                         dynamic_condition, data, summarize)

    except ValueError as e:
      # The helper signals a failed static check with a sentinel first
      # argument followed by (actual_rank, given_rank); turn that into a
      # readable message. Any other ValueError is propagated unchanged.
      if e.args[0] == 'Static rank condition failed':
        raise ValueError(
            '%s. Tensor %s must have rank %d. Received rank %d, shape %s' %
            (message, name, e.args[2], e.args[1], x.get_shape()))
      else:
        raise

  return assert_op
@tf_export('debugging.assert_rank_at_least', v1=[])
def assert_rank_at_least_v2(x, rank, message=None, name=None):
  """Assert that `x` has rank of at least `rank`.

  This Op checks that the rank of `x` is greater than or equal to `rank`.

  If `x` has a rank lower than `rank`, `message`, as well as the shape of `x`
  are printed, and `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    rank: Scalar integer `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to
      "assert_rank_at_least".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank or higher.
    If static checks determine `x` has correct rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: `x` does not have rank at least `rank`, but the rank
      cannot be statically determined.
    ValueError: If static checks determine `x` has mismatched rank.
  """
  return assert_rank_at_least(x=x, rank=rank, message=message, name=name)


@tf_export(v1=['debugging.assert_rank_at_least', 'assert_rank_at_least'])
@deprecation.deprecated_endpoints('assert_rank_at_least')
def assert_rank_at_least(
    x, rank, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank equal to `rank` or higher.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank_at_least(x, 2)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: Numeric `Tensor`.
    rank: Scalar `Tensor`.
    data: The tensors to print out if the condition is False. Defaults to
      error message and the shape of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_rank_at_least".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank or higher.
    If static checks determine `x` has correct rank, a `no_op` is returned.

  Raises:
    ValueError: If static checks determine `x` has wrong rank.
  """
  with ops.name_scope(
      name, 'assert_rank_at_least', (x, rank) + tuple(data or [])):
    x = ops.convert_to_tensor(x, name='x')
    rank = ops.convert_to_tensor(rank, name='rank')
    message = message or ''

    static_condition = lambda actual_rank, given_rank: actual_rank >= given_rank
    dynamic_condition = math_ops.greater_equal

    # `name` is reused as the tensor's display name in error messages.
    if context.executing_eagerly():
      name = ''
    else:
      name = x.name

    if data is None:
      data = [
          message,
          'Tensor %s must have rank at least' % name, rank,
          'Received shape: ', array_ops.shape(x)
      ]

    try:
      assert_op = _assert_rank_condition(x, rank, static_condition,
                                         dynamic_condition, data, summarize)

    except ValueError as e:
      # A failed static check arrives as (sentinel, actual_rank, given_rank).
      if e.args[0] == 'Static rank condition failed':
        raise ValueError(
            '%s. Tensor %s must have rank at least %d. Received rank %d, '
            'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape()))
      else:
        raise

  return assert_op


def _static_rank_in(actual_rank, given_ranks):
  """Returns whether `actual_rank` is one of `given_ranks` (Python check)."""
  return actual_rank in given_ranks


def _dynamic_rank_in(actual_rank, given_ranks):
  """Builds a boolean op that is true if `actual_rank` equals any given rank.

  An empty `given_ranks` yields a constant `False` tensor.
  """
  if len(given_ranks) < 1:
    return ops.convert_to_tensor(False)
  result = math_ops.equal(given_ranks[0], actual_rank)
  for given_rank in given_ranks[1:]:
    result = math_ops.logical_or(
        result, math_ops.equal(given_rank, actual_rank))
  return result


def _assert_ranks_condition(
    x, ranks, static_condition, dynamic_condition, data, summarize):
  """Assert `x` has a rank that satisfies a given condition.

  Args:
    x: Numeric `Tensor`.
    ranks: Iterable of scalar `Tensor`s of dtype `int32`.
    static_condition: A python function that takes
      `[actual_rank, given_ranks]` and returns `True` if the condition is
      satisfied, `False` otherwise.
    dynamic_condition: An `op` that takes [actual_rank, given_ranks]
      and return `True` if the condition is satisfied, `False` otherwise.
    data: The tensors to print out if the condition is false. Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.

  Returns:
    Op raising `InvalidArgumentError` if `x` fails dynamic_condition. If the
    static check succeeds, a `no_op` is returned instead.

  Raises:
    ValueError: If static checks determine `x` fails static_condition, or if
      all ranks are statically known and one of them is not a scalar.
  """
  for rank in ranks:
    assert_type(rank, dtypes.int32)

  # Attempt to determine the given ranks statically.
  ranks_static = tuple([tensor_util.constant_value(rank) for rank in ranks])
  if not any(r is None for r in ranks_static):
    for rank_static in ranks_static:
      if rank_static.ndim != 0:
        raise ValueError('Rank must be a scalar.')

    x_rank_static = x.get_shape().ndims
    if x_rank_static is not None:
      if not static_condition(x_rank_static, ranks_static):
        raise ValueError(
            'Static rank condition failed', x_rank_static, ranks_static)
      return control_flow_ops.no_op(name='static_checks_determined_all_ok')

  condition = dynamic_condition(array_ops.rank(x), ranks)

  # Add the condition that `rank` must have rank zero. Prevents the bug where
  # someone does assert_rank(x, [n]), rather than assert_rank(x, n).
  for rank, rank_static in zip(ranks, ranks_static):
    if rank_static is None:
      this_data = ['Rank must be a scalar. Received rank: ', rank]
      rank_check = assert_rank(rank, 0, data=this_data)
      condition = control_flow_ops.with_dependencies([rank_check], condition)

  return control_flow_ops.Assert(condition, data, summarize=summarize)


@tf_export('debugging.assert_rank_in', v1=[])
def assert_rank_in_v2(x, ranks, message=None, name=None):
  """Assert that `x` has a rank in `ranks`.

  This Op checks that the rank of `x` is in `ranks`.

  If `x` has a different rank, `message`, as well as the shape of `x` are
  printed, and `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    ranks: `Iterable` of scalar `Tensor` objects.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_rank_in".

  Returns:
    Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`.
    If static checks determine `x` has matching rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: `x` does not have rank in `ranks`, but the rank cannot
      be statically determined.
    ValueError: If static checks determine `x` has mismatched rank.
  """
  return assert_rank_in(x=x, ranks=ranks, message=message, name=name)


@tf_export(v1=['debugging.assert_rank_in', 'assert_rank_in'])
@deprecation.deprecated_endpoints('assert_rank_in')
def assert_rank_in(
    x, ranks, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank in `ranks`.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank_in(x, (2, 4))]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: Numeric `Tensor`.
    ranks: Iterable of scalar `Tensor` objects.
    data: The tensors to print out if the condition is False. Defaults to
      error message, the given ranks and the shape of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_rank_in".

  Returns:
    Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`.
    If static checks determine `x` has matching rank, a `no_op` is returned.

  Raises:
    ValueError: If static checks determine `x` has mismatched rank.
  """
  with ops.name_scope(
      name, 'assert_rank_in', (x,) + tuple(ranks) + tuple(data or [])):
    x = ops.convert_to_tensor(x, name='x')
    ranks = tuple([ops.convert_to_tensor(rank, name='rank') for rank in ranks])
    message = message or ''

    # `name` is reused as the tensor's display name in error messages.
    if context.executing_eagerly():
      name = ''
    else:
      name = x.name

    if data is None:
      data = [
          message, 'Tensor %s must have rank in' % name
      ] + list(ranks) + [
          'Received shape: ', array_ops.shape(x)
      ]

    try:
      assert_op = _assert_ranks_condition(x, ranks, _static_rank_in,
                                          _dynamic_rank_in, data, summarize)

    except ValueError as e:
      # A failed static check arrives as (sentinel, actual_rank, given_ranks).
      if e.args[0] == 'Static rank condition failed':
        raise ValueError(
            '%s. Tensor %s must have rank in %s. Received rank %d, '
            'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape()))
      else:
        raise

  return assert_op


@tf_export('debugging.assert_integer', v1=[])
def assert_integer_v2(x, message=None, name=None):
  """Assert that `x` is of integer dtype.

  If `x` has a non-integer type, `message`, as well as the dtype of `x` are
  printed, and `TypeError` is raised.

  This can always be checked statically, so this method returns nothing.

  Args:
    x: A `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_integer".

  Raises:
    TypeError: If `x.dtype` is not a non-quantized integer type.
  """
  assert_integer(x=x, message=message, name=name)


@tf_export(v1=['debugging.assert_integer', 'assert_integer'])
@deprecation.deprecated_endpoints('assert_integer')
def assert_integer(x, message=None, name=None):
  """Assert that `x` is of integer dtype.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_integer(x)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: `Tensor` whose basetype is integer and is not quantized.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_integer".

  Raises:
    TypeError: If `x.dtype` is anything other than non-quantized integer.

  Returns:
    A `no_op` that does nothing. Type can be determined statically.
  """
  message = message or ''
  with ops.name_scope(name, 'assert_integer', [x]):
    x = ops.convert_to_tensor(x, name='x')
    if not x.dtype.is_integer:
      # `name` is reused as the tensor's display name in the error message.
      if context.executing_eagerly():
        name = 'tensor'
      else:
        name = x.name
      err_msg = (
          '%s Expected "x" to be integer type. Found: %s of dtype %s'
          % (message, name, x.dtype))
      raise TypeError(err_msg)

    return control_flow_ops.no_op('statically_determined_was_integer')


@tf_export('debugging.assert_type', v1=[])
def assert_type_v2(tensor, tf_type, message=None, name=None):
  """Asserts that the given `Tensor` is of the specified type.

  This can always be checked statically, so this method returns nothing.

  Args:
    tensor: A `Tensor`.
    tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`,
      etc).
    message: A string to prefix to the default message.
    name: A name for this operation. Defaults to "assert_type"

  Raises:
    TypeError: If the tensor's data type doesn't match `tf_type`.
  """
  assert_type(tensor=tensor, tf_type=tf_type, message=message, name=name)


@tf_export(v1=['debugging.assert_type', 'assert_type'])
@deprecation.deprecated_endpoints('assert_type')
def assert_type(tensor, tf_type, message=None, name=None):
  """Statically asserts that the given `Tensor` is of the specified type.

  Args:
    tensor: A `Tensor`.
    tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`,
      etc).
    message: A string to prefix to the default message.
    name: A name to give this `Op`. Defaults to "assert_type"

  Raises:
    TypeError: If the tensor's data type doesn't match `tf_type`.

  Returns:
    A `no_op` that does nothing. Type can be determined statically.
  """
  message = message or ''
  with ops.name_scope(name, 'assert_type', [tensor]):
    tensor = ops.convert_to_tensor(tensor, name='tensor')
    if tensor.dtype != tf_type:
      # The tensor's name is only included in the message outside eager mode.
      if context.executing_eagerly():
        raise TypeError('%s tensor must be of type %s' % (message, tf_type))
      else:
        raise TypeError('%s %s must be of type %s' % (message, tensor.name,
                                                      tf_type))

    return control_flow_ops.no_op('statically_determined_correct_type')


def _dimension_sizes(x):
  """Gets the dimension sizes of a tensor `x`.

  If a size can be determined statically it is returned as an integer,
  otherwise as a tensor.

  If `x` is a scalar it is treated as rank 1 size 1.

  Args:
    x: A `Tensor`.

  Returns:
    Dimension sizes: a tuple or list when the rank of `x` is statically known,
    otherwise a tensor.
  """
  dynamic_shape = array_ops.shape(x)
  rank = x.get_shape().rank
  rank_is_known = rank is not None
  if rank_is_known and rank == 0:
    return tuple([1])
  if rank_is_known and rank > 0:
    static_shape = x.get_shape().as_list()
    # Fall back to the dynamic size for every statically unknown dimension.
    sizes = [
        int(size) if size is not None else dynamic_shape[i]
        for i, size in enumerate(static_shape)
    ]
    return sizes
  # Unknown rank: decide at run time whether `x` is a scalar.
  has_rank_zero = math_ops.equal(array_ops.rank(x), 0)
  return control_flow_ops.cond(
      has_rank_zero, lambda: array_ops.constant([1]), lambda: dynamic_shape)


def _symbolic_dimension_sizes(symbolic_shape):
  """Returns `symbolic_shape`, or `(1,)` when it is empty."""
  # If len(symbolic_shape) == 0 construct a tuple
  if not symbolic_shape:
    return tuple([1])

  return symbolic_shape


def _has_known_value(dimension_size):
  """Returns whether `dimension_size` is not None and convertible by `int`."""
  not_none = dimension_size is not None
  try:
    int(dimension_size)
    can_be_parsed_as_int = True
  except (ValueError, TypeError):
    can_be_parsed_as_int = False
  return not_none and can_be_parsed_as_int


def _is_symbol_for_any_size(symbol):
  """Returns whether `symbol` (None or '.') stands for an unconstrained size."""
  return symbol in [None, '.']


# Per-tensor record built by `assert_shapes`:
#   x: the tensor being checked.
#   unspecified_dim: True if the shape spec started with `...` or '*'.
#   actual_sizes: result of `_dimension_sizes(x)`.
#   symbolic_sizes: the specified sizes, without the leading `...`/'*' marker.
_TensorDimSizes = collections.namedtuple(
    '_TensorDimSizes',
    ['x', 'unspecified_dim', 'actual_sizes', 'symbolic_sizes'])


@tf_export('debugging.assert_shapes', v1=[])
def assert_shapes_v2(shapes, data=None, summarize=None, message=None,
                     name=None):
  """Assert tensor shapes and dimension size relationships between tensors.

  This Op checks that a collection of tensors shape relationships
  satisfies given constraints.

  Example:

  >>> n = 10
  >>> q = 3
  >>> d = 7
  >>> x = tf.zeros([n,q])
  >>> y = tf.ones([n,d])
  >>> param = tf.Variable([1.0, 2.0, 3.0])
  >>> scalar = 1.0
  >>> tf.debugging.assert_shapes([
  ...  (x, ('N', 'Q')),
  ...  (y, ('N', 'D')),
  ...  (param, ('Q',)),
  ...  (scalar, ()),
  ... ])

  >>> tf.debugging.assert_shapes([
  ...   (x, ('N', 'D')),
  ...   (y, ('N', 'D'))
  ... ])
  Traceback (most recent call last):
  ...
  ValueError: ...

  If `x`, `y`, `param` or `scalar` does not have a shape that satisfies
  all specified constraints, `message`, as well as the first `summarize` entries
  of the first encountered violating tensor are printed, and
  `InvalidArgumentError` is raised.

  Size entries in the specified shapes are checked against other entries by
  their __hash__, except:
    - a size entry is interpreted as an explicit size if it can be parsed as an
      integer primitive.
    - a size entry is interpreted as *any* size if it is None or '.'.

  If the first entry of a shape is `...` (type `Ellipsis`) or '*' that indicates
  a variable number of outer dimensions of unspecified size, i.e. the constraint
  applies to the inner-most dimensions only.

  Scalar tensors and specified shapes of length zero (excluding the 'inner-most'
  prefix) are both treated as having a single dimension of size one.

  Args:
    shapes: An iterable of (`Tensor`, shape) pairs, or a dictionary mapping
      `Tensor` to shape. A shape must be an iterable.
    data: The tensors to print out if the condition is False. Defaults to error
      message and first few entries of the violating tensor.
    summarize: Print this many entries of the tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_shapes".

  Raises:
    ValueError: If static checks determine any shape constraint is violated.
  """
  assert_shapes(
      shapes, data=data, summarize=summarize, message=message, name=name)


@tf_export(v1=['debugging.assert_shapes'])
def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
  """Assert tensor shapes and dimension size relationships between tensors.

  This Op checks that a collection of tensors shape relationships
  satisfies given constraints.

  Example:

  ```python
  tf.assert_shapes([
    (x, ('N', 'Q')),
    (y, ('N', 'D')),
    (param, ('Q',)),
    (scalar, ())
  ])
  ```

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.assert_shapes(shapes)]):
    output = tf.matmul(x, y, transpose_a=True)
  ```

  If `x`, `y`, `param` or `scalar` does not have a shape that satisfies
  all specified constraints, `message`, as well as the first `summarize` entries
  of the first encountered violating tensor are printed, and
  `InvalidArgumentError` is raised.

  Size entries in the specified shapes are checked against other entries by
  their __hash__, except:
    - a size entry is interpreted as an explicit size if it can be parsed as an
      integer primitive.
    - a size entry is interpreted as *any* size if it is None or '.'.

  If the first entry of a shape is `...` (type `Ellipsis`) or '*' that indicates
  a variable number of outer dimensions of unspecified size, i.e. the constraint
  applies to the inner-most dimensions only.

  Scalar tensors and specified shapes of length zero (excluding the 'inner-most'
  prefix) are both treated as having a single dimension of size one.

  Args:
    shapes: An iterable of (`Tensor`, shape) pairs, or a dictionary mapping
      `Tensor` to shape. A shape must be an iterable. A shape given as `None`
      imposes no constraint on its tensor.
    data: The tensors to print out if the condition is False. Defaults to error
      message and first few entries of the violating tensor.
    summarize: Print this many entries of the tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_shapes".

  Returns:
    Op raising `InvalidArgumentError` unless all shape constraints are
    satisfied.
    If static checks determine all constraints are satisfied, a `no_op` is
    returned.

  Raises:
    ValueError: If static checks determine any shape constraint is violated.
  """
  # If the user manages to assemble a dict containing tensors (possible in
  # Graph mode only), make sure we still accept that.
  if isinstance(shapes, dict):
    shapes = shapes.items()

  message = message or ''
  with ops.name_scope(name, 'assert_shapes', [shapes, data]):
    # Shape specified as None implies no constraint
    shape_constraints = [
        (ops.convert_to_tensor(x), s) for x, s in shapes if s is not None
    ]

    executing_eagerly = context.executing_eagerly()

    def tensor_name(x):
      # Display name used in error messages.
      if executing_eagerly:
        return _shape_and_dtype_str(x)
      return x.name

    # Pass 1: validate each shape specification and record, per tensor, its
    # actual and symbolic dimension sizes.
    tensor_dim_sizes = []
    for tensor, symbolic_shape in shape_constraints:
      is_iterable = (
          hasattr(symbolic_shape, '__iter__') or
          hasattr(symbolic_shape, '__getitem__')  # For Python 2 compat.
      )
      if not is_iterable:
        raise ValueError(
            '%s. '
            'Tensor %s. Specified shape must be an iterable. '
            'An iterable has the attribute `__iter__` or `__getitem__`. '
            'Received specified shape: %s' %
            (message, tensor_name(tensor), symbolic_shape))

      # We convert this into a tuple to handle strings, lists and numpy arrays
      symbolic_shape_tuple = tuple(symbolic_shape)

      tensors_specified_innermost = False
      for i, symbol in enumerate(symbolic_shape_tuple):
        if symbol not in [Ellipsis, '*']:
          continue

        if i != 0:
          raise ValueError(
              '%s. '
              'Tensor %s specified shape index %d. '
              'Symbol `...` or `*` for a variable number of '
              'unspecified dimensions is only allowed as the first entry' %
              (message, tensor_name(tensor), i))

        tensors_specified_innermost = True

      # Only include the size of the specified dimensions since the 0th symbol
      # is either ellipsis or *
      tensor_dim_sizes.append(
          _TensorDimSizes(
              tensor, tensors_specified_innermost, _dimension_sizes(tensor),
              _symbolic_dimension_sizes(
                  symbolic_shape_tuple[1:]
                  if tensors_specified_innermost else symbolic_shape_tuple)))

    # Pass 2: build one rank assertion per tensor.
    rank_assertions = []
    for sizes in tensor_dim_sizes:
      rank = len(sizes.symbolic_sizes)
      rank_zero_or_one = rank in [0, 1]
      if sizes.unspecified_dim:
        if rank_zero_or_one:
          # No assertion of rank needed as `x` only need to have rank at least
          # 0. See elif rank_zero_or_one case comment.
          continue
        assertion = assert_rank_at_least(
            x=sizes.x,
            rank=rank,
            data=data,
            summarize=summarize,
            message=message,
            name=name)
      elif rank_zero_or_one:
        # Rank 0 is treated as rank 1 size 1, i.e. there is
        # no distinction between the two in terms of rank.
        # See _dimension_sizes.
        assertion = assert_rank_in(
            x=sizes.x,
            ranks=[0, 1],
            data=data,
            summarize=summarize,
            message=message,
            name=name)
      else:
        assertion = assert_rank(
            x=sizes.x,
            rank=rank,
            data=data,
            summarize=summarize,
            message=message,
            name=name)
      rank_assertions.append(assertion)

    # Pass 3: check every specified size. The first occurrence of a symbol
    # records its size; later occurrences (and explicit integers) are compared
    # against the recorded/explicit value.
    size_assertions = []
    size_specifications = {}
    for sizes in tensor_dim_sizes:
      for i, size_symbol in enumerate(sizes.symbolic_sizes):

        if _is_symbol_for_any_size(size_symbol):
          # Size specified as any implies no constraint
          continue

        # With a leading `...`/'*' the spec describes the inner-most
        # dimensions, so index from the end.
        if sizes.unspecified_dim:
          tensor_dim = i - len(sizes.symbolic_sizes)
        else:
          tensor_dim = i

        if size_symbol in size_specifications or _has_known_value(size_symbol):
          if _has_known_value(size_symbol):
            specified_size = int(size_symbol)
            size_check_message = 'Specified explicitly'
          else:
            specified_size, specified_by_y, specified_at_dim = \
                size_specifications[size_symbol]
            size_check_message = (
                'Specified by tensor %s dimension %d' %
                (tensor_name(specified_by_y), specified_at_dim))

          actual_size = sizes.actual_sizes[tensor_dim]
          if _has_known_value(actual_size) and _has_known_value(specified_size):
            if int(actual_size) != int(specified_size):
              raise ValueError(
                  '%s. %s. Tensor %s dimension %s must have size %d. '
                  'Received size %d, shape %s' %
                  (message, size_check_message, tensor_name(sizes.x),
                   tensor_dim, specified_size, actual_size,
                   sizes.x.get_shape()))
            # No dynamic assertion needed
            continue

          condition = math_ops.equal(
              ops.convert_to_tensor(actual_size),
              ops.convert_to_tensor(specified_size))
          data_ = data
          if data is None:
            data_ = [
                message, size_check_message,
                'Tensor %s dimension' % tensor_name(sizes.x), tensor_dim,
                'must have size', specified_size, 'Received shape: ',
                array_ops.shape(sizes.x)
            ]
          size_assertions.append(
              control_flow_ops.Assert(condition, data_, summarize=summarize))
        else:
          size = sizes.actual_sizes[tensor_dim]
          size_specifications[size_symbol] = (size, sizes.x, tensor_dim)

    # Size assertions are grouped under a control dependency on the rank
    # assertions.
    with ops.control_dependencies(rank_assertions):
      shapes_assertion = control_flow_ops.group(size_assertions)
    return shapes_assertion


# pylint: disable=line-too-long
def _get_diff_for_monotonic_comparison(x):
  """Gets the difference x[1:] - x[:-1].

  `x` is flattened first. An empty tensor of `x`'s dtype is returned when `x`
  has fewer than two elements.

  Raises:
    TypeError: if `x` is not a numeric tensor.
  """
  x = array_ops.reshape(x, [-1])
  if not is_numeric_tensor(x):
    raise TypeError('Expected x to be numeric, instead found: %s' % x)

  # If x has less than 2 elements, there is nothing to compare. So return [].
  is_shorter_than_two = math_ops.less(array_ops.size(x), 2)
  short_result = lambda: ops.convert_to_tensor([], dtype=x.dtype)

  # With 2 or more elements, return x[1:] - x[:-1]
  s_len = array_ops.shape(x) - 1
  diff = lambda: array_ops.strided_slice(x, [1], [1] + s_len)- array_ops.strided_slice(x, [0], s_len)
  return control_flow_ops.cond(is_shorter_than_two, short_result, diff)


@tf_export(
    'debugging.is_numeric_tensor',
    v1=['debugging.is_numeric_tensor', 'is_numeric_tensor'])
@deprecation.deprecated_endpoints('is_numeric_tensor')
def is_numeric_tensor(tensor):
  """Returns `True` if the elements of `tensor` are numbers.

  Specifically, returns `True` if the dtype of `tensor` is one of the following:

  * `tf.float32`
  * `tf.float64`
  * `tf.int8`
  * `tf.int16`
  * `tf.int32`
  * `tf.int64`
  * `tf.uint8`
  * `tf.qint8`
  * `tf.qint32`
  * `tf.quint8`
  * `tf.complex64`

  Returns `False` if `tensor` is of a non-numeric type or if `tensor` is not
  a `tf.Tensor` object.
  """
  return isinstance(tensor, ops.Tensor) and tensor.dtype in NUMERIC_TYPES


@tf_export(
    'math.is_non_decreasing',
    v1=[
        'math.is_non_decreasing', 'debugging.is_non_decreasing',
        'is_non_decreasing'
    ])
@deprecation.deprecated_endpoints('debugging.is_non_decreasing',
                                  'is_non_decreasing')
def is_non_decreasing(x, name=None):
  """Returns `True` if `x` is non-decreasing.

  Elements of `x` are compared in row-major order. The tensor `[x[0],...]`
  is non-decreasing if for every adjacent pair we have `x[i] <= x[i+1]`.
  If `x` has less than two elements, it is trivially non-decreasing.

  See also: `is_strictly_increasing`

  >>> x1 = tf.constant([1.0, 1.0, 3.0])
  >>> tf.math.is_non_decreasing(x1)
  <tf.Tensor: shape=(), dtype=bool, numpy=True>
  >>> x2 = tf.constant([3.0, 1.0, 2.0])
  >>> tf.math.is_non_decreasing(x2)
  <tf.Tensor: shape=(), dtype=bool, numpy=False>

  Args:
    x: Numeric `Tensor`.
    name: A name for this operation (optional). Defaults to "is_non_decreasing"

  Returns:
    Boolean `Tensor`, equal to `True` iff `x` is non-decreasing.

  Raises:
    TypeError: if `x` is not a numeric tensor.
  """
  with ops.name_scope(name, 'is_non_decreasing', [x]):
    diff = _get_diff_for_monotonic_comparison(x)
    # When len(x) = 1, diff = [], less_equal = [], and reduce_all([]) = True.
    zero = ops.convert_to_tensor(0, dtype=diff.dtype)
    return math_ops.reduce_all(math_ops.less_equal(zero, diff))


@tf_export(
    'math.is_strictly_increasing',
    v1=[
        'math.is_strictly_increasing', 'debugging.is_strictly_increasing',
        'is_strictly_increasing'
    ])
@deprecation.deprecated_endpoints('debugging.is_strictly_increasing',
                                  'is_strictly_increasing')
def is_strictly_increasing(x, name=None):
  """Returns `True` if `x` is strictly increasing.

  Elements of `x` are compared in row-major order. The tensor `[x[0],...]`
  is strictly increasing if for every adjacent pair we have `x[i] < x[i+1]`.
  If `x` has less than two elements, it is trivially strictly increasing.

  See also: `is_non_decreasing`

  >>> x1 = tf.constant([1.0, 2.0, 3.0])
  >>> tf.math.is_strictly_increasing(x1)
  <tf.Tensor: shape=(), dtype=bool, numpy=True>
  >>> x2 = tf.constant([3.0, 1.0, 2.0])
  >>> tf.math.is_strictly_increasing(x2)
  <tf.Tensor: shape=(), dtype=bool, numpy=False>

  Args:
    x: Numeric `Tensor`.
    name: A name for this operation (optional).
      Defaults to "is_strictly_increasing"

  Returns:
    Boolean `Tensor`, equal to `True` iff `x` is strictly increasing.

  Raises:
    TypeError: if `x` is not a numeric tensor.
  """
  with ops.name_scope(name, 'is_strictly_increasing', [x]):
    diff = _get_diff_for_monotonic_comparison(x)
    # When len(x) = 1, diff = [], less = [], and reduce_all([]) = True.
    zero = ops.convert_to_tensor(0, dtype=diff.dtype)
    return math_ops.reduce_all(math_ops.less(zero, diff))


def _assert_same_base_type(items, expected_type=None):
  r"""Asserts all items are of the same base type.

  Args:
    items: List of graph items (e.g., `Variable`, `Tensor`, `SparseTensor`,
        `Operation`, or `IndexedSlices`). Can include `None` elements, which
        will be ignored.
    expected_type: Expected type. If not specified, assert all items are
        of the same base type.

  Returns:
    Validated type, or none if neither expected_type nor items provided.

  Raises:
    ValueError: If any types do not match.
""" original_expected_type = expected_type mismatch = False for item in items: if item is not None: item_type = item.dtype.base_dtype if not expected_type: expected_type = item_type elif expected_type != item_type: mismatch = True break if mismatch: # Loop back through and build up an informative error message (this is very # slow, so we don't do it unless we found an error above). expected_type = original_expected_type original_item_str = None for item in items: if item is not None: item_type = item.dtype.base_dtype if not expected_type: expected_type = item_type original_item_str = item.name if hasattr(item, 'name') else str(item) elif expected_type != item_type: raise ValueError('%s, type=%s, must be of the same type (%s)%s.' % ( item.name if hasattr(item, 'name') else str(item), item_type, expected_type, (' as %s' % original_item_str) if original_item_str else '')) return expected_type # Should be unreachable else: return expected_type @tf_export( 'debugging.assert_same_float_dtype', v1=['debugging.assert_same_float_dtype', 'assert_same_float_dtype']) @deprecation.deprecated_endpoints('assert_same_float_dtype') def assert_same_float_dtype(tensors=None, dtype=None): """Validate and return float type based on `tensors` and `dtype`. For ops such as matrix multiplication, inputs and weights must be of the same float type. This function validates that all `tensors` are the same type, validates that type is `dtype` (if supplied), and returns the type. Type must be a floating point type. If neither `tensors` nor `dtype` is supplied, the function will return `dtypes.float32`. Args: tensors: Tensors of input values. Can include `None` elements, which will be ignored. dtype: Expected type. Returns: Validated type. Raises: ValueError: if neither `tensors` nor `dtype` is supplied, or result is not float, or the common type of the inputs is not a floating point type. 
""" if tensors: dtype = _assert_same_base_type(tensors, dtype) if not dtype: dtype = dtypes.float32 elif not dtype.is_floating: raise ValueError('Expected floating point type, got %s.' % dtype) return dtype @tf_export('debugging.assert_scalar', v1=[]) def assert_scalar_v2(tensor, message=None, name=None): """Asserts that the given `tensor` is a scalar. This function raises `ValueError` unless it can be certain that the given `tensor` is a scalar. `ValueError` is also raised if the shape of `tensor` is unknown. This is always checked statically, so this method returns nothing. Args: tensor: A `Tensor`. message: A string to prefix to the default message. name: A name for this operation. Defaults to "assert_scalar" Raises: ValueError: If the tensor is not scalar (rank 0), or if its shape is unknown. """ assert_scalar(tensor=tensor, message=message, name=name) @tf_export(v1=['debugging.assert_scalar', 'assert_scalar']) @deprecation.deprecated_endpoints('assert_scalar') def assert_scalar(tensor, name=None, message=None): """Asserts that the given `tensor` is a scalar (i.e. zero-dimensional). This function raises `ValueError` unless it can be certain that the given `tensor` is a scalar. `ValueError` is also raised if the shape of `tensor` is unknown. Args: tensor: A `Tensor`. name: A name for this operation. Defaults to "assert_scalar" message: A string to prefix to the default message. Returns: The input tensor (potentially converted to a `Tensor`). Raises: ValueError: If the tensor is not scalar (rank 0), or if its shape is unknown. """ with ops.name_scope(name, 'assert_scalar', [tensor]) as name_scope: tensor = ops.convert_to_tensor(tensor, name=name_scope) shape = tensor.get_shape() if shape.ndims != 0: if context.executing_eagerly(): raise ValueError('%sExpected scalar shape, saw shape: %s.' % (message or '', shape,)) else: raise ValueError('%sExpected scalar shape for %s, saw shape: %s.' % (message or '', tensor.name, shape)) return tensor
@tf_export('ensure_shape')
def ensure_shape(x, shape, name=None):
  """Updates the shape of a tensor and checks at runtime that the shape holds.

  For example:
  ```python
  x = tf.compat.v1.placeholder(tf.int32)
  print(x.shape)
  ==> TensorShape(None)
  y = x * 2
  print(y.shape)
  ==> TensorShape(None)

  y = tf.ensure_shape(y, (None, 3, 3))
  print(y.shape)
  ==> TensorShape([Dimension(None), Dimension(3), Dimension(3)])

  with tf.compat.v1.Session() as sess:
    # Raises tf.errors.InvalidArgumentError, because the shape (3,) is not
    # compatible with the shape (None, 3, 3)
    sess.run(y, feed_dict={x: [1, 2, 3]})

  ```

  NOTE: This differs from `Tensor.set_shape` in that it sets the static shape
  of the resulting tensor and enforces it at runtime, raising an error if the
  tensor's runtime shape is incompatible with the specified shape.
  `Tensor.set_shape` sets the static shape of the tensor without enforcing it
  at runtime, which may result in inconsistencies between the statically-known
  shape of tensors and the runtime value of tensors.

  Args:
    x: A `Tensor`.
    shape: A `TensorShape` representing the shape of this tensor, a
      `TensorShapeProto`, a list, a tuple, or None.
    name: A name for this operation (optional). Defaults to "EnsureShape".

  Returns:
    A `Tensor`. Has the same type and contents as `x`. At runtime, raises a
    `tf.errors.InvalidArgumentError` if `shape` is incompatible with the shape
    of `x`.
  """
  # Normalize list/tuple/proto/None inputs to a `TensorShape` before handing
  # off to the underlying op wrapper.
  if not isinstance(shape, tensor_shape.TensorShape):
    shape = tensor_shape.TensorShape(shape)

  return array_ops.ensure_shape(x, shape, name=name)
@ops.RegisterGradient('EnsureShape')
def _ensure_shape_grad(op, grad):
  """Gradient function registered for the `EnsureShape` op.

  Args:
    op: The forward operation; not consulted.
    grad: The incoming gradient.

  Returns:
    `grad`, unchanged.
  """
  del op  # Unused.
  return grad