Source code for upgrade_marshmallow.upgrade
"""Upgrade Assembly"""
import ast
import typing as T
from collections import defaultdict
import astor
ARG = T.TypeVar('ARG') # TODO: bind str
VALUE = T.TypeVar('VALUE') # TODO: bind Any
PATH = T.TypeVar('PATH') # TODO: bind str
RENAME_ARGUMENTS: dict = {
'default': 'dump_default',
'missing': 'load_default',
}
_FIELD_ARGUMENTS: list = [
# 'default',
# 'missing',
'data_key',
'attribute',
'validate',
'required',
'allow_none',
'load_only',
'dump_only',
'error_messages',
'metadata'
]
__fields_args = _FIELD_ARGUMENTS.copy()
__fields_args.extend(RENAME_ARGUMENTS.keys())
__fields_args.extend(RENAME_ARGUMENTS.values())
field_arguments = set(__fields_args)
MAPPING = {}
TREE = defaultdict(dict)
[docs]def replace_as_metadata_kw(
file: PATH,
*,
indent: int = 4,
):
"""
Args:
- file (str): source code file path
- indent (int): indent of source code
"""
tree = astor.parse_file(file)
find_rename_fields = FindRenameFields()
find_rename_ma = FindRenameMarshmallow()
find_imported_into_field = FindImportedIntoField()
# Replace metadata -------------------------------------------------------
rpl_metadata = ReplaceAsMetadataKW()
rpl_metadata_id = id(rpl_metadata)
MAPPING[id(find_rename_fields)] = rpl_metadata_id
MAPPING[id(find_rename_ma)] = rpl_metadata_id
MAPPING[id(find_imported_into_field)] = rpl_metadata_id
TREE[rpl_metadata_id]['fields'] = 'fields'
TREE[rpl_metadata_id]['marshmallow'] = 'marshmallow'
TREE[rpl_metadata_id]['imported_fields'] = set()
find_rename_fields.visit(tree)
find_rename_ma.visit(tree)
find_imported_into_field.visit(tree)
tree = rpl_metadata.visit(tree)
# Replace default->dump_default, missing->load_default -------------------
rpl_default_missing = ReplaceDefaultAndMissing()
rpl_default_missing_id = id(rpl_default_missing)
MAPPING[id(find_rename_fields)] = rpl_default_missing_id
MAPPING[id(find_rename_ma)] = rpl_default_missing_id
MAPPING[id(find_imported_into_field)] = rpl_default_missing_id
TREE[rpl_default_missing_id]['fields'] = 'fields'
TREE[rpl_default_missing_id]['marshmallow'] = 'marshmallow'
TREE[rpl_default_missing_id]['imported_fields'] = set()
find_rename_fields.visit(tree)
find_rename_ma.visit(tree)
find_imported_into_field.visit(tree)
new_tree = rpl_default_missing.visit(tree)
ast.fix_missing_locations(new_tree)
# TODO: formatter + comment issue
# - formatter: yapf + --style='{based_on_style: pep8, indent_width: 2}'
# - comment: no solution yet
codes = astor.to_source(new_tree, indent_with=' ' * indent)
return codes
[docs]class FindRenameFields(ast.NodeTransformer):
def visit_ImportFrom(self, node: ast.ImportFrom) -> T.Any:
node_transformer_id = MAPPING[id(self)]
if node.module == 'marshmallow':
for alias in node.names:
if alias.name == 'fields':
# case: `from marshmallow import fields as ma_fields`
TREE[node_transformer_id]['fields'] = alias.asname or alias.name
return node
[docs]class FindRenameMarshmallow(ast.NodeTransformer):
def visit_Import(self, node: ast.Import) -> T.Any:
node_transformer_id = MAPPING[id(self)]
for alias in node.names:
if alias.name == 'marshmallow':
# case: `import marshmallow as ma`
TREE[node_transformer_id]['marshmallow'] = alias.asname or alias.name
return node
[docs]class FindImportedIntoField(ast.NodeTransformer):
def visit_ImportFrom(self, node: ast.ImportFrom) -> T.Any:
node_transformer_id = MAPPING[id(self)]
if node.module.endswith('marshmallow.fields'):
for alias in node.names:
TREE[node_transformer_id]['imported_fields'].add(alias.asname or alias.name)
return node
def is_marshmallow_feild_call_expression(node: ast.Call, rpl_obj_id: int) -> bool:
# case:
# from marshmallow import fields
# class FooSchema(Schema):
# foo = fields.String(title='foo', description='foo')
# ^
if isinstance(node.func, ast.Attribute):
attr: ast.Attribute = node.func
# case: `marshmallow.fields.String()`
# Call(func=Attribute(value=Attribute(value=Name(id='marshmallow', ctx=), attr='fields', ctx=), attr='String', ctx=), args=[], keywords=[
# ^ ^ ^ ^ ^ ^
# node node.func node.func.value node.func.value.value node.func.value.attr node.func.attr
#
if isinstance(attr.value, ast.Attribute):
if isinstance(attr.value.value, ast.Name):
if attr.value.value.id == TREE[rpl_obj_id]['marshmallow']:
if attr.value.attr == 'fields':
return True
# foo = fields.String(title='foo', description='foo')
# ^
if isinstance(attr.value, ast.Name):
if node.func.value.id == TREE[rpl_obj_id]['fields']:
return True
# Just others ast.Call, but not rely to marshmallow.fields
else:
pass
# case: `from marshmallow.fields import String`
elif (isinstance(node.func, ast.Name)
and node.func.id in TREE[rpl_obj_id]['imported_fields']):
return True
# Just others ast.Call, but not rely to marshmallow.fields
else:
pass
return False
[docs]class ReplaceAsMetadataKW(ast.NodeTransformer):
def visit_Call(self, node: ast.Call) -> T.Any:
if not is_marshmallow_feild_call_expression(node, id(self)):
return node
kws = node.keywords[:]
node.keywords.clear()
metadata: T.List[T.Tuple[ARG, VALUE]] = []
for kw_obj in kws:
# metadata
if kw_obj.arg not in field_arguments:
metadata.append((kw_obj.arg, kw_obj.value))
# the sig.parameters
else:
node.keywords.append(kw_obj)
for kw_obj in node.keywords:
# TODO: fields.String(required=True, title='name', metadata={'description': '...'})
if kw_obj.arg == 'metadata':
raise RuntimeError('not handle ths condition!')
if metadata:
node.keywords.append(ast.keyword(
arg='metadata',
value=ast.Dict(
keys=[ast.Constant(value=arg, kind=None)
for arg, _ in metadata],
values= [value for _, value in metadata]
)
))
return node
[docs]class ReplaceDefaultAndMissing(ast.NodeTransformer):
def visit_Call(self, node: ast.Call) -> T.Any:
if not is_marshmallow_feild_call_expression(node, id(self)):
return node
kws = node.keywords[:]
node.keywords.clear()
# rename = []
for kw_obj in kws:
if kw_obj.arg in RENAME_ARGUMENTS:
# rename.append()
node.keywords.append(ast.keyword(
arg=RENAME_ARGUMENTS[kw_obj.arg],
value=kw_obj.value
))
else:
node.keywords.append(kw_obj)
return node