NumPy 中的数据类型提升#

当混合两种不同的数据类型时,NumPy 必须确定操作结果的适当 dtype.此步骤称为提升或查找公共 dtype.

在典型情况下,用户无需担心提升的细节,因为提升步骤通常确保结果将匹配或超过输入的精度.

例如,当输入具有相同的 dtype 时,结果的 dtype 与输入的 dtype 匹配:

>>> np.int8(1) + np.int8(1)
np.int8(2)

混合两种不同的 dtype 通常会产生具有更高精度输入的 dtype 的结果:

>>> np.int8(4) + np.int64(8)  # 64 > 8
np.int64(12)
>>> np.float32(3) + np.float16(3)  # 32 > 16
np.float32(6.0)

在典型情况下,这不会导致意外.但是,如果你使用非默认 dtype(如无符号整数和低精度浮点数),或者你混合 NumPy 整数,NumPy 浮点数和 Python 标量,则 NumPy 提升规则的某些细节可能相关.请注意,这些详细规则并不总是与其他语言的规则匹配 [1] .

数值 dtype 分为四种“类型”,具有自然的层次结构.

  1. 无符号整数 ( uint )

  2. 有符号整数 ( int )

  3. 浮点数 ( float )

  4. 复数 ( complex )

除了类型之外,NumPy 数值 dtype 还有一个相关联的精度,以位为单位指定.类型和精度一起指定 dtype.例如, uint8 是使用 8 位存储的无符号整数.

操作的结果将始终是任何输入的相等或更高类型.此外,结果将始终具有大于或等于输入的精度.已经,这可能会导致一些可能出乎意料的示例:

  1. 当混合浮点数和整数时,整数的精度可能会强制结果为更高精度的浮点数.例如,涉及 int64float16 的操作的结果是 float64 .

  2. 当混合具有相同精度的无符号和有符号整数时,结果将具有比任何输入更高的精度.此外,如果其中一个已经具有 64 位精度,则没有更高精度的整数可用,例如,涉及 int64uint64 的操作会给出 float64 .

有关详细信息,请参阅下面的“数值提升”部分和图片.

Python 标量的详细行为#

自从 NumPy 2.0 [2] 以来,我们的提升规则中的一个重要点是,虽然涉及两个 NumPy dtype 的操作永远不会损失精度,但涉及 NumPy dtype 和 Python 标量( int , floatcomplex )的操作可能会损失精度.例如,Python 整数和 NumPy 整数之间的运算结果应该是 NumPy 整数,这可能是很直观的.但是,Python 整数具有任意精度,而所有 NumPy dtype 具有固定精度,因此 Python 整数的任意精度无法保留.

更广泛地说,NumPy会考虑Python标量的“种类”,但在确定结果数据类型时会忽略它们的精度.这通常很方便.例如,当处理低精度数据类型的数组时,通常希望与Python标量的简单操作能够保留数据类型.

>>> arr_float32 = np.array([1, 2.5, 2.1], dtype="float32")
>>> arr_float32 + 10.0  # undesirable to promote to float64
array([11. , 12.5, 12.1], dtype=float32)
>>> arr_int16 = np.array([3, 5, 7], dtype="int16")
>>> arr_int16 + 10  # undesirable to promote to int64
array([13, 15, 17], dtype=int16)

在这两种情况下,结果精度都由NumPy数据类型决定.因此, arr_float32 + 3.0 的行为与 arr_float32 + np.float32(3.0) 相同, arr_int16 + 10 的行为与 arr_int16 + np.int16(10.) 相同.

另一个例子是,当将NumPy整数与Python floatcomplex 混合时,结果始终具有 float64complex128 类型:

>> np.int16(1) + 1.0 np.float64(2.0)

但是,当使用低精度数据类型时,这些规则也可能导致令人惊讶的行为.

首先,由于Python值在执行操作之前会转换为NumPy值,因此当结果看起来很明显时,操作可能会因错误而失败.例如, np.int8(1) + 1000 无法继续,因为 1000 超过了 int8 的最大值.当Python标量无法强制转换为NumPy数据类型时,会引发错误:

>>> np.int8(1) + 1000
Traceback (most recent call last):
  ...
OverflowError: Python integer 1000 out of bounds for int8
>>> np.int64(1) * 10**100
Traceback (most recent call last):
...
OverflowError: Python int too large to convert to C long
>>> np.float32(1) + 1e300
np.float32(inf)
... RuntimeWarning: overflow encountered in cast

其次,由于Python浮点数或整数的精度始终被忽略,因此低精度NumPy标量将继续使用其较低精度,除非显式转换为更高精度的NumPy数据类型或Python标量(例如,通过 int() , float()scalar.item() ).这种较低的精度可能不利于某些计算,或者导致不正确的结果,尤其是在整数溢出的情况下:

>>> np.int8(100) + 100  # the result exceeds the capacity of int8
np.int8(-56)
... RuntimeWarning: overflow encountered in scalar add

请注意,NumPy会在标量发生溢出时发出警告,但不会在数组发生溢出时发出警告; 例如, np.array(100, dtype="uint8") + 100 不会发出警告.

数值提升#

下图显示了数值提升规则,纵轴表示种类,横轴表示精度.

../_images/nep-0050-promotion-no-fonts.svg

具有较高种类的数据类型决定了结果数据类型的种类.结果数据类型的精度尽可能低,且在图中不会出现在任何一个输入数据类型的左侧.

请注意以下特定规则和观察结果:

  1. 当Python floatcomplex 与NumPy整数交互时,结果将是 float64complex128 (黄色边框). NumPy布尔值也将被转换为默认整数 [3] . 当另外涉及NumPy浮点数值时,这不相关.

  2. 精度的绘制方式使得 float16 < int16 < uint16 ,因为大的 uint16 不适合 int16 ,而大的 int16 在存储在 float16 中时会损失精度.然而,这种模式被打破了,因为NumPy始终认为 float64complex128 是任何整数值的可接受的提升结果.

  3. 一个特殊的例子是,NumPy将有符号和无符号整数的许多组合提升为 float64 .这里使用更高的种类,因为没有有符号整数数据类型具有足够的精度来容纳 uint64 .

通用提升规则的例外情况#

在NumPy中,提升指的是特定函数对结果的处理方式,在某些情况下,这意味着NumPy可能会偏离 np.result_type 给出的结果.

sumprod 的行为#

np.sumnp.prod 在对整数值(或布尔值)求和时,始终返回デフォルトの整数类型.这通常是 int64 . 这样做的原因是,整数求和否则非常容易溢出,并给出令人困惑的结果.此规则也适用于底层的 np.add.reducenp.multiply.reduce .

NumPy或Python整数标量的值得注意的行为#

NumPy 的类型提升指的是结果数据类型和运算精度,但运算有时会决定这个结果.除法总是返回浮点数值,比较运算总是返回布尔值.

这导致了一些可能看起来像是规则“例外”的情况:

  • NumPy 与 Python 整数或混合精度整数进行比较总是返回正确的结果.输入永远不会以丢失精度的方式进行转换.

  • 无法提升的类型之间的相等性比较将被视为全部 False (相等) 或全部 True (不相等).

  • np.sin 这样总是返回浮点值的单目数学函数,接受任何 Python 整数输入,并通过将其转换为 float64 .

  • 除法总是返回浮点数值,因此也允许任何 NumPy 整数与任何 Python 整数之间进行除法,方法是将两者都转换为 float64 .

原则上,对于其他函数而言,其中一些例外可能是有意义的.如果您认为情况确实如此,请提出问题.

Python 内置类型类的显著行为#

当组合 Python 的内置标量类型(即, float , intcomplex ,而不是标量值)时,类型提升规则可能会显得令人惊讶:

>>> np.result_type(7, np.array([1], np.float32))
dtype('float32')  # The scalar value '7' does not impact type promotion
>>> np.result_type(type(7), np.array([1], np.float32))
dtype('float64')  # The *type* of the scalar value '7' does impact promotion
# Similar situations happen with Python's float and complex types

这种行为的原因是 NumPy 将 int 转换为其默认整数类型,并使用该类型进行类型提升:

>>> np.result_type(int)
dtype('int64')

另请参阅 内置 Python 类型 以获取更多详细信息.

非数值数据类型的类型提升#

NumPy 将类型提升扩展到非数值类型,尽管在许多情况下,类型提升没有明确定义,并被直接拒绝.

以下规则适用:

  • NumPy 字节字符串 ( np.bytes_ ) 可以提升为 Unicode 字符串 ( np.str_ ).但是,对于非 ASCII 字符,将字节转换为 Unicode 将失败.

  • 出于某些目的,NumPy 几乎会将任何其他数据类型提升为字符串.这适用于数组创建或连接.

  • 当没有可行的类型提升时,像 np.array() 这样的数组构造器将使用 object 数据类型.

  • 当结构化数据类型匹配其字段名称和顺序时,可以进行类型提升.在这种情况下,所有字段都会进行单独类型提升.

  • NumPy timedelta 在某些情况下可以与整数进行类型提升.

备注

其中一些规则有些令人惊讶,目前正在考虑在将来进行更改.但是,任何向后不兼容的更改都必须权衡打破现有代码的风险. 如果您对类型提升应该如何工作有特别的想法,请提出问题.

被提升的 dtype 实例的详细信息#

以上讨论主要涉及混合不同 DType 类时的行为.附加到数组的 dtype 实例可以携带附加信息,如字节顺序,元数据,字符串长度或精确的结构化数据类型布局.

虽然字符串长度或结构化数据类型的字段名称很重要,但 NumPy 认为字节顺序,元数据和结构化数据类型的精确布局是存储细节.

在类型提升期间,NumPy 不会考虑这些存储细节:

  • 字节顺序被转换为本机字节顺序.

  • 附加到数据类型的元数据可能会或可能不会被保留.

  • 生成的结构化数据类型将被打包(但如果输入已对齐,则也会对齐).

对于大多数程序来说,这种行为是最佳行为,因为存储细节与最终结果无关,并且使用不正确的字节顺序可能会大大降低评估速度.