Skip to content

Commit

Permalink
[PYG-173, PYG-190, PYG-200] 🐻Edge with Properties Consistency (#275)
Browse files Browse the repository at this point in the history
* refactor: more efficient way of looking up end node field

* refactor; replace end_classes in connection field

* fix: type hints

* refactor: gen node data class

* refactor: fix gen of node api

* refactor; regen

* refactor: updated example

* refactor: generate correct edge with properties

* Ãrefactor: deterministc order

* Ãrefactor: regenæ

* refactro: removed unused

* build: changelog

* Ãrefactor; fix

* refactor: deleted faluty reverse direct relatinos

* tests: failing test

* fix: correct destination node

* refactor: regen

* build: changelog
  • Loading branch information
doctrino authored Aug 14, 2024
1 parent 00d9901 commit b33a652
Show file tree
Hide file tree
Showing 36 changed files with 367 additions and 868 deletions.
24 changes: 10 additions & 14 deletions cognite/pygen/_core/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ def create_edge_apis(
self._config,
)
for field in self.data_class.fields_of_type(fields.BaseConnectionField) # type: ignore[type-abstract]
if not field.is_direct_relation
if field.is_edge
]

@property
Expand All @@ -625,7 +625,7 @@ def generate_data_class_file(self, is_pydantic_v2: bool) -> str:
"""
unique_start_classes = []
unique_end_classes = []
grouped_edge_classes = []
grouped_edge_classes: dict[str, list[str]] = {}
if isinstance(self.data_class, NodeDataClass):
type_data = self._env.get_template("data_class_node.py.jinja")
elif isinstance(self.data_class, EdgeDataClass):
Expand All @@ -636,20 +636,17 @@ def generate_data_class_file(self, is_pydantic_v2: bool) -> str:
unique_end_classes = sorted(
_unique_data_classes([edge.end_class for edge in self.data_class.end_node_field.edge_classes])
)

grouped_edge_classes = [
(key, list(group))
for key, group in itertools.groupby(
sorted(self.data_class.end_node_field.edge_classes), key=lambda c: c.end_class
)
]
_grouped_edge_classes: dict[str, set[str]] = defaultdict(set)
for edge_class in self.data_class.end_node_field.edge_classes:
if "outwards" in edge_class.used_directions:
_grouped_edge_classes[edge_class.end_class.write_name].add(edge_class.start_class.write_name)
elif "inwards" in edge_class.used_directions:
_grouped_edge_classes[edge_class.start_class.write_name].add(edge_class.end_class.write_name)
for start_class, end_classes in sorted(_grouped_edge_classes.items(), key=lambda x: x[0]):
grouped_edge_classes[start_class] = sorted(end_classes)
else:
raise ValueError(f"Unknown data class {type(self.data_class)}")

def create_start_node_set(group: list[EdgeAPIClass]) -> str:
joined = ", ".join([g.start_class.write_name for g in group])
return f"{{{joined}}}"

if is_pydantic_v2 and self.data_class.has_any_field_model_prefix:
names = ", ".join(field.name for field in self.data_class.fields if field.name.startswith("name"))
warnings.warn(
Expand All @@ -671,7 +668,6 @@ def create_start_node_set(group: list[EdgeAPIClass]) -> str:
unique_start_classes=unique_start_classes,
unique_end_classes=unique_end_classes,
grouped_edge_classes=grouped_edge_classes,
create_start_node_set=create_start_node_set,
)
+ "\n"
)
Expand Down
16 changes: 4 additions & 12 deletions cognite/pygen/_core/models/api_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from cognite.pygen import config as pygen_config
from cognite.pygen.utils.text import create_name

from .data_classes import DataClass, EdgeDataClass
from .data_classes import DataClass, EdgeDataClass, NodeDataClass
from .fields import BaseConnectionField, CDFExternalField
from .filter_methods import FilterMethod, FilterParameter

Expand Down Expand Up @@ -159,21 +159,13 @@ def from_fields(
parent_attribute = create_name(field.name, api_class.client_attribute)

# This is always true for Edge Connection Fields
field_end_class = cast(DataClass, field.end_classes[0]) # type: ignore[index]
end_class = cast(NodeDataClass, field.destination_class)

edge_class: EdgeDataClass | None = None
end_class: DataClass
if isinstance(field_end_class, EdgeDataClass):
edge_class = field_end_class
try:
end_class = next(
c.end_class for c in edge_class.end_node_field.edge_classes if c.edge_type == field.edge_type
)
except StopIteration:
raise ValueError(f"Could not find end class {field_end_class.view_id}") from None
if field.edge_class:
edge_class = field.edge_class
filter_method = FilterMethod.from_fields(edge_class.fields, pygen_config.filtering, is_edge_class=True)
else: # NodeDataClass
end_class = field_end_class
filter_method = FilterMethod.from_fields([], pygen_config.filtering)

return cls(
Expand Down
99 changes: 53 additions & 46 deletions cognite/pygen/_core/models/data_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
BaseConnectionField,
BasePrimitiveField,
CDFExternalField,
EdgeClasses,
EdgeClass,
EndNodeField,
Field,
OneToManyConnectionField,
Expand Down Expand Up @@ -338,35 +338,38 @@ def dependencies(self) -> list[DataClass]:
unique: dict[dm.ViewId, DataClass] = {}
for field_ in self.fields:
if isinstance(field_, BaseConnectionField):
for class_ in field_.end_classes or []:
if field_.edge_class:
unique[field_.edge_class.view_id] = field_.edge_class
elif field_.destination_class:
# This will overwrite any existing data class with the same view id
# however, this is not a problem as all data classes are uniquely identified by their view id
unique[class_.view_id] = class_
unique[field_.destination_class.view_id] = field_.destination_class
elif isinstance(field_, EndNodeField):
for class_ in field_.end_classes:
for class_ in field_.destination_classes:
unique[class_.view_id] = class_

return sorted(unique.values(), key=lambda x: x.write_name)

@property
def dependencies_with_edge_destinations(self) -> list[DataClass]:
"""Return a list of all dependencies which also includes the edge
destination if th dependency is a EdgeClass."""
destination if the dependency is a EdgeClass."""
unique: dict[dm.ViewId, DataClass] = {}
for field_ in self.fields:
if isinstance(field_, BaseConnectionField):
for class_ in field_.end_classes or []:
if field_.destination_class:
# This will overwrite any existing data class with the same view id
# however, this is not a problem as all data classes are uniquely identified by their view id
unique[class_.view_id] = class_
if isinstance(class_, EdgeDataClass):
for edge_class in class_.end_node_field.edge_classes:
if field_.edge_direction == "outwards":
unique[edge_class.end_class.view_id] = edge_class.end_class
else:
unique[edge_class.start_class.view_id] = edge_class.start_class
unique[field_.destination_class.view_id] = field_.destination_class
if field_.edge_class:
unique[field_.edge_class.view_id] = field_.edge_class
for edge_class in field_.edge_class.end_node_field.edge_classes:
if field_.edge_direction == "outwards":
unique[edge_class.end_class.view_id] = edge_class.end_class
else:
unique[edge_class.start_class.view_id] = edge_class.start_class
elif isinstance(field_, EndNodeField):
for class_ in field_.end_classes:
for class_ in field_.destination_classes:
unique[class_.view_id] = class_

return sorted(unique.values(), key=lambda x: x.read_name)
Expand Down Expand Up @@ -425,33 +428,33 @@ def has_container_fields(self) -> bool:
@property
def one_to_many_edges_without_properties(self) -> Iterable[OneToManyConnectionField]:
"""All MultiEdges without properties on the edge."""
return (field_ for field_ in self.fields_of_type(OneToManyConnectionField) if field_.is_no_property_edge)
return (field_ for field_ in self.fields_of_type(OneToManyConnectionField) if field_.is_edge_without_properties)

@property
def one_to_one_edge_without_properties(self) -> Iterable[OneToOneConnectionField]:
"""All MultiEdges without properties on the edge."""
return (field_ for field_ in self.fields_of_type(OneToOneConnectionField) if field_.is_no_property_edge)
return (field_ for field_ in self.fields_of_type(OneToOneConnectionField) if field_.is_edge_without_properties)

@property
def one_to_many_edges_with_properties(self) -> Iterable[OneToManyConnectionField]:
"""All MultiEdges with properties on the edge."""
return (field_ for field_ in self.fields_of_type(OneToManyConnectionField) if field_.is_property_edge)
return (field_ for field_ in self.fields_of_type(OneToManyConnectionField) if field_.is_edge_with_properties)

@property
def one_to_one_direct_relations_with_source(self) -> Iterable[OneToOneConnectionField]:
"""All direct relations."""
return (
field_
for field_ in self.fields_of_type(OneToOneConnectionField)
if field_.is_direct_relation and field_.end_classes
if field_.is_direct_relation and field_.destination_class
)

@property
def one_to_one_reverse_direct_relation(self) -> Iterable[OneToOneConnectionField]:
return (
field_
for field_ in self.fields_of_type(OneToOneConnectionField)
if field_.is_reverse_direct_relation and field_.end_classes
if field_.is_reverse_direct_relation and field_.destination_class
)

@property
Expand All @@ -460,15 +463,15 @@ def one_to_many_direct_relations_with_source(self) -> Iterable[OneToManyConnecti
return (
field_
for field_ in self.fields_of_type(OneToManyConnectionField)
if field_.is_direct_relation and field_.end_classes
if field_.is_direct_relation and field_.destination_class
)

@property
def one_to_many_reverse_direct_relations(self) -> Iterable[OneToManyConnectionField]:
return (
field_
for field_ in self.fields_of_type(OneToManyConnectionField)
if field_.is_reverse_direct_relation and field_.end_classes
if field_.is_reverse_direct_relation and field_.destination_class
)

@property
Expand Down Expand Up @@ -510,7 +513,7 @@ def is_edge_class(self) -> bool:

@property
def connections_docs_write(self) -> str:
connections = [f for f in self.fields_of_type(BaseConnectionField) if f.end_classes and f.is_write_field] # type: ignore[type-abstract]
connections = [f for f in self.fields_of_type(BaseConnectionField) if f.destination_class and f.is_write_field] # type: ignore[type-abstract]
if len(connections) == 0:
raise ValueError("No connections found")
elif len(connections) == 1:
Expand All @@ -520,7 +523,7 @@ def connections_docs_write(self) -> str:

@property
def connections_docs(self) -> str:
connections = [f for f in self.fields_of_type(BaseConnectionField) if f.end_classes] # type: ignore[type-abstract]
connections = [f for f in self.fields_of_type(BaseConnectionField) if f.destination_class] # type: ignore[type-abstract]
if len(connections) == 0:
raise ValueError("No connections found")
elif len(connections) == 1:
Expand Down Expand Up @@ -564,6 +567,7 @@ class EdgeDataClass(DataClass):
"""This represent data class used for views marked as used_for='edge'."""

has_node_class: bool
_end_node_field: EndNodeField | None = None

@property
def typed_properties_name(self) -> str:
Expand All @@ -578,10 +582,9 @@ def is_edge_class(self) -> bool:

@property
def end_node_field(self) -> EndNodeField:
try:
return next(field_ for field_ in self.fields if isinstance(field_, EndNodeField))
except StopIteration:
raise ValueError("EdgeDataClass has not been initialized.") from None
if self._end_node_field:
return self._end_node_field
raise ValueError("EdgeDataClass has not been initialized.")

def update_fields(
self,
Expand All @@ -593,29 +596,33 @@ def update_fields(
):
# Find all node views that have an edge with properties in this view
# and get the node class it is pointing to.
edge_classes = []
edge_classes: dict[tuple[str, dm.DirectRelationReference, str], EdgeClass] = {}
for view in views:
view_id = view.as_id()
if view_id not in node_class_by_view_id:
continue
start_class = node_class_by_view_id[view_id]
for _prop_name, prop in view.properties.items():
source_class = node_class_by_view_id[view_id]
for prop in view.properties.values():
if isinstance(prop, dm.EdgeConnection) and prop.edge_source == self.view_id:
end_class = node_class_by_view_id[prop.source]
start, end = (start_class, end_class) if prop.direction == "outwards" else (end_class, start_class)

new_edge_class = EdgeClasses(start, prop.type, end)
if new_edge_class not in edge_classes:
edge_classes.append(new_edge_class)

self.fields.append(
EndNodeField(
name="end_node",
doc_name="end node",
prop_name="end_node",
description="The end node of this edge.",
pydantic_field="Field",
edge_classes=edge_classes,
)
destination_class = node_class_by_view_id[prop.source]
start, end = (
(source_class, destination_class)
if prop.direction == "outwards"
else (destination_class, source_class)
)
identifier = start.read_name, prop.type, end.read_name
if edge_class := edge_classes.get(identifier):
edge_class.used_directions.add(prop.direction)
else:
edge_classes[identifier] = EdgeClass(start, prop.type, end, {prop.direction})

self._end_node_field = EndNodeField(
name="end_node",
doc_name="end node",
prop_name="end_node",
description="The end node of this edge.",
pydantic_field="Field",
edge_classes=list(edge_classes.values()),
)
self.fields.append(self._end_node_field)
super().update_fields(properties, node_class_by_view_id, edge_class_by_view_id, views, config)
4 changes: 2 additions & 2 deletions cognite/pygen/_core/models/fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .cdf_reference import CDFExternalField, CDFExternalListField
from .connections import (
BaseConnectionField,
EdgeClasses,
EdgeClass,
EndNodeField,
OneToManyConnectionField,
OneToOneConnectionField,
Expand All @@ -21,7 +21,7 @@
"CDFExternalField",
"CDFExternalListField",
"EndNodeField",
"EdgeClasses",
"EdgeClass",
"T_Field",
"BaseConnectionField",
"OneToOneConnectionField",
Expand Down
Loading

0 comments on commit b33a652

Please sign in to comment.