Skip to content

Commit

Permalink
[PYG-180] 🐖 Cleanup (#288)
Browse files Browse the repository at this point in the history
* docs; added missing docstrings

* docs: added missing docstrings

* docs: added missing docstrings
  • Loading branch information
doctrino authored Aug 18, 2024
1 parent 7548eee commit 182f079
Showing 1 changed file with 59 additions and 4 deletions.
63 changes: 59 additions & 4 deletions cognite/pygen/_core/models/data_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,24 +68,29 @@ def __eq__(self, other: object):

@staticmethod
def to_base_name(view: dm.View) -> str:
"""Creates Python compatible base name from a view."""
return view.external_id.replace(" ", "_")

@classmethod
def to_base_name_with_version(cls, view: dm.View) -> str:
"""Creates Python compatible base name from a view with version."""
return f"{cls.to_base_name(view)}v{to_pascal(view.version)}".replace(" ", "_")

@classmethod
def to_base_name_with_space(cls, view: dm.View) -> str:
"""Creates Python compatible base name from a view with space."""
return f"{cls.to_base_name(view)}s{to_pascal(view.space)}".replace(" ", "_")

@classmethod
def to_base_name_with_space_and_version(cls, view: dm.View) -> str:
"""Creates Python compatible base name from a view with space and version."""
return f"{cls.to_base_name(view)}v{to_pascal(view.version)}s{to_pascal(view.space)}".replace(" ", "_")

@classmethod
def from_view(
cls, view: dm.View, base_name: str, used_for: Literal["node", "edge"], data_class: pygen_config.DataClassNaming
) -> DataClass:
"""Create a DataClass from a view."""
class_name = create_name(base_name, data_class.name)
if is_reserved_word(class_name, "data class", view.as_id()):
class_name = f"{class_name}_"
Expand Down Expand Up @@ -143,6 +148,11 @@ def update_fields(
has_default_instance_space: bool,
config: pygen_config.PygenConfig,
) -> None:
"""Update the fields of the data class.
This needs to be called after the data class has been created to update all fields with dependencies
on other data classes.
"""
for prop_name, prop in properties.items():
field_ = Field.from_property(
prop_name,
Expand All @@ -167,24 +177,28 @@ def update_fields(
self.initialization.add("fields")

def update_implements_interface_and_writable(self, parents: list[DataClass], is_interface: bool):
"""Update the implements, is_interface and is_writable attributes of the data class."""
self.is_interface = is_interface
self.implements.extend(parents)
self.is_writable = self.is_writable or self.is_all_fields_of_type(OneToManyConnectionField)
self.initialization.add("parents")

def update_direct_children(self, children: list[DataClass]):
"""Update the direct children of the data class."""
self.direct_children.extend(children)
self.initialization.add("children")

@property
def read_base_class(self) -> str:
"""Parent read classes."""
if self.implements:
return ", ".join(f"{interface.read_name}" for interface in self.implements)
else:
return "DomainModel"

@property
def write_base_class(self) -> str:
"""Parent write classes."""
if self.implements:
return ", ".join(f"{interface.write_name}" for interface in self.implements)
else:
Expand All @@ -197,50 +211,62 @@ def query_cls_name(self) -> str:

@property
def view_id_str(self) -> str:
"""The view id as a string."""
return f'dm.ViewId("{self.view_id.space}", "{self.view_id.external_id}", "{self.view_id.version}")'

@property
def has_filtering_fields(self) -> bool:
"""Check if the data class has any fields that support filtering."""
return any(field_.support_filtering for field_ in self.fields_of_type(PrimitiveField))

@property
def filtering_fields(self) -> Iterable[PrimitiveField]:
"""Return all fields that support filtering"""
return (field_ for field_ in self.fields_of_type(PrimitiveField) if field_.support_filtering)

@property
def filtering_import(self) -> str:
"""Import the filtering classes used in the data class."""
return "\n ".join(
f"{cls_name}," for cls_name in sorted(set(field_.filtering_cls for field_ in self.filtering_fields))
)

@property
def typed_read_bases_classes(self) -> str:
"""The parent read classes for the typed data class."""
if self.implements:
return ", ".join(f"{interface.read_name}" for interface in self.implements)
else:
return "TypedEdge" if isinstance(self, EdgeDataClass) else "TypedNode"

@property
def typed_write_bases_classes(self) -> str:
"""The parent write classes for the typed data class."""
if self.implements:
return ", ".join(f"{interface.read_name}Apply" for interface in self.implements)
else:
return "TypedEdgeApply" if isinstance(self, EdgeDataClass) else "TypedNodeApply"

@property
def text_field_names(self) -> str:
"""The name of the text fields Literal."""
return f"{self.read_name}TextFields"

@property
def field_names(self) -> str:
"""The name of the fields Literal."""
return f"{self.read_name}Fields"

@property
def properties_dict_name(self) -> str:
"""The name of the properties dictionary."""
return f"_{self.read_name.upper()}_PROPERTIES_BY_FIELD"

@property
def pydantic_field(self) -> Literal["Field", "pydantic.Field"]:
"""The name of the pydantic field to use.
This is in case we need to use pydantic.Field from pydantic instead of Field.
"""
if any(
name == "Field" for name in [self.read_name, self.write_name, self.read_list_name, self.write_list_name]
) or any(
Expand All @@ -259,6 +285,7 @@ def pydantic_field(self) -> Literal["Field", "pydantic.Field"]:

@property
def init_import(self) -> str:
"""The data class __init__ imports of this data class"""
import_classes = [self.read_name, self.graphql_name]
if self.is_writable:
import_classes.append(self.write_name)
Expand All @@ -275,6 +302,7 @@ def __iter__(self) -> Iterator[Field]:
return iter(self.fields)

def non_parent_fields(self, fields: Iterator[Field] | None = None) -> Iterator[Field]:
"""Return all fields that are not inherited from a parent."""
parent_fields = {field.prop_name for parent in self.implements for field in parent}
return (field for field in fields or self if field.prop_name not in parent_fields)

Expand All @@ -297,50 +325,55 @@ def fields_of_type(self, field_type: tuple[type[Field], ...]) -> Iterator[tuple[
def fields_of_type(
self, field_type: type[T_Field] | tuple[type[Field], ...]
) -> Iterator[T_Field] | Iterator[tuple[Field]]:
"""Return all fields of a specific type."""
return (field_ for field_ in self if isinstance(field_, field_type)) # type: ignore[return-value]

def has_field_of_type(self, type_: type[Field] | tuple[type[Field], ...]) -> bool:
"""Check if the data class has any fields of a specific type."""
return any(isinstance(field_, type_) for field_ in self)

def is_all_fields_of_type(self, type_: type[Field] | tuple[type[Field], ...]) -> bool:
"""Check if all fields are of a specific type."""
return all(isinstance(field_, type_) for field_ in self)

def primitive_fields_of_type(
self, type_: type[dm.PropertyType] | tuple[type[dm.PropertyType], ...]
) -> Iterable[BasePrimitiveField]:
"""Return all primitive fields of a specific type."""
return (
field_
for field_ in self.fields_of_type(BasePrimitiveField) # type: ignore[type-abstract]
if isinstance(field_.type_, type_)
)

def has_primitive_field_of_type(self, type_: type[dm.PropertyType] | tuple[type[dm.PropertyType], ...]) -> bool:
"""Check if the data class has any fields of a specific type."""
return any(self.primitive_fields_of_type(type_))

def is_all_primitive_fields_of_type(self, type_: type[dm.PropertyType] | tuple[type[dm.PropertyType], ...]) -> bool:
"""Check if all fields are of a specific type."""
return all(self.primitive_fields_of_type(type_))

def has_timeseries_fields(self) -> bool:
"""Check if the data class has any time series fields."""
return any(isinstance(field_, CDFExternalField) and field_.is_time_series for field_ in self)

def timeseries_fields(self) -> Iterable[CDFExternalField]:
"""Return all time series fields."""
return (field_ for field_ in self if isinstance(field_, CDFExternalField) and field_.is_time_series)

@property
def _field_type_hints(self) -> Iterable[str]:
return (hint for field_ in self.fields for hint in (field_.as_read_type_hint(), field_.as_write_type_hint()))

@property
def use_optional_type(self) -> bool:
return any("Optional" in hint for hint in self._field_type_hints)

@property
def use_pydantic_field(self) -> bool:
pydantic_field = self.pydantic_field
return any(pydantic_field in hint for hint in self._field_type_hints)

@property
def dependencies(self) -> list[DataClass]:
"""Return a list of all data class dependencies (through fields)."""
unique: dict[dm.ViewId, DataClass] = {}
for field_ in self.fields:
if isinstance(field_, BaseConnectionField):
Expand Down Expand Up @@ -382,6 +415,7 @@ def dependencies_with_edge_destinations(self) -> list[DataClass]:

@property
def has_dependencies(self) -> bool:
"""Check if the data class has any dependencies."""
return bool(self.dependencies)

@property
Expand Down Expand Up @@ -410,6 +444,8 @@ def container_fields(self) -> Iterable[Field]:
)

def container_fields_sorted(self, include: Literal["all", "only-self"] | DataClass = "all") -> list[Field]:
"""Return all container fields sorted by type."""

def key(x: Field) -> int:
return {True: 1, False: 0}[x.is_nullable] if isinstance(x, BasePrimitiveField) else 1

Expand Down Expand Up @@ -462,6 +498,7 @@ def one_to_one_direct_relations_with_source(self) -> Iterable[OneToOneConnection

@property
def one_to_one_reverse_direct_relation(self) -> Iterable[OneToOneConnectionField]:
"""All one to one reverse direct relations."""
return (
field_
for field_ in self.fields_of_type(OneToOneConnectionField)
Expand All @@ -479,6 +516,7 @@ def one_to_many_direct_relations_with_source(self) -> Iterable[OneToManyConnecti

@property
def one_to_many_reverse_direct_relations(self) -> Iterable[OneToManyConnectionField]:
"""All one to many reverse direct relations."""
return (
field_
for field_ in self.fields_of_type(OneToManyConnectionField)
Expand All @@ -487,26 +525,31 @@ def one_to_many_reverse_direct_relations(self) -> Iterable[OneToManyConnectionFi

@property
def has_one_to_one_direct_relations_with_source(self) -> bool:
"""Check if the data class has any one to one direct relations."""
return any(self.one_to_one_direct_relations_with_source)

@property
def primitive_fields_literal(self) -> str:
"""Return a literal with all primitive fields."""
return ", ".join(
f'"{field_.prop_name}"' for field_ in self if isinstance(field_, (PrimitiveField, CDFExternalField))
)

@property
def text_fields_literals(self) -> str:
"""Return a literal with all text fields."""
return ", ".join(
f'"{field_.name}"' for field_ in self.primitive_fields_of_type((dm.Text, dm.CDFExternalIdReference))
)

@property
def fields_literals(self) -> str:
"""Return a literal with all fields."""
return ", ".join(f'"{field_.name}"' for field_ in self if isinstance(field_, BasePrimitiveField))

@property
def container_field_variables(self) -> str:
"""Return a string with all container fields as variables."""
return ", ".join(
f"{field_.name}={field_.name}"
for field_ in self
Expand All @@ -516,14 +559,17 @@ def container_field_variables(self) -> str:

@property
def filter_name(self) -> str:
"""The name of the filter class."""
return f"_create_{self.variable}_filter"

@property
def is_edge_class(self) -> bool:
"""Check if the data class is an edge class."""
return False

@property
def connections_docs_write(self) -> str:
"""Return a string with all connections that are write fields."""
connections = [f for f in self.fields_of_type(BaseConnectionField) if f.destination_class and f.is_write_field] # type: ignore[type-abstract]
if len(connections) == 0:
raise ValueError("No connections found")
Expand All @@ -534,6 +580,7 @@ def connections_docs_write(self) -> str:

@property
def connections_docs(self) -> str:
"""Return a string with all connections."""
connections = [f for f in self.fields_of_type(BaseConnectionField) if f.destination_class] # type: ignore[type-abstract]
if len(connections) == 0:
raise ValueError("No connections found")
Expand All @@ -544,17 +591,20 @@ def connections_docs(self) -> str:

@property
def import_pydantic_field(self) -> str:
"""Import the pydantic field used in the data class."""
if self.pydantic_field == "Field":
return "from pydantic import Field"
else:
return "import pydantic"

@property
def has_any_field_model_prefix(self) -> bool:
"""Check if any field has a model prefix."""
return any(field_.name.startswith("model") for field_ in self)

@property
def has_edges(self) -> bool:
"""Check if the data class has any edges."""
return any(isinstance(field_, BaseConnectionField) and field_.is_edge for field_ in self)

@property
Expand All @@ -573,6 +623,7 @@ class NodeDataClass(DataClass):

@property
def typed_properties_name(self) -> str:
"""The name of the typed properties class."""
if self.has_edge_class:
return f"{self.read_name.removesuffix('Node')}Properties"
else:
Expand All @@ -588,17 +639,20 @@ class EdgeDataClass(DataClass):

@property
def typed_properties_name(self) -> str:
"""The name of the typed properties class."""
if self.has_node_class:
return f"{self.read_name.removesuffix('Edge')}Properties"
else:
return f"{self.read_name}Properties"

@property
def is_edge_class(self) -> bool:
"""Check if the data class is an edge class."""
return True

@property
def end_node_field(self) -> EndNodeField:
"""The end node field of the edge class."""
if self._end_node_field:
return self._end_node_field
raise ValueError("EdgeDataClass has not been initialized.")
Expand All @@ -612,6 +666,7 @@ def update_fields(
has_default_instance_space: bool,
config: pygen_config.PygenConfig,
):
"""Update the fields of the data class."""
# Find all node views that have an edge with properties in this view
# and get the node class it is pointing to.
edge_classes: dict[tuple[str, dm.DirectRelationReference, str], EdgeClass] = {}
Expand Down

0 comments on commit 182f079

Please sign in to comment.