CRUD pipeline
Source location
src/sqlmodel_ext/mixins/table.py — TableBaseMixin and UUIDTableBaseMixin
This chapter explains how save() / get() / update() work internally. For full method signatures see CRUD methods reference; for typical usage see the How-to guides.
TableBaseMixin basics
class TableBaseMixin(AsyncAttrs):
_has_table_mixin: ClassVar[bool] = True # Lets the metaclass identify "this is a table class"
id: int | None = Field(default=None, primary_key=True)
created_at: datetime = Field(default_factory=now)
updated_at: datetime = Field(
sa_type=DateTime,
sa_column_kwargs={'default': now, 'onupdate': now},
default_factory=now,
)Inheriting AsyncAttrs enables await obj.awaitable_attrs.some_relation syntax on model objects, providing additional async safety.
_has_table_mixin = True is a marker that lets the metaclass automatically add table=True in __new__.
save() implementation
save() is the core method; it contains the optimistic-lock retry logic:
async def save(self, session, ..., optimistic_retry_count=0):
cls = type(self)
instance = self
retries_remaining = optimistic_retry_count
current_data = None
while True:
session.add(instance)
try:
await session.commit()
break # Success, exit
except StaleDataError as e: # Version conflict!
await session.rollback()
if retries_remaining <= 0:
raise OptimisticLockError(
message=f"optimistic lock conflict",
model_class=cls.__name__,
record_id=str(instance.id),
expected_version=instance.version,
original_error=e,
) from e
retries_remaining -= 1
# Save current modifications (excluding metadata fields)
if current_data is None:
current_data = self.model_dump(
exclude={'id', 'version', 'created_at', 'updated_at'}
)
# Get the latest record from the database
fresh = await cls.get(session, cls.id == self.id)
if fresh is None:
raise OptimisticLockError("record has been deleted") from e
# Re-apply my changes to the latest record
for key, value in current_data.items():
if hasattr(fresh, key):
setattr(fresh, key, value)
instance = fresh
# After commit, use sa_inspect to safely read ID (avoiding MissingGreenlet)
_insp = inspect(instance)
_instance_id = _insp.identity[0] if _insp.identity else None
result = await cls.get(session, cls.id == _instance_id, load=load)
return resultsession.add() behavior
session.add() does not execute SQL. SQLAlchemy decides what to emit when commit() or flush() runs:
- Object is new → INSERT
- Object is already in the Session and has changes → UPDATE
Why must you use the return value?
Object expiration
By default (expire_on_commit=True), session.commit() expires every object in the Session. The original user object's attributes become "expired", and accessing them triggers an implicit reload query; under AsyncSession that implicit IO fails with MissingGreenlet instead. save() therefore returns a fresh object loaded via cls.get(), which also goes through the Redis cache if CachedTableBaseMixin is enabled.
update() implementation
async def update(self, session, other, extra_data=None,
exclude_unset=True, exclude=None, ...):
update_data = other.model_dump(exclude_unset=exclude_unset, exclude=exclude)
instance.sqlmodel_update(update_data, update=extra_data)
session.add(instance)
await session.commit()PATCH semantics
The key is exclude_unset=True: only fields explicitly set on `other` are copied, and unset fields keep their current values. This is PATCH semantics, as opposed to PUT (full replacement).
get() implementation
This is the longest method (~300 lines), handling multiple scenarios in layers. For the full signature see reference/crud-methods.
Layer 1: Basic query
statement = select(cls)
if condition is not None:
statement = statement.where(condition)Layer 2: Pagination + sorting
if table_view:
order_column = cls.created_at if table_view.order == "created_at" else cls.updated_at
order_by = [desc(order_column) if table_view.desc else asc(order_column)]
statement = statement.order_by(*order_by).offset(table_view.offset).limit(table_view.limit)Layer 3: Time filtering
@classmethod
def _build_time_filters(cls, created_before_datetime, created_after_datetime, ...):
filters = []
if created_after_datetime is not None:
filters.append(col(cls.created_at) >= created_after_datetime)
if created_before_datetime is not None:
filters.append(col(cls.created_at) < created_before_datetime)
...
return filtersLayer 4: Relation preloading
if load:
load_list = load if isinstance(load, list) else [load]
load_chains = cls._build_load_chains(load_list)
for chain in load_chains:
loader = selectinload(chain[0])
for rel in chain[1:]:
loader = loader.selectinload(rel)
statement = statement.options(loader)_build_load_chains automatically detects relation dependencies and builds nested loading chains. For example, load=[User.profile, Profile.avatar] → selectinload(User.profile).selectinload(Profile.avatar).
Layer 5: Polymorphic queries
if is_jti:
polymorphic_cls = with_polymorphic(cls, '*')
statement = select(polymorphic_cls) # Auto-JOINs all sub-tables
if is_sti:
descendant_identities = [m.polymorphic_identity for m in mapper.self_and_descendants]
statement = statement.where(poly_on.in_(descendant_identities))JTI uses with_polymorphic to auto-JOIN sub-tables. STI requires manually adding a WHERE _polymorphic_name IN (...) filter — SQLAlchemy/SQLModel doesn't add this discriminator filter automatically; sqlmodel-ext patches it in.
Layer 6: fetch_mode determines return value
result = await session.exec(statement)
if fetch_mode == "first": return result.first()
elif fetch_mode == "one": return result.one()
elif fetch_mode == "all": return list(result.all())rel() and cond() — type-safe helpers
def rel(relationship: object) -> QueryableAttribute[Any]:
    """Cast Relationship field to QueryableAttribute, fixing basedpyright inference"""
    # A real runtime check (unlike cond()): reject anything that isn't an ORM attribute
    if not isinstance(relationship, QueryableAttribute):
        raise AttributeError(...)
    return relationship
def cond(expr: ColumnElement[bool] | bool) -> ColumnElement[bool]:
"""Narrow column comparison expression to ColumnElement[bool], fixing & | operator type errors"""
return cast(ColumnElement[bool], expr)These two functions are similar to SQLModel's col() — they perform type assertions/casts at runtime to satisfy static type checkers (basedpyright).
get_one() implementation
@classmethod
async def get_one(cls, session, id, *, load=None, with_for_update=False):
return await cls.get(
session, col(cls.id) == id,
fetch_mode='one', load=load, with_for_update=with_for_update,
)Essentially a shortcut for get(fetch_mode='one'). UUIDTableBaseMixin provides a more precisely typed override (accepting only uuid.UUID).
get_exist_one() FastAPI integration
@classmethod
async def get_exist_one(cls, session, id, load=None):
instance = await cls.get(session, col(cls.id) == id, load=load)
if not instance:
if _HAS_FASTAPI:
raise _FastAPIHTTPException(status_code=404, detail="Not found")
raise RecordNotFoundError("Not found")
return instanceAdaptive exception
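At module import time, the module checks whether FastAPI is installed. If it is, get_exist_one() raises HTTPException(404); otherwise it raises RecordNotFoundError. This avoids making FastAPI a hard dependency. Note that the choice depends only on whether FastAPI is importable, so non-HTTP callers (for example background jobs) running in an environment with FastAPI installed also receive HTTPException.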
sanitize_integrity_error() implementation
@staticmethod
def sanitize_integrity_error(e: IntegrityError, default_message: str = "...") -> str:
orig = e.orig
# SQLSTATE 23514 (check_violation): PostgreSQL trigger's RAISE EXCEPTION
if orig is not None and getattr(orig, 'sqlstate', None) == '23514':
error_msg = str(orig)
if '\n' in error_msg:
error_msg = error_msg.split('\n')[0] # Take the first line
if error_msg.startswith('ERROR:'):
error_msg = error_msg[6:].strip()
return error_msg
return default_messagePostgreSQL triggers can produce business-semantic error messages via RAISE EXCEPTION ... USING ERRCODE = 'check_violation', which are safe to display to users. Other constraint errors (FK, unique, etc.) might leak table structure, so the default message is returned.
FOR UPDATE tracking
In the get() method, when with_for_update=True, the locked instance's id() is recorded in session.info:
SESSION_FOR_UPDATE_KEY = '_for_update_locked'
# In get():
if with_for_update:
locked: set[int] = session.info.setdefault(SESSION_FOR_UPDATE_KEY, set())
locked.add(id(instance))This is used by the @requires_for_update decorator for runtime checking.
count() implementation
@classmethod
async def count(cls, session, condition=None, ...):
statement = select(func.count()).select_from(cls)
if condition is not None:
statement = statement.where(condition)
result = await session.scalar(statement)
return result or 0Uses database-level COUNT(*) rather than Python's len().
get_with_count() implementation
@classmethod
async def get_with_count(cls, session, condition=None, *, table_view=None, ...):
total_count = await cls.count(session, condition, ...)
items = await cls.get(session, condition, fetch_mode="all", table_view=table_view, ...)
return ListResponse(count=total_count, items=items)Essentially a combination of count() + get(fetch_mode="all"). The order doesn't affect the result but reading count() first then get() is more intuitive.