Python装饰器与描述符在ORM中的实现
2026/6/16 3:51:55 网站建设 项目流程

Python装饰器与描述符在ORM中的实现

ORM框架大量使用描述符来实现字段的声明式定义。理解描述符在ORM中的实际用法,能直接指导你写出更简洁的数据访问层。

最基础的字段描述符:

class Field:
def __init__(self, name, column_type):
self.name = name
self.column_type = column_type
self.internal_name = f"_{name}"

def __get__(self, obj, objtype=None):
if obj is None:
return self
return getattr(obj, self.internal_name, None)

def __set__(self, obj, value):
self._validate(value)
object.__setattr__(obj, self.internal_name, value)

def _validate(self, value):
if not isinstance(value, self.column_type):
raise TypeError(
f"Expected {self.column_type.__name__}, got {type(value).__name__}"
)

class IntegerField(Field):
def __init__(self, name):
super().__init__(name, int)

class StringField(Field):
def __init__(self, name):
super().__init__(name, str)

class User:
id = IntegerField("id")
name = StringField("name")

def __init__(self, id, name):
self.id = id
self.name = name

u = User(1, "Alice")
print(u.id) # 1
print(u.name) # "Alice"

但这个实现有个问题:field_name参数重复了。在类定义中既写了变量名(id = IntegerField(...)),又写了构造参数("id")。

自动获取字段名需要元类的配合:

import inspect

class ModelMeta(type):
def __new__(mcls, name, bases, namespace):
annotations = namespace.get('__annotations__', {})
for attr_name, attr_type in annotations.items():
if attr_type == int:
field = IntegerField(attr_name)
elif attr_type == str:
field = StringField(attr_name)
else:
continue
namespace[attr_name] = field
return super().__new__(mcls, name, bases, namespace)

class Base(metaclass=ModelMeta):
pass

class User(Base):
id: int
name: str

def __init__(self, id, name):
self.id = id
self.name = name

使用类型注解替代重复的字段名声明,避免了DRY原则的违反。

真实情况更复杂,因为ORM字段需要存储很多元数据:表名、列名、是否为主键、是否有索引、默认值等。以SQLAlchemy风格的完整实现:

class Field:
_registry = {}

def __init__(self, column_type, primary_key=False, default=None, nullable=True, index=False):
self.column_type = column_type
self.primary_key = primary_key
self.default = default
self.nullable = nullable
self.index = index
self.name = None
self.model_class = None

def contribute_to_class(self, model_class, name):
self.name = name
self.model_class = model_class
Field._registry.setdefault(model_class, {})[name] = self

def __get__(self, obj, objtype=None):
if obj is None:
return self
return obj.__dict__.get(self.name, self.default)

def __set__(self, obj, value):
if value is None and not self.nullable:
raise ValueError(f"{self.name} cannot be null")
obj.__dict__[self.name] = value

class ModelMeta(type):
def __new__(mcls, name, bases, namespace):
if name == 'Model':
return super().__new__(mcls, name, bases, namespace)

fields = {}
for attr_name, attr_value in list(namespace.items()):
if isinstance(attr_value, Field):
fields[attr_name] = namespace.pop(attr_name)

cls = super().__new__(mcls, name, bases, namespace)
cls._fields = fields
cls._table_name = name.lower()

for field_name, field in fields.items():
field.contribute_to_class(cls, field_name)

return cls

class Model(metaclass=ModelMeta):
def __init__(self, **kwargs):
for field_name, field in self._fields.items():
if field_name in kwargs:
setattr(self, field_name, kwargs[field_name])
elif field.default is not None:
setattr(self, field_name, field.default() if callable(field.default) else field.default)
elif not field.nullable:
raise ValueError(f"{field_name} is required")

def save(self):
columns = []
values = []
placeholders = []
for name, field in self._fields.items():
value = getattr(self, name)
columns.append(name)
values.append(value)
placeholders.append('?')

sql = f"INSERT INTO {self._table_name} ({', '.join(columns)}) VALUES ({', '.join(placeholders)})"
# execute_sql(sql, values)
print(f"SQL: {sql}")
print(f"Values: {values}")

class IntegerField(Field):
def __init__(self, **kwargs):
super().__init__(int, **kwargs)

class StringField(Field):
def __init__(self, max_length=255, **kwargs):
super().__init__(str, **kwargs)
self.max_length = max_length

def __set__(self, obj, value):
if value is not None and len(str(value)) > self.max_length:
raise ValueError(f"{self.name} exceeds max_length of {self.max_length}")
super().__set__(obj, value)

class User(Model):
id = IntegerField(primary_key=True)
name = StringField(max_length=100)
email = StringField(nullable=True)

user = User(id=1, name="Alice", email="alice@example.com")
user.save()

描述符的__set__方法拦截了属性赋值,在OR-M中同时承担类型验证、长度检查和其他约束校验的职责。Field的contribute_to_class方法在元类创建类时调用,确保字段知道它属于哪个类和叫什么名字。

懒加载(Lazy Loading)是ORM的另一个关键特性,也通过描述符实现:

class LazyField(Field):
def __get__(self, obj, objtype=None):
if obj is None:
return self
value = obj.__dict__.get(self.name)
if value is None and not self.nullable and self.name != 'id':
self._load(obj)
return obj.__dict__.get(self.name)

def _load(self, obj):
# 从数据库加载关联数据
print(f"懒加载 {self.name} for {obj}")
obj.__dict__[self.name] = self._fetch_from_db(obj)

def _fetch_from_db(self, obj):
# 模拟数据库查询
return f"loaded_{self.name}_{obj.id}"

class Profile(Model):
user_id = IntegerField()
bio = LazyField(str, nullable=True)

p = Profile(user_id=1)
print(p.bio) # 首次访问触发_DB查询

外键关系同样由描述符管理:

class ForeignKey(Field):
def __init__(self, to_model, **kwargs):
super().__init__(int, **kwargs)
self.to_model = to_model
self._related_cache_attr = None

def __get__(self, obj, objtype=None):
if obj is None:
return self

fk_value = obj.__dict__.get(self.name)
if fk_value is None:
return None

cache_key = f"_cached_{self.name}"
if cache_key in obj.__dict__:
return obj.__dict__[cache_key]

related_obj = self._fetch_related(fk_value)
obj.__dict__[cache_key] = related_obj
return related_obj

def __set__(self, obj, value):
if isinstance(value, Model):
obj.__dict__[self.name] = value.id
obj.__dict__[f"_cached_{self.name}"] = value
else:
obj.__dict__[self.name] = value

def _fetch_related(self, pk):
print(f"查询关联表: {self.to_model._table_name} WHERE id = {pk}")
# return self.to_model.objects.get(pk=pk)
return MockModel(pk)

class MockModel:
def __init__(self, pk):
self.id = pk

class Post(Model):
id = IntegerField(primary_key=True)
title = StringField(max_length=200)
author_id = ForeignKey(User)

@property
def author(self):
return self.author_id

post = Post(id=1, title="Hello")
post.author_id = User(id=1, name="Bob")
print(post.author.name) # "Bob"

Foreignkey的__get__实现了关联对象的自动加载和缓存。首次访问时检查__dict__中是否有缓存,没有则发起数据库查询,查询结果缓存起来,后续访问直接返回缓存。__set__接受对象或主键值,统一转换成内部存储格式。

N+1查询问题是ORM最有名的性能坑,根源也在描述符的懒加载策略:

class Author(Model):
id = IntegerField(primary_key=True)
name = StringField(max_length=100)

class Book(Model):
id = IntegerField(primary_key=True)
title = StringField(max_length=200)
author_id = ForeignKey(Author)

# N+1问题
books = [Book(id=i, title=f"Book {i}") for i in range(100)]
for book in books:
book.author_id = i % 10
print(book.author.name) # 每个book单独查询一次author

这个循环执行了1次查询获取books + 100次查询获取authors。加上预加载(eager loading)可以解决:

class PrefetchQuerySet:
def __init__(self, model):
self.model = model

def select_related(self, *fields):
self._prefetch_fields = fields
return self

def __iter__(self):
for obj in self._fetch_all():
for field_name in getattr(self, '_prefetch_fields', []):
self._prefetch_relation(obj, field_name)
yield obj

def _prefetch_relation(self, obj, field_name):
field = self.model._fields.get(field_name)
if isinstance(field, ForeignKey):
# 批量查询所有关联对象
fk_values = [getattr(o, field_name) for o in self._cache]
related_objects = self._batch_fetch(field.to_model, fk_values)
for obj in self._cache:
fk_value = getattr(obj, field_name)
related = next((r for r in related_objects if r.id == fk_value), None)
obj.__dict__[f"_cached_{field_name}"] = related

def _batch_fetch(self, model, ids):
ids = list(set(ids))
print(f"批量查询: {model._table_name} WHERE id IN ({', '.join(str(i) for i in ids)})")
return [MockModel(i) for i in ids]

def _fetch_all(self):
self._cache = [Book(id=i, title=f"Book {i}") for i in range(100)]
return self._cache

books_qs = PrefetchQuerySet(Book).select_related('author_id')
for book in books_qs:
print(book.author.name) # 只执行1次批量查询

select_related使用IN查询一次性获取所有关联对象,填回每个对象的缓存。查询次数从1+N降为1+1。

描述符在ORM中的另一个重要场景是属性变更追踪。当字段值被修改时,ORM需要知道哪些字段被改了,以便生成UPDATE语句:

class TrackedField(IntegerField):
def __set__(self, obj, value):
original = obj.__dict__.get(self.name)
super().__set__(obj, value)
if original is not None and original != value:
self._mark_dirty(obj)

def _mark_dirty(self, obj):
if '_dirty_fields' not in obj.__dict__:
obj.__dict__['_dirty_fields'] = set()
obj.__dict__['_dirty_fields'].add(self.name)

user = User(id=1, name="Alice")
user.name = "Bob"
print(user.__dict__.get('_dirty_fields')) # {'name'}

Dirty tracking是ORM生成增量更新语句的基础。save方法检查_dirty_fields集合,只生成被修改字段的update语句,减少不必要的数据库写入。

描述符与装饰器的结合使用在SQLAlchemy 2.0风格的映射中很常见:

def mapped_column(*, primary_key=False, nullable=True, default=None):
def decorator(func):
field = Field(int if primary_key else str)
field.primary_key = primary_key
field.nullable = nullable
field.default = default
func._field = field
return func
return decorator

class DeclarativeMeta(type):
def __new__(mcls, name, bases, namespace):
cls = super().__new__(mcls, name, bases, namespace)
cls._fields = {}
for attr_name in dir(cls):
attr = getattr(cls, attr_name, None)
if hasattr(attr, '_field'):
field = attr._field
field.name = attr_name
field.model_class = cls
cls._fields[attr_name] = field
setattr(cls, attr_name, field)
return cls

class BaseModel(metaclass=DeclarativeMeta):
pass

class Product(BaseModel):
@mapped_column(primary_key=True)
def id(self): pass

@mapped_column(nullable=False)
def name(self): pass

装饰器在类体中被执行,创建一个Field对象并挂载到函数对象上。DeclarativeMeta在创建类时扫描所有带_field的函数,将Field对象注册为类的描述符属性。这个模式结合了装饰器的声明式简洁性和描述符的属性拦截能力。

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询