Skip to content
Snippets Groups Projects
Verified Commit 4a692419 authored by Nik | Klampfradler's avatar Nik | Klampfradler
Browse files

[Refctor] Remove global state from FieldType

parent 268f1dec
No related branches found
No related tags found
No related merge requests found
......@@ -33,15 +33,14 @@ class FieldType:
def get_data_type(cls) -> type:
return cls.data_type
@classmethod
def get_converter(cls) -> Callable[[Any], Any]:
converters_pre = cls.get_args().get("converter_pre", [])
def get_converter(self) -> Callable[[Any], Any]:
converters_pre = self.get_args().get("converter_pre", [])
if isinstance(converters_pre, str):
converters_pre = [converters_pre]
converters_post = cls.get_args().get("converter_post", [])
converters_post = self.get_args().get("converter_post", [])
if isinstance(converters_post, str):
converters_post = [converters_post]
converters = cls.converter
converters = self.converter
if converters is None:
converters = []
elif isinstance(converters, str):
......@@ -62,30 +61,20 @@ class FieldType:
def get_alternative(cls) -> Optional[str]:
return cls.alternative
@classmethod
def get_args(cls) -> dict:
return cls.args or {}
def get_args(self) -> dict:
return self.args or {}
@classmethod
def get_column_name(cls) -> str:
return cls.name
def get_column_name(self) -> str:
return self.name
@classmethod
def get_db_field(cls):
if cls.get_args().get("db_field"):
return cls.get_args()["db_field"]
return cls.db_field
@classmethod
def is_model_valid(cls, model: Model) -> bool:
if not cls.models:
return True
return model in cls.models
def get_db_field(self):
if self.get_args().get("db_field"):
return self.get_args()["db_field"]
return self.db_field
@classmethod
def prepare(cls, school_term: SchoolTerm, base_path: str):
cls.school_term = school_term
cls.base_path = os.path.realpath(base_path)
def __init__(self, school_term: SchoolTerm, base_path: str):
self.school_term = school_term
self.base_path = os.path.realpath(base_path)
class MatchFieldType(FieldType):
......@@ -93,9 +82,8 @@ class MatchFieldType(FieldType):
priority: int = 1
@classmethod
def get_priority(cls):
return cls.get_args().get("priority", "") or cls.priority
def get_priority(self):
return self.get_args().get("priority", "") or self.priority
class DirectMappingFieldType(FieldType):
......@@ -119,9 +107,8 @@ class RegExFieldType(ProcessFieldType):
data_type = str
reg_ex: str = ""
@classmethod
def get_reg_ex(cls):
return cls.reg_ex or cls.get_args().get("reg_ex", "")
def get_reg_ex(self):
return self.reg_ex or self.get_args().get("reg_ex", "")
def process(self, instance: Model, value):
match = re.fullmatch(self.get_reg_ex(), value)
......@@ -198,7 +185,6 @@ field_type_registry = FieldTypeRegistry()
class UniqueReferenceFieldType(MatchFieldType):
name = "unique_reference"
verbose_name = _("Unique reference")
models = [Person, Group]
db_field = "import_ref_csv"
priority = 10
......@@ -207,7 +193,6 @@ class UniqueReferenceFieldType(MatchFieldType):
class NameFieldType(DirectMappingFieldType):
name = "name"
verbose_name = _("Name")
models = [Group]
db_field = "name"
alternative = "short_name"
......@@ -216,7 +201,6 @@ class NameFieldType(DirectMappingFieldType):
class FirstNameFieldType(DirectMappingFieldType):
name = "first_name"
verbose_name = _("First name")
models = [Person]
db_field = "first_name"
......@@ -224,7 +208,6 @@ class FirstNameFieldType(DirectMappingFieldType):
class LastNameFieldType(DirectMappingFieldType):
name = "last_name"
verbose_name = _("Last name")
models = [Person]
db_field = "last_name"
......@@ -232,7 +215,6 @@ class LastNameFieldType(DirectMappingFieldType):
class AdditionalNameFieldType(DirectMappingFieldType):
name = "additional_name"
verbose_name = _("Additional name")
models = [Person]
db_field = "additional_name"
......@@ -240,7 +222,6 @@ class AdditionalNameFieldType(DirectMappingFieldType):
class ShortNameFieldType(MatchFieldType):
name = "short_name"
verbose_name = _("Short name")
models = [Person, Group]
priority = 8
db_field = "short_name"
alternative = "name"
......@@ -250,19 +231,17 @@ class ShortNameFieldType(MatchFieldType):
class EmailFieldType(MatchFieldType):
name = "email"
verbose_name = _("Email")
models = [Person]
db_field = "email"
priority = 12
@classmethod
def get_converter(cls) -> Optional[Callable]:
if "email_domain" in cls.get_args():
def get_converter(self) -> Optional[Callable]:
if "email_domain" in self.get_args():
def add_domain_to_email(value: str) -> str:
if "@" in value:
return value
else:
return f"{value}@{cls.get_args()['email_domain']}"
return f"{value}@{self.get_args()['email_domain']}"
return add_domain_to_email
return super().get_converter()
......@@ -272,7 +251,6 @@ class EmailFieldType(MatchFieldType):
class DateOfBirthFieldType(DirectMappingFieldType):
name = "date_of_birth"
verbose_name = _("Date of birth")
models = [Person]
db_field = "date_of_birth"
converter = "parse_date"
......@@ -281,7 +259,6 @@ class DateOfBirthFieldType(DirectMappingFieldType):
class SexFieldType(DirectMappingFieldType):
name = "sex"
verbose_name = _("Sex")
models = [Person]
db_field = "sex"
converter = "parse_sex"
......@@ -290,7 +267,6 @@ class SexFieldType(DirectMappingFieldType):
class StreetFieldType(DirectMappingFieldType):
name = "street"
verbose_name = _("Street")
models = [Person]
db_field = "street"
......@@ -298,7 +274,6 @@ class StreetFieldType(DirectMappingFieldType):
class HouseNumberFieldType(DirectMappingFieldType):
name = "housenumber"
verbose_name = _("Housenumber")
models = [Person]
db_field = "housenumber"
......@@ -306,7 +281,6 @@ class HouseNumberFieldType(DirectMappingFieldType):
class StreetAndHouseNumberFieldType(RegExFieldType):
name = "street_housenumber"
verbose_name = _("Street and housenumber")
models = [Person]
reg_ex = r"^(?P<street>[\w\s]{3,})\s+(?P<housenumber>\d+\s*[a-zA-Z]*)$"
......@@ -314,7 +288,6 @@ class StreetAndHouseNumberFieldType(RegExFieldType):
class PostalCodeFieldType(DirectMappingFieldType):
name = "postal_code"
verbose_name = _("Postal code")
models = [Person]
db_field = "postal_code"
......@@ -322,7 +295,6 @@ class PostalCodeFieldType(DirectMappingFieldType):
class PlaceFieldType(DirectMappingFieldType):
name = "place"
verbose_name = _("Place")
models = [Person]
db_field = "place"
......@@ -330,7 +302,6 @@ class PlaceFieldType(DirectMappingFieldType):
class PhoneNumberFieldType(DirectMappingFieldType):
name = "phone_number"
verbose_name = _("Phone number")
models = [Person]
db_field = "phone_number"
converter = "parse_phone_number"
......@@ -339,7 +310,6 @@ class PhoneNumberFieldType(DirectMappingFieldType):
class MobileNumberFieldType(DirectMappingFieldType):
name = "mobile_number"
verbose_name = _("Mobile number")
models = [Person]
db_field = "mobile_number"
converter = "parse_phone_number"
......@@ -348,7 +318,6 @@ class MobileNumberFieldType(DirectMappingFieldType):
class IgnoreFieldType(FieldType):
name = "ignore"
verbose_name = _("Ignore data in this field")
models = [Person, Group]
@classmethod
def get_column_name(cls) -> str:
......@@ -359,7 +328,6 @@ class IgnoreFieldType(FieldType):
class DepartmentsFieldType(ProcessFieldType):
name = "departments"
verbose_name = _("Comma-seperated list of departments")
models = [Person]
converter = "parse_comma_separated_data"
def process(self, instance: Model, value):
......@@ -402,7 +370,6 @@ class DepartmentsFieldType(ProcessFieldType):
class GroupSubjectByShortNameFieldType(ProcessFieldType):
name = "group_subject_short_name"
verbose_name = _("Short name of the subject")
models = [Group]
def process(self, instance: Model, value):
with_chronos = apps.is_installed("aleksis.apps.chronos")
......@@ -417,13 +384,13 @@ class GroupSubjectByShortNameFieldType(ProcessFieldType):
class ClassRangeFieldType(ProcessFieldType):
name = "class_range"
verbose_name = _("Class range (e. g. 7a-d)")
models = [Group]
@classmethod
def prepare(cls, school_term: SchoolTerm):
"""Prefetch class groups."""
cls.classes_per_short_name = get_classes_per_short_name(school_term)
cls.classes_per_grade = get_classes_per_grade(cls.classes_per_short_name.keys())
def __init__(self, school_term: SchoolTerm, base_path: str):
# Prefetch class groups
self.classes_per_short_name = get_classes_per_short_name(school_term)
self.classes_per_grade = get_classes_per_grade(self.classes_per_short_name.keys())
super().__init__(school_term, base_path)
def process(self, instance: Model, value):
classes = parse_class_range(
......@@ -438,7 +405,6 @@ class ClassRangeFieldType(ProcessFieldType):
class PrimaryGroupByShortNameFieldType(ProcessFieldType):
name = "primary_group_short_name"
verbose_name = _("Short name of the person's primary group")
models = [Person]
def process(self, instance: Model, value):
group, __ = Group.objects.get_or_create(
......@@ -453,7 +419,6 @@ class PrimaryGroupByShortNameFieldType(ProcessFieldType):
class PrimaryGroupOwnerByShortNameFieldType(ProcessFieldType):
name = "primary_group_owner_short_name"
verbose_name = _("Short name of an owner of the person's primary group")
models = [Person]
def process(self, instance: Model, value):
if instance.primary_group:
......@@ -465,7 +430,6 @@ class PrimaryGroupOwnerByShortNameFieldType(ProcessFieldType):
class GroupOwnerByShortNameFieldType(MultipleValuesFieldType):
name = "group_owner_short_name"
verbose_name = _("Short name of a single group owner")
models = [Group]
def process(self, instance: Model, values: Sequence):
group_owners = bulk_get_or_create(
......@@ -483,8 +447,6 @@ class GroupMembershipByShortNameFieldType(MultipleValuesFieldType):
name = "group_membership_short_name"
verbose_name = _("Short name of the group the person is a member of")
models = [Person]
def process(self, instance: Model, values: Sequence):
groups = Group.objects.filter(short_name__in=values, school_term=self.school_term)
instance.member_of.add(*groups)
......@@ -494,7 +456,6 @@ class GroupMembershipByShortNameFieldType(MultipleValuesFieldType):
class ChildByUniqueReference(ProcessFieldType):
name = "child_by_unique_reference"
verbose_name = _("Child by unique reference (from students import)")
models = [Person]
def process(self, instance: Model, value):
child = Person.objects.get(import_ref_csv=value)
......
......@@ -49,18 +49,18 @@ def import_csv(
cols_for_multiple_fields = {}
converters = {}
match_field_types = []
field_types = {}
for field in template.fields.all():
# Get field type and prepare for import
field_type = field.field_type_class
field_type.prepare(school_term, temp_dir)
if issubclass(field_type, MatchFieldType):
field_type = field.field_type_class(school_term, temp_dir)
if isinstance(field_type, MatchFieldType):
# Field is used to match existing instances
match_field_types.append(field_type)
# Get column name/header
column_name = field_type.get_column_name()
cols.append(column_name)
if issubclass(field_type, MultipleValuesFieldType):
if isinstance(field_type, MultipleValuesFieldType):
# Mark column as containing multiple target fields
cols_for_multiple_fields.setdefault(field_type, [])
cols_for_multiple_fields[field_type].append(column_name)
......@@ -70,6 +70,8 @@ def import_csv(
if field_type.get_converter():
converters[column_name] = field_type.get_converter()
field_types[column_name] = field_type
# Order matching fields by priority
match_field_types = sorted(match_field_types, key=lambda x: x.priority)
......@@ -140,9 +142,9 @@ def import_csv(
# Build dict with all fields that should be directly updated
update_dict = {}
for key, value in row.items():
if key in field_type_registry.field_types:
field_type = field_type_registry.get_from_name(key)
if issubclass(field_type, DirectMappingFieldType):
if key in field_types:
field_type = field_types[key]
if isinstance(field_type, DirectMappingFieldType):
update_dict[field_type.get_db_field()] = value
# Set alternatives for some fields
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment