NumPy 入门
NumPy 是 Python 数据处理的基础库。它最核心的对象是 ndarray,也就是多维数组。
我把这部分放进主线,不是因为一开始就要做复杂科学计算,而是因为后面接 Pandas、机器学习、数值计算时,很多概念都会回到数组、形状和向量化。
创建数组
import numpy as np
a = np.array([1, 2, 3])
b = np.array([[1, 2], [3, 4]])
常见创建方式:
np.zeros((2, 3))
np.ones((2, 3))
np.arange(0, 10, 2)
np.linspace(0, 1, 5)
先看懂三个属性
arr = np.array([[1, 2, 3], [4, 5, 6]])
print(arr.shape) # (2, 3)
print(arr.ndim) # 2
print(arr.dtype) # int64 或 int32,取决于平台
shape:形状ndim:维度数dtype:元素类型
索引与切片
arr = np.array([[1, 2, 3], [4, 5, 6]])
print(arr[0, 1]) # 2
print(arr[:, 1]) # [2 5]
print(arr[0:2, 1:])
和普通 Python 序列一样,NumPy 也支持切片,但多维数组的索引表达力更强。
向量化
NumPy 最有价值的地方之一,是很多运算不需要自己手写 for 循环。
arr = np.array([1, 2, 3, 4])
print(arr * 2) # [2 4 6 8]
print(arr + 10) # [11 12 13 14]
print(arr ** 2) # [1 4 9 16]
数组之间也可以直接按元素运算:
x = np.array([1, 2, 3])
y = np.array([4, 5, 6])
print(x + y)
print(x * y)
常见统计操作
arr = np.array([1, 2, 3, 4, 5])
print(arr.sum())
print(arr.mean())
print(arr.max())
print(arr.min())
print(arr.std())
reshape 与展平
arr = np.arange(6)
matrix = arr.reshape(2, 3)
flat = matrix.flatten()
当数据形状要从一维改成二维、从二维改成一维时,这两个操作很常见。
布尔索引
这是我实际用得非常多的能力:
arr = np.array([1, 2, 3, 4, 5, 6])
mask = arr % 2 == 0
print(arr[mask]) # [2 4 6]
配合 np.where() 也很常见:
arr = np.array([1, 2, 3, 4])
result = np.where(arr % 2 == 0, arr, 0)
两个常见坑
1. 切片通常返回视图
NumPy 的切片很多时候不是拷贝,而是视图。改切片可能会影响原数组。
2. dtype 会影响结果
整数数组做整数运算、浮点数组做浮点运算时,结果表现会不同。遇到精度问题时,先确认 dtype。