diff --git a/server/implementation/src/main/java/io/smallrye/graphql/bootstrap/FederationDataFetcher.java b/server/implementation/src/main/java/io/smallrye/graphql/bootstrap/FederationDataFetcher.java index 90d93c02d..7595af381 100644 --- a/server/implementation/src/main/java/io/smallrye/graphql/bootstrap/FederationDataFetcher.java +++ b/server/implementation/src/main/java/io/smallrye/graphql/bootstrap/FederationDataFetcher.java @@ -1,15 +1,17 @@ package io.smallrye.graphql.bootstrap; -import static java.util.stream.Collectors.toList; import static java.util.stream.Collectors.toSet; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; import com.apollographql.federation.graphqljava._Entity; +import graphql.execution.Async; import graphql.schema.DataFetcher; import graphql.schema.DataFetchingEnvironment; import graphql.schema.DelegatingDataFetchingEnvironment; @@ -20,7 +22,7 @@ import graphql.schema.GraphQLObjectType; import graphql.schema.GraphQLOutputType; -class FederationDataFetcher implements DataFetcher> { +class FederationDataFetcher implements DataFetcher>> { private final GraphQLObjectType queryType; private final GraphQLCodeRegistry codeRegistry; @@ -31,10 +33,11 @@ public FederationDataFetcher(GraphQLObjectType queryType, GraphQLCodeRegistry co } @Override - public List get(DataFetchingEnvironment environment) throws Exception { - return environment.>> getArgument(_Entity.argumentName).stream() - .map(representations -> fetchEntities(environment, representations)) - .collect(toList()); + public CompletableFuture> get(DataFetchingEnvironment environment) throws Exception { + return sequence(environment.>> getArgument(_Entity.argumentName).stream() + .map(representations -> fetchEntities(environment, representations)).map(Async::toCompletableFuture) + .collect(Collectors.toList())); + } private Object fetchEntities(DataFetchingEnvironment env, Map representations) { @@ -90,4 +93,11 @@ public T getArgumentOrDefault(String name, T defaultValue) { throw new RuntimeException("can't fetch data from " + field, e); } } + + static CompletableFuture> sequence(List> com) { + return CompletableFuture.allOf(com.toArray(new CompletableFuture[0])) + .thenApply(v -> com.stream() + .map(CompletableFuture::join) + .collect(Collectors.toList())); + } }