Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PYG-173, PYG-190, PYG-200] 🐻Edge with Properties Consistency #275

Merged
merged 18 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 10 additions & 14 deletions cognite/pygen/_core/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ def create_edge_apis(
self._config,
)
for field in self.data_class.fields_of_type(fields.BaseConnectionField) # type: ignore[type-abstract]
if not field.is_direct_relation
if field.is_edge
]

@property
Expand All @@ -625,7 +625,7 @@ def generate_data_class_file(self, is_pydantic_v2: bool) -> str:
"""
unique_start_classes = []
unique_end_classes = []
grouped_edge_classes = []
grouped_edge_classes: dict[str, list[str]] = {}
if isinstance(self.data_class, NodeDataClass):
type_data = self._env.get_template("data_class_node.py.jinja")
elif isinstance(self.data_class, EdgeDataClass):
Expand All @@ -636,20 +636,17 @@ def generate_data_class_file(self, is_pydantic_v2: bool) -> str:
unique_end_classes = sorted(
_unique_data_classes([edge.end_class for edge in self.data_class.end_node_field.edge_classes])
)

grouped_edge_classes = [
(key, list(group))
for key, group in itertools.groupby(
sorted(self.data_class.end_node_field.edge_classes), key=lambda c: c.end_class
)
]
_grouped_edge_classes: dict[str, set[str]] = defaultdict(set)
for edge_class in self.data_class.end_node_field.edge_classes:
if "outwards" in edge_class.used_directions:
_grouped_edge_classes[edge_class.end_class.write_name].add(edge_class.start_class.write_name)
elif "inwards" in edge_class.used_directions:
_grouped_edge_classes[edge_class.start_class.write_name].add(edge_class.end_class.write_name)
for start_class, end_classes in sorted(_grouped_edge_classes.items(), key=lambda x: x[0]):
grouped_edge_classes[start_class] = sorted(end_classes)
else:
raise ValueError(f"Unknown data class {type(self.data_class)}")

def create_start_node_set(group: list[EdgeAPIClass]) -> str:
joined = ", ".join([g.start_class.write_name for g in group])
return f"{{{joined}}}"

if is_pydantic_v2 and self.data_class.has_any_field_model_prefix:
names = ", ".join(field.name for field in self.data_class.fields if field.name.startswith("name"))
warnings.warn(
Expand All @@ -671,7 +668,6 @@ def create_start_node_set(group: list[EdgeAPIClass]) -> str:
unique_start_classes=unique_start_classes,
unique_end_classes=unique_end_classes,
grouped_edge_classes=grouped_edge_classes,
create_start_node_set=create_start_node_set,
)
+ "\n"
)
Expand Down
16 changes: 4 additions & 12 deletions cognite/pygen/_core/models/api_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from cognite.pygen import config as pygen_config
from cognite.pygen.utils.text import create_name

from .data_classes import DataClass, EdgeDataClass
from .data_classes import DataClass, EdgeDataClass, NodeDataClass
from .fields import BaseConnectionField, CDFExternalField
from .filter_methods import FilterMethod, FilterParameter

Expand Down Expand Up @@ -159,21 +159,13 @@ def from_fields(
parent_attribute = create_name(field.name, api_class.client_attribute)

# This is always true for Edge Connection Fields
field_end_class = cast(DataClass, field.end_classes[0]) # type: ignore[index]
end_class = cast(NodeDataClass, field.destination_class)

edge_class: EdgeDataClass | None = None
end_class: DataClass
if isinstance(field_end_class, EdgeDataClass):
edge_class = field_end_class
try:
end_class = next(
c.end_class for c in edge_class.end_node_field.edge_classes if c.edge_type == field.edge_type
)
except StopIteration:
raise ValueError(f"Could not find end class {field_end_class.view_id}") from None
if field.edge_class:
edge_class = field.edge_class
filter_method = FilterMethod.from_fields(edge_class.fields, pygen_config.filtering, is_edge_class=True)
else: # NodeDataClass
end_class = field_end_class
filter_method = FilterMethod.from_fields([], pygen_config.filtering)

return cls(
Expand Down
99 changes: 53 additions & 46 deletions cognite/pygen/_core/models/data_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
BaseConnectionField,
BasePrimitiveField,
CDFExternalField,
EdgeClasses,
EdgeClass,
EndNodeField,
Field,
OneToManyConnectionField,
Expand Down Expand Up @@ -338,35 +338,38 @@ def dependencies(self) -> list[DataClass]:
unique: dict[dm.ViewId, DataClass] = {}
for field_ in self.fields:
if isinstance(field_, BaseConnectionField):
for class_ in field_.end_classes or []:
if field_.edge_class:
unique[field_.edge_class.view_id] = field_.edge_class
elif field_.destination_class:
# This will overwrite any existing data class with the same view id
# however, this is not a problem as all data classes are uniquely identified by their view id
unique[class_.view_id] = class_
unique[field_.destination_class.view_id] = field_.destination_class
elif isinstance(field_, EndNodeField):
for class_ in field_.end_classes:
for class_ in field_.destination_classes:
unique[class_.view_id] = class_

return sorted(unique.values(), key=lambda x: x.write_name)

@property
def dependencies_with_edge_destinations(self) -> list[DataClass]:
"""Return a list of all dependencies which also includes the edge
destination if th dependency is a EdgeClass."""
destination if the dependency is a EdgeClass."""
unique: dict[dm.ViewId, DataClass] = {}
for field_ in self.fields:
if isinstance(field_, BaseConnectionField):
for class_ in field_.end_classes or []:
if field_.destination_class:
# This will overwrite any existing data class with the same view id
# however, this is not a problem as all data classes are uniquely identified by their view id
unique[class_.view_id] = class_
if isinstance(class_, EdgeDataClass):
for edge_class in class_.end_node_field.edge_classes:
if field_.edge_direction == "outwards":
unique[edge_class.end_class.view_id] = edge_class.end_class
else:
unique[edge_class.start_class.view_id] = edge_class.start_class
unique[field_.destination_class.view_id] = field_.destination_class
if field_.edge_class:
unique[field_.edge_class.view_id] = field_.edge_class
for edge_class in field_.edge_class.end_node_field.edge_classes:
if field_.edge_direction == "outwards":
unique[edge_class.end_class.view_id] = edge_class.end_class
else:
unique[edge_class.start_class.view_id] = edge_class.start_class
elif isinstance(field_, EndNodeField):
for class_ in field_.end_classes:
for class_ in field_.destination_classes:
unique[class_.view_id] = class_

return sorted(unique.values(), key=lambda x: x.read_name)
Expand Down Expand Up @@ -425,33 +428,33 @@ def has_container_fields(self) -> bool:
@property
def one_to_many_edges_without_properties(self) -> Iterable[OneToManyConnectionField]:
"""All MultiEdges without properties on the edge."""
return (field_ for field_ in self.fields_of_type(OneToManyConnectionField) if field_.is_no_property_edge)
return (field_ for field_ in self.fields_of_type(OneToManyConnectionField) if field_.is_edge_without_properties)

@property
def one_to_one_edge_without_properties(self) -> Iterable[OneToOneConnectionField]:
"""All MultiEdges without properties on the edge."""
return (field_ for field_ in self.fields_of_type(OneToOneConnectionField) if field_.is_no_property_edge)
return (field_ for field_ in self.fields_of_type(OneToOneConnectionField) if field_.is_edge_without_properties)

@property
def one_to_many_edges_with_properties(self) -> Iterable[OneToManyConnectionField]:
"""All MultiEdges with properties on the edge."""
return (field_ for field_ in self.fields_of_type(OneToManyConnectionField) if field_.is_property_edge)
return (field_ for field_ in self.fields_of_type(OneToManyConnectionField) if field_.is_edge_with_properties)

@property
def one_to_one_direct_relations_with_source(self) -> Iterable[OneToOneConnectionField]:
"""All direct relations."""
return (
field_
for field_ in self.fields_of_type(OneToOneConnectionField)
if field_.is_direct_relation and field_.end_classes
if field_.is_direct_relation and field_.destination_class
)

@property
def one_to_one_reverse_direct_relation(self) -> Iterable[OneToOneConnectionField]:
return (
field_
for field_ in self.fields_of_type(OneToOneConnectionField)
if field_.is_reverse_direct_relation and field_.end_classes
if field_.is_reverse_direct_relation and field_.destination_class
)

@property
Expand All @@ -460,15 +463,15 @@ def one_to_many_direct_relations_with_source(self) -> Iterable[OneToManyConnecti
return (
field_
for field_ in self.fields_of_type(OneToManyConnectionField)
if field_.is_direct_relation and field_.end_classes
if field_.is_direct_relation and field_.destination_class
)

@property
def one_to_many_reverse_direct_relations(self) -> Iterable[OneToManyConnectionField]:
return (
field_
for field_ in self.fields_of_type(OneToManyConnectionField)
if field_.is_reverse_direct_relation and field_.end_classes
if field_.is_reverse_direct_relation and field_.destination_class
)

@property
Expand Down Expand Up @@ -510,7 +513,7 @@ def is_edge_class(self) -> bool:

@property
def connections_docs_write(self) -> str:
connections = [f for f in self.fields_of_type(BaseConnectionField) if f.end_classes and f.is_write_field] # type: ignore[type-abstract]
connections = [f for f in self.fields_of_type(BaseConnectionField) if f.destination_class and f.is_write_field] # type: ignore[type-abstract]
if len(connections) == 0:
raise ValueError("No connections found")
elif len(connections) == 1:
Expand All @@ -520,7 +523,7 @@ def connections_docs_write(self) -> str:

@property
def connections_docs(self) -> str:
connections = [f for f in self.fields_of_type(BaseConnectionField) if f.end_classes] # type: ignore[type-abstract]
connections = [f for f in self.fields_of_type(BaseConnectionField) if f.destination_class] # type: ignore[type-abstract]
if len(connections) == 0:
raise ValueError("No connections found")
elif len(connections) == 1:
Expand Down Expand Up @@ -564,6 +567,7 @@ class EdgeDataClass(DataClass):
"""This represent data class used for views marked as used_for='edge'."""

has_node_class: bool
_end_node_field: EndNodeField | None = None

@property
def typed_properties_name(self) -> str:
Expand All @@ -578,10 +582,9 @@ def is_edge_class(self) -> bool:

@property
def end_node_field(self) -> EndNodeField:
try:
return next(field_ for field_ in self.fields if isinstance(field_, EndNodeField))
except StopIteration:
raise ValueError("EdgeDataClass has not been initialized.") from None
if self._end_node_field:
return self._end_node_field
raise ValueError("EdgeDataClass has not been initialized.")

def update_fields(
self,
Expand All @@ -593,29 +596,33 @@ def update_fields(
):
# Find all node views that have an edge with properties in this view
# and get the node class it is pointing to.
edge_classes = []
edge_classes: dict[tuple[str, dm.DirectRelationReference, str], EdgeClass] = {}
for view in views:
view_id = view.as_id()
if view_id not in node_class_by_view_id:
continue
start_class = node_class_by_view_id[view_id]
for _prop_name, prop in view.properties.items():
source_class = node_class_by_view_id[view_id]
for prop in view.properties.values():
if isinstance(prop, dm.EdgeConnection) and prop.edge_source == self.view_id:
end_class = node_class_by_view_id[prop.source]
start, end = (start_class, end_class) if prop.direction == "outwards" else (end_class, start_class)

new_edge_class = EdgeClasses(start, prop.type, end)
if new_edge_class not in edge_classes:
edge_classes.append(new_edge_class)

self.fields.append(
EndNodeField(
name="end_node",
doc_name="end node",
prop_name="end_node",
description="The end node of this edge.",
pydantic_field="Field",
edge_classes=edge_classes,
)
destination_class = node_class_by_view_id[prop.source]
start, end = (
(source_class, destination_class)
if prop.direction == "outwards"
else (destination_class, source_class)
)
identifier = start.read_name, prop.type, end.read_name
if edge_class := edge_classes.get(identifier):
edge_class.used_directions.add(prop.direction)
else:
edge_classes[identifier] = EdgeClass(start, prop.type, end, {prop.direction})

self._end_node_field = EndNodeField(
name="end_node",
doc_name="end node",
prop_name="end_node",
description="The end node of this edge.",
pydantic_field="Field",
edge_classes=list(edge_classes.values()),
)
self.fields.append(self._end_node_field)
super().update_fields(properties, node_class_by_view_id, edge_class_by_view_id, views, config)
4 changes: 2 additions & 2 deletions cognite/pygen/_core/models/fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .cdf_reference import CDFExternalField, CDFExternalListField
from .connections import (
BaseConnectionField,
EdgeClasses,
EdgeClass,
EndNodeField,
OneToManyConnectionField,
OneToOneConnectionField,
Expand All @@ -21,7 +21,7 @@
"CDFExternalField",
"CDFExternalListField",
"EndNodeField",
"EdgeClasses",
"EdgeClass",
"T_Field",
"BaseConnectionField",
"OneToOneConnectionField",
Expand Down
Loading