Skip to content

Commit 206240a

Browse files
authored
Smarter spec suggestions in protocols and implementations (#549)
in defprotocol first arg is always t in defimpl use inference to guess type
1 parent be0af9d commit 206240a

File tree

3 files changed

+154
-29
lines changed

3 files changed

+154
-29
lines changed

apps/language_server/lib/language_server/providers/code_lens/type_spec.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ defmodule ElixirLS.LanguageServer.Providers.CodeLens.TypeSpec do
1818
resp =
1919
for {_, line, {mod, fun, arity}, contract, is_macro} <- Server.suggest_contracts(uri),
2020
SourceFile.function_def_on_line?(text, line, fun),
21-
spec = ContractTranslator.translate_contract(fun, contract, is_macro) do
21+
spec = ContractTranslator.translate_contract(fun, contract, is_macro, mod) do
2222
CodeLens.build_code_lens(
2323
line,
2424
"@spec #{spec}",

apps/language_server/lib/language_server/providers/code_lens/type_spec/contract_translator.ex

Lines changed: 79 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ defmodule ElixirLS.LanguageServer.Providers.CodeLens.TypeSpec.ContractTranslator
33
alias Erl2ex.Convert.{Context, ErlForms}
44
alias Erl2ex.Pipeline.{Parse, ModuleData, ExSpec}
55

6-
def translate_contract(fun, contract, is_macro) do
6+
def translate_contract(fun, contract, is_macro, mod) do
77
# FIXME: Private module
88
{[%ExSpec{specs: [spec]} | _], _} =
99
"-spec foo#{contract}."
@@ -21,6 +21,7 @@ defmodule ElixirLS.LanguageServer.Providers.CodeLens.TypeSpec.ContractTranslator
2121
spec
2222
|> Macro.postwalk(&tweak_specs/1)
2323
|> drop_macro_env(is_macro)
24+
|> improve_defprotocol_spec(mod, fun)
2425
|> Macro.to_string()
2526
|> String.replace("()", "")
2627
|> Code.format_string!(line_length: :infinity)
@@ -122,9 +123,7 @@ defmodule ElixirLS.LanguageServer.Providers.CodeLens.TypeSpec.ContractTranslator
122123
end
123124

124125
defp translate_map(struct_type, fields) do
125-
struct_type_spec_exists =
126-
ElixirSense.Core.Normalized.Typespec.get_types(struct_type)
127-
|> Enum.any?(&match?({kind, {:t, _, []}} when kind in [:type, :opaque], &1))
126+
struct_type_spec_exists = struct_type_spec_exists?(struct_type)
128127

129128
if struct_type_spec_exists do
130129
# struct_type.t/0 public/opaque type exists, assume it's a struct
@@ -136,4 +135,80 @@ defmodule ElixirLS.LanguageServer.Providers.CodeLens.TypeSpec.ContractTranslator
136135
{:%, [], [struct_type, map]}
137136
end
138137
end
138+
139+
defp struct_type_spec_exists?(struct_type) do
140+
ElixirSense.Core.Normalized.Typespec.get_types(struct_type)
141+
|> Enum.any?(&match?({kind, {:t, _, []}} when kind in [:type, :opaque], &1))
142+
end
143+
144+
defp improve_defprotocol_spec(ast, mod, fun) do
145+
cond do
146+
Code.ensure_loaded?(mod) and function_exported?(mod, :__protocol__, 1) ->
147+
# defprotocol
148+
# defs in defprotocol do not have when and have at least 1 arg
149+
{:"::", [], [{:foo, [], [{:any, [], []} | rest]}, {:any, [], []}]} = ast
150+
# first arg in defprotocol defs is always of type t
151+
{:"::", [], [{:foo, [], [{:t, [], []} | rest]}, {:any, [], []}]}
152+
153+
Code.ensure_loaded?(mod) and function_exported?(mod, :__impl__, 1) ->
154+
# defimpl
155+
implementation_of = mod.__impl__(:protocol)
156+
157+
{:"::", [], [{:foo, [], args}, res]} = ast
158+
arity = length(args)
159+
160+
if {fun, arity} in implementation_of.__protocol__(:functions) do
161+
# protocol fun
162+
implemented_for_type =
163+
case mod.__impl__(:for) do
164+
Any ->
165+
{:any, [], []}
166+
167+
Atom ->
168+
{:atom, [], []}
169+
170+
Integer ->
171+
{:integer, [], []}
172+
173+
Float ->
174+
{:float, [], []}
175+
176+
BitString ->
177+
{:binary, [], []}
178+
179+
Map ->
180+
{:map, [], []}
181+
182+
List ->
183+
{:list, [], []}
184+
185+
Function ->
186+
{:function, [], []}
187+
188+
Port ->
189+
{:port, [], []}
190+
191+
PID ->
192+
{:pid, [], []}
193+
194+
Tuple ->
195+
{:tuple, [], []}
196+
197+
Reference ->
198+
{:reference, [], []}
199+
200+
struct_type ->
201+
translate_map(struct_type, [])
202+
end
203+
204+
{:"::", [], [{:foo, [], [implemented_for_type | tl(args)]}, res]}
205+
else
206+
# non protocol fun/macro
207+
ast
208+
end
209+
210+
true ->
211+
ast
212+
end
213+
end
139214
end

apps/language_server/test/providers/code_lens/type_spec/contract_translator_test.exs

Lines changed: 74 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,131 +4,181 @@ defmodule ElixirLS.LanguageServer.Providers.CodeLens.TypeSpec.ContractTranslator
44

55
test "translate struct when struct.t type exists" do
66
contract = '() -> \#{\'__struct__\':=\'Elixir.DateTime\'}'
7-
assert "foo :: DateTime.t()" == ContractTranslator.translate_contract(:foo, contract, false)
7+
8+
assert "foo :: DateTime.t()" ==
9+
ContractTranslator.translate_contract(:foo, contract, false, Atom)
810
end
911

1012
test "don't translate struct when struct.t type does not exist" do
1113
contract = '() -> \#{\'__struct__\':=\'Elixir.SomeOtherStruct\'}'
1214

1315
assert "foo :: %SomeOtherStruct{}" ==
14-
ContractTranslator.translate_contract(:foo, contract, false)
16+
ContractTranslator.translate_contract(:foo, contract, false, Atom)
1517
end
1618

1719
test "struct" do
1820
contract = '() -> \#{\'__struct__\':=atom(), atom()=>any()}'
19-
assert "foo :: struct" == ContractTranslator.translate_contract(:foo, contract, false)
21+
assert "foo :: struct" == ContractTranslator.translate_contract(:foo, contract, false, Atom)
2022
end
2123

2224
test "drop macro env argument" do
2325
contract = '(any(), integer()) -> integer()'
2426

2527
assert "foo(any, integer) :: integer" ==
26-
ContractTranslator.translate_contract(:foo, contract, false)
28+
ContractTranslator.translate_contract(:foo, contract, false, Atom)
2729

2830
assert "foo(integer) :: integer" ==
29-
ContractTranslator.translate_contract(:foo, contract, true)
31+
ContractTranslator.translate_contract(:foo, contract, true, Atom)
3032
end
3133

3234
test "atom :ok" do
3335
contract = '(any()) -> ok'
34-
assert "foo(any) :: :ok" == ContractTranslator.translate_contract(:foo, contract, false)
36+
assert "foo(any) :: :ok" == ContractTranslator.translate_contract(:foo, contract, false, Atom)
3537
end
3638

3739
test "atom true" do
3840
contract = '(any()) -> true'
39-
assert "foo(any) :: true" == ContractTranslator.translate_contract(:foo, contract, false)
41+
42+
assert "foo(any) :: true" ==
43+
ContractTranslator.translate_contract(:foo, contract, false, Atom)
4044
end
4145

4246
test "atom _ substitution" do
4347
contract = '(_) -> false'
44-
assert "foo(any) :: false" == ContractTranslator.translate_contract(:foo, contract, false)
48+
49+
assert "foo(any) :: false" ==
50+
ContractTranslator.translate_contract(:foo, contract, false, Atom)
4551
end
4652

4753
test "do not drop when substitutions" do
4854
contract = '(X) -> atom() when X :: any()'
4955

5056
assert "foo(x) :: atom when x: any" ==
51-
ContractTranslator.translate_contract(:foo, contract, false)
57+
ContractTranslator.translate_contract(:foo, contract, false, Atom)
5258
end
5359

5460
test "keyword" do
5561
contract = '(any()) -> list({atom(), any()})'
56-
assert "foo(any) :: keyword" == ContractTranslator.translate_contract(:foo, contract, false)
62+
63+
assert "foo(any) :: keyword" ==
64+
ContractTranslator.translate_contract(:foo, contract, false, Atom)
5765

5866
contract = '(any()) -> list({atom(), _})'
59-
assert "foo(any) :: keyword" == ContractTranslator.translate_contract(:foo, contract, false)
67+
68+
assert "foo(any) :: keyword" ==
69+
ContractTranslator.translate_contract(:foo, contract, false, Atom)
6070
end
6171

6272
test "keyword(t)" do
6373
contract = '(any()) -> list({atom(), integer()})'
6474

6575
assert "foo(any) :: keyword(integer)" ==
66-
ContractTranslator.translate_contract(:foo, contract, false)
76+
ContractTranslator.translate_contract(:foo, contract, false, Atom)
6777
end
6878

6979
test "[type]" do
7080
contract = '(any()) -> list(atom())'
71-
assert "foo(any) :: [atom]" == ContractTranslator.translate_contract(:foo, contract, false)
81+
82+
assert "foo(any) :: [atom]" ==
83+
ContractTranslator.translate_contract(:foo, contract, false, Atom)
7284
end
7385

7486
test "list" do
7587
contract = '(any()) -> list(any())'
76-
assert "foo(any) :: list" == ContractTranslator.translate_contract(:foo, contract, false)
88+
89+
assert "foo(any) :: list" ==
90+
ContractTranslator.translate_contract(:foo, contract, false, Atom)
7791
end
7892

7993
test "empty list" do
8094
contract = '(any()) -> []'
81-
assert "foo(any) :: []" == ContractTranslator.translate_contract(:foo, contract, false)
95+
assert "foo(any) :: []" == ContractTranslator.translate_contract(:foo, contract, false, Atom)
8296
end
8397

8498
test "[...]" do
8599
contract = '(any()) -> nonempty_list(any())'
86-
assert "foo(any) :: [...]" == ContractTranslator.translate_contract(:foo, contract, false)
100+
101+
assert "foo(any) :: [...]" ==
102+
ContractTranslator.translate_contract(:foo, contract, false, Atom)
87103

88104
contract = '(any()) -> nonempty_list(_)'
89-
assert "foo(any) :: [...]" == ContractTranslator.translate_contract(:foo, contract, false)
105+
106+
assert "foo(any) :: [...]" ==
107+
ContractTranslator.translate_contract(:foo, contract, false, Atom)
90108
end
91109

92110
test "[type, ...]" do
93111
contract = '(any()) -> nonempty_list(atom())'
94112

95113
assert "foo(any) :: [atom, ...]" ==
96-
ContractTranslator.translate_contract(:foo, contract, false)
114+
ContractTranslator.translate_contract(:foo, contract, false, Atom)
97115
end
98116

99117
test "undoes conversion of :_ to any inside bitstring" do
100118
contract = '(any()) -> <<_:2, _:_*3>>'
101119

102120
assert "foo(any) :: <<_::2, _::_*3>>" ==
103-
ContractTranslator.translate_contract(:foo, contract, false)
121+
ContractTranslator.translate_contract(:foo, contract, false, Atom)
104122
end
105123

106124
test "function" do
107125
contract = '(any()) -> fun((...) -> ok)'
108126

109127
assert "foo(any) :: (... -> :ok)" ==
110-
ContractTranslator.translate_contract(:foo, contract, false)
128+
ContractTranslator.translate_contract(:foo, contract, false, Atom)
111129
end
112130

113131
test "fun" do
114132
contract = '(any()) -> fun((...) -> any())'
115-
assert "foo(any) :: fun" == ContractTranslator.translate_contract(:foo, contract, false)
133+
assert "foo(any) :: fun" == ContractTranslator.translate_contract(:foo, contract, false, Atom)
116134
end
117135

118136
test "empty map" do
119137
contract = '(any()) -> \#{}'
120-
assert "foo(any) :: %{}" == ContractTranslator.translate_contract(:foo, contract, false)
138+
assert "foo(any) :: %{}" == ContractTranslator.translate_contract(:foo, contract, false, Atom)
121139
end
122140

123141
test "map" do
124142
contract = '(any()) -> \#{any()=>any()}'
125-
assert "foo(any) :: map" == ContractTranslator.translate_contract(:foo, contract, false)
143+
assert "foo(any) :: map" == ContractTranslator.translate_contract(:foo, contract, false, Atom)
126144
end
127145

128146
test "map with fields" do
129147
contract = '(any()) -> \#{integer()=>any(), 1:=atom(), abc:=4}'
130148

131149
assert "foo(any) :: %{optional(integer) => any, 1 => atom, :abc => 4}" ==
132-
ContractTranslator.translate_contract(:foo, contract, false)
150+
ContractTranslator.translate_contract(:foo, contract, false, Atom)
151+
end
152+
153+
test "defprotocol type t" do
154+
contract = '(any()) -> any()'
155+
156+
assert "foo(t) :: any" ==
157+
ContractTranslator.translate_contract(:foo, contract, false, Enumerable)
158+
159+
contract = '(any(), any()) -> any()'
160+
161+
assert "foo(t, any) :: any" ==
162+
ContractTranslator.translate_contract(:foo, contract, false, Enumerable)
163+
164+
contract = '(any()) -> any()'
165+
assert "foo(any) :: any" == ContractTranslator.translate_contract(:foo, contract, false, Atom)
166+
167+
contract = '(any(), any()) -> any()'
168+
169+
assert "foo(any, any) :: any" ==
170+
ContractTranslator.translate_contract(:foo, contract, false, Atom)
171+
end
172+
173+
test "defimpl first arg" do
174+
contract = '(any()) -> any()'
175+
176+
assert "count(list) :: any" ==
177+
ContractTranslator.translate_contract(:count, contract, false, Enumerable.List)
178+
179+
contract = '(any()) -> any()'
180+
181+
assert "count(Date.Range.t()) :: any" ==
182+
ContractTranslator.translate_contract(:count, contract, false, Enumerable.Date.Range)
133183
end
134184
end

0 commit comments

Comments
 (0)