【设计模式】Python仓储模式:从入门到实战
这篇是设计模式系列的学习笔记,这次来聊聊仓储模式
Python仓储模式:让数据访问不再混乱
前言
这篇是设计模式系列的学习笔记,这次来聊聊仓储模式(Repository Pattern)。
上一篇讲了工厂模式和依赖注入,最后提到了分层架构里的 Repository 层。这篇就专门把数据访问这块的设计讲透。
说实话,很多人写 FastAPI 项目的时候,习惯在路由函数里直接写 SQLAlchemy 查询,一个接口几十行,查询、过滤、分页全揉在一起。项目小的时候还好,一旦业务复杂起来,代码就变成一锅粥了——同样的查询逻辑写了好几遍,改个字段名得改十几个地方,测试更是没法写。
仓储模式就是解决这个问题的。它把数据访问逻辑封装起来,让业务层不用关心数据是怎么存的、怎么取的。听起来是不是有点像工厂模式?没错,设计模式之间经常是互相配合的,仓储模式和工厂模式、依赖注入一起用,能让代码架构非常清晰。
这篇文章会从"为什么需要"讲起,然后一步步实现一个完整的仓储层,包括泛型仓储、工作单元模式这些进阶内容。内容比较多,但都是实战中会用到的,耐心看完肯定有收获。
🏠个人主页:山沐与山
文章目录
- 一、仓储模式是什么
- 二、不用仓储模式会怎样
- 三、仓储模式的基本实现
- 四、泛型仓储:减少重复代码
- 五、工作单元模式配合仓储
- 六、FastAPI完整实战
- 七、仓储模式 vs 直接用ORM
- 八、测试中的巨大优势
- 九、常见问题与最佳实践
- 十、总结
一、仓储模式是什么
1.1 从一个比喻说起
想象一下图书馆。你想借一本书,需要知道这本书具体放在哪个架子的第几层吗?不需要,你只需要告诉图书管理员书名,管理员会帮你找到并拿给你。还书的时候也一样,你不用操心书该放回哪里,交给管理员就行。
仓储模式里的 Repository 就是这个"图书管理员"。你的业务代码只需要说"我要 ID 为 1 的用户"或者"帮我保存这个订单",至于数据是存在 MySQL 还是 PostgreSQL,用的是什么 ORM,业务代码完全不需要知道。
1.2 正式定义
仓储模式(Repository Pattern)是一种将数据访问逻辑与业务逻辑分离的设计模式。它提供一个类似"集合"的接口来访问领域对象,让业务层可以像操作内存中的集合一样操作数据,而不需要关心底层的数据存储细节。
用更直白的话说:Repository 是数据层的抽象,它把"怎么存数据"这件事藏起来,只暴露"存什么、取什么"的接口。
1.3 仓储模式的核心思想
仓储模式有几个核心理念:
第一,隔离数据访问细节。业务层不直接和数据库打交道,而是通过 Repository 这个中间层。这样数据库换了(比如从 MySQL 换到 PostgreSQL),或者 ORM 换了(比如从 SQLAlchemy 换到 Tortoise),业务层的代码不用改。
第二,统一数据访问接口。不管底层是关系型数据库、NoSQL、文件系统还是远程 API,Repository 对外提供的接口是一致的。这让业务代码变得简洁,也让测试变得容易(可以用内存实现替换真实数据库)。
第三,集中管理查询逻辑。所有和某个实体相关的查询都放在对应的 Repository 里。想找"获取活跃用户"的逻辑?去 UserRepository 里找就行,不用在几十个文件里翻。
1.4 仓储模式在分层架构中的位置
在典型的分层架构中,Repository 处于数据访问层,向上对接服务层(Service),向下对接具体的数据存储:
┌─────────────────────────────────────────────────────┐
│ Presentation Layer │
│ (Controllers / Routers) │
│ 处理HTTP请求,调用Service,返回响应 │
└───────────────────────┬─────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────┐
│ Service Layer │
│ (Business Logic) │
│ 业务逻辑,调用Repository获取数据 │
└───────────────────────┬─────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────┐
│ Repository Layer │
│ (Data Access Logic) │
│ 封装数据访问,提供类集合的操作接口 │
└───────────────────────┬─────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────┐
│ Data Storage Layer │
│ (Database / Cache / External API) │
│ 实际的数据存储 │
└─────────────────────────────────────────────────────┘
每一层只和相邻的层打交道,这样修改任何一层都不会影响其他层(只要接口不变)。
二、不用仓储模式会怎样
在讲怎么实现之前,先看看不用仓储模式会遇到什么问题。这样你才能理解为什么要引入这个模式。
2.1 典型的"面条代码"
很多 FastAPI 项目一开始是这样写的,所有逻辑都塞在路由函数里:
from fastapi import FastAPI, Depends, HTTPException
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_, func
from database import get_db
from models import User, Order, Product
app = FastAPI()
@app.get("/users/{user_id}")
def get_user(user_id: int, db: Session = Depends(get_db)):
# 直接在路由里写查询逻辑
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
return user
@app.get("/users")
def list_users(
skip: int = 0,
limit: int = 10,
is_active: bool = None,
keyword: str = None,
db: Session = Depends(get_db)
):
# 查询逻辑越来越复杂
query = db.query(User)
if is_active is not None:
query = query.filter(User.is_active == is_active)
if keyword:
query = query.filter(
or_(
User.username.ilike(f"%{keyword}%"),
User.email.ilike(f"%{keyword}%")
)
)
total = query.count()
users = query.offset(skip).limit(limit).all()
return {"total": total, "items": users}
@app.get("/users/{user_id}/orders")
def get_user_orders(
user_id: int,
status: str = None,
start_date: str = None,
end_date: str = None,
db: Session = Depends(get_db)
):
# 又是一大段查询逻辑
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
query = db.query(Order).filter(Order.user_id == user_id)
if status:
query = query.filter(Order.status == status)
if start_date:
query = query.filter(Order.created_at >= start_date)
if end_date:
query = query.filter(Order.created_at <= end_date)
return query.all()
@app.post("/orders")
def create_order(
user_id: int,
product_id: int,
quantity: int,
db: Session = Depends(get_db)
):
# 更多的数据库操作...
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
product = db.query(Product).filter(Product.id == product_id).first()
if not product:
raise HTTPException(status_code=404, detail="商品不存在")
if product.stock < quantity:
raise HTTPException(status_code=400, detail="库存不足")
# 创建订单
order = Order(
user_id=user_id,
product_id=product_id,
quantity=quantity,
total_price=product.price * quantity
)
db.add(order)
# 扣减库存
product.stock -= quantity
db.commit()
db.refresh(order)
return order
看起来能跑,但问题其实很多。
2.2 问题分析
问题一:代码重复严重
"根据 ID 查用户"这个逻辑在 get_user、get_user_orders、create_order 里都写了一遍。如果以后查询条件变了(比如要加上软删除过滤),得改好几个地方。累不累?
# 这段代码重复了三次
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
问题二:业务逻辑和数据访问逻辑混在一起
路由函数本来应该只负责"接收请求、返回响应",现在还要处理复杂的数据库查询。一个函数做了太多事情,违反单一职责原则。
更糟糕的是,业务规则(比如"库存不足不能下单")和数据操作(查询、更新)交织在一起,代码的意图变得不清晰,后来维护的人很难理解这段代码到底在干什么。
问题三:难以测试
想测试 create_order 的业务逻辑,必须准备一个真实的数据库,还得提前插入用户和商品数据。测试又慢又脆弱,稍微改点东西测试就挂了。
如果能把数据访问抽象出来,测试时用个假的实现(返回固定数据),就能快速验证业务逻辑是否正确。
问题四:难以复用
假设有另一个地方也需要"获取用户的订单列表",怎么办?把那一大段代码复制过去?还是提取成函数?提取成函数放在哪里?
没有统一的规范,代码组织会越来越乱,每个人都有自己的写法,最后就是一团糟。
问题五:切换数据源困难
现在用的是 SQLAlchemy 查 PostgreSQL,哪天要加个 Redis 缓存,或者某些数据要从外部 API 获取,改动会非常大。因为数据库操作散落在各处,没有统一的抽象层。
2.3 我们需要什么
总结一下,我们需要:
| 需求 | 说明 |
|---|---|
| 统一的地方 | 放所有和某个实体相关的数据访问逻辑 |
| 抽象的接口 | 让业务层不依赖具体的数据库实现 |
| 易于测试 | 可以用假实现替换真实数据库 |
| 可复用 | 查询逻辑写一次到处用 |
这就是仓储模式要解决的问题。
三、仓储模式的基本实现
理解了问题,现在来看解决方案。
3.1 定义仓储接口
首先,定义一个抽象的接口,规定 Repository 应该提供哪些操作。Python 里可以用 Protocol 或 ABC 来定义接口。
Protocol 是 Python 3.8+ 引入的,它实现的是"结构化子类型"(structural subtyping),也叫鸭子类型的静态版本。只要一个类实现了 Protocol 定义的方法,就被认为是该 Protocol 的实现,不需要显式继承。这比 ABC 更灵活。
from typing import Protocol, TypeVar, Generic, Optional, List
# 定义泛型类型变量
T = TypeVar('T') # 实体类型
ID = TypeVar('ID') # ID类型
class IRepository(Protocol[T, ID]):
"""
仓储接口
这是一个泛型协议,T 是实体类型,ID 是主键类型。
所有具体的 Repository 都应该实现这个接口定义的方法。
"""
def get_by_id(self, id: ID) -> Optional[T]:
"""根据ID获取单个实体,不存在返回 None"""
...
def get_all(self) -> List[T]:
"""获取所有实体"""
...
def add(self, entity: T) -> T:
"""添加一个实体"""
...
def update(self, entity: T) -> T:
"""更新一个实体"""
...
def delete(self, id: ID) -> bool:
"""删除一个实体,返回是否成功"""
...
这个接口定义了最基本的 CRUD 操作(Create、Read、Update、Delete)。注意,这里没有任何和具体数据库相关的代码,它只是一个"契约",规定实现类必须提供这些方法。
3.2 实现具体的仓储
有了接口,接下来实现具体的 Repository。这里以用户仓储为例,使用 SQLAlchemy 作为 ORM:
from typing import Optional, List
from sqlalchemy.orm import Session
from models import User
class UserRepository:
"""
用户仓储的 SQLAlchemy 实现
这个类负责所有和 User 实体相关的数据库操作。
业务层通过这个类来访问用户数据,不需要知道底层用的是什么数据库。
"""
def __init__(self, db: Session):
"""
构造函数
Args:
db: SQLAlchemy 的数据库会话,通过依赖注入传入
"""
self.db = db
def get_by_id(self, user_id: int) -> Optional[User]:
"""
根据ID获取用户
返回 Optional 类型意味着可能返回 None,
调用方需要处理用户不存在的情况。
"""
return self.db.query(User).filter(User.id == user_id).first()
def get_by_username(self, username: str) -> Optional[User]:
"""根据用户名获取用户"""
return self.db.query(User).filter(User.username == username).first()
def get_by_email(self, email: str) -> Optional[User]:
"""根据邮箱获取用户"""
return self.db.query(User).filter(User.email == email).first()
def get_all(
self,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None
) -> List[User]:
"""
获取用户列表
支持分页和过滤。在实际项目中,这个方法的参数可能会更多,
比如排序字段、排序方向、多种过滤条件等。
"""
query = self.db.query(User)
# 条件过滤
if is_active is not None:
query = query.filter(User.is_active == is_active)
# 分页
return query.offset(skip).limit(limit).all()
def count(self, is_active: Optional[bool] = None) -> int:
"""统计用户数量"""
query = self.db.query(User)
if is_active is not None:
query = query.filter(User.is_active == is_active)
return query.count()
def add(self, user: User) -> User:
"""
添加用户
注意这里只是把对象加到 session 里,
真正写入数据库是在 commit 时。
"""
self.db.add(user)
self.db.flush() # flush 会生成 ID,但不提交事务
self.db.refresh(user) # 刷新对象,获取数据库生成的值
return user
def update(self, user: User) -> User:
"""更新用户"""
self.db.flush()
self.db.refresh(user)
return user
def delete(self, user_id: int) -> bool:
"""删除用户"""
user = self.get_by_id(user_id)
if user:
self.db.delete(user)
return True
return False
def exists_by_username(self, username: str) -> bool:
"""检查用户名是否存在"""
return self.db.query(
self.db.query(User).filter(User.username == username).exists()
).scalar()
def exists_by_email(self, email: str) -> bool:
"""检查邮箱是否存在"""
return self.db.query(
self.db.query(User).filter(User.email == email).exists()
).scalar()
这个实现有几点值得注意:
构造函数接收 Session:Repository 不自己创建数据库连接,而是从外部传入。这是依赖注入的体现,让 Repository 可以在不同的上下文中复用(比如测试时传入一个内存数据库的 Session)。
方法命名清晰:get_by_id、get_by_username 这样的命名一眼就能看出方法的作用。Repository 的方法名应该反映业务语义,而不是技术细节。
分离了 flush 和 commit:Repository 只负责数据操作,不负责事务控制(commit)。事务的边界应该由更上层(Service 层或工作单元)来决定,这样可以把多个操作放在一个事务里。
3.3 在服务层使用仓储
有了 Repository,服务层的代码就变得简洁多了:
from typing import Optional, List
from fastapi import HTTPException, status
from repositories.user_repository import UserRepository
from models import User
from schemas import UserCreate, UserUpdate, UserResponse
class UserService:
"""
用户服务
服务层负责业务逻辑,它通过 Repository 来访问数据,
而不直接操作数据库。
"""
def __init__(self, user_repo: UserRepository):
"""通过依赖注入接收 Repository"""
self.user_repo = user_repo
def get_user(self, user_id: int) -> UserResponse:
"""获取用户信息"""
user = self.user_repo.get_by_id(user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户不存在"
)
return UserResponse.model_validate(user)
def get_users(
self,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None
) -> dict:
"""获取用户列表(带分页)"""
users = self.user_repo.get_all(skip=skip, limit=limit, is_active=is_active)
total = self.user_repo.count(is_active=is_active)
return {
"total": total,
"items": [UserResponse.model_validate(u) for u in users]
}
def create_user(self, user_data: UserCreate) -> UserResponse:
"""创建用户"""
# 业务规则:检查用户名是否已存在
if self.user_repo.exists_by_username(user_data.username):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="用户名已存在"
)
# 业务规则:检查邮箱是否已存在
if self.user_repo.exists_by_email(user_data.email):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="邮箱已被注册"
)
# 创建用户实体
user = User(
username=user_data.username,
email=user_data.email,
hashed_password=self._hash_password(user_data.password)
)
# 通过 Repository 保存
user = self.user_repo.add(user)
return UserResponse.model_validate(user)
def _hash_password(self, password: str) -> str:
"""密码哈希"""
from passlib.context import CryptContext
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
return pwd_context.hash(password)
对比一下之前的"面条代码",现在的服务层:
| 改进点 | 说明 |
|---|---|
| 职责清晰 | 只处理业务逻辑,不写 SQL 查询 |
| 代码简洁 | 数据访问都委托给 Repository |
| 易于测试 | 可以 mock Repository 来测试业务逻辑 |
| 可读性强 | 方法名反映业务意图,读代码像读文档 |
3.4 配置依赖注入
最后,用 FastAPI 的依赖注入把这些串起来:
from fastapi import FastAPI, Depends
from sqlalchemy.orm import Session
from database import get_db
from repositories.user_repository import UserRepository
from services.user_service import UserService
app = FastAPI()
def get_user_repository(db: Session = Depends(get_db)) -> UserRepository:
"""创建用户仓储实例"""
return UserRepository(db)
def get_user_service(
user_repo: UserRepository = Depends(get_user_repository)
) -> UserService:
"""创建用户服务实例"""
return UserService(user_repo)
@app.get("/users/{user_id}")
def get_user(
user_id: int,
service: UserService = Depends(get_user_service)
):
"""路由函数变得非常简洁"""
return service.get_user(user_id)
@app.get("/users")
def list_users(
skip: int = 0,
limit: int = 100,
is_active: bool = None,
service: UserService = Depends(get_user_service)
):
return service.get_users(skip=skip, limit=limit, is_active=is_active)
依赖链是这样的:Router → Service → Repository → Database Session。每一层只依赖下一层的抽象,不依赖具体实现,这就是依赖倒置原则的体现。
四、泛型仓储:减少重复代码
上面的 UserRepository 挺好的,但如果项目里有 User、Order、Product、Category 等十几个实体,每个都写一遍基础的 CRUD 方法,太重复了。
泛型仓储(Generic Repository)可以解决这个问题——把通用的 CRUD 逻辑提取到基类,具体的 Repository 只需要继承基类,添加特有的方法。
4.1 实现泛型基类
from typing import TypeVar, Generic, Optional, List, Type, Any
from sqlalchemy.orm import Session
from sqlalchemy import inspect
ModelType = TypeVar("ModelType") # ORM 模型类型
class BaseRepository(Generic[ModelType]):
"""
泛型仓储基类
这个类实现了通用的 CRUD 操作,具体的 Repository 可以继承它,
自动获得这些基础功能,只需要添加特有的查询方法。
"""
def __init__(self, model: Type[ModelType], db: Session):
"""
Args:
model: ORM 模型类,比如 User、Order
db: 数据库会话
"""
self.model = model
self.db = db
def get_by_id(self, id: Any) -> Optional[ModelType]:
"""根据主键获取实体"""
# 自动检测模型的主键字段
pk_columns = inspect(self.model).primary_key
if len(pk_columns) == 1:
pk_name = pk_columns[0].name
return self.db.query(self.model).filter(
getattr(self.model, pk_name) == id
).first()
else:
# 复合主键的情况
filters = [
getattr(self.model, col.name) == id[col.name]
for col in pk_columns
]
return self.db.query(self.model).filter(*filters).first()
def get_all(
self,
skip: int = 0,
limit: int = 100,
**filters
) -> List[ModelType]:
"""
获取实体列表,支持分页和简单过滤
filters 参数允许传入任意字段的过滤条件,比如:
get_all(is_active=True, category_id=1)
"""
query = self.db.query(self.model)
# 应用过滤条件
for field, value in filters.items():
if value is not None and hasattr(self.model, field):
query = query.filter(getattr(self.model, field) == value)
return query.offset(skip).limit(limit).all()
def count(self, **filters) -> int:
"""统计实体数量"""
query = self.db.query(self.model)
for field, value in filters.items():
if value is not None and hasattr(self.model, field):
query = query.filter(getattr(self.model, field) == value)
return query.count()
def exists(self, id: Any) -> bool:
"""检查实体是否存在"""
return self.get_by_id(id) is not None
def add(self, entity: ModelType) -> ModelType:
"""添加实体"""
self.db.add(entity)
self.db.flush()
self.db.refresh(entity)
return entity
def add_many(self, entities: List[ModelType]) -> List[ModelType]:
"""批量添加实体"""
self.db.add_all(entities)
self.db.flush()
for entity in entities:
self.db.refresh(entity)
return entities
def update(self, entity: ModelType) -> ModelType:
"""更新实体"""
self.db.flush()
self.db.refresh(entity)
return entity
def delete(self, id: Any) -> bool:
"""删除实体"""
entity = self.get_by_id(id)
if entity:
self.db.delete(entity)
return True
return False
def delete_many(self, ids: List[Any]) -> int:
"""批量删除实体,返回实际删除的数量"""
pk_columns = inspect(self.model).primary_key
pk_name = pk_columns[0].name
count = self.db.query(self.model).filter(
getattr(self.model, pk_name).in_(ids)
).delete(synchronize_session=False)
return count
4.2 继承基类创建具体仓储
现在,创建具体的 Repository 变得非常简单:
from typing import Optional, List
from sqlalchemy.orm import Session
from sqlalchemy import or_
from models import User, Order, Product
from repositories.base import BaseRepository
class UserRepository(BaseRepository[User]):
"""
用户仓储
继承自 BaseRepository[User],自动获得所有基础 CRUD 方法。
这里只需要添加 User 特有的查询方法。
"""
def __init__(self, db: Session):
super().__init__(User, db)
# ===== User 特有的方法 =====
def get_by_username(self, username: str) -> Optional[User]:
"""根据用户名查询"""
return self.db.query(User).filter(User.username == username).first()
def get_by_email(self, email: str) -> Optional[User]:
"""根据邮箱查询"""
return self.db.query(User).filter(User.email == email).first()
def search(
self,
keyword: str,
skip: int = 0,
limit: int = 100
) -> List[User]:
"""搜索用户(在用户名和邮箱中模糊匹配)"""
return self.db.query(User).filter(
or_(
User.username.ilike(f"%{keyword}%"),
User.email.ilike(f"%{keyword}%")
)
).offset(skip).limit(limit).all()
def exists_by_username(self, username: str) -> bool:
"""检查用户名是否存在"""
return self.db.query(
self.db.query(User).filter(User.username == username).exists()
).scalar()
def exists_by_email(self, email: str) -> bool:
"""检查邮箱是否存在"""
return self.db.query(
self.db.query(User).filter(User.email == email).exists()
).scalar()
class OrderRepository(BaseRepository[Order]):
"""订单仓储"""
def __init__(self, db: Session):
super().__init__(Order, db)
def get_by_user(
self,
user_id: int,
status: Optional[str] = None,
skip: int = 0,
limit: int = 100
) -> List[Order]:
"""获取用户的订单"""
query = self.db.query(Order).filter(Order.user_id == user_id)
if status:
query = query.filter(Order.status == status)
return query.order_by(Order.created_at.desc()).offset(skip).limit(limit).all()
def get_recent_orders(self, days: int = 7) -> List[Order]:
"""获取最近几天的订单"""
from datetime import datetime, timedelta
since = datetime.now() - timedelta(days=days)
return self.db.query(Order).filter(Order.created_at >= since).all()
def calculate_user_total(self, user_id: int) -> float:
"""计算用户的累计消费金额"""
from sqlalchemy import func
result = self.db.query(func.sum(Order.total_price)).filter(
Order.user_id == user_id,
Order.status == "completed"
).scalar()
return result or 0.0
class ProductRepository(BaseRepository[Product]):
"""商品仓储"""
def __init__(self, db: Session):
super().__init__(Product, db)
def get_by_category(
self,
category_id: int,
in_stock: bool = True
) -> List[Product]:
"""获取某分类下的商品"""
query = self.db.query(Product).filter(Product.category_id == category_id)
if in_stock:
query = query.filter(Product.stock > 0)
return query.all()
def search_by_name(self, keyword: str) -> List[Product]:
"""搜索商品"""
return self.db.query(Product).filter(
Product.name.ilike(f"%{keyword}%")
).all()
def decrease_stock(self, product_id: int, quantity: int) -> bool:
"""
扣减库存
使用乐观锁防止超卖:只有 stock >= quantity 时才会更新成功
"""
result = self.db.query(Product).filter(
Product.id == product_id,
Product.stock >= quantity
).update(
{Product.stock: Product.stock - quantity},
synchronize_session=False
)
return result > 0
看到了吗?每个具体的 Repository 只需要几行代码就定义好了,因为基础方法都继承自 BaseRepository。特有的业务查询方法单独定义,代码量大大减少,而且结构清晰。
4.3 泛型仓储的优缺点
| 优点 | 缺点 |
|---|---|
| 减少重复代码 | 复杂查询仍需单独实现 |
| 统一的 CRUD 接口 | 增加一层抽象 |
新增实体 Repository 很快 |
泛型可能让新手困惑 |
| 便于统一添加日志、审计等功能 | 过度抽象可能适得其反 |
五、工作单元模式配合仓储
前面的 Repository 实现里,我们只用了 flush() 而没有 commit()。这是故意的——事务控制应该由更上层来负责。工作单元模式(Unit of Work)就是专门管理事务的。
5.1 什么是工作单元
工作单元的核心思想是:把一组相关的数据库操作放在一起,要么全部成功,要么全部失败。
比如创建订单这个业务,需要:
- 创建订单记录
- 更新商品库存
- 创建支付记录
这三个操作必须是原子的——要么都成功,要么都回滚。如果订单创建成功但库存更新失败,数据就不一致了。
工作单元负责:
- 追踪一个业务操作涉及的所有数据变更
- 在合适的时机统一提交(
commit)或回滚(rollback) - 确保事务的完整性
5.2 实现工作单元
from typing import Optional
from sqlalchemy.orm import Session
from repositories.user_repository import UserRepository
from repositories.order_repository import OrderRepository
from repositories.product_repository import ProductRepository
class UnitOfWork:
"""
工作单元
管理多个 Repository 的操作,确保这些操作在同一个事务中执行。
典型用法:
with UnitOfWork(db) as uow:
user = uow.users.get_by_id(1)
order = Order(user_id=user.id, ...)
uow.orders.add(order)
uow.commit()
"""
def __init__(self, session: Session):
self._session = session
# 懒加载 Repository
self._users: Optional[UserRepository] = None
self._orders: Optional[OrderRepository] = None
self._products: Optional[ProductRepository] = None
@property
def users(self) -> UserRepository:
"""用户仓储(懒加载)"""
if self._users is None:
self._users = UserRepository(self._session)
return self._users
@property
def orders(self) -> OrderRepository:
"""订单仓储"""
if self._orders is None:
self._orders = OrderRepository(self._session)
return self._orders
@property
def products(self) -> ProductRepository:
"""商品仓储"""
if self._products is None:
self._products = ProductRepository(self._session)
return self._products
def __enter__(self):
"""进入上下文管理器"""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""退出上下文管理器,异常时自动回滚"""
if exc_type is not None:
self.rollback()
def commit(self):
"""提交事务"""
try:
self._session.commit()
except Exception:
self.rollback()
raise
def rollback(self):
"""回滚事务"""
self._session.rollback()
def flush(self):
"""刷新会话(不提交)"""
self._session.flush()
5.3 在服务层使用工作单元
from fastapi import HTTPException, status
from unit_of_work import UnitOfWork
from models import Order
from schemas import OrderCreate
class OrderService:
"""订单服务"""
def __init__(self, uow: UnitOfWork):
self.uow = uow
def create_order(self, order_data: OrderCreate) -> Order:
"""
创建订单
涉及多个表(订单、库存),必须在一个事务中完成
"""
# 1. 检查用户是否存在
user = self.uow.users.get_by_id(order_data.user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户不存在"
)
# 2. 检查商品并计算总价
total_price = 0
order_items = []
for item in order_data.items:
product = self.uow.products.get_by_id(item.product_id)
if not product:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"商品 {item.product_id} 不存在"
)
if product.stock < item.quantity:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"商品 {product.name} 库存不足"
)
subtotal = product.price * item.quantity
total_price += subtotal
order_items.append({
'product': product,
'quantity': item.quantity,
'price': product.price
})
# 3. 创建订单
order = Order(
user_id=user.id,
total_price=total_price,
status="pending"
)
self.uow.orders.add(order)
self.uow.flush() # 获取订单 ID
# 4. 扣减库存
for item_data in order_items:
product = item_data['product']
quantity = item_data['quantity']
success = self.uow.products.decrease_stock(product.id, quantity)
if not success:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"商品 {product.name} 库存不足"
)
# 5. 提交事务
self.uow.commit()
return order
看到没?UnitOfWork 把多个 Repository 组织在一起,服务层通过 uow.users、uow.orders、uow.products 访问不同的仓储,最后统一 commit()。任何一步失败,整个事务都会回滚。
5.4 配置依赖注入
from fastapi import FastAPI, Depends
from sqlalchemy.orm import Session
from database import get_db
from unit_of_work import UnitOfWork
from services.order_service import OrderService
app = FastAPI()
def get_uow(db: Session = Depends(get_db)) -> UnitOfWork:
"""每个请求一个工作单元"""
return UnitOfWork(db)
def get_order_service(uow: UnitOfWork = Depends(get_uow)) -> OrderService:
"""创建订单服务"""
return OrderService(uow)
@app.post("/orders")
def create_order(
order_data: OrderCreate,
service: OrderService = Depends(get_order_service)
):
return service.create_order(order_data)
六、FastAPI完整实战
把前面讲的内容整合起来,看看一个完整的项目结构:
6.1 项目结构
myapp/
├── main.py # 应用入口
├── config.py # 配置
├── database.py # 数据库连接
├── unit_of_work.py # 工作单元
├── dependencies.py # 依赖注入配置
│
├── models/ # ORM 模型
│ ├── __init__.py
│ ├── user.py
│ ├── order.py
│ └── product.py
│
├── schemas/ # Pydantic 模型
│ ├── __init__.py
│ ├── user.py
│ ├── order.py
│ └── product.py
│
├── repositories/ # 仓储层
│ ├── __init__.py
│ ├── base.py # 泛型基类
│ ├── user_repository.py
│ ├── order_repository.py
│ └── product_repository.py
│
├── services/ # 服务层
│ ├── __init__.py
│ ├── user_service.py
│ ├── order_service.py
│ └── product_service.py
│
└── routers/ # 路由层
├── __init__.py
├── user_router.py
├── order_router.py
└── product_router.py
6.2 依赖注入配置
# dependencies.py
from fastapi import Depends
from sqlalchemy.orm import Session
from database import get_db
from unit_of_work import UnitOfWork
from services.user_service import UserService
from services.order_service import OrderService
from services.product_service import ProductService
def get_uow(db: Session = Depends(get_db)) -> UnitOfWork:
"""获取工作单元"""
return UnitOfWork(db)
def get_user_service(uow: UnitOfWork = Depends(get_uow)) -> UserService:
"""获取用户服务"""
return UserService(uow)
def get_order_service(uow: UnitOfWork = Depends(get_uow)) -> OrderService:
"""获取订单服务"""
return OrderService(uow)
def get_product_service(uow: UnitOfWork = Depends(get_uow)) -> ProductService:
"""获取商品服务"""
return ProductService(uow)
6.3 路由示例
# routers/user_router.py
from typing import Optional
from fastapi import APIRouter, Depends, status
from dependencies import get_user_service
from services.user_service import UserService
from schemas.user import UserCreate, UserUpdate, UserResponse, UserListResponse
router = APIRouter(prefix="/users", tags=["用户管理"])
@router.get("/", response_model=UserListResponse)
def list_users(
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
service: UserService = Depends(get_user_service)
):
"""获取用户列表"""
return service.get_users(skip=skip, limit=limit, is_active=is_active)
@router.get("/{user_id}", response_model=UserResponse)
def get_user(
user_id: int,
service: UserService = Depends(get_user_service)
):
"""获取用户详情"""
return service.get_user(user_id)
@router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
def create_user(
user_data: UserCreate,
service: UserService = Depends(get_user_service)
):
"""创建用户"""
return service.create_user(user_data)
@router.put("/{user_id}", response_model=UserResponse)
def update_user(
user_id: int,
user_data: UserUpdate,
service: UserService = Depends(get_user_service)
):
"""更新用户"""
return service.update_user(user_id, user_data)
@router.delete("/{user_id}")
def delete_user(
user_id: int,
service: UserService = Depends(get_user_service)
):
"""删除用户"""
return service.delete_user(user_id)
6.4 主应用
# main.py
from fastapi import FastAPI
from routers import user_router, order_router, product_router
app = FastAPI(
title="电商API",
description="使用仓储模式的 FastAPI 示例",
version="1.0.0"
)
app.include_router(user_router.router)
app.include_router(order_router.router)
app.include_router(product_router.router)
@app.get("/")
def root():
return {"message": "Welcome to the API"}
@app.get("/health")
def health():
return {"status": "healthy"}
七、仓储模式 vs 直接用ORM
看到这里,你可能会想:这搞得这么复杂,直接用 SQLAlchemy 不是更简单?
这是个好问题。仓储模式不是银弹,它有成本也有收益,需要根据项目情况权衡。
7.1 直接使用 ORM 的场景
适合的情况:
- 项目较小,CRUD 为主,业务逻辑简单
- 团队对 ORM 熟悉,开发速度优先
- 不太可能更换数据库或 ORM
- 不需要严格的测试覆盖
# 直接用 ORM,简单直接
@app.get("/users/{user_id}")
def get_user(user_id: int, db: Session = Depends(get_db)):
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404)
return user
这样写没什么问题。如果你的项目就是几个简单接口,引入仓储模式反而是过度设计。
7.2 使用仓储模式的场景
适合的情况:
- 项目较大,业务逻辑复杂
- 需要良好的测试覆盖
- 可能更换底层存储
- 团队协作,需要清晰的代码边界
- 追求代码的可维护性和可扩展性
7.3 对比总结
| 方面 | 直接用ORM | 仓储模式 |
|---|---|---|
| 代码量 | 少 | 多 |
| 学习成本 | 低 | 中等 |
| 开发速度(初期) | 快 | 慢一点 |
| 维护成本(后期) | 高 | 低 |
| 可测试性 | 差 | 好 |
| 灵活性 | 差 | 好 |
| 适合项目规模 | 小型 | 中大型 |
7.4 我的建议
| 项目类型 | 建议 |
|---|---|
| 小项目(几个简单接口) | 直接用 ORM |
| 中型项目 | 至少把数据访问逻辑提取到单独模块 |
| 大型项目 | 完整分层架构 + 工作单元 |
| 需要高测试覆盖 | 仓储模式 |
八、测试中的巨大优势
仓储模式最大的优势之一就是方便测试。
8.1 不使用仓储模式的测试困境
# 业务代码直接操作数据库
class UserService:
def __init__(self, db: Session):
self.db = db
def create_user(self, username: str, email: str, password: str):
if self.db.query(User).filter(User.username == username).first():
raise ValueError("用户名已存在")
user = User(username=username, email=email, password=password)
self.db.add(user)
self.db.commit()
return user
# 测试代码
def test_create_user():
# 必须准备真实数据库
db = TestSessionLocal()
service = UserService(db)
# 清理可能存在的测试数据
db.query(User).filter(User.username == "testuser").delete()
db.commit()
# 执行测试
user = service.create_user("testuser", "test@example.com", "password")
assert user.username == "testuser"
# 清理
db.delete(user)
db.commit()
db.close()
问题很多:需要真实数据库、测试数据要准备和清理、测试速度慢、测试之间可能互相影响。
8.2 使用仓储模式的测试
创建一个内存实现的"假仓储":
from typing import Optional, List, Dict
from models import User
class FakeUserRepository:
"""假的用户仓储,用于测试"""
def __init__(self):
self._users: Dict[int, User] = {}
self._next_id = 1
def get_by_id(self, user_id: int) -> Optional[User]:
return self._users.get(user_id)
def get_by_username(self, username: str) -> Optional[User]:
for user in self._users.values():
if user.username == username:
return user
return None
def add(self, user: User) -> User:
user.id = self._next_id
self._next_id += 1
self._users[user.id] = user
return user
def exists_by_username(self, username: str) -> bool:
return self.get_by_username(username) is not None
def clear(self):
"""清空所有数据"""
self._users.clear()
self._next_id = 1
测试变得简单快速:
import pytest
from services.user_service import UserService
class FakeUnitOfWork:
"""假的工作单元"""
def __init__(self):
self.users = FakeUserRepository()
self.committed = False
def commit(self):
self.committed = True
def rollback(self):
pass
class TestUserService:
@pytest.fixture
def uow(self):
return FakeUnitOfWork()
@pytest.fixture
def service(self, uow):
return UserService(uow)
def test_create_user_success(self, service, uow):
"""测试成功创建用户"""
user_data = UserCreate(
username="testuser",
email="test@example.com",
password="password123"
)
result = service.create_user(user_data)
assert result.username == "testuser"
assert uow.committed
def test_create_user_duplicate_username(self, service, uow):
"""测试用户名重复"""
# 先添加一个用户
existing_user = User(
username="existinguser",
email="existing@example.com",
hashed_password="xxx"
)
uow.users.add(existing_user)
# 尝试创建同名用户
user_data = UserCreate(
username="existinguser",
email="new@example.com",
password="password123"
)
with pytest.raises(HTTPException) as exc_info:
service.create_user(user_data)
assert exc_info.value.status_code == 400
assert "用户名已存在" in exc_info.value.detail
8.3 FastAPI 集成测试
import pytest
from fastapi.testclient import TestClient
from main import app
from dependencies import get_uow
class TestUserAPI:
@pytest.fixture
def client(self):
fake_uow = FakeUnitOfWork()
# 预设测试数据
fake_uow.users.add(User(
username="existinguser",
email="existing@example.com",
hashed_password="xxx"
))
# 覆盖依赖
app.dependency_overrides[get_uow] = lambda: fake_uow
client = TestClient(app)
yield client
app.dependency_overrides.clear()
def test_create_user(self, client):
response = client.post("/users/", json={
"username": "newuser",
"email": "new@example.com",
"password": "password123"
})
assert response.status_code == 201
assert response.json()["username"] == "newuser"
8.4 测试的好处总结
| 好处 | 说明 |
|---|---|
| 快速 | 内存运行,毫秒级完成 |
| 隔离 | 每个测试有自己的假仓储 |
| 稳定 | 不受网络、数据库状态影响 |
| 简单 | 准备测试数据很方便 |
| 专注 | 单独测试业务逻辑 |
九、常见问题与最佳实践
9.1 Repository 应该返回什么?
返回 ORM 模型还是 DTO?
建议:Repository 返回 ORM 模型,转换成 DTO 是 Service 层的事。
# Repository 返回 ORM 对象
class UserRepository:
def get_by_id(self, user_id: int) -> Optional[User]:
return self.db.query(User).filter(User.id == user_id).first()
# Service 层做转换
class UserService:
def get_user(self, user_id: int) -> UserResponse:
user = self.repo.get_by_id(user_id) # User(ORM)
return UserResponse.model_validate(user) # UserResponse(DTO)
9.2 复杂查询放在哪里?
原则:和数据获取相关的放 Repository,和业务规则相关的放 Service。
# Repository:复杂的数据查询
class OrderRepository:
def get_monthly_stats(self, year: int, month: int) -> dict:
"""获取月度订单统计(数据查询)"""
from sqlalchemy import func
result = self.db.query(
func.count(Order.id).label('count'),
func.sum(Order.total_price).label('total')
).filter(
func.extract('year', Order.created_at) == year,
func.extract('month', Order.created_at) == month,
Order.status == 'completed'
).first()
return {
'count': result.count or 0,
'total': float(result.total or 0)
}
# Service:基于数据的业务计算
class ReportService:
def generate_monthly_report(self, year: int, month: int) -> dict:
"""生成月度报告(业务逻辑)"""
stats = self.uow.orders.get_monthly_stats(year, month)
# 业务计算
avg_order_value = (
stats['total'] / stats['count']
if stats['count'] > 0 else 0
)
return {
'order_count': stats['count'],
'total_revenue': stats['total'],
'avg_order_value': avg_order_value
}
9.3 如何处理 N+1 问题?
使用 joinedload 或 selectinload 预加载关联数据:
from sqlalchemy.orm import joinedload, selectinload
class OrderRepository:
def get_with_items(self, order_id: int) -> Optional[Order]:
"""获取订单及其订单项(预加载)"""
return self.db.query(Order).options(
selectinload(Order.items).selectinload(OrderItem.product)
).filter(Order.id == order_id).first()
9.4 仓储应该有多细粒度?
一个实体一个仓储是常见做法。但紧密相关的实体可以共用:
# 聚合根模式:订单项通过订单仓储访问
class OrderRepository:
def add_item(self, order_id: int, item: OrderItem): ...
def remove_item(self, order_id: int, item_id: int): ...
def get_items(self, order_id: int) -> List[OrderItem]: ...
9.5 性能优化建议
批量操作:避免循环中一个个操作数据库
# 不好:N次数据库操作
for user_id in user_ids:
user = repo.get_by_id(user_id)
# 好:1次数据库操作
users = repo.get_by_ids(user_ids)
分页:始终使用分页
# 好
users = repo.get_all(skip=0, limit=100)
# 不好
users = repo.get_all() # 可能有几万条
十、总结
这篇文章把仓储模式的各个方面都讲到了,总结一下关键点:
仓储模式是什么
它是数据访问层的抽象,把"怎么存数据"的细节藏起来,对外提供类似"集合"的接口。Repository 就像图书馆管理员,业务代码只需要说"我要这本书",不用管书放在哪个架子上。
为什么要用仓储模式
主要是为了解耦、可测试、可维护。数据访问逻辑集中管理,业务层不依赖具体的数据库实现,测试时可以用假实现替换真数据库。
怎么实现
从最基础的 Repository 类开始,然后引入泛型基类减少重复代码,再配合工作单元模式管理事务。
什么时候用
小项目直接用 ORM 就够了,中大型项目、需要高测试覆盖的项目建议使用。
关键要点总结
| 概念 | 说明 | 适用场景 |
|---|---|---|
Repository |
数据访问抽象 | 封装查询逻辑 |
| 泛型仓储 | 减少重复的 CRUD 代码 | 多实体项目 |
| 工作单元 | 事务管理 | 多表操作 |
| 假仓储 | 测试用的内存实现 | 单元测试 |
仓储模式 + 工厂模式 + 依赖注入 + 单例,这几个模式组合起来,形成了完整的分层架构。
下一篇打算讲观察者模式,这个在事件驱动的场景(比如发消息通知、触发异步任务)非常有用。在 FastAPI 里配合背景任务或消息队列使用,能让系统架构更加松耦合。
热门专栏推荐
- Agent小册
- 服务器部署
- Java基础合集
- Python基础合集
- Go基础合集
- 大数据合集
- 前端小册
- 数据库合集
- Redis 合集
- Spring 全家桶
- 微服务全家桶
- 数据结构与算法合集
- 设计模式小册
- 消息队列合集
等等等还有许多优秀的合集在主页等着大家的光顾,感谢大家的支持
文章到这里就结束了,如果有什么疑问的地方请指出,诸佬们一起来评论区一起讨论😊
希望能和诸佬们一起努力,今后我们一起观看感谢您的阅读🙏
如果帮助到您不妨3连支持一下,创造不易您们的支持是我的动力🌟
更多推荐



所有评论(0)