@@ -978,6 +978,7 @@ public function walkJoin($join): string
978
978
*/
979
979
public function walkCoalesceExpression ($ coalesceExpression ): string
980
980
{
981
+ $ rawTypes = [];
981
982
$ expressionTypes = [];
982
983
$ allTypesContainNull = true ;
983
984
@@ -987,22 +988,67 @@ public function walkCoalesceExpression($coalesceExpression): string
987
988
continue ;
988
989
}
989
990
990
- $ type = $ this ->unmarshalType ($ expression ->dispatch ($ this ));
991
- $ allTypesContainNull = $ allTypesContainNull && $ this ->canBeNull ($ type );
991
+ $ rawType = $ this ->unmarshalType ($ expression ->dispatch ($ this ));
992
+ $ rawTypes [] = $ rawType ;
993
+
994
+ $ allTypesContainNull = $ allTypesContainNull && $ this ->canBeNull ($ rawType );
992
995
993
996
// Some drivers manipulate the types, lets avoid false positives by generalizing constant types
994
997
// e.g. sqlsrv: "COALESCE returns the data type of value with the highest precedence"
995
998
// e.g. mysql: COALESCE(1, 'foo') === '1' (undocumented? https://gist.github.com/jrunning/4535434)
996
- $ expressionTypes [] = $ this ->generalizeConstantType ($ type , false );
999
+ $ expressionTypes [] = $ this ->generalizeConstantType ($ rawType , false );
997
1000
}
998
1001
999
- $ type = TypeCombinator::union (...$ expressionTypes );
1002
+ $ generalizedUnion = TypeCombinator::union (...$ expressionTypes );
1000
1003
1001
1004
if (!$ allTypesContainNull ) {
1002
- $ type = TypeCombinator::removeNull ($ type );
1005
+ $ generalizedUnion = TypeCombinator::removeNull ($ generalizedUnion );
1003
1006
}
1004
1007
1005
- return $ this ->marshalType ($ type );
1008
+ if ($ this ->driverType === DriverDetector::MYSQLI || $ this ->driverType === DriverDetector::PDO_MYSQL ) {
1009
+ return $ this ->marshalType (
1010
+ $ this ->inferCoalesceForMySql ($ rawTypes , $ generalizedUnion )
1011
+ );
1012
+ }
1013
+
1014
+ return $ this ->marshalType ($ generalizedUnion );
1015
+ }
1016
+
1017
+ /**
1018
+ * @param list<Type> $rawTypes
1019
+ */
1020
+ private function inferCoalesceForMySql (array $ rawTypes , Type $ originalResult ): Type
1021
+ {
1022
+ $ containsString = false ;
1023
+ $ containsFloat = false ;
1024
+ $ allIsNumericExcludingLiteralString = true ;
1025
+
1026
+ foreach ($ rawTypes as $ rawType ) {
1027
+ $ rawTypeNoNull = TypeCombinator::removeNull ($ rawType );
1028
+ $ isLiteralString = $ rawTypeNoNull instanceof DqlConstantStringType && $ rawTypeNoNull ->getOriginLiteralType () === AST \Literal::STRING ;
1029
+
1030
+ if (!$ this ->containsOnlyNumericTypes ($ rawTypeNoNull ) || $ isLiteralString ) {
1031
+ $ allIsNumericExcludingLiteralString = false ;
1032
+ }
1033
+
1034
+ if ($ rawTypeNoNull ->isString ()->yes ()) {
1035
+ $ containsString = true ;
1036
+ }
1037
+
1038
+ if (!$ rawTypeNoNull ->isFloat ()->yes ()) {
1039
+ continue ;
1040
+ }
1041
+
1042
+ $ containsFloat = true ;
1043
+ }
1044
+
1045
+ if ($ containsFloat && $ allIsNumericExcludingLiteralString ) {
1046
+ return $ this ->simpleFloatify ($ originalResult );
1047
+ } elseif ($ containsString ) {
1048
+ return $ this ->simpleStringify ($ originalResult );
1049
+ }
1050
+
1051
+ return $ originalResult ;
1006
1052
}
1007
1053
1008
1054
/**
@@ -2107,4 +2153,34 @@ private function isSupportedDriver(): bool
2107
2153
], true );
2108
2154
}
2109
2155
2156
+ private function simpleStringify (Type $ type ): Type
2157
+ {
2158
+ return TypeTraverser::map ($ type , static function (Type $ type , callable $ traverse ): Type {
2159
+ if ($ type instanceof UnionType || $ type instanceof IntersectionType) {
2160
+ return $ traverse ($ type );
2161
+ }
2162
+
2163
+ if ($ type instanceof IntegerType || $ type instanceof FloatType || $ type instanceof BooleanType) {
2164
+ return $ type ->toString ();
2165
+ }
2166
+
2167
+ return $ traverse ($ type );
2168
+ });
2169
+ }
2170
+
2171
+ private function simpleFloatify (Type $ type ): Type
2172
+ {
2173
+ return TypeTraverser::map ($ type , static function (Type $ type , callable $ traverse ): Type {
2174
+ if ($ type instanceof UnionType || $ type instanceof IntersectionType) {
2175
+ return $ traverse ($ type );
2176
+ }
2177
+
2178
+ if ($ type instanceof IntegerType || $ type instanceof BooleanType || $ type instanceof StringType) {
2179
+ return $ type ->toFloat ();
2180
+ }
2181
+
2182
+ return $ traverse ($ type );
2183
+ });
2184
+ }
2185
+
2110
2186
}
0 commit comments