使用numpy.where函数出现的问题与思考
前段时间在做一些数据处理分析的时候用到了np.where函数,错把这个函数当成了逻辑运算和条件判断语句,导致一开始出现错误的时候还莫名其妙,明明没有用错,输入也正确,为什么还是报错?其实这也是我们的惯性思维导致的,在此记录下,作为一种提醒。
我们还是先看看np.where函数的功能,下面是官方文档的解释:
numpy.where(condition, [x, y, ]/)
Return elements chosen from x or y depending on condition.
Note:
When only condition is provided, this function is a shorthand fornp.asarray(condition).nonzero(). Using nonzero directly should be preferred, as it behaves correctly for subclasses. The rest of this documentation covers only the case where all three arguments are provided.
Parameters:
condition: array_like, bool
Where True, yield x, otherwise yield y.
x, y: array_like
Values from which to choose. x, y and condition need to be broadcastable to some shape.
Returns:
out: ndarray
An array with elements from x where condition is True, and elements from y elsewhere.
Notes:
If all the arrays are 1-D,whereis equivalent to:
[xv if c else yv for c, xv, yv in zip(condition, x, y)]
这个函数的功能很简单,就是根据条件参数condition的值从x或y中选择返回的元素,即在condition为True的位置,返回x对应位置的值,其他位置则返回y对应位置的值,这里x, y 和condition的数据类型都是数组或类数组的,且它们的shape必须相同或者可以广播到同样的shape。
这里也来看看官网的一些简单例子,能更好理解这个函数的功能和用法:
# 一维数组
In [2]: a = np.arange(10)
In [3]: a
Out[3]: array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
In [4]: np.where(a < 5, a, 10*a)
Out[4]: array([ 0,  1,  2,  3,  4, 50, 60, 70, 80, 90])
# 多维数组
In [5]: np.where([[True, False], [True, True]],
   ...:          [[1, 2], [3, 4]],
   ...:          [[9, 8], [7, 6]])
Out[5]:
array([[1, 8],
       [3, 4]])
# 传递的参数需要广播
In [6]: x, y = np.ogrid[:3, :4]
In [7]: x
Out[7]:
array([[0],
       [1],
       [2]])
In [8]: y
Out[8]: array([[0, 1, 2, 3]])
In [9]: x < y
Out[9]:
array([[False,  True,  True,  True],
       [False, False,  True,  True],
       [False, False, False,  True]])
In [10]: np.where(x < y, x, 10 + y)  # both x and 10+y are broadcast
Out[10]:
array([[10,  0,  0,  0],
       [10, 11,  1,  1],
       [10, 11, 12,  2]])
In [11]: a = np.array([[0, 1, 2],
    ...:               [0, 2, 4],
    ...:               [0, 3, 6]])
In [12]: np.where(a < 4, a, -1)  # -1 is broadcast
Out[12]:
array([[ 0,  1,  2],
       [ 0,  2, -1],
       [ 0,  3, -1]])
回到开头说的情况,当时是把np.where函数和... if ... else ...条件判断语句搞混淆了,以为前者和后者的执行逻辑是一样的,所以在执行下面的语句时就报错了:

出现了IndexError: list index out of range的错误,当时就在想明明条件len(lst)==0的结果为True,应该返回空串,不会执行后面的lst[0]语句了啊,怎么还会报错呢?而使用条件判断语句就没有问题,条件为True时返回空串,条件不满足时才执行语句lst[0],报错:

接着我干脆直接把第一个参数设置为True和False看看执行结果,结果都报错:

说明对于np.where,不管条件是否满足,后面的两个语句都会执行。这时我才意识到np.where是个函数,而不是类似... if ... else ...的条件判断语句。对于if ... else ... 语句,它会根据条件是否满足选择执行某个分支的代码,其他分支的代码不会执行。而对于函数,在你调用函数的时候,传递给函数每个参数的表达式都会先被执行,再把得到的值作为实参传递给函数。所以这也能够解释通为什么执行代码np.where(len(lst)==0, '', lst[0])时都会报错了,报错信息来自于语句lst[0],因为在调用函数时,不管其他参数值是什么,它都会被执行。
总结一句,编写程序代码时,如果只看函数的功能,有时可能会因为惯性思维导致一些自己觉得莫名其妙的错误,就像在这里,因为np.where函数的功能和if else语句的功能类似,所以也自以为其执行的逻辑也是一样的,而忘记了函数本身的执行逻辑,结果就出现错误了。所以,惯性思维是把双刃剑,有时能助你快速解决问题,有时也会给你带来一些麻烦,正确认识这一点,才能有效避免这一误区。