# python程序性能优化实战1

# 参考

# 前言

有许多工具和库可以帮助我们写出更加高效的Python代码。但是在我们深入这些第三方方案之前,先来看一看如何编写更加高效的纯Python代码,无论是在计算方面还是IO方面。虽然不能说所有,但确实有许多问题可以通过对Python能力与局限的理解来规避。

为了展示Python自己的性能优化工具,我们假设一些问题来作演示。假设你是一名数据工程师,准备对全球气候数据进行分析。数据来源是这里 (opens new window),归属美国国家海洋与大气管理局。时间很紧,而且只能用标准Python,受限于预算,也没法购买更多的计算资源。数据一个月之后就能就绪,你要利用这段时间来提升代码性能。你的任务是寻找最需要优化的地方,提升它们的性能。

你的第一个任务是度量现有的数据处理代码性能。尽管已经知道代码比较慢,但还是需要为性能瓶颈找到响应的依据。度量的步骤是十分重要的,让你能够以严格系统化的方式来定位代码中瓶颈的所在。而大家平时最常见的拍脑袋方法反而不行,许多性能问题点都是反直觉的。

对于纯Python代码的优化是最直接的,而这个领域也是问题比较多的,所以在这个领域的优化收益颇丰。本篇中,我们会看到Python自带的开箱即用的工具如何帮助我们开发性能更好的代码。我们从代码度量开始,使用集中不同的工具来检测问题区域。然后我们聚焦Python的基础数据结构:列表、集合和字典。我们的目标是提升这些数据结构的效率,优化内存分配的性能。最后,我们会看到现代Python的延迟求值技术是如何改善数据处理流程的性能的。

本篇中只看纯Python代码,不涉及第三方库,但会用到一些外部工具来优化性能和访问数据。我们会使用Snakeviz对度量结果可视化,还会用到line_profiler来逐行测量性能。最后使用request库来下载数据。

如果你使用Docker,直接用默认的镜像即可。

# 对应用的IO和计算负载进行度量

我们的第一个目标是从气象站下载数据,计算某个年份的最低温度。NOAA网站上有CSV格式的数据,按照年份与气象站切分。比如al-hourly/access/2021/01494099999.csv (opens new window)这个文件包含了2021年01494099999气象站的所有记录。这些记录中有温度、气压等信息,而且一天可能有多条记录。

让我们写一个脚本来下载一段时间内某一批气象站的数据,然后计算每个气象站的最低温度。

# 下载数据并计算最低温度

我们的脚本通过命令行运行,接受一个站点列表,开始和结束的年份。这里是处理输入的代码:

import collections
import csv
import datetime
import sys
import requests

stations = sys.argv[1].split(",")
years = [int(year) for year in sys.argv[2].split("-")]
start_year = years[0]
end_year = years[1]

我们使用requests库来获取文件,如下:

TEMPLATE_URL = "https://www.ncei.noaa.gov/data/global-hourly/access/{year}/{station}.csv"
TEMPLATE_FILE = "station_{station}_{year}.csv"

def download_data(station, year):
  my_url = TEMPLATE_URL.format(station=station, year=year)
  req = requests.get(my_url)
  if req.status_code != 200:
    return # not found
  w = open(TEMPLATE_FILE.format(station=station, year=year), "wt")
  w.write(req.text)
  w.close()

def download_all_data(stations, start_year, end_year):
  for station in stations:
    for year in range(start_year, end_year + 1):
      download_data(station, year)

这部分代码会将每个文件都写入本地磁盘。接下来我们获取单个文件中的所有温度:

def get_file_temperatures(file_name):
    with open(file_name, "rt") as f:
        reader = csv.reader(f)
        header = next(reader)
        for row in reader:
            station = row[header.index("STATION")]
            # date = datetime.datetime.fromisoformat(row[header.index('DATE')])
            tmp = row[header.index("TMP")]
            temperature, status = tmp.split(",")
            if status != "1":
                continue
            temperature = int(temperature) / 10
            yield temperature

接下来获取每个站点的所有温度,得到最小值:

def get_all_temperatures(stations, start_year, end_year):
    temperatures = collections.defaultdict(list)
    for station in stations:
        for year in range(start_year, end_year + 1):
            for temperature in get_file_temperatures(TEMPLATE_FILE.format(station=station, year=year)):
                temperatures[station].append(temperature)
    return temperatures


def get_min_temperatures(all_temperatures):
    return {station: min(temperatures) for station, temperatures in all_temperatures.items()}

最后把所有的部分组合起来:下载数据、获取温度、计算最小值和输出结果。

python load.py 01044099999,02293099999 2021-2021

输出的结果是:{'01044099999': -10.0, '02293099999': -27.6}

现在好戏才刚刚开始。我们的目标是继续下载更多站点在更多年份的更多数据。为了处理这么多的数据,我们需要让代码尽可能地高效。而提升代码效率的第一步便是度量其效率,找到性能的瓶颈。为此我们会用到Python内置的度量机制。

# Python内置的度量工具

我们的第一步是对代码进行度量,检查每个函数的时间消耗。为此,我们需要通过Python的cProfile模块来运行代码。这个模块内置于Python之中,帮助我们从代码中获取执行信息。这里要注意,我们用的不是profile模块,这货要慢几个数量级,除非你要开发自己的测量工具时才用得上。

我们可以通过下面的命令来运行:

python -m cProfile -s cumulative load.py 01044099999,02293099999 2021-2021 > profile.txt

这里我们通过-m参数运行Python,执行cProfile模块。Python官方推荐使用这个模块来收集运行信息。我们根据累计执行时间对数据进行排序。执行的结果部分展示如下:

这里的结果是根据在函数上累计执行的时间来排序的。另一种方式是根据每个函数的调用次数来排序。可以看到,这里只调用了1次download_all_data,但是其累计执行时间却和整个脚本差不多。这里有两列名字都叫做percall,第一列代表的是排除子调用之后的时间开销,第二列代表的是包含子调用之后的时间开销。这里download_all_data的时间明显都花在了子调用上。

在许多IO密集的场景中,比如现在这个例子,很有可能是IO占据了主要的时间开销。在我们的例子中,既有网络IO(下载数据),又有磁盘IO(写入文件)。网络开销差异巨大,收到许多因素的影响,一般是最大的时间消耗所在,让我们尝试来缓解这个问题。

# 使用本地缓存来减少网络使用

为了减少网络通信,我们可以在首次下载文件的时候保存一个副本给未来使用。我们会构建一个本地数据缓存。我们使用和之前一样的代码,但是会检查文件是否存在,存在的话就不会重复下载。

def download_all_data(stations, start_year, end_year):
    for station in stations:
        for year in range(start_year, end_year + 1):
            if not os.path.exists(TEMPLATE_FILE.format(station=station, year=year)):
                download_data(station, year)

第一次运行的时候还是和之前一样慢,但是第二次就不需要进行网络访问了。当前这个例子中有数量级的优化。可以通过下面的命令执行:

python -m cProfile -s cumulative load_cache.py 01044099999,02293099999 2021-2021 > profile_cache.txt

结果如下:

尽管时间下降了一个数量级,IO仍然名列前茅。现在问题不在网络,而在于磁盘访问了,当然也是因为计算量相对较小。

注意:本例中展示的缓存虽然可以将速度提升一个数量级,但是缓存的管理是很难的,也是常见的bug来源。本例中的文件不会随着时间而变化,但大部分情况不是这样的,缓存管理代码需要识别这个问题。我们也会在后续的内容中回头再看这个问题。

接下来我们看看CPU是限制因素的话怎么办。

# 度量代码以寻找性能瓶颈

这里我们用同样的数据但是换一个任务,主要考验CPU。我们使用NOAA所有的站点,计算站点之间的距离,复杂度位N²。

我们提前准备好了数据,下面的计算的代码。

def get_locations():
    with open("locations.csv", "rt") as f:
        reader = csv.reader(f)
        header = next(reader)
        for row in reader:
            station = row[header.index("STATION")]
            lat = float(row[header.index("LATITUDE")])
            lon = float(row[header.index("LONGITUDE")])
            yield station, (lat, lon)


def get_distance(p1, p2):
    lat1, lon1 = p1
    lat2, lon2 = p2

    lat_dist = math.radians(lat2 - lat1)
    lon_dist = math.radians(lon2 - lon1)
    a = (
        math.sin(lat_dist / 2) * math.sin(lat_dist / 2) +
        math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) *
        math.sin(lon_dist / 2) * math.sin(lon_dist / 2)
    )
    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
    earth_radius = 6371
    dist = earth_radius * c

    return dist


def get_distances(stations, locations):
    distances = {}
    for first_i in range(len(stations) - 1):
        first_station = stations[first_i]
        first_location = locations[first_station]      
        for second_i in range(first_i, len(stations)):
            second_station = stations[second_i]
            second_location = locations[second_station]
            distances[(first_station, second_station)] = get_distance(
                first_location, second_location)
    return distances


locations = {station: (lat, lon) for station, (lat, lon) in get_locations()}
stations = sorted(locations.keys())
distances = get_distances(stations, locations)

这段代码跑起来要花不少时间,还会消耗不少内存。如果你内存不够,可以少处理一些站点。接下来我们使用Python的工具来看看大部分时间花在了哪里。

# 将度量信息可视化

我们再一次使用度量工具来寻找影响执行速度的代码。为了更好地理解跟踪情况,我们会使用一个外部的工具,SnakeViz (opens new window)

我们先保存一个度量文件:

python -m cProfile -o distance_cache.prof distance_cache.py

这里的-o参数会指定度量数据保存的位置,其它的调用照旧。

注意:Python提供了pstats模块来分析写入了磁盘的跟踪记录。你可以用python -m pstats distance_cache.prof命令来分析程序的执行成本。你可以在文档中找到更多的细节,第五篇也会涉及。

为了分析这些信息,我们会使用基于web的可视化工具,直接通过snakeviz distance_cache.prof即可,会跳出一个交互式浏览器界面(图2.1)

大部分时间都花在了get_distance上,但究竟是哪里呢?我们可以看到部分数学函数的开销,但是Python的度量工具没法提供函数内部更细粒度的视角。我们看到的只是聚合视角。是的,math.sin确实花了不少时间,但是我们在好多地方都用到了,到底是哪里有问题呢?为此我们需要引入逐行度量的工具。

# 逐行度量

尽管内置的度量工具可以帮我们找到哪一段代码引起的性能问题,但是局限也不少。这里我们会讨论这些局限,并引入逐行度量工具来进一步寻找性能瓶颈。

##为了了解get_distance函数中每一行代码的代价,我们会使用line_profiler库 (opens new window)。具体用法很简单,只需要加个标注就行了:

@profile
def get_distance(p1, p2):

你可能会发现,我们并没有从任何地方导入profile注解。这是因为我们将使用line_profile包中的便捷脚本kernprof来处理。让我们用下面的方式来进行逐行测量:

kernprof -l lprofile_distance_cache.py

逐行度量会让执行速度慢几个数量级,完整跑完可能要几个小时。执行完毕之后可以使用下面的命令来查看结果。

python -m line_profiler lprofile_distance_cache.py.lprof

如果你看看图2.2中的结果就会发现去多调用的耗时都不少,也是我们可能需要优化的地方。现阶段我们度量完毕之后就此打住,但后续第六篇中我们还会进行优化。

图2.2

可以看到,逐行度量的结果比内置的工具更容易让人看明白。

# 心得:度量代码性能

我们一开始尝试的内置度量工具也帮了不少忙,而且运行起来要比逐行度量工具更快。但是逐行度量工具深入函数内部,提供的信息量更多。相反,Python内置的工具只提供函数的累计值,以及在子调用上花费的时间。某些场景中可能会知道子调用是什么,但是一般情况下是不清楚的。一个完整的度量策略需要把这些都考虑在内。

我们在这里采取的是一种比较理性的策略:先尝试内置的cProfile模块,因为它跑的快,还能提供一些概要信息。如果这些不够,那就使用逐行度量,信息更多,运行也更慢。我们这里的主要目标是寻找性能瓶颈;后续的篇章中 会进行优化。有时候仅仅对现有方案做局部调整还不够,甚至需要整体重新架构。

其它度量工具
度量代码的时候,有一个工具不得不提,那就是timeit模块。这可能是新手用的最多的模块,网上的例子也很多。结合使用最简单的是IPython和Jupyter Notebook的场景,只要加上%timeit%标记,就可以对任何东西进行度量。比如在IPython中:
In [1]: %timeit list(range(1000000))
27.4 ms ± 72.5 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
你可以在timeit的文档中找到更多细节,但是一般来说上面这种用法就够了。

到这里你应该已经熟悉度量的代码和工具了,下一步我们看看如何优化Python数据结构的使用。

# 优化基础数据结构的速度:列表、集合、字典

接下来我们要寻找基础数据结构用的不好的地方,用更加高效的方式重写。这里我们还是使用NOAA的例子,但任务变成了一个站点在特定周期内是否出现过某些气温。下面的代码会读取站点01044099999在2005到2021年间的数据:

stations = ['01044099999']
start_year = 2005
end_year = 2021
download_all_data(stations, start_year, end_year)
all_temperatures = get_all_temperatures(stations, start_year, end_year)

first_all_temperatures = all_temperatures[stations[0]]

first_all_temperatures中保存了站点观测到的所有气温。通过print(len(first_all_temperatures)、max(first_all_temperatures)、min(first_all_temperatures)))等方法,我们知道这里一共有141082条记录,最高气温27℃,最低-16℃。

# 列表搜索的性能

我们要检查某个温度是否存在于first_all_temperatures列表中。让我们通过下面的方法来粗略估计这个过程要花多少时间:

%timeit (-10.7 in first_all_temperatures)

我电脑上的结果如下:

313 μs ± 6.39 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

如果我们查询一个不在列表中的值:

%timeit (-100 in first_all_temperatures)

结果变成了:

2.87 ms ± 20.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

几乎比前面的结果高了一个数量级。

为什么第二次搜索的性能如此之差呢?因为这里的搜索中,in操作符需要从头开始顺序扫描整个列表。这意味着,在最差的情况下,也就是目标不存在于列表中的时候,我们需要遍历整个列表。对于小列表尚且影响不大。但是随着数据规模扩大,这种额外增加的开销就很可观了。

这里我们没有绝对的数字可供参考,但是从毫秒到微妙的变化肯定是不理想的,至少应该降低一个数量级。

# 使用集合搜索

让我们看看将列表转换为集合之后表现是否有所改善。

set_first_all_temperatures = set(first_all_temperatures)
print(len(set_first_all_temperatures))

%timeit (-10.7 in set_first_all_temperatures)
%timeit (-100 in set_first_all_temperatures)

结果如下:

62.1 ns ± 3.27 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
26.6 ns ± 0.115 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

比之前的版本快了几个数量级!为什么呢?原因有两个:其一是集合的大小,其二是复杂度。

复杂度我们晚点再说,先看体积。还记得吗,原始的列表中有14万个元素,但是在集合中,所有雷同的元素都缩减为了一个元素,对列表实现了去重,最终只剩下400个元素,差不多是350倍的差距。体积变小了,搜索速度也就变快了。

这里的心得是关注列表中的重复元素,使用集合去重之后可以在更小的数据集上执行搜索。但是Python中列表和集合还有一个深刻的区别。

# Python中的列表、集合、字典的复杂度

前面的性能提升很大程度上来源于数据结构体积的缩小。那么如果列表数据中没有重复值,转换成集合之后体积不就一样了吗?让我们用下面的代码模拟一下这种情况:

a_list_range = list(range(100000))
a_set_range = set(a_list_range)
%timeit 50000 in a_list_range
%timeit 50000 in a_set_range
%timeit 500000 in a_list_range
%timeit 500000 in a_set_range

我们在列表与集合中都保存了0到99999的数字,从中寻找50000和500000,记时结果如下:

621 µs ± 23.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
54.4 ns ± 1.86 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
1.24 ms ± 23.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
40.8 ns ± 1.44 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

集合的表现仍然好于列表。在Python(准确地说是CPython)中,集合是通过哈希实现的。搜索一个元素的代价就是哈希搜索的代价。哈希函数有许多有点,也有许多设计问题。当我们拿列表和集合对比的时候,我们一般可以假设集合的查询是常量复杂度,体积从10到1000万都一样。虽然不够准确,但是可以用这种直观的方式来理解集合查询为什么比列表查询要快。

集合一般是以类似字典的方式实现的,不过没有值而已。当你搜索字典中的键时,效率和搜索集合是一样的。不过,字典与集合也不是万能的。比如,当你要搜索一个区间的时候,列表可能效率更高。在有序列表中,你可以先找到区间中的低位元素,然后再遍历找到第一个大于区间高位的元素,结束搜索。而在集合或字典中,你需要为区间中的每个元素进行检索。所以,如果你知道要搜索的确切值,那么字典更快;如果要搜索的是区间,那么用二分法来搜索有序列表效果更好。

考虑到列表在Python代码中到处都是,一定有许多场合中可以替换为更加合适的数据结构。不过列表本身作为一种基础数据结构,也有许多不错的使用场景。这里的点在于激发你的思想,而不是简单的禁用列表。

提示:在大型列表中使用in操作符一定要谨慎。Python代码中有大量这样的用法,对于小型列表可能问题不大,但是列表规模大了之后问题就严重了。
从软件工程实践的角度出发,in操作符的使用,很可能从开发阶段不起眼的小问题,演变为生产环境中的大问题。程序员一般只用小数据集进行测试,将大数据集用于单元测试也不太实际。而真正的数据规模可能很大,一旦引入,甚至会把整个系统卡死。
更加系统性的方案是,时不时地用大数据集进行测试。这不用每次都做,也不是专门针对in操作符而做,而是提醒注意开发环境和生产环境之间由数据规模带来的性能差异。

顺便说下,对于大部分搜索操作,有一类比列表和集合更好的数据结构:树。但本篇我们主要看Python内置的数据结构,暂时不包含树。

完整地介绍如何选择合适的算法与数据结构超出了本书的范围,而且通常也是计算机课程中最难的一门课。所以这里的点不在于穷尽这个话题,而是让你了解Python中常见的选择。如果你觉得Python中现有的数据结构不满足需求,可能需要考虑别的数据结构,可以参考别的专门介绍数据结构与算法的书籍。

另一个有用的资源是Python自己的TimeComplexity (opens new window),包含了Python中许多数据结构上的各种操作的时间复杂度。

本篇到这里关注的都是时间上的表现,但这不是大数据集性能的唯一影响因素,接下来我们要看看另一个重要的因素,内存占用。

# 寻找过度的内存分配

内存消耗对于性能影响也很关键。有效的内存分配可以让更多进程并行运行于同一个机器上。

让我们回到熟悉的NOAA数据库,让我们看看如何减少数据对磁盘的占用。为此,我们从调研数据文件的内容开始。此处的目标是加载一部分文件,对特征分布做一些统计。

def download_all_data(stations, start_year, end_year):
    for station in stations:
        for year in range(start_year, end_year + 1):
            if not os.path.exists(TEMPLATE_FILE.format(station=station, year=year)):
                download_data(station, year)


def get_all_files(stations, start_year, end_year):
    all_files = collections.defaultdict(list)
    for station in stations:
        for year in range(start_year, end_year + 1):
            f = open(TEMPLATE_FILE.format(station=station, year=year), 'rb')
            content = list(f.read())
            all_files[station].append(content)
            f.close()
    return all_files


stations = ['01044099999']
start_year = 2005
end_year = 2021
download_all_data(stations, start_year, end_year)
all_files = get_all_files(stations, start_year, end_year)

这里的all_files是一个字典,其中每个条目都包含了某一个站点的全部内容。让我们来研究一下其内存开销。

# 查看Python内存预测中的坑

Python在sys模块中提供了一个函数,getsizeof,返回的是一个对象所占据的内存空间。我们可以通过下面的代码来理解字典对象占据的内存空间:

print(sys.getsizeof(all_files))
print(sys.getsizeof(all_files.values()))
print(sys.getsizeof(list(all_files.values())))

结果分别是:240、40、60。

这里getsizeof返回的结果未必如你所想象,因为磁盘上的文件是以MB为单位的,这里返回的大小在1KB以下不免让人疑惑。这里getsizeof返回的是容器大小(第一个是字典,第二个是迭代器,第三个是列表),不包含其中的内容。所以我们要考虑的两种不同的东西对内存的占用:容器中的内容,还有容器本身。

注意,这里getsizeof的实现本身是没有问题的,只是和用户的预期不太一致,一般用户期待其能够囊括对象中引用到的全部内容。如果你阅读官方文档,你会找到一种递归式的实现指令来解决这个问题。对我们来说,这个小插曲正好是深入理解CPython内存分配的不错的起点。

接下来获取一些关于站点数据的信息:

station_content = all_files[stations[0]]
print(len(station_content))
print(sys.getsizeof(station_content))

返回的结果分别是17和248,字典中只有一个条目,这个条目对应的是一个列表,列表中有17个对象。列表本身是248字节,但是不包含其中的内容。接下来看看列表中第一个对象的大小。

print(len(station_content[0]))
print(sys.getsizeof(station_content[0]))
print(type(station_content[0]))

这里的长度是1303981,对应文件的大小。getsizeof的结果是10431904,差不多是前者的8倍,为什么呢?每个元素是一个指针,指向一个字符,每个指针是8字节。这里看起来很不妙,数据结构很大,我们没有很好的处理。接下来再看看单个的字符:

print(sys.getsizeof(station_content[0][0]))
print(type(station_content[0][0]))

这在体积上是巨大的。输出是28,类型是int。每个字符,即便只需要1个字节,这里都花费了28字节来表示。因此我们的列表花费了10431904加上28x1303981,总计46943372,大约是原始文件的36倍大。幸运的是,我们还有办法优化,CPython在内存分配方面还是挺聪明的。

CPython可以用一种更加精妙的方式来分配对象,远不像我们这种算法那样原始。让我们来计算内部内容的大小,不过不再是遍历矩阵中所有的数字,我们要确保不会重复计算。在Python中,如果一个对象被重复使用,那么这个对象会有相同的id。如果我们多次看见相同的id,我们应该只记录一次内存分配。

single_file_data = station_content[0]
all_ids = set()
for entry in single_file_data:
    all_ids.add(id(entry))
print(len(all_ids))

前面的代码获取了我们所有数字的唯一标识符。在CPython中,这会在内存分配的时候发生。CPython很明智地看到了相同的字符串内容会被反复使用,ASCII字符是由0到127之间的数字表示的,而上面的代码返回的结果是46。

所以,CPython在内存分配方面还是挺聪明的。这里的内存消耗就是列表的开销(10431904),加上46个不同的字符,几乎可以忽略。虽然Python在内存分配方面表现不错,但是不要指望每次都有最佳效果,因为这取决于数据的模式。

Python中的对象缓存和重用
Python会尽可能重用对象,但我们对这个预期要慎重。一方面这个特性取决于具体的实现,CPython和其它Python版本的行为是不一样的。另一方面,即便CPython也不保证每个版本的行为都一样。最后,即便是同一个版本,其工作细节也未必是一目了然。

这里我们使用数字列表的形式来展示文件内容,那么其它的表示方式如何呢?

# 其它展示形式的内存占用

接下来我们看看用别的方式来表示文件中的内容,有时候可能更好,有时候可能更差。这里主要是理解每种机制背后的代价。相比于使用整数来表示每个字符,我们可以使用长度为1的字符串:

single_file_str_list = [chr(i) for i in single_file_data]

这种方式比之前的还要糟糕,可以看看每个长度为1的字符串要占据多少空间:

print(sys.getsizeof(single_file_str_list[0]))

返回的是50,而之前用整数表示的时候还只需要28,我们不会采用这种方案。

图2.3

Python在处理许多小对象的时候开销是蛮大的。为什么整数需要28字节,而字符串需要50字节呢?每个Python对象需要24字节的开销,另外根据类型不同,还要加上额外的开销。如图2.3所示,字符串的开销比字节数组要更大。

我们还可以用另一种更加直观的方式来表示文件内容:相比于一个个字符, 我们使用一个字符串来容纳整个文件。

single_file_str = ''.join(single_file_str_list)
print(sys.getsizeof(single_file_str))

这里的大小是1304030,也就是文件大小加上字符串对象的开销。尽管这种方法简单又直观,但我们还是会沿用容器和字节序列,因为这些方法还有提升的空间。

# 使用数组作为一种紧凑表示以取代列表

这里我们来看看另一种元素容器,数组,是如何提升内存使用效率的。让我们回顾以下get_all_files函数的实现:

def get_all_files_clean(stations, start_year, end_year):
    all_files = collections.defaultdict(list)
    for station in stations:
        for year in range(start_year, end_year + 1):
            f = open(TEMPLATE_FILE.format(station=station, year=year), 'rb')
            content = f.read()
            all_files[station].append(content)
            f.close()
    return all_files

这里原先是content = list(f.read (opens new window)()),将read函数返回的内容转换为列表。现在不再转成列表,而是直接返回字节数组,让我们看看对象的体积:

print(type(single_file_data))
print(sys.getsizeof(single_file_data))

这里的类型是bytes,包含数据在内的大小是1304014。

数组是固定大小的,只能容纳相同类型的数据,因此对数据表示可以更加紧凑。

列表中的内存占用
当你分配一个列表,Python会将来可能添加的内容分配额外的空间,因此列表实际占据的空间比你想象的大。这也让数据插入更方便,不用每次都另外分配内存,除非之前分配的空间耗尽。当然这里的代价就是内存损耗,这种损耗也不大,除非你有许许多多小列表。

许多和数组管理相关的代码都在array模块中。不过在本篇之外我们就不会用array模块了,我们会使用NumPy,全面优于前者。不过这里的点在于理解和消除对象的内存损耗。

在这个阶段,你应该洞察Python对象在内存使用中的代价和陷阱。最后,我们会理解如何计算Python对象的内存使用。

# 将我们的所学系统化,估算Python对象的内存消耗

现在你已经对内存分配有了基本的理解,掌握了底层的原则,我们会通过代码将之前的知识融合进一个功能函数,对内存的消耗给出一个较好的评估。

我们要汇聚本篇中所有的知识点,在下面的函数中计算不仅包含对象的大小,还有容器带来的开销。如果你查阅下面的代码,你会发现ID跟踪、容器计数、字符串和数组管理。

针对通用对象计算体积是名副其实的雷区(原则上说,只用Python的方法是不太可能完成的)。下面的代码会尽可能避免重复计算对象、容器和迭代器。

from array import array
from collections.abc import Iterable, Mapping
from sys import getsizeof
from types import GeneratorType


def compute_allocation(obj) -> int:
    my_ids = set([id(obj)])
    to_compute = [obj]
    allocation_size = 0
    container_allocation = 0
    while len(to_compute) > 0:
        obj_to_check = to_compute.pop()
        allocation_size += getsizeof(obj_to_check)
        if type(obj_to_check) == str:
            continue
        if type(obj_to_check) == array:
            continue
        elif isinstance(obj_to_check, GeneratorType):
            continue
        elif isinstance(obj_to_check, Mapping):
            container_allocation += getsizeof(obj_to_check)
            for ikey, ivalue in obj_to_check.items():
                if id(ikey) not in my_ids:
                    my_ids.add(id(ikey))
                    to_compute.append(id(ikey))
                if id(ivalue) not in my_ids:
                    my_ids.add(ivalue)
                    to_compute.append(id(ivalue))
        elif isinstance(obj_to_check, Iterable):
            container_allocation += getsizeof(obj_to_check)
            for inner in obj_to_check:
                if id(inner) not in my_ids:
                    my_ids.add(id(inner))
                    to_compute.append(inner)
    return allocation_size, allocation_size - container_allocation

这里我们使用了迭代式的方法来计算内存分配。这种方法有利于递归式的实现,但是由于Python对递归实现和尾递归优化支持不理想,我们这里使用了迭代式的实现。

通过C或Rust的第三方库来计算对象的大小,更多是依赖于Python实现来以某种形式提供一些信息。对于这些库,要查阅文档获取更多细节。

警告,你可以使用一些Python的内存度量库。我曾经用过一些这类工具,效果一般,毕竟Python中的内存估计有许多坑。如果你用的话,一定要小心。
还有一些更加底层的方式来检查Python中的内存分配,我们会在后续谈到NumPy的时候再说。本篇中我们只看Python本身,不涉及外部库。

# 心得:估算Python对象的内存使用

总结一下,估算内存对象的大小远比想象的要困难。sys.getsizeof不会汇报整个对象的大小,因此需要额外的工作来准确计算对象的大小。一般情况下,这个问题不好解决:有些用底层语言编写的库可能不会报告它们分配的内存大小。

精炼的内存分配有许多好处。其一是在内存受限的情况下,允许更多进程并行运行。其二是为许多内存算法提供了工作的空间,不再像其它算法那样需要缓慢的磁盘访问了。

# 为大数据流水线应用延迟求值和生成器

现在我们要将注意力转向Python 3中广泛引入的延迟求值语法。它会任何计算推迟到实际需要使用结果数据的时候再进行,而不是在这之前进行。这对于处理大规模数据极为有用,因为有时候这些计算要花很多时间,甚至不可能完成。如果你使用生成器,你就已经用上了延迟求值。Python 3远比Python 2更“懒”,因为range、map和zip这些函数都延迟化了。这种方式可以让你处理更多的数据,使用更少的内存,更简单地创建数据流水线。

# 使用生成器取代标准函数

让我们回顾本篇第一节中的代码:

def get_file_temperatures(file_name):
    with open(file_name, "rt") as f:
        reader = csv.reader(f)
        header = next(reader)
        for row in reader:
            station = row[header.index("STATION")]
            # date = datetime.datetime.fromisoformat(row[header.index('DATE')])
            tmp = row[header.index("TMP")]
            temperature, status = tmp.split(",")
            if status != "1":
                continue
            temperature = int(temperature) / 10
            yield temperature

get_file_temperatures就是一个生成器(注意最后的yield)。让我们运行这个生成器。

temperatures = get_file_temperatures(TEMPLATE_FILE.format(station="01044099999", year=2021))
print(type(temperatures))
print(sys.getsizeof(temperatures))

这里返回的类型是generator,结构的大小是112。平时这个生成器不会做什么,只有你开始遍历的时候才会开始按需执行:

for temperature in temperatures:
    print(temperature)

这种方法有许多好处。首先也是最大的好处,不用一次性为所有的温度数据分配内存,因为它们是一个个处理的。而列表的话就需要内存来同时维护所有的温度数据。这一点非常重要,尤其是函数返回的数据结构有很多元素的情况下,直接关系到我们是否有足够的内存来执行代码。

第二,有时候我们不需要获取全部的结果,提前计算的话会把时间浪费在无效的计算上。假设,你要写一个函数来看是否存在低于0℃的情况,你就不需要所有的结果,只要出现一个0℃以下的值就能停了。

想要触发计算也很简单:

temperatures = list(temperatures)

这样你就失去了生成器的优势,但是有时候这也是有用的。比如,在计算时间和内存消耗可行的情况下,如果需要多次访问结果,及时求值的方式更加合理。

注意:Python 2和Python 3最大的区别之一是许多内置的工具进行了延迟化改造,这里说到的zip、map、filter等在Python 2中的行为是完全不同的。

生成器可以用来减少内存开销,甚至是计算时间。所以当你写代码返回数据序列的时候,问问自己能否将其转换成生成器。

# 总结

  • 检测性能瓶颈要比直觉想象的要难。性能度量是寻找性能缺陷的第一步。直觉一般是错的,实证的方法才是定位性能问题的可靠方法。
  • Python内部的度量工具挺有用,就是理解起来有点困难。类似SnakeViz的可视化工具可能帮我们更好地理解度量信息。
  • Python内部的度量系统在帮我们定位瓶颈的时候还是有局限。line_profiler这样的工具更加精确,当然代价是运行耗时较长。
  • 尽管CPU性能是我们做性能优化时候第一步要考虑的,内存消耗也是同等重要,有许多间接的好处。比如,内存优化很差的应用,如果能够优化为完全内存算法,将产生可观的时间收益。
  • Python提供的基础数据结构如果用的不好也会有性能影响。比如,在未排序的列表中搜索元素代价是很大的。我们要考虑Python基础数据结构上的操作复杂度。这些数据结构在Python程序中无处不在,优化的效果立竿见影。
  • 如果对计算复杂度有基础的理解,对编写高效代码是十分关键的。要时不时检查Python版本的变化,有时候底层的实现变了,算法的性能也会改变。
  • 延迟计算技术让我们使用更少的内存,甚至规避一大部分计算。
  • 本篇中所有的内容的适用性都很广泛,而且是其它所有技术的前置步骤。