Skip to content

Add a null-check to java enum safeValueOf #5904

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 28, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ internal object JavaClassNames {
val Map: ClassName = ClassName.get("java.util", "Map")
val MapOfStringToObject = ParameterizedTypeName.get(Map, String, Object)
val JavaOptional = ClassName.get("java.util", "Optional")
val Objects = ClassName.get("java.util", "Objects")

val ObjectBuilderKt = ClassName.get(apolloApiPackageName, "ObjectBuilderKt")
val ObjectMap = ClassName.get(apolloApiPackageName, "ObjectMap")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package com.apollographql.apollo3.compiler.codegen.java.helpers

import com.apollographql.apollo3.compiler.GeneratedMethod
import com.apollographql.apollo3.compiler.GeneratedMethod.*
import com.apollographql.apollo3.compiler.internal.applyIf
import com.apollographql.apollo3.compiler.GeneratedMethod.EQUALS_HASH_CODE
import com.apollographql.apollo3.compiler.GeneratedMethod.TO_STRING
import com.apollographql.apollo3.compiler.codegen.Identifier.__h
import com.apollographql.apollo3.compiler.codegen.java.JavaClassNames
import com.apollographql.apollo3.compiler.codegen.java.L
import com.apollographql.apollo3.compiler.codegen.java.joinToCode
import com.apollographql.apollo3.compiler.internal.applyIf
import com.squareup.javapoet.ClassName
import com.squareup.javapoet.CodeBlock
import com.squareup.javapoet.FieldSpec
Expand All @@ -29,8 +30,8 @@ import javax.lang.model.element.Modifier
internal fun TypeSpec.Builder.makeClassFromParameters(
generateMethods: List<GeneratedMethod>,
parameters: List<ParameterSpec>,
className: ClassName
): TypeSpec.Builder {
className: ClassName,
): TypeSpec.Builder {
addMethod(
MethodSpec.constructorBuilder()
.addModifiers(Modifier.PUBLIC)
Expand All @@ -55,7 +56,7 @@ internal fun TypeSpec.Builder.makeClassFromParameters(

internal fun TypeSpec.Builder.addGeneratedMethods(
className: ClassName,
generateMethods: List<GeneratedMethod> = listOf(EQUALS_HASH_CODE, TO_STRING)
generateMethods: List<GeneratedMethod> = listOf(EQUALS_HASH_CODE, TO_STRING),
): TypeSpec.Builder {
return applyIf(generateMethods.contains(EQUALS_HASH_CODE)) { withEqualsImplementation(className) }
.applyIf(generateMethods.contains(EQUALS_HASH_CODE)) { withHashCodeImplementation() }
Expand All @@ -68,8 +69,8 @@ internal fun TypeSpec.Builder.addGeneratedMethods(
internal fun TypeSpec.Builder.makeClassFromProperties(
generateMethods: List<GeneratedMethod>,
fields: List<FieldSpec>,
className: ClassName
): TypeSpec.Builder {
className: ClassName,
): TypeSpec.Builder {
addMethod(
MethodSpec.constructorBuilder()
.addModifiers(Modifier.PUBLIC)
Expand All @@ -94,42 +95,45 @@ internal fun TypeSpec.Builder.makeClassFromProperties(

internal fun TypeSpec.Builder.withToStringImplementation(className: ClassName): TypeSpec.Builder {
fun printFieldCode(fieldIndex: Int, fieldName: String) =
CodeBlock.builder()
.let { if (fieldIndex > 0) it.add(" + \", \"\n") else it.add("\n") }
.indent()
.add("+ \$S + \$L", "$fieldName=", fieldName)
.unindent()
.build()
CodeBlock.builder()
.let { if (fieldIndex > 0) it.add(" + \", \"\n") else it.add("\n") }
.indent()
.add("+ \$S + \$L", "$fieldName=", fieldName)
.unindent()
.build()

fun methodCode() =
CodeBlock.builder()
.beginControlFlow("if (\$L == null)", MEMOIZED_TO_STRING_VAR)
.add("\$L = \$S", "\$toString", "${className.simpleName()}{")
.add(fieldSpecs
.filter { !it.hasModifier(Modifier.STATIC) }
.filter { !it.hasModifier(Modifier.TRANSIENT) }
.map { it.name }
.mapIndexed(::printFieldCode)
.fold(CodeBlock.builder(), CodeBlock.Builder::add)
.build())
.add(CodeBlock.builder()
.indent()
.add("\n+ \$S;\n", "}")
.unindent()
.build())
.endControlFlow()
.addStatement("return \$L", MEMOIZED_TO_STRING_VAR)
.build()
CodeBlock.builder()
.beginControlFlow("if (\$L == null)", MEMOIZED_TO_STRING_VAR)
.add("\$L = \$S", "\$toString", "${className.simpleName()}{")
.add(fieldSpecs
.filter { !it.hasModifier(Modifier.STATIC) }
.filter { !it.hasModifier(Modifier.TRANSIENT) }
.map { it.name }
.mapIndexed(::printFieldCode)
.fold(CodeBlock.builder(), CodeBlock.Builder::add)
.build())
.add(CodeBlock.builder()
.indent()
.add("\n+ \$S;\n", "}")
.unindent()
.build()
)
.endControlFlow()
.addStatement("return \$L", MEMOIZED_TO_STRING_VAR)
.build()

return addField(FieldSpec.builder(JavaClassNames.String, MEMOIZED_TO_STRING_VAR, Modifier.PRIVATE, Modifier.VOLATILE,
Modifier.TRANSIENT)
.build())
.addMethod(MethodSpec.methodBuilder("toString")
.addAnnotation(JavaClassNames.Override)
.addModifiers(Modifier.PUBLIC)
.returns(JavaClassNames.String)
.addCode(methodCode())
.build())
return addField(
FieldSpec.builder(JavaClassNames.String, MEMOIZED_TO_STRING_VAR, Modifier.PRIVATE, Modifier.VOLATILE, Modifier.TRANSIENT).build()
)
.addMethod(
MethodSpec.methodBuilder("toString")
.addAnnotation(JavaClassNames.Override)
.addModifiers(Modifier.PUBLIC)
.returns(JavaClassNames.String)
.addCode(methodCode())
.build()
)
}

private fun List<FieldSpec>.equalsCode(): CodeBlock = filter { !it.hasModifier(Modifier.STATIC) }
Expand All @@ -138,92 +142,114 @@ private fun List<FieldSpec>.equalsCode(): CodeBlock = filter { !it.hasModifier(M
.joinToCode("\n &&")

private fun FieldSpec.equalsCode() =
CodeBlock.builder()
.let {
if (type.isPrimitive) {
if (type == TypeName.DOUBLE) {
it.add("Double.doubleToLongBits(this.\$L) == Double.doubleToLongBits(that.\$L)",
name, name)
} else {
it.add("this.\$L == that.\$L", name, name)
}
CodeBlock.builder()
.let {
if (type.isPrimitive) {
if (type == TypeName.DOUBLE) {
it.add("Double.doubleToLongBits(this.\$L) == Double.doubleToLongBits(that.\$L)", name, name)
} else {
it.add("((this.\$L == null) ? (that.\$L == null) : this.\$L.equals(that.\$L))", name, name, name, name)
it.add("this.\$L == that.\$L", name, name)
}
} else {
it.add("((this.\$L == null) ? (that.\$L == null) : this.\$L.equals(that.\$L))", name, name, name, name)
}
.build()
}
.build()

internal fun TypeSpec.Builder.withEqualsImplementation(className: ClassName): TypeSpec.Builder {
val hasSuperClass = build().superclass != ClassName.OBJECT
fun methodCode(typeJavaClass: ClassName) =
CodeBlock.builder()
.beginControlFlow("if (o == this)")
.addStatement("return true")
.endControlFlow()
.beginControlFlow("if (o instanceof \$T)", typeJavaClass)
.apply {
if (fieldSpecs.isEmpty()) {
CodeBlock.builder()
.beginControlFlow("if (o == this)")
.addStatement("return true")
.endControlFlow()
.beginControlFlow("if (o instanceof \$T)", typeJavaClass)
.apply {
if (fieldSpecs.isEmpty()) {
if (hasSuperClass) {
add("return super.equals(o);\n")
} else {
add("return true;\n")
}
} else {
addStatement("\$T that = (\$T) o", typeJavaClass, typeJavaClass)
if (hasSuperClass) {
add("return super.equals(o) && $L;\n", fieldSpecs.equalsCode())
} else {
addStatement("\$T that = (\$T) o", typeJavaClass, typeJavaClass)
add("return $L;\n", if (fieldSpecs.isEmpty()) "true" else fieldSpecs.equalsCode())
add("return $L;\n", fieldSpecs.equalsCode())
}
}
.endControlFlow()
.addStatement("return false")
.build()
}
.endControlFlow()
.addStatement("return false")
.build()

return addMethod(MethodSpec.methodBuilder("equals")
.addAnnotation(JavaClassNames.Override)
.addModifiers(Modifier.PUBLIC)
.returns(TypeName.BOOLEAN)
.addParameter(ParameterSpec.builder(TypeName.OBJECT, "o").build())
.addCode(methodCode(className))
.build())
.addAnnotation(JavaClassNames.Override)
.addModifiers(Modifier.PUBLIC)
.returns(TypeName.BOOLEAN)
.addParameter(ParameterSpec.builder(TypeName.OBJECT, "o").build())
.addCode(methodCode(className))
.build()
)
}

internal fun TypeSpec.Builder.withHashCodeImplementation(): TypeSpec.Builder {
val hasSuperClass = build().superclass != ClassName.OBJECT
fun hashFieldCode(field: FieldSpec) =
CodeBlock.builder()
.addStatement("$__h *= 1000003")
.let {
if (field.type.isPrimitive) {
when (field.type.withoutAnnotations()) {
TypeName.DOUBLE -> it.addStatement("$__h ^= Double.valueOf(\$L).hashCode()", field.name)
TypeName.BOOLEAN -> it.addStatement("$__h ^= Boolean.valueOf(\$L).hashCode()", field.name)
else -> it.addStatement("$__h ^= \$L", field.name)
}
} else {
it.addStatement("$__h ^= (\$L == null) ? 0 : \$L.hashCode()", field.name, field.name)
CodeBlock.builder()
.addStatement("$__h *= 1000003")
.let {
if (field.type.isPrimitive) {
when (field.type.withoutAnnotations()) {
TypeName.DOUBLE -> it.addStatement("$__h ^= Double.valueOf(\$L).hashCode()", field.name)
TypeName.BOOLEAN -> it.addStatement("$__h ^= Boolean.valueOf(\$L).hashCode()", field.name)
else -> it.addStatement("$__h ^= \$L", field.name)
}
} else {
it.addStatement("$__h ^= (\$L == null) ? 0 : \$L.hashCode()", field.name, field.name)
}
.build()
}
.build()

fun methodCode() =
CodeBlock.builder()
.beginControlFlow("if (!\$L)", MEMOIZED_HASH_CODE_FLAG_VAR)
.addStatement("int $__h = 1")
.add(fieldSpecs
.filter { !it.hasModifier(Modifier.STATIC) }
.filter { !it.hasModifier(Modifier.TRANSIENT) }
.map(::hashFieldCode)
.fold(CodeBlock.builder(), CodeBlock.Builder::add)
.build())
.addStatement("\$L = $__h", MEMOIZED_HASH_CODE_VAR)
.addStatement("\$L = true", MEMOIZED_HASH_CODE_FLAG_VAR)
.endControlFlow()
.addStatement("return \$L", MEMOIZED_HASH_CODE_VAR)
.build()
CodeBlock.builder()
.beginControlFlow("if (!\$L)", MEMOIZED_HASH_CODE_FLAG_VAR)
.addStatement(
if (hasSuperClass) {
"int $__h = super.hashCode()"
} else {
"int $__h = 1"
}
)
.add(
fieldSpecs
.filter { !it.hasModifier(Modifier.STATIC) }
.filter { !it.hasModifier(Modifier.TRANSIENT) }
.map(::hashFieldCode)
.fold(CodeBlock.builder(), CodeBlock.Builder::add)
.build()
)
.addStatement("\$L = $__h", MEMOIZED_HASH_CODE_VAR)
.addStatement("\$L = true", MEMOIZED_HASH_CODE_FLAG_VAR)
.endControlFlow()
.addStatement("return \$L", MEMOIZED_HASH_CODE_VAR)
.build()

return addField(FieldSpec.builder(TypeName.INT, MEMOIZED_HASH_CODE_VAR, Modifier.PRIVATE, Modifier.VOLATILE,
Modifier.TRANSIENT).build())
.addField(FieldSpec.builder(TypeName.BOOLEAN, MEMOIZED_HASH_CODE_FLAG_VAR, Modifier.PRIVATE,
Modifier.VOLATILE, Modifier.TRANSIENT).build())
.addMethod(MethodSpec.methodBuilder("hashCode")
.addAnnotation(JavaClassNames.Override)
.addModifiers(Modifier.PUBLIC)
.returns(TypeName.INT)
.addCode(methodCode())
.build())
return addField(
FieldSpec.builder(TypeName.INT, MEMOIZED_HASH_CODE_VAR, Modifier.PRIVATE, Modifier.VOLATILE, Modifier.TRANSIENT).build()
)
.addField(
FieldSpec.builder(TypeName.BOOLEAN, MEMOIZED_HASH_CODE_FLAG_VAR, Modifier.PRIVATE, Modifier.VOLATILE, Modifier.TRANSIENT).build()
)
.addMethod(
MethodSpec.methodBuilder("hashCode")
.addAnnotation(JavaClassNames.Override)
.addModifiers(Modifier.PUBLIC)
.returns(TypeName.INT)
.addCode(methodCode())
.build()
)
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ internal class EnumAsClassBuilder(
)
.addMethod(
MethodSpec.constructorBuilder()
.addModifiers(Modifier.PUBLIC)
.addModifiers(Modifier.PRIVATE)
.addParameter(ParameterSpec.builder(JavaClassNames.String, rawValue).build())
.addCode("this.$rawValue = $rawValue;\n")
.build()
Expand All @@ -86,7 +86,7 @@ internal class EnumAsClassBuilder(
.returns(selfClassName)
.addCode(
CodeBlock.builder()
.beginControlFlow("switch($rawValue)")
.beginControlFlow("switch ($T.requireNonNull($rawValue))", JavaClassNames.Objects)
.apply {
values.forEach {
add("case $S: return $T.$L;\n", it.name, selfClassName, it.targetName.escapeTypeReservedWord()
Expand All @@ -113,7 +113,7 @@ internal class EnumAsClassBuilder(
.addJavadoc(L, "An enum value that wasn't known at compile time.\n")
.addMethod(
MethodSpec.constructorBuilder()
.addModifiers(Modifier.PUBLIC)
.addModifiers(Modifier.PRIVATE)
.addParameter(ParameterSpec.builder(JavaClassNames.String, rawValue).build())
.addCode("super($rawValue);\n")
.build()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ internal class EnumAsEnumBuilder(
.returns(selfClassName)
.addCode(
CodeBlock.builder()
.beginControlFlow("switch ($rawValue)")
.beginControlFlow("switch ($T.requireNonNull($rawValue))", JavaClassNames.Objects)
.apply {
values.forEach {
add("case $S: return $T.$L;\n", it.name, selfClassName, it.targetName.escapeTypeReservedWord()
Expand Down
Loading
Loading