1
0
mirror of https://github.com/checktheroads/hyperglass synced 2024-05-11 05:55:08 +00:00

Fix HyperglassMultiModel root class inheritance

This commit is contained in:
thatmattlove
2021-09-17 12:09:12 -07:00
parent 66e69db17d
commit ed58c3622b
2 changed files with 14 additions and 16 deletions

View File

@@ -315,10 +315,10 @@ class NativeDirective(Directive):
DirectiveT = t.Union[NativeDirective, Directive] DirectiveT = t.Union[NativeDirective, Directive]
class Directives(HyperglassMultiModel[DirectiveT]): class Directives(HyperglassMultiModel[Directive]):
"""Collection of directives.""" """Collection of directives."""
def __init__(self, *items: t.Dict[str, t.Any]) -> None: def __init__(self, *items: t.Union[DirectiveT, t.Dict[str, t.Any]]) -> None:
"""Initialize base class and validate objects.""" """Initialize base class and validate objects."""
super().__init__(*items, model=Directive, accessor="id") super().__init__(*items, model=Directive, accessor="id")

View File

@@ -128,13 +128,15 @@ class HyperglassMultiModel(GenericModel, t.Generic[MultiModelT]):
extra = "forbid" extra = "forbid"
validate_assignment = True validate_assignment = True
def __init__(self, *items: t.Dict[str, t.Any], model: MultiModelT, accessor: str) -> None: def __init__(
self, *items: t.Union[MultiModelT, t.Dict[str, t.Any]], model: MultiModelT, accessor: str
) -> None:
"""Validate items.""" """Validate items."""
items = self._valid_items(*items, model=model, accessor=accessor)
super().__init__(__root__=items)
self._count = len(items)
self._accessor = accessor self._accessor = accessor
self._model = model self._model = model
valid = self._valid_items(*items)
super().__init__(__root__=valid)
self._count = len(self.__root__)
def __iter__(self) -> t.Iterator[MultiModelT]: def __iter__(self) -> t.Iterator[MultiModelT]:
"""Iterate items.""" """Iterate items."""
@@ -179,31 +181,27 @@ class HyperglassMultiModel(GenericModel, t.Generic[MultiModelT]):
"""Access item count.""" """Access item count."""
return self._count return self._count
@staticmethod
def _valid_items( def _valid_items(
*to_validate: t.List[t.Union[MultiModelT, t.Dict[str, t.Any]]], self, *to_validate: t.List[t.Union[MultiModelT, t.Dict[str, t.Any]]]
model: MultiModelT,
accessor: str
) -> t.List[MultiModelT]: ) -> t.List[MultiModelT]:
items = [ items = [
item item
for item in to_validate for item in to_validate
if any( if any(
( (
(isinstance(item, dict) and accessor in item), (isinstance(item, self.model) and hasattr(item, self.accessor)),
(isinstance(item, model) and hasattr(item, accessor)), (isinstance(item, t.Dict) and self.accessor in item),
), ),
) )
] ]
for index, item in enumerate(items): for index, item in enumerate(items):
if isinstance(item, dict): if isinstance(item, t.Dict):
items[index] = model(**item) items[index] = self.model(**item)
return items return items
def add(self, *items, unique_by: t.Optional[str] = None) -> None: def add(self, *items, unique_by: t.Optional[str] = None) -> None:
"""Add an item to the model.""" """Add an item to the model."""
to_add = self._valid_items(*items, model=self.model, accessor=self.accessor) to_add = self._valid_items(*items)
if unique_by is not None: if unique_by is not None:
unique_by_values = { unique_by_values = {
getattr(obj, unique_by) for obj in (*self, *to_add) if hasattr(obj, unique_by) getattr(obj, unique_by) for obj in (*self, *to_add) if hasattr(obj, unique_by)