Skip to content
Snippets Groups Projects
Verified Commit cc24052e authored by Nik | Klampfradler's avatar Nik | Klampfradler
Browse files

Rewrite import logic to not rely on uninstantiated field types

parent 4a692419
No related branches found
No related tags found
No related merge requests found
......@@ -13,11 +13,18 @@ Added
~~~~~
* ZIP files with multiple CSVs and accompanying photos can now be imported
* Field types can now provide values for arbitrary alternative DB fields
Fixed
~~~~~
* CSV files with non-UTF-8 charsets can now be imported
* Imports could expose undefined behaviour when hitting the same interpreter process
Changed
~~~~~~~
* Refactored import and field type code for better readability
`2.1`_ - 2022-01-17
-------------------
......
......@@ -31,6 +31,7 @@ pedasos_courses:
has_header_row: true
separator: "\t"
fields:
- ignore
- short_name
- class_range
- group_subject_short_name
......
......@@ -14,7 +14,7 @@ from aleksis.apps.csv_import.util.class_range_helpers import (
parse_class_range,
)
from aleksis.apps.csv_import.util.converters import converter_registry
from aleksis.apps.csv_import.util.import_helpers import bulk_get_or_create, with_prefix
from aleksis.apps.csv_import.util.import_helpers import with_prefix
from aleksis.core.models import Group, Person, SchoolTerm
from aleksis.core.util.core_helpers import get_site_preferences
......@@ -26,7 +26,7 @@ class FieldType:
data_type: type = str
db_field: str = ""
converter: Optional[Union[str, Sequence[str]]] = None
alternative: Optional[str] = None
alternative_db_fields: Optional[str] = None
args: Optional[dict] = None
@classmethod
......@@ -57,24 +57,27 @@ class FieldType:
return _converter_chain
@classmethod
def get_alternative(cls) -> Optional[str]:
return cls.alternative
def get_args(self) -> dict:
return self.args or {}
def get_column_name(self) -> str:
return self.name
def get_db_field(self):
def get_db_field(self) -> str:
if self.get_args().get("db_field"):
return self.get_args()["db_field"]
return self.db_field
def get_alternative_db_fields(self) -> list[str]:
if self.get_args().get("alternative_db_fields"):
return self.get_args()["alterntive_db_fields"]
return self.alternative_db_fields or []
def get_column_name(self) -> str:
"""Get column name for use in Pandas structures."""
return self.column_name
def __init__(self, school_term: SchoolTerm, base_path: str):
self.school_term = school_term
self.base_path = os.path.realpath(base_path)
self.column_name = f"col_{uuid4()}"
class MatchFieldType(FieldType):
......@@ -89,10 +92,6 @@ class MatchFieldType(FieldType):
class DirectMappingFieldType(FieldType):
"""Set value directly in DB."""
@classmethod
def get_column_name(cls):
return cls.get_db_field()
class ProcessFieldType(FieldType):
"""Field type with custom logic for importing."""
......@@ -118,25 +117,9 @@ class RegExFieldType(ProcessFieldType):
instance.save()
class MultipleValuesFieldType(ProcessFieldType):
"""Has multiple columns."""
def process(self, instance: Model, values: Sequence):
pass
@classmethod
def get_column_name(cls) -> str:
return f"{cls.name}_{uuid4()}"
class FieldTypeRegistry:
def __init__(self):
self.field_types = {}
self.allowed_field_types_for_models = {}
self.allowed_field_types_for_all_models = set()
self.alternatives = {}
self.match_field_types = []
self.process_field_types = []
def register(self, field_type: Type[FieldType]):
"""Add new `FieldType` to registry.
......@@ -147,21 +130,6 @@ class FieldTypeRegistry:
raise ValueError(f"The field type {field_type.name} is already registered.")
self.field_types[field_type.name] = field_type
if not field_type.models:
self.allowed_field_types_for_all_models.add(field_type)
else:
for model in field_type.models:
self.allowed_field_types_for_models.setdefault(model, []).append(field_type)
if field_type.get_alternative():
self.alternatives[field_type] = field_type.get_alternative()
if issubclass(field_type, MatchFieldType):
self.match_field_types.append((field_type.priority, field_type))
if issubclass(field_type, ProcessFieldType):
self.process_field_types.append(field_type)
return field_type
def get_from_name(self, name: str) -> FieldType:
......@@ -173,10 +141,6 @@ class FieldTypeRegistry:
"""Return choices in Django format."""
return [(f.name, f.verbose_name) for f in self.field_types.values()]
@property
def unique_references_by_priority(self) -> Sequence[FieldType]:
return sorted(self.match_field_types)
field_type_registry = FieldTypeRegistry()
......@@ -194,7 +158,7 @@ class NameFieldType(DirectMappingFieldType):
name = "name"
verbose_name = _("Name")
db_field = "name"
alternative = "short_name"
alternative_db_fields = ["short_name"]
@field_type_registry.register
......@@ -224,7 +188,7 @@ class ShortNameFieldType(MatchFieldType):
verbose_name = _("Short name")
priority = 8
db_field = "short_name"
alternative = "name"
alternative_db_fields = ["name", "first_name", "last_name"]
@field_type_registry.register
......@@ -319,10 +283,6 @@ class IgnoreFieldType(FieldType):
name = "ignore"
verbose_name = _("Ignore data in this field")
@classmethod
def get_column_name(cls) -> str:
return f"_ignore_{uuid4()}"
@field_type_registry.register
class DepartmentsFieldType(ProcessFieldType):
......@@ -427,29 +387,32 @@ class PrimaryGroupOwnerByShortNameFieldType(ProcessFieldType):
@field_type_registry.register
class GroupOwnerByShortNameFieldType(MultipleValuesFieldType):
class GroupOwnerByShortNameFieldType(ProcessFieldType):
name = "group_owner_short_name"
verbose_name = _("Short name of a single group owner")
def process(self, instance: Model, values: Sequence):
group_owners = bulk_get_or_create(
Person,
values,
attr="short_name",
default_attrs="last_name",
defaults={"first_name": "?"},
def process(self, instance: Model, short_name: str):
group_owner, __ = Person.objects.get_or_create(
short_name=short_name,
defaults={"first_name": "?", "last_name": short_name},
)
instance.owners.set(group_owners)
if self.get_args().get("clear", False):
instance.owners.set([group_owner])
else:
instance.owners.add(group_owner)
@field_type_registry.register
class GroupMembershipByShortNameFieldType(MultipleValuesFieldType):
class GroupMembershipByShortNameFieldType(ProcessFieldType):
name = "group_membership_short_name"
verbose_name = _("Short name of the group the person is a member of")
def process(self, instance: Model, values: Sequence):
groups = Group.objects.filter(short_name__in=values, school_term=self.school_term)
instance.member_of.add(*groups)
def process(self, instance: Model, short_name: str):
try:
group = Group.objects.get(short_name=short_name, school_term=self.school_term)
instance.member_of.add(group)
except Group.DoesNotExist:
pass
@field_type_registry.register
......
from typing import Optional, Sequence, Union
from django.db.models import Model
from typing import Optional
def with_prefix(prefix: Optional[str], value: str) -> str:
......@@ -14,49 +12,3 @@ def with_prefix(prefix: Optional[str], value: str) -> str:
return f"{prefix} {value}"
else:
return value
def bulk_get_or_create(
model: Model,
objs: Sequence,
attr: str,
default_attrs: Optional[Union[Sequence[str], str]] = None,
defaults: Optional[dict] = None,
) -> Sequence[Model]:
"""
Do get_or_create on a list of values.
:param model: Model on which get_or_create should be executed
:param objs: List of values
:param attr: Field of model which should be set
:param default_attrs: One or more extra fields of model which also should be set to the value
:param defaults: Extra fields of model which should be set to a specific value
:return: List of instances
"""
if not defaults:
defaults = {}
if not default_attrs:
default_attrs = []
if not isinstance(default_attrs, list):
default_attrs = [default_attrs]
attrs = default_attrs + [attr]
qs = model.objects.filter(**{f"{attr}__in": objs})
existing_values = qs.values_list(attr, flat=True)
instances = [x for x in qs]
for obj in objs:
if obj in existing_values:
continue
kwargs = defaults
for _attr in attrs:
kwargs[_attr] = obj
instance = model.objects.create(**kwargs)
instances.append(instance)
return instances
......@@ -18,8 +18,7 @@ from tqdm import tqdm
from aleksis.apps.csv_import.field_types import (
DirectMappingFieldType,
MatchFieldType,
MultipleValuesFieldType,
field_type_registry,
ProcessFieldType,
)
from aleksis.apps.csv_import.settings import FALSE_VALUES, TRUE_VALUES
from aleksis.core.models import Group, Person
......@@ -45,36 +44,21 @@ def import_csv(
# Dissect template definition
# These structures will be filled with information for columns
data_types = {}
cols = []
cols_for_multiple_fields = {}
converters = {}
match_field_types = []
field_types = {}
for field in template.fields.all():
# Get field type and prepare for import
field_type = field.field_type_class(school_term, temp_dir)
if isinstance(field_type, MatchFieldType):
# Field is used to match existing instances
match_field_types.append(field_type)
# Get column name/header
column_name = field_type.get_column_name()
cols.append(column_name)
if isinstance(field_type, MultipleValuesFieldType):
# Mark column as containing multiple target fields
cols_for_multiple_fields.setdefault(field_type, [])
cols_for_multiple_fields[field_type].append(column_name)
# Get data type and conversion rules, if any
# Get data type and conversion rules, if any,
# to be passed to Pandas
data_types[column_name] = field_type.get_data_type()
if field_type.get_converter():
converters[column_name] = field_type.get_converter()
field_types[column_name] = field_type
# Order matching fields by priority
match_field_types = sorted(match_field_types, key=lambda x: x.priority)
# Determine whether the data file is a plain CSV or an archive
if import_job.data_file.name.endswith(".zip"):
# Unpack to temporary directory
......@@ -108,7 +92,7 @@ def import_csv(
data = pandas.read_csv(
csv,
sep=template.parsed_separator,
names=cols,
names=data_types.keys(),
header=0 if template.has_header_row else None,
index_col=template.has_index_col,
dtype=data_types,
......@@ -147,40 +131,41 @@ def import_csv(
if isinstance(field_type, DirectMappingFieldType):
update_dict[field_type.get_db_field()] = value
# Set alternatives for some fields
for (
field_type_origin,
alternative_name,
) in field_type_registry.alternatives.items():
if (
model in field_type_origin.models
and field_type_origin.name not in row
and alternative_name in row
):
update_dict[field_type_origin.name] = row[alternative_name]
# Set group type for imported groups if defined in template globally
if template.group_type and model == Group:
update_dict["group_type"] = template.group_type
# Determine available fields for finding existing instances
get_dict = {}
match_field_found = False
for (
priority,
match_field_type,
) in match_field_types:
if match_field_found or match_field_type.name not in row:
update_dict[match_field_type.get_db_field()] = row[
match_field_type.name
]
elif match_field_type.name in row:
get_dict[match_field_type.get_db_field()] = row[match_field_type.name]
match_field_found = True
if not match_field_found:
match_field_priority = 0
for (column_name, field_type,) in sorted(
filter(lambda f: isinstance(f[1], MatchFieldType), field_types.items()),
key=lambda f: f[1].priority,
):
if match_field_priority and match_field_priority < field_type.priority:
# We found a match field, but with less important priority than
# those before, so we write data from the field
update_dict[field_type.get_db_field()] = row[column_name]
elif column_name in row:
get_dict[field_type.get_db_field()] = row[column_name]
match_field_priority = field_type.priority
if not match_field_priority:
raise ValueError(_("Missing unique reference or other matching fields."))
# Set alternatives for some fields
for column_name, field_type in field_types.items():
for alternative_db_field in field_type.get_alternative_db_fields():
origin_db_field = field_type.get_db_field()
if (
hasattr(model, alternative_db_field)
and alternative_db_field not in update_dict
):
if origin_db_field in update_dict:
update_dict[alternative_db_field] = update_dict[origin_db_field]
elif origin_db_field in get_dict:
update_dict[alternative_db_field] = get_dict[origin_db_field]
# Set school term globally if model is school term related
if hasattr(model, "school_term") and school_term:
get_dict["school_term"] = school_term
......@@ -189,37 +174,24 @@ def import_csv(
try:
get_dict["defaults"] = update_dict
instance, created = model.objects.update_or_create(**get_dict)
# Process fields spanning multiple target attributes
values_for_multiple_fields = {}
for field_type, cols_for_field_type in cols_for_multiple_fields.items():
values_for_multiple_fields[field_type] = []
for col in cols_for_field_type:
value = row[col]
values_for_multiple_fields[field_type].append(value)
field_type().process(instance, values_for_multiple_fields[field_type])
if created:
created_count += 1
# Process field types with custom logic
for process_field_type in field_type_registry.process_field_types:
if process_field_type.name in row:
try:
process_field_type().process(
instance, row[process_field_type.name]
)
except RuntimeError as e:
if recorder:
recorder.add_message(messages.ERROR, str(e))
else:
logging.error(str(e))
for column_name, field_type in filter(
lambda f: isinstance(f[1], ProcessFieldType), field_types.items()
):
try:
field_type.process(instance, row[column_name])
except RuntimeError as e:
if recorder:
recorder.add_message(messages.ERROR, str(e))
else:
logging.error(str(e))
# Add current instance to group if import defines a target group for persons
if template.group and isinstance(instance, Person):
instance.member_of.add(template.group)
if created:
created_count += 1
except (
ValueError,
ValidationError,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment