Skip to content

Commit

Permalink
Fix for federation data fetcher correct support namespaces (#2172)
Browse files Browse the repository at this point in the history
* Add group support for federation

* add cache and recursive search for matching field definition when using namespace

* Fix field selecting for namespace and non namespaces types

* Added tests for federation namespaces

* Fix federation namespace test

---------

Co-authored-by: Roman Lovakov <rlovakov@nota.tech>
  • Loading branch information
RoMiRoSSaN and Roman Lovakov authored Sep 4, 2024
1 parent c9f4649 commit 60ff734
Show file tree
Hide file tree
Showing 11 changed files with 467 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
Expand All @@ -29,6 +30,7 @@
import graphql.schema.GraphQLNamedSchemaElement;
import graphql.schema.GraphQLNonNull;
import graphql.schema.GraphQLObjectType;
import graphql.schema.GraphQLSchemaElement;
import graphql.schema.GraphQLType;
import io.smallrye.graphql.spi.config.Config;

Expand All @@ -37,6 +39,7 @@ public class FederationDataFetcher implements DataFetcher<CompletableFuture<List
public static final String TYPENAME = "__typename";
private final GraphQLObjectType queryType;
private final GraphQLCodeRegistry codeRegistry;
private final HashMap<TypeAndArgumentNames, TypeFieldWrapper> cache = new HashMap<>();

public FederationDataFetcher(GraphQLObjectType queryType, GraphQLCodeRegistry codeRegistry) {
this.queryType = queryType;
Expand All @@ -55,17 +58,13 @@ public CompletableFuture<List<Object>> get(DataFetchingEnvironment environment)
var repsWithPositionPerType = representations.stream().collect(Collectors.groupingBy(r -> r.typeAndArgumentNames));
//then we search for the field definition to resolve the objects
var fieldDefinitions = repsWithPositionPerType.keySet().stream()
.collect(Collectors.toMap(Function.identity(), typeAndArgumentNames -> {
var batchDefinition = findBatchFieldDefinition(typeAndArgumentNames);
if (batchDefinition == null) {
return findFieldDefinition(typeAndArgumentNames);
} else {
return batchDefinition;
}
}));
.collect(Collectors.toMap(Function.identity(), typeAndArgumentNames -> cache.computeIfAbsent(
typeAndArgumentNames, type -> Objects.requireNonNullElseGet(
findBatchFieldDefinition(type),
() -> findFieldDefinition(type)))));
return sequence(repsWithPositionPerType.entrySet().stream().map(e -> {
var fieldDefinition = fieldDefinitions.get(e.getKey());
if (getGraphqlTypeFromField(fieldDefinition) instanceof GraphQLList) {
if (getGraphqlTypeFromField(fieldDefinition.getField()) instanceof GraphQLList) {
//use batch loader if available
return executeList(fieldDefinition, environment, e.getValue());
} else {
Expand All @@ -79,35 +78,67 @@ public CompletableFuture<List<Object>> get(DataFetchingEnvironment environment)
.sorted(Comparator.comparingInt(r -> r.position)).map(r -> r.Result).collect(Collectors.toList()));

}
Map<TypeAndArgumentNames, GraphQLFieldDefinition> cache = new HashMap<>();
return sequence(representations.stream()
.map(rep -> fetchEntities(environment, rep,
cache.computeIfAbsent(rep.typeAndArgumentNames, this::findFieldDefinition)))
.collect(Collectors.toList())).thenApply(l -> l.stream().map(r -> r.Result).collect(Collectors.toList()));
}

private GraphQLFieldDefinition findBatchFieldDefinition(TypeAndArgumentNames typeAndArgumentNames) {
private TypeFieldWrapper findRecursiveFieldDefinition(TypeAndArgumentNames typeAndArgumentNames,
GraphQLFieldDefinition field, BiFunction<GraphQLFieldDefinition, String, Boolean> matchesReturnType) {
if (field.getType() instanceof GraphQLObjectType) {
for (GraphQLSchemaElement child : field.getType().getChildren()) {
if (child instanceof GraphQLFieldDefinition) {
GraphQLFieldDefinition definition = (GraphQLFieldDefinition) child;
if (matchesReturnType.apply(definition, typeAndArgumentNames.type)
&& matchesArguments(typeAndArgumentNames, definition)) {
return new TypeFieldWrapper((GraphQLObjectType) field.getType(), definition);
} else if (definition.getType() instanceof GraphQLObjectType) {
return findRecursiveFieldDefinition(typeAndArgumentNames, definition, matchesReturnType);
}
}
}
}
return null;
}

private TypeFieldWrapper findBatchFieldDefinition(TypeAndArgumentNames typeAndArgumentNames) {
for (GraphQLFieldDefinition field : queryType.getFields()) {
if (matchesReturnTypeList(field, typeAndArgumentNames.type) && matchesArguments(typeAndArgumentNames, field)) {
return field;
return new TypeFieldWrapper(queryType, field);
}
}
for (GraphQLFieldDefinition field : queryType.getFields()) {
TypeFieldWrapper typeFieldWrapper = findRecursiveFieldDefinition(typeAndArgumentNames, field,
this::matchesReturnTypeList);
if (typeFieldWrapper != null) {
return typeFieldWrapper;
}
}
return null;
}

private GraphQLFieldDefinition findFieldDefinition(TypeAndArgumentNames typeAndArgumentNames) {
private TypeFieldWrapper findFieldDefinition(TypeAndArgumentNames typeAndArgumentNames) {
for (GraphQLFieldDefinition field : queryType.getFields()) {
if (matchesReturnType(field, typeAndArgumentNames.type) && matchesArguments(typeAndArgumentNames, field)) {
return field;
return new TypeFieldWrapper(queryType, field);
}
}
for (GraphQLFieldDefinition field : queryType.getFields()) {
TypeFieldWrapper typeFieldWrapper = findRecursiveFieldDefinition(typeAndArgumentNames, field,
this::matchesReturnType);
if (typeFieldWrapper != null) {
return typeFieldWrapper;
}
}

throw new RuntimeException(
"no query found for " + typeAndArgumentNames.type + " by " + typeAndArgumentNames.argumentNames);
}

private CompletableFuture<ResultObject> fetchEntities(DataFetchingEnvironment env, Representation representation,
GraphQLFieldDefinition field) {
return execute(field, env, representation);
TypeFieldWrapper wrapper) {
return execute(wrapper, env, representation);
}

private boolean matchesReturnType(GraphQLFieldDefinition field, String typename) {
Expand Down Expand Up @@ -140,9 +171,9 @@ private boolean matchesArguments(TypeAndArgumentNames typeAndArgumentNames, Grap
return argumentNames.equals(typeAndArgumentNames.argumentNames);
}

private CompletableFuture<List<ResultObject>> executeList(GraphQLFieldDefinition field, DataFetchingEnvironment env,
private CompletableFuture<List<ResultObject>> executeList(TypeFieldWrapper wrapper, DataFetchingEnvironment env,
List<Representation> representations) {
DataFetcher<?> dataFetcher = codeRegistry.getDataFetcher(queryType, field);
DataFetcher<?> dataFetcher = codeRegistry.getDataFetcher(wrapper.getType(), wrapper.getField());
Map<String, List<Object>> arguments = new HashMap<>();
representations.forEach(r -> {
r.arguments.forEach((argumentName, argumentValue) -> {
Expand Down Expand Up @@ -183,7 +214,7 @@ public <T> T getArgumentOrDefault(String name, T defaultValue) {
resultList = (List<Object>) results;
} else {
throw new IllegalStateException(
"Result of batchDataFetcher for Field " + field.getName() + " needs to be a list"
"Result of batchDataFetcher for Field " + wrapper.getField().getName() + " needs to be a list"
+ results.toString());
}

Expand All @@ -197,13 +228,13 @@ public <T> T getArgumentOrDefault(String name, T defaultValue) {
.collect(Collectors.toList());
});
} catch (Exception e) {
throw new RuntimeException("can't fetch data from " + field, e);
throw new RuntimeException("can't fetch data from " + wrapper.getField(), e);
}
}

private CompletableFuture<ResultObject> execute(GraphQLFieldDefinition field, DataFetchingEnvironment env,
private CompletableFuture<ResultObject> execute(TypeFieldWrapper wrapper, DataFetchingEnvironment env,
Representation representation) {
DataFetcher<?> dataFetcher = codeRegistry.getDataFetcher(queryType, field);
DataFetcher<?> dataFetcher = codeRegistry.getDataFetcher(wrapper.getType(), wrapper.getField());
DataFetchingEnvironment argsEnv = new DelegatingDataFetchingEnvironment(env) {
@Override
public Map<String, Object> getArguments() {
Expand All @@ -230,7 +261,7 @@ public <T> T getArgumentOrDefault(String name, T defaultValue) {
return Async.toCompletableFuture(dataFetcher.get(argsEnv))
.thenApply(o -> new ResultObject(o, representation.position));
} catch (Exception e) {
throw new RuntimeException("can't fetch data from " + field, e);
throw new RuntimeException("can't fetch data from " + wrapper.getField(), e);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package io.smallrye.graphql.bootstrap;

import graphql.schema.GraphQLFieldDefinition;
import graphql.schema.GraphQLObjectType;

class TypeFieldWrapper {
private final GraphQLObjectType type;
private final GraphQLFieldDefinition field;

public TypeFieldWrapper(GraphQLObjectType type, GraphQLFieldDefinition field) {
this.type = type;
this.field = field;
}

public GraphQLObjectType getType() {
return type;
}

public GraphQLFieldDefinition getField() {
return field;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
package io.smallrye.graphql.execution;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

import java.io.IOException;
import java.io.InputStream;
import java.util.function.BiFunction;
import java.util.stream.Stream;

import jakarta.json.JsonObject;
import jakarta.json.JsonString;
import jakarta.json.JsonValue;

import org.jboss.jandex.IndexView;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import graphql.schema.GraphQLSchema;
import io.smallrye.graphql.api.Directive;
import io.smallrye.graphql.api.federation.External;
import io.smallrye.graphql.api.federation.Key;
import io.smallrye.graphql.bootstrap.Bootstrap;
import io.smallrye.graphql.schema.SchemaBuilder;
import io.smallrye.graphql.schema.model.Schema;
import io.smallrye.graphql.spi.config.Config;
import io.smallrye.graphql.test.namespace.NamedNamespaceModel;
import io.smallrye.graphql.test.namespace.NamedNamespaceTestApi;
import io.smallrye.graphql.test.namespace.NamedNamespaceWIthGroupingKeyModel;
import io.smallrye.graphql.test.namespace.NamedNamespaceWithGroupingKeyTestApi;
import io.smallrye.graphql.test.namespace.SourceNamespaceModel;
import io.smallrye.graphql.test.namespace.SourceNamespaceTestApi;
import io.smallrye.graphql.test.namespace.UnamedModel;
import io.smallrye.graphql.test.namespace.UnnamedTestApi;

/**
* Test for Federated namespaces
*/
public class FederatedNamespaceTest {
private static final TestConfig config = (TestConfig) Config.get();
private static ExecutionService executionService;

@AfterAll
static void afterAll() {
config.reset();
config.federationEnabled = false;
System.setProperty("smallrye.graphql.federation.enabled", "false");
}

@BeforeAll
static void beforeAll() {
config.federationEnabled = true;
System.setProperty("smallrye.graphql.federation.enabled", "true");

IndexView index = buildIndex(Directive.class, Key.class, External.class, Key.Keys.class,
NamedNamespaceModel.class, NamedNamespaceTestApi.class,
NamedNamespaceWIthGroupingKeyModel.class, NamedNamespaceWithGroupingKeyTestApi.class,
SourceNamespaceModel.class, SourceNamespaceTestApi.class,
SourceNamespaceTestApi.First.class, SourceNamespaceTestApi.Second.class,
UnamedModel.class, UnnamedTestApi.class);

GraphQLSchema graphQLSchema = createGraphQLSchema(index);
Schema schema = SchemaBuilder.build(index);
executionService = new ExecutionService(graphQLSchema, schema);
}

private static IndexView buildIndex(Class<?>... classes) {
org.jboss.jandex.Indexer indexer = new org.jboss.jandex.Indexer();
Stream.of(classes).forEach(cls -> index(indexer, cls));
return indexer.complete();
}

private static InputStream getResourceStream(Class<?> type) {
String name = type.getName().replace(".", "/") + ".class";
return Thread.currentThread().getContextClassLoader().getResourceAsStream(name);
}

private static void index(org.jboss.jandex.Indexer indexer, Class<?> cls) {
try {
indexer.index(getResourceStream(cls));
} catch (IOException e) {
throw new RuntimeException(e);
}
}

private static GraphQLSchema createGraphQLSchema(IndexView index) {
Schema schema = SchemaBuilder.build(index);
assertNotNull(schema, "Schema should not be null");
GraphQLSchema graphQLSchema = Bootstrap.bootstrap(schema, true);
assertNotNull(graphQLSchema, "GraphQLSchema should not be null");
return graphQLSchema;
}

private static JsonObject executeAndGetResult(String graphQL) {
JsonObjectResponseWriter jsonObjectResponseWriter = new JsonObjectResponseWriter(graphQL);
jsonObjectResponseWriter.logInput();
executionService.executeSync(jsonObjectResponseWriter.getInput(), jsonObjectResponseWriter);
jsonObjectResponseWriter.logOutput();
return jsonObjectResponseWriter.getOutput();
}

private void test(String type, String id) {
JsonObject jsonObject = executeAndGetResult(TEST_QUERY.apply(type, id));
assertNotNull(jsonObject);

JsonValue jsonValue = jsonObject.getJsonObject("data")
.getJsonArray("_entities")
.getJsonObject(0)
.get("value");
String value = ((JsonString) jsonValue).getString();
assertEquals(value, id);
}

@Test
public void findEntityWithoutNamespace() {
test(UnamedModel.class.getSimpleName(), "unnamed_id");
}

@Test
public void findEntityWithNameNamespace() {
test(NamedNamespaceModel.class.getSimpleName(), "named_id");
}

@Test
public void findEntityWithSourceNamespace() {
test(SourceNamespaceModel.class.getSimpleName(), "source_id");
}

@Test
public void findEntityWithWithGroupedKeyAndNamespace() {
String id = "grouped_key";

JsonObject jsonObject = executeAndGetResult(GROUPED_KEY_QUERY.apply(
NamedNamespaceWIthGroupingKeyModel.class.getSimpleName(),
id));
assertNotNull(jsonObject);

JsonValue jsonValue = jsonObject.getJsonObject("data")
.getJsonArray("_entities")
.getJsonObject(0)
.get("value");
String value = ((JsonString) jsonValue).getString();
assertEquals(value, id);

jsonValue = jsonObject.getJsonObject("data")
.getJsonArray("_entities")
.getJsonObject(0)
.get("anotherId");
String anotherId = ((JsonString) jsonValue).getString();
assertEquals(anotherId, "otherKey_" + id);
}

private static final BiFunction<String, String, String> GROUPED_KEY_QUERY = (type, id) -> "query {\n" +
"_entities(\n" +
" representations: { id: \"" + id + "\", anotherId : \"otherKey_" + id + "\", __typename: \"" + type + "\" }\n" +
") {\n" +
" __typename\n" +
" ... on " + type + " {\n" +
" id\n" +
" anotherId\n" +
" value\n" +
" }\n" +
" }\n" +
"}";

private static final BiFunction<String, String, String> TEST_QUERY = (type, id) -> "query {\n" +
"_entities(\n" +
" representations: { id: \"" + id + "\", __typename: \"" + type + "\" }\n" +
") {\n" +
" __typename\n" +
" ... on " + type + " {\n" +
" id\n" +
" value\n" +
" }\n" +
" }\n" +
"}";
}
Loading

0 comments on commit 60ff734

Please sign in to comment.