From 889393587499680f0ff9abfef13bc97af77729b4 Mon Sep 17 00:00:00 2001 From: Tim te Beek Date: Sat, 19 Aug 2023 15:01:48 +0200 Subject: [PATCH] Simplify boolean expressions comparing with null and `isEmpty` (#3489) For https://github.com/openrewrite/rewrite-templating/issues/28 --- ...SimplifyBooleanExpressionVisitorTest.java} | 74 +++- .../SimplifyBooleanExpressionVisitor.java | 324 +++++++++++------- .../org/openrewrite/test/RewriteTest.java | 25 +- 3 files changed, 283 insertions(+), 140 deletions(-) rename rewrite-java-test/src/test/java/org/openrewrite/java/cleanup/{SimplifyBooleanExpressionTest.java => SimplifyBooleanExpressionVisitorTest.java} (64%) diff --git a/rewrite-java-test/src/test/java/org/openrewrite/java/cleanup/SimplifyBooleanExpressionTest.java b/rewrite-java-test/src/test/java/org/openrewrite/java/cleanup/SimplifyBooleanExpressionVisitorTest.java similarity index 64% rename from rewrite-java-test/src/test/java/org/openrewrite/java/cleanup/SimplifyBooleanExpressionTest.java rename to rewrite-java-test/src/test/java/org/openrewrite/java/cleanup/SimplifyBooleanExpressionVisitorTest.java index f0f0941d0fb..49c15245cb7 100644 --- a/rewrite-java-test/src/test/java/org/openrewrite/java/cleanup/SimplifyBooleanExpressionTest.java +++ b/rewrite-java-test/src/test/java/org/openrewrite/java/cleanup/SimplifyBooleanExpressionVisitorTest.java @@ -16,6 +16,8 @@ package org.openrewrite.java.cleanup; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; import org.openrewrite.DocumentExample; import org.openrewrite.Issue; import org.openrewrite.test.RecipeSpec; @@ -25,7 +27,7 @@ import static org.openrewrite.test.RewriteTest.toRecipe; @SuppressWarnings("ALL") -class SimplifyBooleanExpressionTest implements RewriteTest { +class SimplifyBooleanExpressionVisitorTest implements RewriteTest { @Override public void defaults(RecipeSpec spec) { @@ -346,4 +348,74 @@ public class A { ) ); } + + @ParameterizedTest + @Issue("https://github.com/openrewrite/rewrite-templating/issues/28") + // Mimic what would be inserted by a Refaster template using two nullable parameters, with the second one a literal + @CsvSource(delimiterString = "//", textBlock = """ + a == null || a.isEmpty() // a == null || a.isEmpty() + a == null || !a.isEmpty() // a == null || !a.isEmpty() + a != null && a.isEmpty() // a != null && a.isEmpty() + a != null && !a.isEmpty() // a != null && !a.isEmpty() + + "" == null || "".isEmpty() // true + "" == null || !"".isEmpty() // false + "" != null && "".isEmpty() // true + "" != null && !"".isEmpty() // false + + "b" == null || "b".isEmpty() // false + "b" == null || !"b".isEmpty() // true + "b" != null && "b".isEmpty() // false + "b" != null && !"b".isEmpty() // true + + a == null || a.isEmpty() || "" == null || "".isEmpty() // true + a == null || a.isEmpty() || "" == null || !"".isEmpty() // a == null || a.isEmpty() + a == null || a.isEmpty() || "" != null && "".isEmpty() // true + a == null || a.isEmpty() || "" != null && !"".isEmpty() // a == null || a.isEmpty() + a == null || a.isEmpty() && "" == null || "".isEmpty() // true + a == null || a.isEmpty() && "" == null || !"".isEmpty() // a == null + a == null || a.isEmpty() && "" != null && "".isEmpty() // a == null || a.isEmpty() + a == null || a.isEmpty() && "" != null && !"".isEmpty() // a == null + a == null || !a.isEmpty() || "" == null || "".isEmpty() // true + a == null || !a.isEmpty() || "" == null || !"".isEmpty() // a == null || !a.isEmpty() + a == null || !a.isEmpty() || "" != null && "".isEmpty() // true + a == null || !a.isEmpty() || "" != null && !"".isEmpty() // a == null || !a.isEmpty() + a == null || !a.isEmpty() && "" == null || "".isEmpty() // true + a == null || !a.isEmpty() && "" == null || !"".isEmpty() // a == null + a == null || !a.isEmpty() && "" != null && "".isEmpty() // a == null || !a.isEmpty() + a == null || !a.isEmpty() && "" != null && !"".isEmpty() // a == null + + a == null || a.isEmpty() || "b" == null || "b".isEmpty() // a == null || a.isEmpty() + a == null || a.isEmpty() || "b" == null || !"b".isEmpty() // true + a == null || a.isEmpty() || "b" != null && "b".isEmpty() // a == null || a.isEmpty() + a == null || a.isEmpty() || "b" != null && !"b".isEmpty() // true + a == null || a.isEmpty() && "b" == null || "b".isEmpty() // a == null + a == null || a.isEmpty() && "b" == null || !"b".isEmpty() // true + a == null || a.isEmpty() && "b" != null && "b".isEmpty() // a == null + a == null || a.isEmpty() && "b" != null && !"b".isEmpty() // a == null || a.isEmpty() + a == null || !a.isEmpty() || "b" == null || "b".isEmpty() // a == null || !a.isEmpty() + a == null || !a.isEmpty() || "b" == null || !"b".isEmpty() // true + a == null || !a.isEmpty() || "b" != null && "b".isEmpty() // a == null || !a.isEmpty() + a == null || !a.isEmpty() || "b" != null && !"b".isEmpty() // true + a == null || !a.isEmpty() && "b" == null || "b".isEmpty() // a == null + a == null || !a.isEmpty() && "b" == null || !"b".isEmpty() // true + a == null || !a.isEmpty() && "b" != null && "b".isEmpty() // a == null + a == null || !a.isEmpty() && "b" != null && !"b".isEmpty() // a == null || !a.isEmpty() + """) + void simplifyLiteralNull(String before, String after) { + //language=java + String template = """ + class A { + void foo(String a) { + boolean c = %s; + } + } + """; + String beforeJava = template.formatted(before); + if (before.equals(after)) { + rewriteRun(java(beforeJava)); + } else { + rewriteRun(java(beforeJava, template.formatted(after))); + } + } } diff --git a/rewrite-java/src/main/java/org/openrewrite/java/cleanup/SimplifyBooleanExpressionVisitor.java b/rewrite-java/src/main/java/org/openrewrite/java/cleanup/SimplifyBooleanExpressionVisitor.java index 81b200fa775..e5fe977895b 100644 --- a/rewrite-java/src/main/java/org/openrewrite/java/cleanup/SimplifyBooleanExpressionVisitor.java +++ b/rewrite-java/src/main/java/org/openrewrite/java/cleanup/SimplifyBooleanExpressionVisitor.java @@ -15,149 +15,221 @@ */ package org.openrewrite.java.cleanup; -import org.openrewrite.*; +import org.openrewrite.Cursor; +import org.openrewrite.ExecutionContext; +import org.openrewrite.Tree; import org.openrewrite.internal.lang.Nullable; import org.openrewrite.java.JavaIsoVisitor; import org.openrewrite.java.JavaVisitor; +import org.openrewrite.java.MethodMatcher; import org.openrewrite.java.UnwrapParentheses; import org.openrewrite.java.format.AutoFormatVisitor; -import org.openrewrite.java.tree.Expression; -import org.openrewrite.java.tree.J; -import org.openrewrite.java.tree.JavaSourceFile; -import org.openrewrite.java.tree.Space; +import org.openrewrite.java.tree.*; + +import java.util.Collections; import static java.util.Objects.requireNonNull; -public class SimplifyBooleanExpressionVisitor extends JavaVisitor { - private static final String MAYBE_AUTO_FORMAT_ME = "MAYBE_AUTO_FORMAT_ME"; +public class SimplifyBooleanExpressionVisitor extends JavaVisitor { + private static final String MAYBE_AUTO_FORMAT_ME = "MAYBE_AUTO_FORMAT_ME"; - @Override - public J visit(@Nullable Tree tree, ExecutionContext ctx) { - if (tree instanceof JavaSourceFile) { - JavaSourceFile cu = (JavaSourceFile) requireNonNull(super.visit(tree, ctx)); - if (tree != cu) { - // recursive simplification - cu = (JavaSourceFile) visitNonNull(cu, ctx); - } - return cu; - } - return super.visit(tree, ctx); + @Override + public J visit(@Nullable Tree tree, ExecutionContext ctx) { + if (tree instanceof JavaSourceFile) { + JavaSourceFile cu = (JavaSourceFile) requireNonNull(super.visit(tree, ctx)); + if (tree != cu) { + // recursive simplification + cu = (JavaSourceFile) visitNonNull(cu, ctx); } + return cu; + } + return super.visit(tree, ctx); + } - @Override - public J visitBinary(J.Binary binary, ExecutionContext ctx) { - J j = super.visitBinary(binary, ctx); - J.Binary asBinary = (J.Binary) j; - - if (asBinary.getOperator() == J.Binary.Type.And) { - if (isLiteralFalse(asBinary.getLeft())) { - maybeUnwrapParentheses(); - j = asBinary.getLeft(); - } else if (isLiteralFalse(asBinary.getRight())) { - maybeUnwrapParentheses(); - j = asBinary.getRight().withPrefix(asBinary.getRight().getPrefix().withWhitespace("")); - } else if (removeAllSpace(asBinary.getLeft()).printTrimmed(getCursor()) - .equals(removeAllSpace(asBinary.getRight()).printTrimmed(getCursor()))) { - maybeUnwrapParentheses(); - j = asBinary.getLeft(); - } - } else if (asBinary.getOperator() == J.Binary.Type.Or) { - if (isLiteralTrue(asBinary.getLeft())) { - maybeUnwrapParentheses(); - j = asBinary.getLeft(); - } else if (isLiteralTrue(asBinary.getRight())) { - maybeUnwrapParentheses(); - j = asBinary.getRight().withPrefix(asBinary.getRight().getPrefix().withWhitespace("")); - } else if (removeAllSpace(asBinary.getLeft()).printTrimmed(getCursor()) - .equals(removeAllSpace(asBinary.getRight()).printTrimmed(getCursor()))) { - maybeUnwrapParentheses(); - j = asBinary.getLeft(); - } - } else if (asBinary.getOperator() == J.Binary.Type.Equal) { - if (isLiteralTrue(asBinary.getLeft())) { - maybeUnwrapParentheses(); - j = asBinary.getRight().withPrefix(asBinary.getRight().getPrefix().withWhitespace("")); - } else if (isLiteralTrue(asBinary.getRight())) { - maybeUnwrapParentheses(); - j = asBinary.getLeft(); - } - } else if (asBinary.getOperator() == J.Binary.Type.NotEqual) { - if (isLiteralFalse(asBinary.getLeft())) { - maybeUnwrapParentheses(); - j = asBinary.getRight().withPrefix(asBinary.getRight().getPrefix().withWhitespace("")); - } else if (isLiteralFalse(asBinary.getRight())) { - maybeUnwrapParentheses(); - j = asBinary.getLeft(); - } - } - if (asBinary != j) { - getCursor().getParentTreeCursor().putMessage(MAYBE_AUTO_FORMAT_ME, ""); - } - return j; - } + @Override + public J visitBinary(J.Binary binary, ExecutionContext ctx) { + J j = super.visitBinary(binary, ctx); + J.Binary asBinary = (J.Binary) j; - @Override - public J postVisit(J tree, ExecutionContext ctx) { - J j = super.postVisit(tree, ctx); - if (getCursor().pollMessage(MAYBE_AUTO_FORMAT_ME) != null) { - j = new AutoFormatVisitor<>().visit(j, ctx, getCursor().getParentOrThrow()); - } - return j; + if (asBinary.getOperator() == J.Binary.Type.And) { + if (isLiteralFalse(asBinary.getLeft())) { + maybeUnwrapParentheses(); + j = asBinary.getLeft(); + } else if (isLiteralFalse(asBinary.getRight())) { + maybeUnwrapParentheses(); + j = asBinary.getRight().withPrefix(asBinary.getRight().getPrefix().withWhitespace("")); + } else if (isLiteralTrue(asBinary.getLeft())) { + maybeUnwrapParentheses(); + j = asBinary.getRight(); + } else if (isLiteralTrue(asBinary.getRight())) { + maybeUnwrapParentheses(); + j = asBinary.getLeft().withPrefix(asBinary.getLeft().getPrefix().withWhitespace("")); + } else if (removeAllSpace(asBinary.getLeft()).printTrimmed(getCursor()) + .equals(removeAllSpace(asBinary.getRight()).printTrimmed(getCursor()))) { + maybeUnwrapParentheses(); + j = asBinary.getLeft(); } - - @Override - public J visitUnary(J.Unary unary, ExecutionContext ctx) { - J j = super.visitUnary(unary, ctx); - J.Unary asUnary = (J.Unary) j; - - if (asUnary.getOperator() == J.Unary.Type.Not) { - if (isLiteralTrue(asUnary.getExpression())) { - maybeUnwrapParentheses(); - j = ((J.Literal) asUnary.getExpression()).withValue(false).withValueSource("false"); - } else if (isLiteralFalse(asUnary.getExpression())) { - maybeUnwrapParentheses(); - j = ((J.Literal) asUnary.getExpression()).withValue(true).withValueSource("true"); - } else if (asUnary.getExpression() instanceof J.Unary && ((J.Unary) asUnary.getExpression()).getOperator() == J.Unary.Type.Not) { - maybeUnwrapParentheses(); - j = ((J.Unary) asUnary.getExpression()).getExpression(); - } - } - if (asUnary != j) { - getCursor().getParentTreeCursor().putMessage(MAYBE_AUTO_FORMAT_ME, ""); - } - return j; + } else if (asBinary.getOperator() == J.Binary.Type.Or) { + if (isLiteralTrue(asBinary.getLeft())) { + maybeUnwrapParentheses(); + j = asBinary.getLeft(); + } else if (isLiteralTrue(asBinary.getRight())) { + maybeUnwrapParentheses(); + j = asBinary.getRight().withPrefix(asBinary.getRight().getPrefix().withWhitespace("")); + } else if (isLiteralFalse(asBinary.getLeft())) { + maybeUnwrapParentheses(); + j = asBinary.getRight(); + } else if (isLiteralFalse(asBinary.getRight())) { + maybeUnwrapParentheses(); + j = asBinary.getLeft().withPrefix(asBinary.getLeft().getPrefix().withWhitespace("")); + } else if (removeAllSpace(asBinary.getLeft()).printTrimmed(getCursor()) + .equals(removeAllSpace(asBinary.getRight()).printTrimmed(getCursor()))) { + maybeUnwrapParentheses(); + j = asBinary.getLeft(); } - - /** - * Specifically for removing immediately-enclosing parentheses on Identifiers and Literals. - * This queues a potential unwrap operation for the next visit. After unwrapping something, it's possible - * there are more Simplifications this recipe can identify and perform, which is why visitCompilationUnit - * checks for any changes to the entire Compilation Unit, and if so, queues up another SimplifyBooleanExpression - * recipe call. This convergence loop eventually reconciles any remaining Boolean Expression Simplifications - * the recipe can perform. - */ - private void maybeUnwrapParentheses() { - Cursor c = getCursor().getParentOrThrow().getParentTreeCursor(); - if (c.getValue() instanceof J.Parentheses) { - doAfterVisit(new UnwrapParentheses<>(c.getValue())); - } + } else if (asBinary.getOperator() == J.Binary.Type.Equal) { + if (isLiteralTrue(asBinary.getLeft())) { + maybeUnwrapParentheses(); + j = asBinary.getRight().withPrefix(asBinary.getRight().getPrefix().withWhitespace("")); + } else if (isLiteralTrue(asBinary.getRight())) { + maybeUnwrapParentheses(); + j = asBinary.getLeft().withPrefix(asBinary.getLeft().getPrefix().withWhitespace(" ")); + } else { + j = maybeReplaceCompareWithNull(asBinary, true); } - - private boolean isLiteralTrue(@Nullable Expression expression) { - return expression instanceof J.Literal && ((J.Literal) expression).getValue() == Boolean.valueOf(true); + } else if (asBinary.getOperator() == J.Binary.Type.NotEqual) { + if (isLiteralFalse(asBinary.getLeft())) { + maybeUnwrapParentheses(); + j = asBinary.getRight().withPrefix(asBinary.getRight().getPrefix().withWhitespace("")); + } else if (isLiteralFalse(asBinary.getRight())) { + maybeUnwrapParentheses(); + j = asBinary.getLeft().withPrefix(asBinary.getLeft().getPrefix().withWhitespace(" ")); + } else { + j = maybeReplaceCompareWithNull(asBinary, false); } + } + if (asBinary != j) { + getCursor().getParentTreeCursor().putMessage(MAYBE_AUTO_FORMAT_ME, ""); + } + return j; + } - private boolean isLiteralFalse(@Nullable Expression expression) { - return expression instanceof J.Literal && ((J.Literal) expression).getValue() == Boolean.valueOf(false); - } + @Override + public J postVisit(J tree, ExecutionContext ctx) { + J j = super.postVisit(tree, ctx); + if (getCursor().pollMessage(MAYBE_AUTO_FORMAT_ME) != null) { + j = new AutoFormatVisitor<>().visit(j, ctx, getCursor().getParentOrThrow()); + } + return j; + } + + @Override + public J visitUnary(J.Unary unary, ExecutionContext ctx) { + J j = super.visitUnary(unary, ctx); + J.Unary asUnary = (J.Unary) j; - private J removeAllSpace(J j) { - //noinspection ConstantConditions - return new JavaIsoVisitor() { - @Override - public Space visitSpace(Space space, Space.Location loc, Integer integer) { - return Space.EMPTY; - } - }.visit(j, 0); + if (asUnary.getOperator() == J.Unary.Type.Not) { + if (isLiteralTrue(asUnary.getExpression())) { + maybeUnwrapParentheses(); + j = ((J.Literal) asUnary.getExpression()).withValue(false).withValueSource("false"); + } else if (isLiteralFalse(asUnary.getExpression())) { + maybeUnwrapParentheses(); + j = ((J.Literal) asUnary.getExpression()).withValue(true).withValueSource("true"); + } else if (asUnary.getExpression() instanceof J.Unary && ((J.Unary) asUnary.getExpression()).getOperator() == J.Unary.Type.Not) { + maybeUnwrapParentheses(); + j = ((J.Unary) asUnary.getExpression()).getExpression(); } } + if (asUnary != j) { + getCursor().getParentTreeCursor().putMessage(MAYBE_AUTO_FORMAT_ME, ""); + } + return j; + } + + private final MethodMatcher isEmpty = new MethodMatcher("java.lang.String isEmpty()"); + + @Override + public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext executionContext) { + J j = super.visitMethodInvocation(method, executionContext); + J.MethodInvocation asMethod = (J.MethodInvocation) j; + Expression select = asMethod.getSelect(); + if (isEmpty.matches(asMethod) + && select instanceof J.Literal + && select.getType() == JavaType.Primitive.String) { + maybeUnwrapParentheses(); + return booleanLiteral(method, J.Literal.isLiteralValue(select, "")); + } + return j; + } + + /** + * Specifically for removing immediately-enclosing parentheses on Identifiers and Literals. + * This queues a potential unwrap operation for the next visit. After unwrapping something, it's possible + * there are more Simplifications this recipe can identify and perform, which is why visitCompilationUnit + * checks for any changes to the entire Compilation Unit, and if so, queues up another SimplifyBooleanExpression + * recipe call. This convergence loop eventually reconciles any remaining Boolean Expression Simplifications + * the recipe can perform. + */ + private void maybeUnwrapParentheses() { + Cursor c = getCursor().getParentOrThrow().getParentTreeCursor(); + if (c.getValue() instanceof J.Parentheses) { + doAfterVisit(new UnwrapParentheses<>(c.getValue())); + } + } + + private boolean isLiteralTrue(@Nullable Expression expression) { + return expression instanceof J.Literal && ((J.Literal) expression).getValue() == Boolean.valueOf(true); + } + + private boolean isLiteralFalse(@Nullable Expression expression) { + return expression instanceof J.Literal && ((J.Literal) expression).getValue() == Boolean.valueOf(false); + } + + private boolean isNullLiteral(Expression expression) { + return expression instanceof J.Literal && ((J.Literal) expression).getType() == JavaType.Primitive.Null; + } + + private boolean isNonNullLiteral(Expression expression) { + return expression instanceof J.Literal && ((J.Literal) expression).getType() != JavaType.Primitive.Null; + } + + private J maybeReplaceCompareWithNull(J.Binary asBinary, boolean valueIfEqual) { + Expression left = asBinary.getLeft(); + Expression right = asBinary.getRight(); + + boolean leftIsNull = isNullLiteral(left); + boolean rightIsNull = isNullLiteral(right); + if (leftIsNull && rightIsNull) { + maybeUnwrapParentheses(); + return booleanLiteral(asBinary, valueIfEqual); + } + boolean leftIsNonNullLiteral = isNonNullLiteral(left); + boolean rightIsNonNullLiteral = isNonNullLiteral(right); + if ((leftIsNull && rightIsNonNullLiteral) || (rightIsNull && leftIsNonNullLiteral)) { + maybeUnwrapParentheses(); + return booleanLiteral(asBinary, !valueIfEqual); + } + + return asBinary; + } + + private J.Literal booleanLiteral(J j, boolean value) { + return new J.Literal(Tree.randomId(), + j.getPrefix(), + j.getMarkers(), + value, + String.valueOf(value), + Collections.emptyList(), + JavaType.Primitive.Boolean); + } + + private J removeAllSpace(J j) { + //noinspection ConstantConditions + return new JavaIsoVisitor() { + @Override + public Space visitSpace(Space space, Space.Location loc, Integer integer) { + return Space.EMPTY; + } + }.visit(j, 0); + } +} diff --git a/rewrite-test/src/main/java/org/openrewrite/test/RewriteTest.java b/rewrite-test/src/main/java/org/openrewrite/test/RewriteTest.java index f107c16d250..26ff51becf9 100644 --- a/rewrite-test/src/main/java/org/openrewrite/test/RewriteTest.java +++ b/rewrite-test/src/main/java/org/openrewrite/test/RewriteTest.java @@ -552,19 +552,18 @@ default void rewriteRun(Consumer spec, SourceSpec... sourceSpecs) } static void assertContentEquals(SourceFile sourceFile, String expected, String actual, String errorMessagePrefix) { - try { - try (InMemoryDiffEntry diffEntry = new InMemoryDiffEntry( - sourceFile.getSourcePath(), - sourceFile.getSourcePath(), - null, - expected, - actual, - Collections.emptySet() - )) { - assertThat(actual) - .as(errorMessagePrefix + " \"%s\":\n%s", sourceFile.getSourcePath(), diffEntry.getDiff()) - .isEqualTo(expected); - } + try (InMemoryDiffEntry diffEntry = new InMemoryDiffEntry( + sourceFile.getSourcePath(), + sourceFile.getSourcePath(), + null, + expected, + actual, + Collections.emptySet() + )) { + assertThat(actual) + .as(errorMessagePrefix + " \"%s\":\n%s", sourceFile.getSourcePath(), diffEntry.getDiff()) + .isEqualTo(expected); + } catch (LinkageError e) { // in case JGit fails to load properly assertThat(actual)