跳到主要内容

NumPy 入门

NumPy 是 Python 数据处理的基础库。它最核心的对象是 ndarray,也就是多维数组。

我把这部分放进主线,不是因为一开始就要做复杂科学计算,而是因为后面接 Pandas、机器学习、数值计算时,很多概念都会回到数组、形状和向量化。

创建数组

import numpy as np

a = np.array([1, 2, 3])
b = np.array([[1, 2], [3, 4]])

常见创建方式:

np.zeros((2, 3))
np.ones((2, 3))
np.arange(0, 10, 2)
np.linspace(0, 1, 5)

先看懂三个属性

arr = np.array([[1, 2, 3], [4, 5, 6]])

print(arr.shape) # (2, 3)
print(arr.ndim) # 2
print(arr.dtype) # int64 或 int32,取决于平台
  • shape:形状
  • ndim:维度数
  • dtype:元素类型

索引与切片

arr = np.array([[1, 2, 3], [4, 5, 6]])

print(arr[0, 1]) # 2
print(arr[:, 1]) # [2 5]
print(arr[0:2, 1:])

和普通 Python 序列一样,NumPy 也支持切片,但多维数组的索引表达力更强。

向量化

NumPy 最有价值的地方之一,是很多运算不需要自己手写 for 循环。

arr = np.array([1, 2, 3, 4])

print(arr * 2) # [2 4 6 8]
print(arr + 10) # [11 12 13 14]
print(arr ** 2) # [1 4 9 16]

数组之间也可以直接按元素运算:

x = np.array([1, 2, 3])
y = np.array([4, 5, 6])

print(x + y)
print(x * y)

常见统计操作

arr = np.array([1, 2, 3, 4, 5])

print(arr.sum())
print(arr.mean())
print(arr.max())
print(arr.min())
print(arr.std())

reshape 与展平

arr = np.arange(6)
matrix = arr.reshape(2, 3)
flat = matrix.flatten()

当数据形状要从一维改成二维、从二维改成一维时,这两个操作很常见。

布尔索引

这是我实际用得非常多的能力:

arr = np.array([1, 2, 3, 4, 5, 6])
mask = arr % 2 == 0

print(arr[mask]) # [2 4 6]

配合 np.where() 也很常见:

arr = np.array([1, 2, 3, 4])
result = np.where(arr % 2 == 0, arr, 0)

两个常见坑

1. 切片通常返回视图

NumPy 的切片很多时候不是拷贝,而是视图。改切片可能会影响原数组。

2. dtype 会影响结果

整数数组做整数运算、浮点数组做浮点运算时,结果表现会不同。遇到精度问题时,先确认 dtype

我更关心的学习重点

  • 能看懂 shape
  • 能熟练做索引和切片
  • 明白什么叫向量化
  • 先用现成数组运算,再考虑自己写循环

关联阅读