NumPy 中的数据类型提升#
当混合两种不同的数据类型时,NumPy 必须确定操作结果的适当 dtype.此步骤称为提升或查找公共 dtype.
在典型情况下,用户无需担心提升的细节,因为提升步骤通常确保结果将匹配或超过输入的精度.
例如,当输入具有相同的 dtype 时,结果的 dtype 与输入的 dtype 匹配:
>>> np.int8(1) + np.int8(1)
np.int8(2)
混合两种不同的 dtype 通常会产生一个结果,其 dtype 具有较高精度的输入:
>>> np.int8(4) + np.int64(8) # 64 > 8
np.int64(12)
>>> np.float32(3) + np.float16(3) # 32 > 16
np.float32(6.0)
在典型情况下,这不会导致意外.但是,如果您使用非默认的 dtype(如无符号整数和低精度浮点数),或者如果混合 NumPy 整数,NumPy 浮点数和 Python 标量,则 NumPy 提升规则的某些细节可能相关.请注意,这些详细规则并不总是与其他语言的规则匹配 [1] .
数值 dtype 分为四种"类型",具有自然的层次结构.
无符号整数 (
uint)有符号整数 (
int)浮点数 (
float)复数 (
complex)
除了类型之外,NumPy 数值 dtype 还具有关联的精度,以位为单位指定.类型和精度共同指定 dtype.例如, uint8 是使用 8 位存储的无符号整数.
操作的结果将始终是任何输入的相同或更高类型.此外,结果将始终具有大于或等于输入的精度.已经,这可能导致一些可能出乎意料的例子:
当混合浮点数和整数时,整数的精度可能会强制结果为更高精度的浮点数.例如,涉及
int64和float16的操作的结果是float64.当混合具有相同精度的无符号和有符号整数时,结果将具有比任何一个输入更高的精度.此外,如果其中一个已经具有 64 位精度,则没有更高的精度整数可用,例如,涉及
int64和uint64的操作给出float64.
请参阅下面的"数值提升"部分和图像,以获取有关两者的详细信息.
Python 标量的详细行为#
自从 NumPy 2.0 [2] 以来,我们提升规则中的一个重要点是,虽然涉及两个 NumPy dtype 的操作永远不会丢失精度,但涉及 NumPy dtype 和 Python 标量( int , float 或 complex )的操作可能会丢失精度.例如,Python 整数和 NumPy 整数之间的操作结果应该是 NumPy 整数,这可能是很直观的.但是,Python 整数具有任意精度,而所有 NumPy dtype 具有固定精度,因此 Python 整数的任意精度无法保留.
更一般地,NumPy 在确定结果 dtype 时会考虑 Python 标量的"种类",但忽略其精度.这通常很方便.例如,当使用低精度 dtype 的数组时,通常希望与 Python 标量的简单操作保留 dtype.
>>> arr_float32 = np.array([1, 2.5, 2.1], dtype="float32")
>>> arr_float32 + 10.0 # undesirable to promote to float64
array([11. , 12.5, 12.1], dtype=float32)
>>> arr_int16 = np.array([3, 5, 7], dtype="int16")
>>> arr_int16 + 10 # undesirable to promote to int64
array([13, 15, 17], dtype=int16)
在这两种情况下,结果精度都由 NumPy dtype 决定.因此, arr_float32 + 3.0 的行为与 arr_float32 + np.float32(3.0) 相同, arr_int16 + 10 的行为与 arr_int16 + np.int16(10.) 相同.
作为另一个例子,当将 NumPy 整数与 Python float 或 complex 混合时,结果始终具有类型 float64 或 complex128 :
>> np.int16(1) + 1.0 np.float64(2.0)
但是,当使用低精度 dtype 时,这些规则也可能导致令人惊讶的行为.
首先,由于 Python 值在执行操作之前会被转换为 NumPy 值,因此当结果看起来很明显时,操作可能会失败并出现错误.例如, np.int8(1) + 1000 无法继续,因为 1000 超出了 int8 的最大值.当 Python 标量无法强制转换为 NumPy dtype 时,会引发错误:
>>> np.int8(1) + 1000
Traceback (most recent call last):
...
OverflowError: Python integer 1000 out of bounds for int8
>>> np.int64(1) * 10**100
Traceback (most recent call last):
...
OverflowError: Python int too large to convert to C long
>>> np.float32(1) + 1e300
np.float32(inf)
... RuntimeWarning: overflow encountered in cast
其次,由于 Python 浮点数或整数的精度总是被忽略,因此低精度的 NumPy 标量将继续使用其较低精度,除非显式转换为更高精度的 NumPy dtype 或 Python 标量(例如,通过 int() , float() 或 scalar.item() ).这种较低的精度可能不利于某些计算或导致不正确的结果,尤其是在整数溢出的情况下:
>>> np.int8(100) + 100 # the result exceeds the capacity of int8
np.int8(-56)
... RuntimeWarning: overflow encountered in scalar add
请注意,当标量发生溢出时,NumPy 会发出警告,但数组则不会;例如, np.array(100, dtype="uint8") + 100 不会发出警告.
数值提升#
下图显示了数值提升规则,其中种类在纵轴上,精度在横轴上.
具有较高种类的输入 dtype 决定了结果 dtype 的种类.结果 dtype 的精度尽可能低,且不会出现在图中任何输入 dtype 的左侧.
请注意以下具体规则和观察结果:
当 Python
float或complex与 NumPy 整数交互时,结果将为float64或complex128(黄色边框).NumPy 布尔值也会被转换为默认整数 [3] . 当额外涉及 NumPy 浮点值时,这无关紧要.精度被绘制成
float16 < int16 < uint16,因为大的uint16不适合int16,并且大的int16在存储在float16中时会失去精度.然而,这种模式被打破了,因为 NumPy 总是认为float64和complex128对于任何整数值都是可接受的提升结果.一个特殊情况是,NumPy 将有符号整数和无符号整数的许多组合提升为
float64.这里使用了更高的种类,因为没有有符号整数 dtype 具有足够的精度来保存uint64.
通用提升规则的例外情况#
在NumPy中,提升是指特定函数对结果的处理方式,在某些情况下,这意味着NumPy可能会偏离 np.result_type 给出的结果.
sum 和 prod 的行为#
当对整数值(或布尔值)求和时, np.sum 和 np.prod 将始终返回默认整数类型.这通常是一个 int64 .这样做的原因是整数求和非常容易溢出并给出令人困惑的结果.此规则也适用于底层 np.add.reduce 和 np.multiply.reduce .
使用 NumPy 或 Python 整数标量时的显著行为#
NumPy 提升指的是结果 dtype 和操作精度,但操作有时会决定该结果.除法总是返回浮点值,比较总是返回布尔值.
这导致了可能出现的规则"例外":
NumPy 与 Python 整数或混合精度整数的比较总是返回正确的结果.输入永远不会以失去精度的方式转换.
无法提升的类型之间的相等比较将被认为是全部
False(相等)或全部True(不相等).像
np.sin这样总是返回浮点值的单值数学函数,通过将其转换为float64来接受任何 Python 整数输入.除法总是返回浮点值,因此也允许任何 NumPy 整数与任何 Python 整数值之间的除法,方法是将两者都转换为
float64.
原则上,其中一些例外情况可能对其他函数有意义.如果您认为情况确实如此,请提出问题.
Python 内置类型类别的显著行为#
当组合 Python 的内置标量类型(即, float , int 或 complex ,而不是标量值)时,提升规则可能会显得令人惊讶:
>>> np.result_type(7, np.array([1], np.float32))
dtype('float32') # The scalar value '7' does not impact type promotion
>>> np.result_type(type(7), np.array([1], np.float32))
dtype('float64') # The *type* of the scalar value '7' does impact promotion
# Similar situations happen with Python's float and complex types
这种行为的原因是 NumPy 将 int 转换为其默认整数类型,并使用该类型进行提升:
>>> np.result_type(int)
dtype('int64')
另请参见 内置 Python 类型 了解更多详细信息.
非数值数据类型的提升#
NumPy 将提升扩展到非数值类型,尽管在许多情况下,提升没有明确定义,并会被简单地拒绝.
以下规则适用:
NumPy 字节字符串 (
np.bytes_) 可以被提升为 Unicode 字符串 (np.str_).但是,对于非 ASCII 字符,将字节转换为 Unicode 将会失败.出于某些目的,NumPy 会将几乎任何其他数据类型提升为字符串.这适用于数组创建或连接.
当没有可行的提升时,像
np.array()这样的数组构造函数将使用objectdtype.当结构化 dtype 的字段名称和顺序匹配时,可以进行提升. 在这种情况下,所有字段都会被单独提升.
NumPy
timedelta在某些情况下可以与整数进行提升.
备注
其中一些规则有些令人惊讶,并且正在考虑在未来进行更改. 但是,任何向后不兼容的更改都必须权衡打破现有代码的风险. 如果您对提升应该如何工作有任何想法,请提出问题.
提升的 dtype 实例的详细信息#
上面的讨论主要涉及混合不同 DType 类时的行为. 附加到数组的 dtype 实例可以携带其他信息,例如字节顺序,元数据,字符串长度或精确的结构化 dtype 布局.
虽然结构化 dtype 的字符串长度或字段名称很重要,但 NumPy 认为字节顺序,元数据和结构化 dtype 的确切布局是存储细节.
在提升期间,NumPy 不会考虑这些存储细节:
字节顺序被转换为本机字节顺序.
附加到 dtype 的元数据可能会或可能不会被保留.
生成的结构化 dtype 将被打包(但如果输入已对齐,则也会对齐).
对于大多数程序来说,这种行为是最好的,因为存储细节与最终结果无关,并且使用不正确的字节顺序可能会大大降低计算速度.