Skip to content

Commit

Permalink
handle vector
Browse files Browse the repository at this point in the history
  • Loading branch information
adshidtadka committed Oct 26, 2023
1 parent 0cbd3c0 commit d81ff07
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions src/sqlacodegen/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.exc import CompileError
from sqlalchemy.sql.elements import TextClause
from pgvector.sqlalchemy import Vector

from .models import (
ColumnAttribute,
Expand Down Expand Up @@ -213,7 +214,9 @@ def collect_imports_for_model(self, model: Model) -> None:
self.collect_imports_for_constraint(index)

def collect_imports_for_column(self, column: Column[Any]) -> None:
self.add_import(column.type)
if not is_vector(column):
# NOTE(adshidtadka): from sqlalchemy.sql.sqltypes import NullType が追加されるのを避ける
self.add_import(column.type)

if isinstance(column.type, ARRAY):
self.add_import(column.type.item_type.__class__)
Expand Down Expand Up @@ -437,7 +440,10 @@ def render_column(
# Render the column type if there are no foreign keys on it or any of them
# points back to itself
if not dedicated_fks or any(fk.column is column for fk in dedicated_fks):
args.append(self.render_column_type(column.type))
if is_vector(column):
args.append("Vector")
else:
args.append(self.render_column_type(column.type))

for fk in dedicated_fks:
args.append(self.render_constraint(fk))
Expand Down Expand Up @@ -1230,8 +1236,12 @@ def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
column_python_type = f"{python_type_module}.{python_type_name}"
self.add_module_import(python_type_module)
except NotImplementedError:
self.add_literal_import("typing", "Any")
column_python_type = "Any"
if is_vector(column):
self.add_literal_import("pgvector.sqlalchemy", "Vector")
column_python_type = "Vector"
else:
self.add_literal_import("typing", "Any")
column_python_type = "Any"

if column.nullable:
self.add_literal_import("typing", "Optional")
Expand Down Expand Up @@ -1661,3 +1671,6 @@ def render_relationship_args(self, arguments: str) -> list[str]:
rendered_args.append("sa_relationship_kwargs={'uselist': False}")

return rendered_args

def is_vector(column: Column):
return column.type.__class__.__name__ == "NullType" and "vector" in column.key

This comment has been minimized.

Copy link
@hirayu

hirayu Oct 26, 2023

Member

descriptionの方が良いかも

0 comments on commit d81ff07

Please sign in to comment.