Skip to content

Commit bb02245

Browse files
Undinyopox
authored andcommitted
intellij-rust#8224: fix type inference if dependencies have names as stdlib crates
1 parent c7103d6 commit bb02245

File tree

6 files changed

+99
-51
lines changed

6 files changed

+99
-51
lines changed

src/main/kotlin/org/rust/cargo/project/workspace/CargoWorkspace.kt

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import com.intellij.openapi.util.UserDataHolderBase
99
import com.intellij.openapi.util.UserDataHolderEx
1010
import com.intellij.openapi.vfs.VirtualFile
1111
import com.intellij.openapi.vfs.VirtualFileManager
12+
import com.intellij.util.ThreeState
1213
import org.jetbrains.annotations.TestOnly
1314
import org.rust.cargo.CfgOptions
1415
import org.rust.cargo.project.model.CargoProjectsService
@@ -48,7 +49,14 @@ interface CargoWorkspace {
4849
val featureGraph: FeatureGraph
4950

5051
fun findPackageById(id: PackageId): Package? = packages.find { it.id == id }
51-
fun findPackageByName(name: String): Package? = packages.find { it.name == name || it.normName == name }
52+
fun findPackageByName(name: String, isStd: ThreeState = ThreeState.UNSURE): Package? = packages.find {
53+
if (it.name != name && it.normName != name) return@find false
54+
when (isStd) {
55+
ThreeState.YES -> it.origin == STDLIB
56+
ThreeState.NO -> it.origin == WORKSPACE || it.origin == DEPENDENCY
57+
ThreeState.UNSURE -> true
58+
}
59+
}
5260

5361
fun findTargetByCrateRoot(root: VirtualFile): Target?
5462
fun isCrateRoot(root: VirtualFile) = findTargetByCrateRoot(root) != null

src/main/kotlin/org/rust/lang/core/resolve/KnownItems.kt

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import com.intellij.openapi.util.Key
1010
import com.intellij.psi.util.CachedValue
1111
import com.intellij.psi.util.CachedValueProvider
1212
import com.intellij.psi.util.CachedValuesManager
13+
import com.intellij.util.ThreeState
1314
import org.rust.cargo.project.model.CargoProject
1415
import org.rust.cargo.project.workspace.CargoWorkspace
1516
import org.rust.cargo.util.AutoInjectedCrates.CORE
@@ -51,8 +52,8 @@ class KnownItems(
5152
fun findLangItemRaw(langAttribute: String, crateName: String) =
5253
lookup.findLangItem(langAttribute, crateName)
5354

54-
fun findItemRaw(path: String): RsNamedElement? =
55-
lookup.findItem(path)
55+
fun findItemRaw(path: String, isStd: Boolean): RsNamedElement? =
56+
lookup.findItem(path, isStd)
5657

5758
/**
5859
* Find some known item by its "lang" attribute
@@ -66,8 +67,8 @@ class KnownItems(
6667
crateName: String = CORE
6768
): T? = findLangItemRaw(langAttribute, crateName) as? T
6869

69-
inline fun <reified T : RsNamedElement> findItem(path: String): T? =
70-
findItemRaw(path) as? T
70+
inline fun <reified T : RsNamedElement> findItem(path: String, isStd: Boolean = true): T? =
71+
findItemRaw(path, isStd) as? T
7172

7273
val Vec: RsStructOrEnumItemElement? get() = findItem("alloc::vec::Vec")
7374
val String: RsStructOrEnumItemElement? get() = findItem("alloc::string::String")
@@ -145,12 +146,12 @@ class KnownItems(
145146

146147
interface KnownItemsLookup {
147148
fun findLangItem(langAttribute: String, crateName: String): RsNamedElement?
148-
fun findItem(path: String): RsNamedElement?
149+
fun findItem(path: String, isStd: Boolean): RsNamedElement?
149150
}
150151

151152
private object DummyKnownItemsLookup : KnownItemsLookup {
152153
override fun findLangItem(langAttribute: String, crateName: String): RsNamedElement? = null
153-
override fun findItem(path: String): RsNamedElement? = null
154+
override fun findItem(path: String, isStd: Boolean): RsNamedElement? = null
154155
}
155156

156157
private class RealKnownItemsLookup(
@@ -167,9 +168,9 @@ private class RealKnownItemsLookup(
167168
}.orElse(null)
168169
}
169170

170-
override fun findItem(path: String): RsNamedElement? {
171+
override fun findItem(path: String, isStd: Boolean): RsNamedElement? {
171172
return resolvedItems.getOrPut(path) {
172-
Optional.ofNullable(resolveStringPath(path, workspace, project)?.first)
173+
Optional.ofNullable(resolveStringPath(path, workspace, project, ThreeState.fromBoolean(isStd))?.first)
173174
}.orElse(null)
174175
}
175176
}
@@ -189,11 +190,11 @@ enum class KnownDerivableTrait(
189190
PartialOrd(KnownItems::PartialOrd, arrayOf(PartialEq)),
190191
Ord(KnownItems::Ord, arrayOf(PartialOrd, Eq, PartialEq)),
191192

192-
Serialize({ it.findItem("serde::Serialize") }, isStd = false),
193-
Deserialize({ it.findItem("serde::Deserialize") }, isStd = false),
193+
Serialize({ it.findItem("serde::Serialize", isStd = false) }, isStd = false),
194+
Deserialize({ it.findItem("serde::Deserialize", isStd = false) }, isStd = false),
194195

195196
// TODO Fail also derives `Display`. Ignore it for now
196-
Fail({ it.findItem("failure::Fail") }, arrayOf(Debug), isStd = false),
197+
Fail({ it.findItem("failure::Fail", isStd = false) }, arrayOf(Debug), isStd = false),
197198
;
198199

199200
fun findTrait(items: KnownItems): RsTraitItem? = resolver(items)

src/main/kotlin/org/rust/lang/core/resolve/NameResolution.kt

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import com.intellij.psi.util.CachedValue
2222
import com.intellij.psi.util.CachedValueProvider
2323
import com.intellij.psi.util.CachedValuesManager
2424
import com.intellij.psi.util.PsiTreeUtil
25+
import com.intellij.util.ThreeState
2526
import org.rust.cargo.project.workspace.CargoWorkspace
2627
import org.rust.cargo.project.workspace.PackageOrigin
2728
import org.rust.cargo.util.AutoInjectedCrates.CORE
@@ -759,9 +760,14 @@ fun processLocalVariables(place: RsElement, originalProcessor: (RsPatBinding) ->
759760
/**
760761
* Resolves an absolute path.
761762
*/
762-
fun resolveStringPath(path: String, workspace: CargoWorkspace, project: Project): Pair<RsNamedElement, CargoWorkspace.Package>? {
763+
fun resolveStringPath(
764+
path: String,
765+
workspace: CargoWorkspace,
766+
project: Project,
767+
isStd: ThreeState = ThreeState.UNSURE
768+
): Pair<RsNamedElement, CargoWorkspace.Package>? {
763769
val (pkgName, crateRelativePath) = splitAbsolutePath(path) ?: return null
764-
val pkg = workspace.findPackageByName(pkgName) ?: run {
770+
val pkg = workspace.findPackageByName(pkgName, isStd) ?: run {
765771
return if (isUnitTestMode) {
766772
// Allows to set a fake path for some item in tests via
767773
// lang attribute, e.g. `#[lang = "std::iter::Iterator"]`

src/test/kotlin/org/rust/RustProjectDescriptors.kt

Lines changed: 53 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,39 @@ open class RustProjectDescriptorBase : LightProjectDescriptor() {
9595
CargoWorkspaceData(packages, emptyMap(), emptyMap(), contentRoot), CfgOptions.DEFAULT)
9696
}
9797

98+
protected open fun externalPackage(
99+
contentRoot: String,
100+
source: String?,
101+
name: String,
102+
targetName: String = name,
103+
version: String = "0.0.1",
104+
origin: PackageOrigin = PackageOrigin.DEPENDENCY,
105+
libKind: LibKind = LibKind.LIB,
106+
procMacroArtifact: CargoWorkspaceData.ProcMacroArtifact? = null,
107+
): Package {
108+
return Package(
109+
id = "$name $version",
110+
contentRootUrl = contentRoot,
111+
name = name,
112+
version = version,
113+
targets = listOf(
114+
// don't use `FileUtil.join` here because it uses `File.separator`
115+
// which is system dependent although all other code uses `/` as separator
116+
Target(source?.let { "$contentRoot/$it" } ?: "", targetName,
117+
TargetKind.Lib(libKind), Edition.EDITION_2015, doctest = true, requiredFeatures = emptyList())
118+
),
119+
source = source,
120+
origin = origin,
121+
edition = Edition.EDITION_2015,
122+
features = emptyMap(),
123+
enabledFeatures = emptySet(),
124+
cfgOptions = CfgOptions.EMPTY,
125+
env = emptyMap(),
126+
outDirUrl = null,
127+
procMacroArtifact = procMacroArtifact
128+
)
129+
}
130+
98131
protected fun testCargoPackage(contentRoot: String, name: String = "test-package") = Package(
99132
id = "$name 0.0.1",
100133
contentRootUrl = contentRoot,
@@ -292,38 +325,6 @@ open class WithCustomStdlibRustProjectDescriptor(
292325
}
293326

294327
object WithDependencyRustProjectDescriptor : RustProjectDescriptorBase() {
295-
private fun externalPackage(
296-
contentRoot: String,
297-
source: String?,
298-
name: String,
299-
targetName: String = name,
300-
version: String = "0.0.1",
301-
origin: PackageOrigin = PackageOrigin.DEPENDENCY,
302-
libKind: LibKind = LibKind.LIB,
303-
procMacroArtifact: CargoWorkspaceData.ProcMacroArtifact? = null,
304-
): Package {
305-
return Package(
306-
id = "$name $version",
307-
contentRootUrl = contentRoot,
308-
name = name,
309-
version = version,
310-
targets = listOf(
311-
// don't use `FileUtil.join` here because it uses `File.separator`
312-
// which is system dependent although all other code uses `/` as separator
313-
Target(source?.let { "$contentRoot/$it" } ?: "", targetName,
314-
TargetKind.Lib(libKind), Edition.EDITION_2015, doctest = true, requiredFeatures = emptyList())
315-
),
316-
source = source,
317-
origin = origin,
318-
edition = Edition.EDITION_2015,
319-
features = emptyMap(),
320-
enabledFeatures = emptySet(),
321-
cfgOptions = CfgOptions.EMPTY,
322-
env = emptyMap(),
323-
outDirUrl = null,
324-
procMacroArtifact = procMacroArtifact
325-
)
326-
}
327328

328329
override fun setUp(fixture: CodeInsightTestFixture) {
329330
val root = fixture.findFileInTempDir(".")!!
@@ -395,6 +396,26 @@ object WithDependencyRustProjectDescriptor : RustProjectDescriptorBase() {
395396
}
396397
}
397398

399+
private class WithStdlibLikeDependencyRustProjectDescriptor : RustProjectDescriptorBase() {
400+
override fun testCargoProject(module: Module, contentRoot: String): CargoWorkspace {
401+
val packages = listOf(
402+
testCargoPackage(contentRoot),
403+
externalPackage("$contentRoot/core", "lib.rs", "core"),
404+
externalPackage("$contentRoot/alloc", "lib.rs", "alloc"),
405+
externalPackage("$contentRoot/std", "lib.rs", "std")
406+
)
407+
return CargoWorkspace.deserialize(
408+
Paths.get("${Urls.newFromIdea(contentRoot).path}/workspace/Cargo.toml"),
409+
CargoWorkspaceData(packages, emptyMap(), emptyMap(), contentRoot), CfgOptions.DEFAULT)
410+
}
411+
}
412+
413+
/**
414+
* Provides `core`, `alloc` and `std` workspace dependencies.
415+
* It's supposed to be used to check how the plugin works with dependencies that have the same name as stdlib packages
416+
*/
417+
object WithStdlibAndStdlibLikeDependencyRustProjectDescriptor : WithRustup(WithStdlibLikeDependencyRustProjectDescriptor())
418+
398419
private fun RsToolchainBase.getRustcInfo(): RustcInfo? {
399420
val rustc = rustc()
400421
val sysroot = rustc.getSysroot(Paths.get(".")) ?: return null

src/test/kotlin/org/rust/lang/core/resolve/RsStdlibResolveTest.kt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,12 @@ class RsStdlibResolveTest : RsResolveTestBase() {
410410
}
411411
""")
412412

413-
fun `test resolve derive traits`() {
413+
fun `test resolve derive traits`() = resolveDeriveTraits()
414+
415+
@ProjectDescriptor(WithStdlibAndStdlibLikeDependencyRustProjectDescriptor::class)
416+
fun `test resolve derive traits with stdlib-like dependencies`() = resolveDeriveTraits()
417+
418+
private fun resolveDeriveTraits() {
414419
val traitToPath = mapOf(
415420
"std::marker::Clone" to "clone.rs",
416421
"std::marker::Copy" to "marker.rs",

src/test/kotlin/org/rust/lang/core/type/RsStdlibExpressionTypeInferenceTest.kt

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,7 @@
55

66
package org.rust.lang.core.type
77

8-
import org.rust.ExpandMacros
9-
import org.rust.MockEdition
10-
import org.rust.ProjectDescriptor
11-
import org.rust.WithStdlibRustProjectDescriptor
8+
import org.rust.*
129
import org.rust.cargo.project.workspace.CargoWorkspace
1310
import org.rust.lang.core.macros.MacroExpansionScope
1411
import org.rust.lang.core.psi.ext.*
@@ -126,6 +123,16 @@ class RsStdlibExpressionTypeInferenceTest : RsTypificationTestBase() {
126123
}
127124
""")
128125

126+
@ProjectDescriptor(WithStdlibAndStdlibLikeDependencyRustProjectDescriptor::class)
127+
fun `test vec! with stdlib-like dependencies`() = stubOnlyTypeInfer("""
128+
//- main.rs
129+
fn main() {
130+
let x = vec!(1, 2u16, 4, 8);
131+
x;
132+
//^ Vec<u16> | Vec<u16, Global>
133+
}
134+
""")
135+
129136
fun `test vec! no_std`() = stubOnlyTypeInfer("""
130137
//- main.rs
131138
#![no_std]

0 commit comments

Comments
 (0)