Skip to content

Commit c87ee29

Browse files
janedbalondrejmirtes
authored andcommitted
Improve COALESCE inference for MySQL
1 parent 2004f84 commit c87ee29

File tree

3 files changed

+270
-18
lines changed

3 files changed

+270
-18
lines changed

src/Type/Doctrine/Query/QueryResultTypeWalker.php

+82-6
Original file line numberDiff line numberDiff line change
@@ -978,6 +978,7 @@ public function walkJoin($join): string
978978
*/
979979
public function walkCoalesceExpression($coalesceExpression): string
980980
{
981+
$rawTypes = [];
981982
$expressionTypes = [];
982983
$allTypesContainNull = true;
983984

@@ -987,22 +988,67 @@ public function walkCoalesceExpression($coalesceExpression): string
987988
continue;
988989
}
989990

990-
$type = $this->unmarshalType($expression->dispatch($this));
991-
$allTypesContainNull = $allTypesContainNull && $this->canBeNull($type);
991+
$rawType = $this->unmarshalType($expression->dispatch($this));
992+
$rawTypes[] = $rawType;
993+
994+
$allTypesContainNull = $allTypesContainNull && $this->canBeNull($rawType);
992995

993996
// Some drivers manipulate the types, lets avoid false positives by generalizing constant types
994997
// e.g. sqlsrv: "COALESCE returns the data type of value with the highest precedence"
995998
// e.g. mysql: COALESCE(1, 'foo') === '1' (undocumented? https://gist.github.com/jrunning/4535434)
996-
$expressionTypes[] = $this->generalizeConstantType($type, false);
999+
$expressionTypes[] = $this->generalizeConstantType($rawType, false);
9971000
}
9981001

999-
$type = TypeCombinator::union(...$expressionTypes);
1002+
$generalizedUnion = TypeCombinator::union(...$expressionTypes);
10001003

10011004
if (!$allTypesContainNull) {
1002-
$type = TypeCombinator::removeNull($type);
1005+
$generalizedUnion = TypeCombinator::removeNull($generalizedUnion);
10031006
}
10041007

1005-
return $this->marshalType($type);
1008+
if ($this->driverType === DriverDetector::MYSQLI || $this->driverType === DriverDetector::PDO_MYSQL) {
1009+
return $this->marshalType(
1010+
$this->inferCoalesceForMySql($rawTypes, $generalizedUnion)
1011+
);
1012+
}
1013+
1014+
return $this->marshalType($generalizedUnion);
1015+
}
1016+
1017+
/**
1018+
* @param list<Type> $rawTypes
1019+
*/
1020+
private function inferCoalesceForMySql(array $rawTypes, Type $originalResult): Type
1021+
{
1022+
$containsString = false;
1023+
$containsFloat = false;
1024+
$allIsNumericExcludingLiteralString = true;
1025+
1026+
foreach ($rawTypes as $rawType) {
1027+
$rawTypeNoNull = TypeCombinator::removeNull($rawType);
1028+
$isLiteralString = $rawTypeNoNull instanceof DqlConstantStringType && $rawTypeNoNull->getOriginLiteralType() === AST\Literal::STRING;
1029+
1030+
if (!$this->containsOnlyNumericTypes($rawTypeNoNull) || $isLiteralString) {
1031+
$allIsNumericExcludingLiteralString = false;
1032+
}
1033+
1034+
if ($rawTypeNoNull->isString()->yes()) {
1035+
$containsString = true;
1036+
}
1037+
1038+
if (!$rawTypeNoNull->isFloat()->yes()) {
1039+
continue;
1040+
}
1041+
1042+
$containsFloat = true;
1043+
}
1044+
1045+
if ($containsFloat && $allIsNumericExcludingLiteralString) {
1046+
return $this->simpleFloatify($originalResult);
1047+
} elseif ($containsString) {
1048+
return $this->simpleStringify($originalResult);
1049+
}
1050+
1051+
return $originalResult;
10061052
}
10071053

10081054
/**
@@ -2107,4 +2153,34 @@ private function isSupportedDriver(): bool
21072153
], true);
21082154
}
21092155

2156+
private function simpleStringify(Type $type): Type
2157+
{
2158+
return TypeTraverser::map($type, static function (Type $type, callable $traverse): Type {
2159+
if ($type instanceof UnionType || $type instanceof IntersectionType) {
2160+
return $traverse($type);
2161+
}
2162+
2163+
if ($type instanceof IntegerType || $type instanceof FloatType || $type instanceof BooleanType) {
2164+
return $type->toString();
2165+
}
2166+
2167+
return $traverse($type);
2168+
});
2169+
}
2170+
2171+
private function simpleFloatify(Type $type): Type
2172+
{
2173+
return TypeTraverser::map($type, static function (Type $type, callable $traverse): Type {
2174+
if ($type instanceof UnionType || $type instanceof IntersectionType) {
2175+
return $traverse($type);
2176+
}
2177+
2178+
if ($type instanceof IntegerType || $type instanceof BooleanType || $type instanceof StringType) {
2179+
return $type->toFloat();
2180+
}
2181+
2182+
return $traverse($type);
2183+
});
2184+
}
2185+
21102186
}

0 commit comments

Comments
 (0)