diff --git a/parquet-generator/src/main/java/org/apache/parquet/filter2/IncrementallyUpdatedFilterPredicateGenerator.java b/parquet-generator/src/main/java/org/apache/parquet/filter2/IncrementallyUpdatedFilterPredicateGenerator.java index 7f66ce3821..f62ba56a89 100644 --- a/parquet-generator/src/main/java/org/apache/parquet/filter2/IncrementallyUpdatedFilterPredicateGenerator.java +++ b/parquet-generator/src/main/java/org/apache/parquet/filter2/IncrementallyUpdatedFilterPredicateGenerator.java @@ -281,51 +281,50 @@ private void addInequalityCase(TypeInfo info, String op, boolean expectMultipleR } private void addInNotInCase(TypeInfo info, boolean isEq, boolean expectMultipleResults) throws IOException { - add(" if (clazz.equals(" + info.className + ".class)) {\n" + " if (pred.getValues().contains(null)) {\n" - + " valueInspector = new ValueInspector() {\n" - + " @Override\n" - + " public void updateNull() {\n" - + " setResult(" - + isEq + ");\n" + " }\n" - + "\n" - + " @Override\n" - + " public void update(" - + info.primitiveName + " value) {\n" + " setResult(" - + !isEq + ");\n" + " }\n" - + " };\n" - + " } else {\n" - + " final Set<" + String nullResult = isEq ? "containsNull" : "!containsNull"; + add(" if (clazz.equals(" + info.className + ".class)) {\n" + " final Set<" + info.className + "> target = (Set<" + info.className + ">) pred.getValues();\n" - + " final PrimitiveComparator<" + + " final boolean containsNull = target.contains(null);\n" + + " final PrimitiveComparator<" + info.className + "> comparator = getComparator(columnPath);\n" + "\n" - + " valueInspector = new ValueInspector() {\n" - + " @Override\n" - + " public void updateNull() {\n" - + " setResult(" - + !isEq + ");\n" + " }\n" + + " valueInspector = new ValueInspector() {\n" + + " @Override\n" + + " public void updateNull() {\n"); + if (!expectMultipleResults) { + add(" setResult(" + nullResult + ");\n"); + } else { + add(" if (" + nullResult + ") { setResult(true); }\n"); + } + add(" }\n" + "\n" - + " @Override\n" - + " public void update(" + + " @Override\n" + + " public void update(" + info.primitiveName + " value) {\n"); if (expectMultipleResults) { - add(" if (isKnown()) return;\n"); + add(" if (isKnown()) return;\n"); } - add(" for (" + info.primitiveName + " i : target) {\n"); - - add(" if(" + compareEquality("value", "i", isEq) + ") {\n"); + add(" for (" + info.className + " i : target) {\n" + + " if (i == null) { continue; }\n" + + " " + + info.primitiveName + " targetValue = i;\n"); - add(" setResult(true);\n return;\n"); + add(" if(" + compareEquality("value", "targetValue", true) + ") {\n"); - add(" }\n"); + if (!expectMultipleResults || isEq) { + add(" setResult(" + isEq + ");\n"); + } + add(" return;\n"); add(" }\n"); - if (!expectMultipleResults) { - add(" setResult(false);\n"); - } + add(" }\n"); + if (!expectMultipleResults || !isEq) { + add(" setResult(" + !isEq + ");\n"); + } + add(" }\n"); - add(" };\n" + " }\n" + " }\n\n"); + add(" };\n" + " }\n\n"); } private void addUdpBegin() throws IOException { diff --git a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/TestRecordLevelFilters.java b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/TestRecordLevelFilters.java index 0c03f548b2..c69dce6917 100644 --- a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/TestRecordLevelFilters.java +++ b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/TestRecordLevelFilters.java @@ -44,7 +44,9 @@ import java.util.ArrayList; import java.util.HashSet; import java.util.Iterator; +import java.util.LinkedHashSet; import java.util.List; +import java.util.Set; import java.util.stream.LongStream; import org.apache.parquet.example.data.Group; import org.apache.parquet.filter2.compat.FilterCompat; @@ -156,7 +158,11 @@ private static void assertFilter(List found, UserFilter f) { } private static void assertPredicate(FilterPredicate predicate, long... expectedIds) throws IOException { - List found = PhoneBookWriter.readFile(phonebookFile, FilterCompat.get(predicate)); + assertPredicate(phonebookFile, predicate, expectedIds); + } + + private static void assertPredicate(File file, FilterPredicate predicate, long... expectedIds) throws IOException { + List found = PhoneBookWriter.readFile(file, FilterCompat.get(predicate)); assertEquals(expectedIds.length, found.size()); for (int i = 0; i < expectedIds.length; i++) { @@ -175,6 +181,35 @@ public boolean keep(User u) { }); } + @Test + public void testNotInChecksAllValues() throws Exception { + File file = PhoneBookWriter.writeToFile(List.of( + new User(300, "lon-1", null, new Location(1.0, null)), + new User(301, "lon-2", null, new Location(2.0, null)), + new User(302, "lon-3", null, new Location(3.0, null)))); + DoubleColumn lon = doubleColumn("location.lon"); + + Set values = new LinkedHashSet<>(); + values.add(1.0); + values.add(3.0); + assertPredicate(file, notIn(lon, values), 301L); + } + + @Test + public void testInAndNotInCheckNullAndNonNullValues() throws Exception { + File file = PhoneBookWriter.writeToFile(List.of( + new User(300, "lon-null", null, new Location(null, null)), + new User(301, "lon-2", null, new Location(2.0, null)), + new User(302, "lon-3", null, new Location(3.0, null)))); + DoubleColumn lon = doubleColumn("location.lon"); + + Set values = new LinkedHashSet<>(); + values.add(null); + values.add(3.0); + assertPredicate(file, in(lon, values), 300L, 302L); + assertPredicate(file, notIn(lon, values), 301L); + } + @Test public void testAllFilter() throws Exception { BinaryColumn name = binaryColumn("name");