Skip to content

Commit 0f2ab3b

Browse files
authored
Code action: Expand catch all variant (#987)
* code action for expanding catch all with variants * make work with polyvariants * extend to work on options * changelog + fix
1 parent 89deb9c commit 0f2ab3b

File tree

5 files changed

+340
-48
lines changed

5 files changed

+340
-48
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
- Emit `%todo` instead of `failwith("TODO")` when we can (ReScript >= v11.1). https://github.com/rescript-lang/rescript-vscode/pull/981
3535
- Complete `%todo`. https://github.com/rescript-lang/rescript-vscode/pull/981
3636
- Add code action for extracting a locally defined module into its own file. https://github.com/rescript-lang/rescript-vscode/pull/983
37+
- Add code action for expanding catch-all patterns. https://github.com/rescript-lang/rescript-vscode/pull/987
3738

3839
## 1.50.0
3940

analysis/src/CompletionFrontEnd.ml

+3-1
Original file line numberDiff line numberDiff line change
@@ -1005,7 +1005,9 @@ let completionWithParser1 ~currentFile ~debug ~offset ~path ~posCursor
10051005
typedCompletionExpr expr;
10061006
match expr.pexp_desc with
10071007
| Pexp_match (expr, cases)
1008-
when cases <> [] && locHasCursor expr.pexp_loc = false ->
1008+
when cases <> []
1009+
&& locHasCursor expr.pexp_loc = false
1010+
&& Option.is_none findThisExprLoc ->
10091011
if Debug.verbose () then
10101012
print_endline "[completionFrontend] Checking each case";
10111013
let ctxPath = exprToContextPath expr in

analysis/src/Xform.ml

+230-42
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,42 @@ let rangeOfLoc (loc : Location.t) =
1111
let end_ = loc |> Loc.end_ |> mkPosition in
1212
{Protocol.start; end_}
1313

14+
let extractTypeFromExpr expr ~debug ~path ~currentFile ~full ~pos =
15+
match
16+
expr.Parsetree.pexp_loc
17+
|> CompletionFrontEnd.findTypeOfExpressionAtLoc ~debug ~path ~currentFile
18+
~posCursor:(Pos.ofLexing expr.Parsetree.pexp_loc.loc_start)
19+
with
20+
| Some (completable, scope) -> (
21+
let env = SharedTypes.QueryEnv.fromFile full.SharedTypes.file in
22+
let completions =
23+
completable
24+
|> CompletionBackEnd.processCompletable ~debug ~full ~pos ~scope ~env
25+
~forHover:true
26+
in
27+
let rawOpens = Scope.getRawOpens scope in
28+
match completions with
29+
| {env} :: _ -> (
30+
let opens =
31+
CompletionBackEnd.getOpens ~debug ~rawOpens ~package:full.package ~env
32+
in
33+
match
34+
CompletionBackEnd.completionsGetCompletionType2 ~debug ~full ~rawOpens
35+
~opens ~pos completions
36+
with
37+
| Some (typ, _env) ->
38+
let extractedType =
39+
match typ with
40+
| ExtractedType t -> Some t
41+
| TypeExpr t ->
42+
TypeUtils.extractType t ~env ~package:full.package
43+
|> TypeUtils.getExtractedType
44+
in
45+
extractedType
46+
| None -> None)
47+
| _ -> None)
48+
| _ -> None
49+
1450
module IfThenElse = struct
1551
(* Convert if-then-else to switch *)
1652

@@ -324,6 +360,196 @@ module AddTypeAnnotation = struct
324360
| _ -> ()))
325361
end
326362

363+
module ExpandCatchAllForVariants = struct
364+
let mkIterator ~pos ~result =
365+
let expr (iterator : Ast_iterator.iterator) (e : Parsetree.expression) =
366+
(if e.pexp_loc |> Loc.hasPos ~pos then
367+
match e.pexp_desc with
368+
| Pexp_match (switchExpr, cases) -> (
369+
let catchAllCase =
370+
cases
371+
|> List.find_opt (fun (c : Parsetree.case) ->
372+
match c with
373+
| {pc_lhs = {ppat_desc = Ppat_any}} -> true
374+
| _ -> false)
375+
in
376+
match catchAllCase with
377+
| None -> ()
378+
| Some catchAllCase ->
379+
result := Some (switchExpr, catchAllCase, cases))
380+
| _ -> ());
381+
Ast_iterator.default_iterator.expr iterator e
382+
in
383+
{Ast_iterator.default_iterator with expr}
384+
385+
let xform ~path ~pos ~full ~structure ~currentFile ~codeActions ~debug =
386+
let result = ref None in
387+
let iterator = mkIterator ~pos ~result in
388+
iterator.structure iterator structure;
389+
match !result with
390+
| None -> ()
391+
| Some (switchExpr, catchAllCase, cases) -> (
392+
if Debug.verbose () then
393+
print_endline
394+
"[codeAction - ExpandCatchAllForVariants] Found target switch";
395+
let currentConstructorNames =
396+
cases
397+
|> List.filter_map (fun (c : Parsetree.case) ->
398+
match c with
399+
| {pc_lhs = {ppat_desc = Ppat_construct ({txt}, _)}} ->
400+
Some (Longident.last txt)
401+
| {pc_lhs = {ppat_desc = Ppat_variant (name, _)}} -> Some name
402+
| _ -> None)
403+
in
404+
match
405+
switchExpr
406+
|> extractTypeFromExpr ~debug ~path ~currentFile ~full
407+
~pos:(Pos.ofLexing switchExpr.pexp_loc.loc_end)
408+
with
409+
| Some (Tvariant {constructors}) ->
410+
let missingConstructors =
411+
constructors
412+
|> List.filter (fun (c : SharedTypes.Constructor.t) ->
413+
currentConstructorNames |> List.mem c.cname.txt = false)
414+
in
415+
if List.length missingConstructors > 0 then
416+
let newText =
417+
missingConstructors
418+
|> List.map (fun (c : SharedTypes.Constructor.t) ->
419+
c.cname.txt
420+
^
421+
match c.args with
422+
| Args [] -> ""
423+
| Args _ | InlineRecord _ -> "(_)")
424+
|> String.concat " | "
425+
in
426+
let range = rangeOfLoc catchAllCase.pc_lhs.ppat_loc in
427+
let codeAction =
428+
CodeActions.make ~title:"Expand catch-all" ~kind:RefactorRewrite
429+
~uri:path ~newText ~range
430+
in
431+
codeActions := codeAction :: !codeActions
432+
else ()
433+
| Some (Tpolyvariant {constructors}) ->
434+
let missingConstructors =
435+
constructors
436+
|> List.filter (fun (c : SharedTypes.polyVariantConstructor) ->
437+
currentConstructorNames |> List.mem c.name = false)
438+
in
439+
if List.length missingConstructors > 0 then
440+
let newText =
441+
missingConstructors
442+
|> List.map (fun (c : SharedTypes.polyVariantConstructor) ->
443+
Res_printer.polyVarIdentToString c.name
444+
^
445+
match c.args with
446+
| [] -> ""
447+
| _ -> "(_)")
448+
|> String.concat " | "
449+
in
450+
let range = rangeOfLoc catchAllCase.pc_lhs.ppat_loc in
451+
let codeAction =
452+
CodeActions.make ~title:"Expand catch-all" ~kind:RefactorRewrite
453+
~uri:path ~newText ~range
454+
in
455+
codeActions := codeAction :: !codeActions
456+
else ()
457+
| Some (Toption (env, innerType)) -> (
458+
if Debug.verbose () then
459+
print_endline
460+
"[codeAction - ExpandCatchAllForVariants] Found option type";
461+
let innerType =
462+
match innerType with
463+
| ExtractedType t -> Some t
464+
| TypeExpr t -> (
465+
match TypeUtils.extractType ~env ~package:full.package t with
466+
| None -> None
467+
| Some (t, _) -> Some t)
468+
in
469+
match innerType with
470+
| Some ((Tvariant _ | Tpolyvariant _) as variant) ->
471+
let currentConstructorNames =
472+
cases
473+
|> List.filter_map (fun (c : Parsetree.case) ->
474+
match c with
475+
| {
476+
pc_lhs =
477+
{
478+
ppat_desc =
479+
Ppat_construct
480+
( {txt = Lident "Some"},
481+
Some {ppat_desc = Ppat_construct ({txt}, _)} );
482+
};
483+
} ->
484+
Some (Longident.last txt)
485+
| {
486+
pc_lhs =
487+
{
488+
ppat_desc =
489+
Ppat_construct
490+
( {txt = Lident "Some"},
491+
Some {ppat_desc = Ppat_variant (name, _)} );
492+
};
493+
} ->
494+
Some name
495+
| _ -> None)
496+
in
497+
let hasNoneCase =
498+
cases
499+
|> List.exists (fun (c : Parsetree.case) ->
500+
match c.pc_lhs.ppat_desc with
501+
| Ppat_construct ({txt = Lident "None"}, _) -> true
502+
| _ -> false)
503+
in
504+
let missingConstructors =
505+
match variant with
506+
| Tvariant {constructors} ->
507+
constructors
508+
|> List.filter_map (fun (c : SharedTypes.Constructor.t) ->
509+
if currentConstructorNames |> List.mem c.cname.txt = false
510+
then
511+
Some
512+
( c.cname.txt,
513+
match c.args with
514+
| Args [] -> false
515+
| _ -> true )
516+
else None)
517+
| Tpolyvariant {constructors} ->
518+
constructors
519+
|> List.filter_map
520+
(fun (c : SharedTypes.polyVariantConstructor) ->
521+
if currentConstructorNames |> List.mem c.name = false then
522+
Some
523+
( Res_printer.polyVarIdentToString c.name,
524+
match c.args with
525+
| [] -> false
526+
| _ -> true )
527+
else None)
528+
| _ -> []
529+
in
530+
if List.length missingConstructors > 0 || not hasNoneCase then
531+
let newText =
532+
"Some("
533+
^ (missingConstructors
534+
|> List.map (fun (name, hasArgs) ->
535+
name ^ if hasArgs then "(_)" else "")
536+
|> String.concat " | ")
537+
^ ")"
538+
in
539+
let newText =
540+
if hasNoneCase then newText else newText ^ " | None"
541+
in
542+
let range = rangeOfLoc catchAllCase.pc_lhs.ppat_loc in
543+
let codeAction =
544+
CodeActions.make ~title:"Expand catch-all" ~kind:RefactorRewrite
545+
~uri:path ~newText ~range
546+
in
547+
codeActions := codeAction :: !codeActions
548+
else ()
549+
| _ -> ())
550+
| _ -> ())
551+
end
552+
327553
module ExhaustiveSwitch = struct
328554
(* Expand expression to be an exhaustive switch of the underlying value *)
329555
type posType = Single of Pos.t | Range of Pos.t * Pos.t
@@ -336,46 +562,6 @@ module ExhaustiveSwitch = struct
336562
}
337563
| Selection of {expr: Parsetree.expression}
338564

339-
module C = struct
340-
let extractTypeFromExpr expr ~debug ~path ~currentFile ~full ~pos =
341-
match
342-
expr.Parsetree.pexp_loc
343-
|> CompletionFrontEnd.findTypeOfExpressionAtLoc ~debug ~path
344-
~currentFile
345-
~posCursor:(Pos.ofLexing expr.Parsetree.pexp_loc.loc_start)
346-
with
347-
| Some (completable, scope) -> (
348-
let env = SharedTypes.QueryEnv.fromFile full.SharedTypes.file in
349-
let completions =
350-
completable
351-
|> CompletionBackEnd.processCompletable ~debug ~full ~pos ~scope ~env
352-
~forHover:true
353-
in
354-
let rawOpens = Scope.getRawOpens scope in
355-
match completions with
356-
| {env} :: _ -> (
357-
let opens =
358-
CompletionBackEnd.getOpens ~debug ~rawOpens ~package:full.package
359-
~env
360-
in
361-
match
362-
CompletionBackEnd.completionsGetCompletionType2 ~debug ~full
363-
~rawOpens ~opens ~pos completions
364-
with
365-
| Some (typ, _env) ->
366-
let extractedType =
367-
match typ with
368-
| ExtractedType t -> Some t
369-
| TypeExpr t ->
370-
TypeUtils.extractType t ~env ~package:full.package
371-
|> TypeUtils.getExtractedType
372-
in
373-
extractedType
374-
| None -> None)
375-
| _ -> None)
376-
| _ -> None
377-
end
378-
379565
let mkIteratorSingle ~pos ~result =
380566
let expr (iterator : Ast_iterator.iterator) (exp : Parsetree.expression) =
381567
(match exp.pexp_desc with
@@ -434,7 +620,7 @@ module ExhaustiveSwitch = struct
434620
| Some (Selection {expr}) -> (
435621
match
436622
expr
437-
|> C.extractTypeFromExpr ~debug ~path ~currentFile ~full
623+
|> extractTypeFromExpr ~debug ~path ~currentFile ~full
438624
~pos:(Pos.ofLexing expr.pexp_loc.loc_start)
439625
with
440626
| None -> ()
@@ -460,7 +646,7 @@ module ExhaustiveSwitch = struct
460646
| Some (Switch {switchExpr; completionExpr; pos}) -> (
461647
match
462648
completionExpr
463-
|> C.extractTypeFromExpr ~debug ~path ~currentFile ~full ~pos
649+
|> extractTypeFromExpr ~debug ~path ~currentFile ~full ~pos
464650
with
465651
| None -> ()
466652
| Some extractedType -> (
@@ -743,6 +929,8 @@ let extractCodeActions ~path ~startPos ~endPos ~currentFile ~debug =
743929
match Cmt.loadFullCmtFromPath ~path with
744930
| Some full ->
745931
AddTypeAnnotation.xform ~path ~pos ~full ~structure ~codeActions ~debug;
932+
ExpandCatchAllForVariants.xform ~path ~pos ~full ~structure ~codeActions
933+
~currentFile ~debug;
746934
ExhaustiveSwitch.xform ~printExpr ~path
747935
~pos:
748936
(if startPos = endPos then Single startPos

analysis/tests/src/Xform.res

+38-5
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
type kind = First | Second | Third
1+
type kind = First | Second | Third | Fourth(int)
22
type r = {name: string, age: int}
33

4-
let ret = _ => assert false
5-
let kind = assert false
4+
let ret = _ => assert(false)
5+
let kind = assert(false)
66

77
if kind == First {
88
// ^xfm
@@ -63,7 +63,7 @@ let bar = () => {
6363
}
6464
//^xfm
6565
}
66-
@res.partial Inner.foo(1)
66+
Inner.foo(1, ...)
6767
}
6868

6969
module ExtractableModule = {
@@ -72,4 +72,37 @@ module ExtractableModule = {
7272
// A comment here
7373
let doStuff = a => a + 1
7474
// ^xfm
75-
}
75+
}
76+
77+
let variant = First
78+
79+
let _x = switch variant {
80+
| First => "first"
81+
| _ => "other"
82+
// ^xfm
83+
}
84+
85+
let polyvariant: [#first | #second | #"illegal identifier" | #third(int)] = #first
86+
87+
let _y = switch polyvariant {
88+
| #first => "first"
89+
| _ => "other"
90+
// ^xfm
91+
}
92+
93+
let variantOpt = Some(variant)
94+
95+
let _x = switch variantOpt {
96+
| Some(First) => "first"
97+
| _ => "other"
98+
// ^xfm
99+
}
100+
101+
let polyvariantOpt = Some(polyvariant)
102+
103+
let _x = switch polyvariantOpt {
104+
| Some(#first) => "first"
105+
| None => "nothing"
106+
| _ => "other"
107+
// ^xfm
108+
}

0 commit comments

Comments
 (0)