Skip to content

Commit

Permalink
Instanceof pattern: add type parameters if type casts use type param
Browse files Browse the repository at this point in the history
Correction
  • Loading branch information
BoykoAlex committed Sep 13, 2024
1 parent abb0338 commit 625b89c
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ private static class InstanceOfPatternReplacements {
private final Map<J.InstanceOf, Set<Cursor>> contextScopes = new HashMap<>();
private final Map<J.TypeCast, J.InstanceOf> replacements = new HashMap<>();
private final Map<J.InstanceOf, J.VariableDeclarations.NamedVariable> variablesToDelete = new HashMap<>();
private final Map<J.VariableDeclarations.NamedVariable, J.TypeCast> variableTypeCast = new HashMap<>();

public void registerInstanceOf(J.InstanceOf instanceOf, Set<J> contexts) {
Expression expression = instanceOf.getExpression();
Expand Down Expand Up @@ -195,6 +196,7 @@ public void registerTypeCast(J.TypeCast typeCast, Cursor cursor) {
if (parent.getValue() instanceof J.VariableDeclarations.NamedVariable
&& !variablesToDelete.containsKey(instanceOf)) {
variablesToDelete.put(instanceOf, parent.getValue());
variableTypeCast.put(parent.getValue(), typeCast);
} else {
replacements.put(typeCast, instanceOf);
}
Expand All @@ -215,6 +217,19 @@ public J.InstanceOf processInstanceOf(J.InstanceOf instanceOf, Cursor cursor) {
if (!contextScopes.containsKey(instanceOf)) {
return instanceOf;
}

// Find a single replacement type tree based on tyep cast expressions matching the instanceof expression
TypeTree typeTree = computeReplacementTypeTree(instanceOf);

// Either type casts to different type of expressions have been found or no type casts at all.
// Nothing to do in this case. Leave the code as is to stay on the safe side
if (typeTree == null) {
variablesToDelete.remove(instanceOf);
variableTypeCast.remove(instanceOf);
replacements.entrySet().removeIf(e -> e.getValue() == instanceOf);
return instanceOf;
}

@Nullable JavaType type = ((TypedTree) instanceOf.getClazz()).getType();
String name = patternVariableName(instanceOf, cursor);
J.InstanceOf result = instanceOf.withPattern(new J.Identifier(
Expand All @@ -226,23 +241,25 @@ public J.InstanceOf processInstanceOf(J.InstanceOf instanceOf, Cursor cursor) {
type,
null));
JavaType.FullyQualified fqType = TypeUtils.asFullyQualified(type);
if (fqType != null && !fqType.getTypeParameters().isEmpty() && !(instanceOf.getClazz() instanceof J.ParameterizedType)) {
TypedTree oldTypeTree = (TypedTree) instanceOf.getClazz();

// Each type parameter is turned into a wildcard, i.e. `List` -> `List<?>` or `Map.Entry` -> `Map.Entry<?,?>`
List<Expression> wildcardsList = IntStream.range(0, fqType.getTypeParameters().size())
.mapToObj(i -> new J.Wildcard(randomId(), Space.EMPTY, Markers.EMPTY, null, null))
.collect(Collectors.toList());

J.ParameterizedType newTypeTree = new J.ParameterizedType(
randomId(),
oldTypeTree.getPrefix(),
Markers.EMPTY,
oldTypeTree.withPrefix(Space.EMPTY),
null,
oldTypeTree.getType()
).withTypeParameters(wildcardsList);
result = result.withClazz(newTypeTree);
// Check if type parameters (i.e. <?>) should be added. Check if type cast type has type parameters. If yes add the type parameters wildcards.
if (typeTree instanceof J.ParameterizedType && !(instanceOf.getClazz() instanceof J.ParameterizedType)) {
TypedTree oldTypeTree = (TypedTree) instanceOf.getClazz();

// Each type parameter is turned into a wildcard, i.e. `List` -> `List<?>` or `Map.Entry` -> `Map.Entry<?,?>`
List<Expression> wildcardsList = IntStream.range(0, fqType.getTypeParameters().size())
.mapToObj(i -> new J.Wildcard(randomId(), Space.EMPTY, Markers.EMPTY, null, null))
.collect(Collectors.toList());

J.ParameterizedType newTypeTree = new J.ParameterizedType(
randomId(),
oldTypeTree.getPrefix(),
Markers.EMPTY,
oldTypeTree.withPrefix(Space.EMPTY),
null,
oldTypeTree.getType()
).withTypeParameters(wildcardsList);
result = result.withClazz(newTypeTree);
}

// update entry in replacements to share the pattern variable name
Expand All @@ -254,6 +271,34 @@ public J.InstanceOf processInstanceOf(J.InstanceOf instanceOf, Cursor cursor) {
return result;
}

private TypeTree computeReplacementTypeTree(J.InstanceOf instanceOf) {
TypeTree currentType = null;
for (Map.Entry<J.TypeCast, J.InstanceOf> entry : replacements.entrySet()) {
if (entry.getValue() == instanceOf) {
TypeTree typeTree = entry.getKey().getClazz().getTree();
if (typeTree != null && typeTree.getType() != null) {
if (currentType == null) {
currentType = entry.getKey().getClazz().getTree();
} else if (!typeTree.getType().equals(currentType.getType())) {
return null;
}
}
}
}
J.VariableDeclarations.NamedVariable v = variablesToDelete.get(instanceOf);
if (v != null) {
TypeTree typeTree = variableTypeCast.get(v).getClazz().getTree();
if (typeTree != null && typeTree.getType() != null) {
if (currentType == null) {
currentType = variableTypeCast.get(v).getClazz().getTree();
} else if (!typeTree.getType().equals(currentType.getType())) {
return null;
}
}
}
return currentType;
}

private String patternVariableName(J.InstanceOf instanceOf, Cursor cursor) {
VariableNameStrategy strategy;
if (root instanceof J.If) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ void test(Object o) {
}

@Test
void genericsWithoutParameters() {
void genericsWithoutParameters_1() {
rewriteRun(
//language=java
java(
Expand Down Expand Up @@ -190,6 +190,75 @@ public static List<Map<String, Object>> applyRoutesType(Object routes) {
);
}

@Test
void genericsWithoutParameters_2() {
rewriteRun(
//language=java
java(
"""
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
public class A {
@SuppressWarnings("unchecked")
public static List<Map<String, Object>> applyRoutesType(Object routes) {
if (routes instanceof List) {
List routesList = (List) routes;
if (routesList.isEmpty()) {
return Collections.emptyList();
}
}
return Collections.emptyList();
}
}
""",
"""
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
public class A {
@SuppressWarnings("unchecked")
public static List<Map<String, Object>> applyRoutesType(Object routes) {
if (routes instanceof List routesList) {
if (routesList.isEmpty()) {
return Collections.emptyList();
}
}
return Collections.emptyList();
}
}
"""
)
);
}

@Test
void genericsWithoutParameters_3() {
rewriteRun(
//language=java
java(
"""
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
public class A {
@SuppressWarnings("unchecked")
public static List<Map<String, Object>> applyRoutesType(Object routes) {
if (routes instanceof List) {
List<Object> routesList = (List<Object>) routes;
String.join(",", (List) routes);
}
return Collections.emptyList();
}
}
"""
)
);
}

@Test
void primitiveArray() {
rewriteRun(
Expand Down Expand Up @@ -302,7 +371,7 @@ void test(Object o) {
public class A {
void test(Object o) {
Map.Entry entry = null;
if (o instanceof Map.Entry<?,?> entry1) {
if (o instanceof Map.Entry entry1) {
entry = entry1;
}
System.out.println(entry);
Expand Down Expand Up @@ -975,6 +1044,52 @@ String test(Object o) {
)
);
}
@Test
void iterableParameter() {
rewriteRun(
//language=java
java(
"""
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class ApplicationSecurityGroupsParameterHelper {
static final String APPLICATION_SECURITY_GROUPS = "application-security-groups";
public Map<String, Object> transformGatewayParameters(Map<String, Object> parameters) {
Map<String, Object> environment = new HashMap<>();
Object applicationSecurityGroups = parameters.get(APPLICATION_SECURITY_GROUPS);
if (applicationSecurityGroups instanceof List) {
environment.put(APPLICATION_SECURITY_GROUPS, String.join(",", (List) applicationSecurityGroups));
}
return environment;
}
}
""",
"""
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class ApplicationSecurityGroupsParameterHelper {
static final String APPLICATION_SECURITY_GROUPS = "application-security-groups";
public Map<String, Object> transformGatewayParameters(Map<String, Object> parameters) {
Map<String, Object> environment = new HashMap<>();
Object applicationSecurityGroups = parameters.get(APPLICATION_SECURITY_GROUPS);
if (applicationSecurityGroups instanceof List list) {
environment.put(APPLICATION_SECURITY_GROUPS, String.join(",", list));
}
return environment;
}
}
"""
)
);
}
}

@Nested
Expand Down

0 comments on commit 625b89c

Please sign in to comment.