Python Numpy.where() 函数

Sohaib Atiq 2023年1月30日 2020年6月17日
  1. numpy.where() 语法
  2. 示例代码:numpy.where(), 没有 [x, y] 输入
  3. 示例代码:numpy.where() 与 1-D 数组的关系
  4. 示例代码:numpy.where() 与二维数组的关系
  5. 示例代码:numpy.where() 有多个条件
Python Numpy.where() 函数

numpy.where() 函数在没有给出 xy 的情况下,生成符合输入条件的数组索引;或者根据给定的条件从 xy 中生成数组元素。

numpy.where() 语法

numpy.where(condition,[x,y])

参数

condition array_like, TrueFalse
如果条件是 True,则输出包含 x 的元素,否则,输出包含 y 的元素
x,y 返回值产生的数组
可以同时传递 (x, y) 或不传递

返回值

它返回一个数组。如果条件为 True,结果包含 x 元素,如果条件为 False,结果包含 y 元素。

如果没有给定 x, y,则返回数组的索引。

示例代码:numpy.where(), 没有 [x, y] 输入

import numpy as np

m = np.array([1,2,3,4,5])

n = np.where(m > 3)

print(n)

输出:

(array([3, 4], dtype=int64),)

它返回 m 的索引,其中元素大于 3 - a > 3

如果你需要的是元素而不是索引。

示例代码:numpy.where() 与 1-D 数组的关系

import numpy as np

m = np.where([True, False, True], [1,2,3], [4, 5, 6])

print(m)

输出:

[1 5 3]

当条件是一个 1-D 数组时,Numpy.where() 函数对条件数组进行迭代,如果条件元素是 True,则从 x 中选择元素,如果条件元素是 False,则从 y 中选择元素。

Numpy 其中 1-D 数组

示例代码:numpy.where() 与二维数组的关系

import numpy as np

x = np.array([[10, 20, 30],
               [3, 50, 5]])
y = np.array([[70, 80, 90],
             [100, 110, 120]])
condition = np.where(x>20,x,y)

print("Input array :")
print(x)
print(y)
print("Output array with condition applied:")
print(condition)

输出:

Input array :
[[10 20 30]
[ 3 50  5]]
[[ 70  80  90]
[100 110 120]]
Output array with condition applied:
[[ 70  80  30]
[100  50 120]]

它将 x>20 的条件应用于 x 的所有元素,如果是 True,则输出 x 的元素,如果是 False,则输出 y 的元素。

我们做一个简化的例子来说明它的工作原理。

import numpy as np

m = np.where([[True, False, True],
              [False, True, False]],
             [[1,2,3],
              [4, 5, 6]],
             [[7,8,9],
              [10, 11, 12]])

print(m)

输出:

[[ 1  8  3]
 [10  5 12]]

Numpy 其中 1-D 数组

示例代码:numpy.where() 有多个条件

我们也可以在 numpy.where() 函数中应用两个或多个条件。

import numpy as np

m = np.array([1,2,3,4,5])

n = np.where((m > 1) & (m < 5), m, 0)

print(n)

输出:

[0 2 3 4 0]

它应用 m > 1m < 5 这两个多重条件,如果元素满足这两个条件,则返回元素。

多重条件之间的逻辑不限于 AND(&),也接受 OR(|)。

import numpy as np

m = np.array([1,2,3,4,5])

n = np.where((m < 2) | (m > 4), m, 0)

print(n)

输出:

[1 0 0 0 5]