@johnsona 边看文档边写,或者自己抽象一层。我封装的 base 类供参考:
```
class BaseModel(db.Model):
    """Abstract declarative base bundling common query and CRUD helpers.

    Concrete models inherit the two timestamp columns plus class-level
    finders (``find``, ``page_find``, ``find_one_by`` ...), persistence
    helpers (``add``, ``update``, ``save``, ``delete``) and dict/JSON
    conversion (``from_dict``, ``to_dict``, ``to_json``).

    No default is declared for the timestamp columns here, so their values
    have to be supplied by the caller or by code outside this class.
    """

    __abstract__ = True

    create_time = db.Column(db.DateTime, nullable=False, index=True, comment='创建时间')
    update_time = db.Column(db.DateTime, nullable=False, index=True, comment='修改时间')
    # Alternative way to get automatically maintained timestamps:
    # created_at = db.Column(db.DateTime, default=sqlalchemy.func.now(), nullable=False)
    # updated_at = db.Column(db.DateTime, default=sqlalchemy.func.now(), onupdate=sqlalchemy.func.now(), nullable=False)

    @classmethod
    def qry(cls) -> BaseQuery:
        """Return a session-bound query selecting this model."""
        return db.session.query(cls)

    @classmethod
    def filter(cls, *args) -> Query:
        """Return ``qry()`` narrowed by the given filter criteria."""
        return cls.qry().filter(*args)

    @classmethod
    def _build_query(cls, rsql=None, sort=None, ignore_fields=None):
        """Build the filtered (and optionally ordered) query shared by the finders.

        :param rsql: filter expression handed to ``to_db_query_sql``.
        :param sort: sort expression handed to ``to_db_order``; skipped when
            ``None`` or the empty string.
        :param ignore_fields: forwarded unchanged to ``to_db_query_sql``.
        """
        query = cls.qry().filter(to_db_query_sql(cls, rsql, ignore_fields))
        if sort is not None and sort != '':
            query = query.order_by(to_db_order(cls, sort))
        return query

    @classmethod
    def count(cls, rsql=None, sort=None, ignore_fields=None):
        """Return the number of rows matching ``rsql``."""
        return cls._build_query(rsql, sort, ignore_fields).count()

    @classmethod
    def count_by(cls, **kwargs):
        """Return the number of rows whose columns equal the given keyword values."""
        return cls.qry().filter_by(**kwargs).count()

    @classmethod
    def find(cls, rsql=None, sort=None, ignore_fields=None):
        """Return all rows matching ``rsql``, ordered by ``sort`` when given."""
        return cls._build_query(rsql, sort, ignore_fields).all()

    @classmethod
    def page_find(cls, page_num, page_size, rsql=None, sort=None, and_sql=None, ignore_fields=None):
        """Return one page of rows matching ``rsql`` combined with ``and_sql``.

        :param page_num: page number passed to ``paginate``.
        :param page_size: page size passed to ``paginate``.
        :param rsql: optional filter expression; a blank value is replaced by
            ``and_sql``.
        :param and_sql: optional extra expression that is AND-ed with ``rsql``.
        :returns: whatever the query's ``paginate`` returns.
        """
        if rsql is None or rsql.strip() == '':
            rsql = and_sql
        elif and_sql is not None:
            rsql = "(%s) and (%s)" % (rsql, and_sql)
        query = cls._build_query(rsql, sort, ignore_fields)
        # Keyword arguments work on releases where paginate() accepts
        # positionals as well as on those where it is keyword-only.
        return query.paginate(page=page_num, per_page=page_size)

    @classmethod
    def get(cls, _id):
        """Return the row identified by ``_id`` via ``Query.get``."""
        return cls.qry().get(_id)

    @classmethod
    def delete(cls, _id):
        """Delete the rows whose ``id`` equals ``_id`` and commit.

        Relies on the concrete model defining an ``id`` attribute; this base
        class does not declare one.
        """
        cls.qry().filter(cls.id == _id).delete()
        db.session.commit()

    def save(self):
        """Add this instance to the session and commit.

        Ordinary errors are logged and the session is rolled back; they are
        not re-raised, so the caller is not told that the save failed.
        Non-``Exception`` interruptions (``KeyboardInterrupt``,
        ``SystemExit`` ...) also roll back but are propagated instead of
        being swallowed.
        """
        try:
            db.session.add(self)
            db.session.flush()
            db.session.commit()
        except Exception as e:
            _log.error(e)
            db.session.rollback()
        except BaseException:
            db.session.rollback()
            raise

    @staticmethod
    def _check_and_save(item, save):
        """Run ``item.check()`` when the item defines one, then optionally save."""
        check = getattr(item, 'check', None)
        if callable(check):
            check()
        if save:
            item.save()

    @classmethod
    def add(cls, obj, save=True):
        """Create (and by default persist) an instance from a dict or a model object.

        :param obj: a plain ``dict`` (converted through ``from_dict``) or an
            object whose type is exactly ``cls``; anything else is ignored.
        :param save: when true, ``save()`` is called on the item.
        :returns: the item, or ``None`` when ``obj`` was not usable.
        """
        item = None
        _type = type(obj)
        if _type is dict:
            item = cls.from_dict(obj)
        elif _type is cls:
            item = obj
        if item is not None:
            cls._check_and_save(item, save)
        return item

    def update_from_dict(self, obj, valid_fields=None):
        """Copy column values from ``obj`` onto this instance.

        Values are first normalised through ``from_dict``. Only keys naming
        a table column are applied, and when ``valid_fields`` is given only
        keys contained in it.
        """
        columns = {m.key for m in self.__table__.columns}
        item = self.from_dict(obj)
        if item is None:
            # from_dict rejected the input type; there is nothing to copy.
            return
        for k in obj.keys():
            if k not in columns:
                continue
            if valid_fields is not None and k not in valid_fields:
                continue
            setattr(self, k, getattr(item, k))

    @classmethod
    def update(cls, obj, save=True, valid_fields=None):
        """Update (and by default persist) a row from a dict or a model object.

        :param obj: a plain ``dict`` containing the primary-key value, or an
            object whose type is exactly ``cls``.
        :param save: when true, ``save()`` is called on the item.
        :param valid_fields: optional whitelist of keys that may be updated
            from a dict.
        :returns: the updated item, or ``None`` when the dict lacks the
            primary key, no row has that key, or ``obj`` was not usable.
        """
        item = None
        _type = type(obj)
        if _type is dict:
            pk = inspect(cls).primary_key[0].name
            if pk not in obj:
                return None
            item = cls.get(obj[pk])
            if item is None:
                # No such row: report it the same way as a missing key
                # instead of failing on ``None.update_from_dict``.
                return None
            item.update_from_dict(obj, valid_fields)
        elif _type is cls:
            item = obj
        if item is not None:
            cls._check_and_save(item, save)
        return item

    @classmethod
    def find_one_by(cls, **kwargs):
        """Return the result of ``one_or_none()`` for the keyword filters."""
        return cls.qry().filter_by(**kwargs).one_or_none()

    @classmethod
    def find_first_by(cls, **kwargs):
        """Return the result of ``first()`` for the keyword filters."""
        return cls.qry().filter_by(**kwargs).first()

    @classmethod
    def find_all_by(cls, **kwargs) -> list:
        """Return all rows matching the keyword filters."""
        return cls.qry().filter_by(**kwargs).all()

    @staticmethod
    def _coerce_column_value(column, raw):
        """Convert ``raw`` to a value suited to ``column``'s declared type.

        The column type is matched by exact class, so subclasses of the
        listed types (for example ``BigInteger``) take the final string
        branch.
        """
        if raw is None:
            return None
        column_type = type(column.type)
        if column_type is Date:
            return to_local_datetime(raw).date()
        if column_type is DateTime:
            return to_local_datetime(raw)
        if column_type is Boolean:
            if isinstance(raw, str):
                return raw.lower() == 'true'
            return bool(raw)
        if column_type is Integer:
            if isinstance(raw, str):
                # int() also accepts signed numbers, which isdigit() rejects;
                # unparsable text still maps to None.
                try:
                    return int(raw)
                except ValueError:
                    return None
            return raw
        if column_type is Float:
            return float(raw)
        # NOTE(review): every remaining column type is stringified - confirm
        # this is intended for non-text columns.
        return str(raw)

    @classmethod
    def from_dict(cls, obj):
        """Build an instance from a ``dict`` or ``RowProxy``.

        Only keys matching a column name are used; each value is coerced by
        ``_coerce_column_value``.

        :returns: a new instance, or ``None`` when ``obj`` is neither a
            ``dict`` nor a ``RowProxy`` (exact types).
        """
        if type(obj) not in [dict, RowProxy]:
            return None
        values = {}
        for column in inspect(cls).columns:
            if column.name not in obj.keys():
                continue
            values[column.name] = cls._coerce_column_value(column, obj[column.name])
        return cls(**values)

    def to_dict(self, rel=True, ignore=None):
        """Returns the model properties as a dict

        :param rel: when true, export every public, non-callable attribute
            found through ``dir()`` (not just table columns); when false,
            export table columns only.
        :param ignore: attribute names to leave out.
        :rtype: dict
        """
        if ignore is None:
            ignore = []
        if rel:
            names = {
                attr for attr in dir(self)
                if not attr.startswith('_')
                and attr not in ignore
                and attr not in ['metadata', 'query', 'query_class']
                and not callable(getattr(self, attr))
            }
        else:
            # Skip the dir()/getattr scan entirely: it is not needed for the
            # column-only view and touches every attribute of the instance.
            names = [m.key for m in self.__table__.columns if m.key not in ignore]

        result = {}
        for attr in names:
            value = getattr(self, attr)
            if isinstance(value, list):
                result[attr] = [x.to_dict() if hasattr(x, "to_dict") else x for x in value]
            elif hasattr(value, "to_dict"):
                result[attr] = value.to_dict()
            elif isinstance(value, dict):
                result[attr] = {
                    k: (v.to_dict() if hasattr(v, "to_dict") else v)
                    for k, v in value.items()
                }
            else:
                result[attr] = value
        return result

    def to_json(self):
        """Serialise ``to_dict()`` to a JSON string.

        Attribute names listed in ``self._json_ignore`` (when present) are
        left out.
        """
        def extended_encoder(x):
            # datetime must be tested first: it is a subclass of date.
            if isinstance(x, datetime.datetime):
                return x.strftime("%Y-%m-%d %H:%M:%S")
            if isinstance(x, datetime.date):
                return x.strftime("%Y-%m-%d")
            if isinstance(x, UUID):
                return str(x)
            # NOTE(review): any other unsupported type is emitted as JSON
            # null rather than raising TypeError - confirm this is intended.
            return None

        json_ignore = getattr(self, '_json_ignore', [])
        return json.dumps(self.to_dict(ignore=json_ignore), default=extended_encoder)

    @staticmethod
    def result_to_dict(result, keys):
        """Turn raw result rows into dicts keyed by the ``name`` of each item in ``keys``."""
        names = [attr.name for attr in keys]
        return [dict(zip(names, row)) for row in result]
```