# 装饰器

# 装饰器

# 装饰器概述

装饰器的作用与意义,在于其可以通过自定义的函数或类,在不改变原函数的基础上,改变原函数的一些功能。

不改变原函数,可能是说原函数实现的功能不变。 但是当我们试着打印出 greet() 函数的一些元信息, 会发现,greet() 函数被装饰以后,它的元信息变了。元信息告诉我们“它不再是以前的那个 greet() 函数,而是被 wrapper() 函数取代了,为了解决这个问题,我们通常使用内置的装饰器@functools.wrap,它会帮助保留原函数的元信息(也就是将原函数的元信息,拷贝到对应的装饰器函数里)。

装饰器将额外增加的功能,封装在自己的装饰器函数或类中;如果你想要调用它,只需要在原函数的顶部,加上 @decorator 即可。显然,这样做可以让你的代码得到高度的抽象、分离与简化。

# 装饰器本质

  • 本质上装饰器只是一个语法糖-即@,其实质是在@的位置将被装饰器的类或函数作为参数传递给装饰器函数并调用装饰器函数,然后再用装饰器函数的返回值替换被装饰的类或函数。

    def dec(fun):
        print("IN DEC")
        print("func_name: ",fun.__name__)
        return "something from dec"
    
    @dec
    def f():
        print("IN F")
    
    # 第六行的@dec等价于在定义完函数之后执行下列语句
    # f = dec(f)
    
    print(f)
    
    
  • 实际上装饰器的存在前提是语言支持函数一等公民-即函数可以像变量一样被作为参数传递、返回等,这在Python中能实现的原因是Python一切皆对象的缘故,即无论是内建对象-int、list等,还是自定义类、还是函数,其本质上全是对象,函数是函数对象,因此可以做到函数作为一等公民这个特性。

# 装饰器优缺点

# 优点:

  1. 代码重用: 装饰器可以将一些通用的功能封装成可重用的装饰器函数,从而减少了代码的重复性。
  2. 代码简洁: 装饰器可以使代码更加简洁和易读,避免了在每个函数中重复相似的逻辑。
  3. 可扩展性: 可以方便地添加、移除或者修改装饰器,而不用修改原函数的代码。
  4. 分离关注点: 装饰器允许你将关注点从业务逻辑中分离出来,使得代码更易于维护和理解。
  5. 动态性: 装饰器可以在运行时动态地选择要应用的装饰器,从而提供了更大的灵活性。

# 缺点:

  1. 难以调试: 由于函数的行为可能会被多个装饰器修改,当出现问题时,定位错误可能相对困难。
  2. 理解难度: 对于初学者来说,装饰器的概念可能比较抽象,需要一定的时间来理解其工作原理。
  3. 顺序问题: 装饰器的顺序很重要,不同顺序可能导致不同的结果,这需要开发者深入理解各个装饰器的行为。
  4. 性能影响: 虽然通常情况下装饰器的性能影响较小,但是在某些性能要求较高的场景下,多层嵌套的装饰器可能会引起性能问题。

# 装饰器执行顺序

当一个函数被多个装饰器装饰时,这些装饰器会按照从上到下的顺序依次执行,最后返回被装饰后的函数对象。

例如,假设有三个装饰器 decorator1​、decorator2​ 和 decorator3​,并且一个函数 my_function​ 被这三个装饰器装饰,那么它们的执行顺序如下:

@decorator1
@decorator2
def my_function():
    pass
  1. Python 会解释器会把 my_function​ 传给 decorator2​,并将 decorator2​ 的返回值传给 decorator1​。
  2. 然后,decorator1​ 的返回值就是装饰后的 my_function​ 函数对象,即my_function = decorator1(decorator2(my_function))

即对函数的修改是从下到上的。

但是需要注意函数执行时是从上到下的-实际上就是函数的嵌套,修改结果最外面一层的首先被执行。

试验以及分析结果如下:

image

# decorator test
def AA(func):
    print('AA OUTER')
    def func_a(*args, **kwargs):
        print('AA INNER BEFORE')
        res = func(*args, **kwargs)
        print('AA INNER AFTER')
        return res
    return func_a

def BB(func):
    print('BB OUTER')
    def func_b(*args, **kwargs):
        print('BB INNER BEFORE')
        res = func(*args, **kwargs)
        print('BB INNER AFTER')
        return res
    return func_b

@BB
@AA
def f(x):
    print('F CALLED')
    return x * 10

## 完全等价于
# f = BB(AA(f))

# 在调用被装饰的函数之前输出-即即使注释掉下面对f的调用也会输出
# AA OUTER
# BB OUTER

print(f(1))

# BB INNER BEFORE
# AA INNER BEFORE
# F CALLED
# AA INNER AFTER
# BB INNER AFTER
# 10

# 装饰器示例

# 极简装饰器

def record_time(func):
    """自定义装饰函数的装饰器"""
  
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time()
        result = func(*args, **kwargs)
        print(f'{func.__name__}: {time() - start}秒')
        return result
  
    return wrapper

# 参数化装饰器

from functools import wraps
from time import time


def record(output):
	"""可以参数化的装饰器"""
	def decorate(func):
		@wraps(func)
		def wrapper(*args, **kwargs):
			start = time()
			result = func(*args, **kwargs)
			output(func.__name__, time() - start)
			return result
		return wrapper
	return decorate

# 装饰器类

from functools import wraps
from time import time


class Record():
    """通过定义类的方式定义装饰器"""
    """类有`__call__`魔术方法,该类对象就是可调用对象,可以当做装饰器来使用。"""

    def __init__(self, output):
        print("I m init")
        self.output = output

    def __call__(self, func):

        @wraps(func)
        def wrapper(*args, **kwargs):
            start = time()
            result = func(*args, **kwargs)
            self.output(func.__name__, time() - start)
            return result

        return wrapper


@Record(print)
def fun():
    print("测试一下")

# 对类的装饰器

def decorator(cls):
    print(cls.a) # 1
    cls.p = 3
    return cls

@decorator
class A:
    a = 1

print(A.p) # 3

# 另类的定义变量的方式

def eval_now(func):
    return func()

# some other code before...

@eval_now
def logger():
    # log format
    formatter = logging.Formatter(
        '[%(asctime)s] %(process)5d %(levelname) 8s - %(message)s',
        '%Y-%m-%d %H:%M:%S',
    )

    # stdout handler
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(formatter)
    stdout_handler.setLevel(logging.DEBUG)

    # stderr handler
    stderr_handler = logging.StreamHandler(sys.stderr)
    stderr_handler.setFormatter(formatter)
    stderr_handler.setLevel(logging.ERROR)

    # logger object
    logger = logging.Logger(__name__)
    logger.setLevel(logging.DEBUG)
    logger.addHandler(stdout_handler)
    logger.addHandler(stderr_handler)

    return logger

# again some other code after...

# 参数化装饰器-优雅写法

# 写法1

def log_slow_call(wrapped=None, seconds=10):
  
    if wrapped is None:
        def another_decorator(wrapped):
            return log_slow_call(wrapped, seconds)
        return another_decorator
  
    def proxy(*args, **kwargs):
        start_time = time.time()
        result = wrapped(*args, **kwargs)
        expired = time.time() - start_time
        if expired > seconds:
            logging.warning('call {} expires {} seconds'.format(wrapped.__name__, expired))
        return result
  
    return proxy


# Case A
# 函数执行超过5秒输出日志
@log_show_call(seconds=5)
def some_func():
    # ...

# Case B
# 默认函数执行超过10秒输出日志
@log_slow_call
def some_func():
    # ...

# 写法2

def log_slow_call(wrapped=None, seconds=10):
  
    if wrapped is None:
        return partial(log_slow_call, seconds=seconds)
  
    def proxy(*args, **kwargs):
        start_time = time.time()
        result = wrapped(*args, **kwargs)
        expired = time.time() - start_time
        if expired > seconds:
            logging.warning('call {} expires {} seconds'.format(wrapped.__name__, expired))
        return result
  
    return proxy


# Case A
# 函数执行超过5秒输出日志
@log_show_call(seconds=5)
def some_func():
    # ...

# Case B
# 默认函数执行超过10秒输出日志
@log_slow_call
def some_func():
    # ...

# 装饰器用法实例

# 身份认证

最常见的身份认证的应用。这个很容易理解,举个最常见的例子,你登录微信,需要输入用户名密码,然后点击确认,这样,服务器端便会查询你的用户名是否存在、是否和密码匹配等等。如果认证通过,你就可以顺利登录;如果不通过,就抛出异常并提示你登录失败。

再比如一些网站,你不登录也可以浏览内容,但如果你想要发布文章或留言,在点击发布时,服务器端便会查询你是否登录。如果没有登录,就不允许这项操作等等。

我们来看一个大概的代码示例:

import functools

def authenticate(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        request = args[0]
        if check_user_logged_in(request): # 如果用户处于登录状态
            return func(*args, **kwargs) # 执行函数post_comment() 
        else:
            raise Exception('Authentication failed')
    return wrapper
  
@authenticate
def post_comment(request, ...)
    ...
 

这段代码中,我们定义了装饰器 authenticate;而函数 post_comment(),则表示发表用户对某篇文章的评论。每次调用这个函数前,都会先检查用户是否处于登录状态,如果是登录状态,则允许这项操作;如果没有登录,则不允许。

# 日志记录

日志记录同样是很常见的一个案例。在实际工作中,如果你怀疑某些函数的耗时过长,导致整个系统的 latency(延迟)增加,所以想在线上测试某些函数的执行时间,那么,装饰器就是一种很常用的手段。

我们通常用下面的方法来表示:

import time
import functools

def log_execution_time(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        res = func(*args, **kwargs)
        end = time.perf_counter()
        print('{} took {} ms'.format(func.__name__, (end - start) * 1000))
        return res
    return wrapper
  
@log_execution_time
def calculate_similarity(items):
    ...

这里,装饰器 log_execution_time 记录某个函数的运行时间,并返回其执行结果。如果你想计算任何函数的执行时间,在这个函数上方加上@log_execution_time即可。

# 输入合理性检查

再来看今天要讲的第三个应用,输入合理性检查。

在大型公司的机器学习框架中,我们调用机器集群进行模型训练前,往往会用装饰器对其输入(往往是很长的 JSON 文件)进行合理性检查。这样就可以大大避免,输入不正确对机器造成的巨大开销。

它的写法往往是下面的格式:

import functools

def validation_check(input):
    @functools.wraps(func)
    def wrapper(*args, **kwargs): 
        ... # 检查输入是否合法
  
@validation_check
def neural_network_training(param1, param2, ...):
    ...

其实在工作中,很多情况下都会出现输入不合理的现象。因为我们调用的训练模型往往很复杂,输入的文件有成千上万行,很多时候确实也很难发现。

试想一下,如果没有输入的合理性检查,很容易出现“模型训练了好几个小时后,系统却报错说输入的一个参数不对,成果付之一炬”的现象。这样的“惨案”,大大减缓了开发效率,也对机器资源造成了巨大浪费。

# 缓存

最后,我们来看缓存方面的应用。关于缓存装饰器的用法,其实十分常见,这里我以 Python 内置的 LRU cache 为例来说明(如果你不了解 LRU cache,可以点击链接自行查阅)。

LRU cache,在 Python 中的表示形式是@lru_cache。@lru_cache会缓存进程中的函数参数和结果,当缓存满了以后,会删除 least recenly used 的数据。

正确使用缓存装饰器,往往能极大地提高程序运行效率。为什么呢?我举一个常见的例子来说明。

大型公司服务器端的代码中往往存在很多关于设备的检查,比如你使用的设备是安卓还是 iPhone,版本号是多少。这其中的一个原因,就是一些新的 feature,往往只在某些特定的手机系统或版本上才有(比如 Android v200+)。

这样一来,我们通常使用缓存装饰器,来包裹这些检查函数,避免其被反复调用,进而提高程序运行效率,比如写成下面这样:

@lru_cache
def check(param1, param2, ...) # 检查用户设备类型,版本号等等
    ...