Skip to content

Commit

Permalink
feat: implement 'convert to named lambda parameters' code action
Browse files Browse the repository at this point in the history
  • Loading branch information
KacperFKorban committed Aug 13, 2024
1 parent 057e7bc commit 0eabc2e
Show file tree
Hide file tree
Showing 11 changed files with 550 additions and 1 deletion.
14 changes: 14 additions & 0 deletions metals/src/main/scala/scala/meta/internal/metals/Compilers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,20 @@ class Compilers(
}
}.getOrElse(Future.successful(Nil.asJava))

def convertToNamedLambdaParameters(
position: TextDocumentPositionParams,
token: CancelToken,
): Future[ju.List[TextEdit]] = {
withPCAndAdjustLsp(position) { (pc, pos, adjust) =>
pc.convertToNamedLambdaParameters(
CompilerOffsetParamsUtils.fromPos(pos, token)
).asScala
.map { edits =>
adjust.adjustTextEdits(edits)
}
}
}.getOrElse(Future.successful(Nil.asJava))

def implementAbstractMembers(
params: TextDocumentPositionParams,
token: CancelToken,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,20 @@ object ServerCommands {
|""".stripMargin,
)

final case class ConvertToNamedLambdaParametersRequest(
position: TextDocumentPositionParams
)
val ConvertToNamedLambdaParameters =
new ParametrizedCommand[ConvertToNamedLambdaParametersRequest](
"convert-to-named-lambda-parameters",
"Convert wildcard lambda parameters to named parameters",
"""|Whenever a user chooses code action to convert to named lambda parameters, this command is later run to
|rewrite the lambda to use named parameters.
|""".stripMargin,
"""|Object with [TextDocumentPositionParams](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#textDocumentPositionParams) of the target lambda
|""".stripMargin,
)

val GotoLog = new Command(
"goto-log",
"Check logs",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ final class CodeActionProvider(
new MillifyDependencyCodeAction(buffers),
new MillifyScalaCliDependencyCodeAction(buffers),
new ConvertCommentCodeAction(buffers),
new ConvertToNamedLambdaParameters(trees, compilers, languageClient),
)

def codeActions(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package scala.meta.internal.metals.codeactions

import scala.concurrent.ExecutionContext
import scala.concurrent.Future

import scala.meta.Term
import scala.meta.internal.metals.Compilers
import scala.meta.internal.metals.MetalsEnrichments._
import scala.meta.internal.metals.ServerCommands
import scala.meta.internal.metals.clients.language.MetalsLanguageClient
import scala.meta.internal.metals.codeactions.CodeAction
import scala.meta.internal.metals.codeactions.CodeActionBuilder
import scala.meta.internal.metals.logging
import scala.meta.internal.parsing.Trees
import scala.meta.pc.CancelToken

import org.eclipse.{lsp4j => l}

/**
* Code action to convert a wildcard lambda to a lambda with named parameters
* e.g.
*
* List(1, 2).map(<<_>> + 1) => List(1, 2).map(i => i + 1)
*/
class ConvertToNamedLambdaParameters(
trees: Trees,
compilers: Compilers,
languageClient: MetalsLanguageClient,
) extends CodeAction {

override val kind: String = l.CodeActionKind.RefactorRewrite

override type CommandData =
ServerCommands.ConvertToNamedLambdaParametersRequest

override def command: Option[ActionCommand] = Some(
ServerCommands.ConvertToNamedLambdaParameters
)

override def handleCommand(
data: ServerCommands.ConvertToNamedLambdaParametersRequest,
token: CancelToken,
)(implicit ec: ExecutionContext): Future[Unit] = {
val uri = data.position.getTextDocument().getUri()
for {
edits <- compilers.convertToNamedLambdaParameters(
data.position,
token,
)
_ = logging.logErrorWhen(
edits.isEmpty(),
s"Could not convert lambda at position ${data.position} to named lambda",
)
workspaceEdit = new l.WorkspaceEdit(Map(uri -> edits).asJava)
_ <- languageClient
.applyEdit(new l.ApplyWorkspaceEditParams(workspaceEdit))
.asScala
} yield ()
}

override def contribute(
params: l.CodeActionParams,
token: CancelToken,
)(implicit ec: ExecutionContext): Future[Seq[l.CodeAction]] = {
val path = params.getTextDocument().getUri().toAbsolutePath
val range = params.getRange()
val maybeLambda =
trees.findLastEnclosingAt[Term.AnonymousFunction](path, range.getStart())
maybeLambda
.map { lambda =>
val position = new l.TextDocumentPositionParams(
params.getTextDocument(),
new l.Position(lambda.pos.startLine, lambda.pos.startColumn),
)
val command =
ServerCommands.ConvertToNamedLambdaParameters.toLsp(
ServerCommands.ConvertToNamedLambdaParametersRequest(position)
)
val codeAction = CodeActionBuilder.build(
title = ConvertToNamedLambdaParameters.title,
kind = kind,
command = Some(command),
)
Future.successful(Seq(codeAction))
}
.getOrElse(Future.successful(Nil))
}

}

object ConvertToNamedLambdaParameters {
def title: String = "Convert to named lambda parameters"
}
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,15 @@ public CompletableFuture<List<TextEdit>> inlineValue(OffsetParams params) {
public abstract CompletableFuture<List<TextEdit>> convertToNamedArguments(OffsetParams params,
List<Integer> argIndices);

/**
* Return the text edits for converting a wildcard lambda to a named lambda.
*/
public CompletableFuture<List<TextEdit>> convertToNamedLambdaParameters(OffsetParams params) {
return CompletableFuture.supplyAsync(() -> {
throw new DisplayableException("Convert to named lambda parameters is not available in this version of Scala");
});
};

/**
* The text contents of the given file changed.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package scala.meta.internal.mtags

/**
* Helpers for generating variable names based on the desired types.
*/
object TermNameInference {

/** Single character names for types. (`Int` => `i`, `i1`, `i2`, ...) */
def singleLetterNameStream(typeName: String): LazyList[String] = {
val typeName1 = sanitizeInput(typeName)
val firstCharStr = typeName1.headOption.getOrElse('x').toLower.toString
numberedStreamFromName(firstCharStr)
}

/** Names only from upper case letters (`OnDemandSymbolIndex` => `odsi`, `odsi1`, `odsi2`, ...) */
def shortNameStream(typeName: String): LazyList[String] = {
val typeName1 = sanitizeInput(typeName)
val upperCases = typeName1.filter(_.isUpper).map(_.toLower)
val name = if (upperCases.isEmpty) typeName1 else upperCases
numberedStreamFromName(name)
}

/** Names from lower case letters (`OnDemandSymbolIndex` => `onDemandSymbolIndex`, `onDemandSymbolIndex1`, ...) */
def fullNameStream(typeName: String): LazyList[String] = {
val typeName1 = sanitizeInput(typeName)
val withFirstLower =
typeName1.headOption.map(_.toLower).getOrElse('x') + typeName1.drop(1)
numberedStreamFromName(withFirstLower)
}

/** A lazy list of names: a, b, ..., z, aa, ab, ..., az, ba, bb, ... */
def saneNamesStream: LazyList[String] = {
val letters = ('a' to 'z').map(_.toString)
def computeNext(acc: String): String = {
if (acc.last == 'z')
computeNext(acc.init) + letters.head
else
acc.init + letters(letters.indexOf(acc.last) + 1)
}
def loop(acc: String): LazyList[String] =
acc #:: loop(computeNext(acc))
loop("a")
}

private def sanitizeInput(typeName: String): String =
typeName.filter(_.isLetterOrDigit)

private def numberedStreamFromName(name: String): LazyList[String] = {
val rest = LazyList.from(1).map(name + _)
name #:: rest
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package scala.meta.internal.pc

import java.nio.file.Paths

import scala.meta.internal.mtags.MtagsEnrichments.*
import scala.meta.internal.mtags.TermNameInference.*
import scala.meta.pc.OffsetParams

import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.core.Contexts.Context
import dotty.tools.dotc.core.Flags
import dotty.tools.dotc.interactive.Interactive
import dotty.tools.dotc.interactive.InteractiveDriver
import dotty.tools.dotc.util.SourceFile
import dotty.tools.dotc.util.SourcePosition
import org.eclipse.lsp4j as l

/**
* Facilitates the code action that converts a wildcard lambda to a lambda with named parameters
* e.g.
*
* List(1, 2).map(<<_>> + 1) => List(1, 2).map(i => i + 1)
*/
final class ConvertToNamedLambdaParametersProvider(
driver: InteractiveDriver,
params: OffsetParams
):
import ConvertToNamedLambdaParametersProvider._

def convertToNamedLambdaParameters: Either[String, List[l.TextEdit]] = {
val uri = params.uri
val filePath = Paths.get(uri)
driver.run(
uri,
SourceFile.virtual(filePath.toString, params.text),
)
val unit = driver.latestRun
given newctx: Context = driver.currentCtx.fresh.setCompilationUnit(unit)
val pos = driver.sourcePosition(params)
val trees = driver.openedTrees(uri)
val treeList = Interactive.pathTo(trees, pos)
// Extractor for a lambda function (needs context, so has to be defined here)
val LambdaExtractor = Lambda(using newctx)
// select the most inner wildcard lambda
val firstLambda = treeList.collectFirst {
case LambdaExtractor(params, rhsFn) if params.forall(isWildcardParam) =>
params -> rhsFn
}

firstLambda match {
case Some((params, lambda)) =>
// avoid names that are either defined or referenced in the lambda
val namesToAvoid = allDefAndRefNamesInTree(lambda)
// compute parameter names based on the type of the parameter
val computedParamNames: List[String] =
params.foldLeft(List.empty[String]) { (acc, param) =>
val name = singleLetterNameStream(param.tpe.typeSymbol.name.toString())
.find(n => !namesToAvoid.contains(n) && !acc.contains(n))
acc ++ name.toList
}
if computedParamNames.size == params.size then
val paramReferenceEdits = params.zip(computedParamNames).flatMap { (param, paramName) =>
val paramReferencePosition = findParamReferencePosition(param, lambda)
paramReferencePosition.toList.map { pos =>
val position = pos.toLsp
val range = new l.Range(
position.getStart(),
position.getEnd()
)
new l.TextEdit(range, paramName)
}
}
val paramNamesStr = computedParamNames.mkString(", ")
val paramDefsStr =
if params.size == 1 then paramNamesStr
else s"($paramNamesStr)"
val defRange = new l.Range(
lambda.sourcePos.toLsp.getStart(),
lambda.sourcePos.toLsp.getStart()
)
val paramDefinitionEdits = List(
new l.TextEdit(defRange, s"$paramDefsStr => ")
)
Right(paramDefinitionEdits ++ paramReferenceEdits)
else
Right(Nil)
case _ =>
Right(Nil)
}
}

end ConvertToNamedLambdaParametersProvider

object ConvertToNamedLambdaParametersProvider:
class Lambda(using Context):
def unapply(tree: tpd.Block): Option[(List[tpd.ValDef], tpd.Tree)] = tree match {
case tpd.Block((ddef @ tpd.DefDef(_, tpd.ValDefs(params) :: Nil, _, body: tpd.Tree)) :: Nil, tpd.Closure(_, meth, _))
if ddef.symbol == meth.symbol =>
params match {
case List(param) =>
// lambdas with multiple wildcard parameters are represented as a single parameter function and a block with wildcard valdefs
Some(multipleUnderscoresFromBody(param, body))
case _ => Some(params -> body)
}
case _ => None
}
end Lambda

private def multipleUnderscoresFromBody(param: tpd.ValDef, body: tpd.Tree)(using Context): (List[tpd.ValDef], tpd.Tree) = body match {
case tpd.Block(defs, expr) if param.symbol.is(Flags.Synthetic) =>
val wildcardParamDefs = defs.collect {
case valdef: tpd.ValDef if isWildcardParam(valdef) => valdef
}
if wildcardParamDefs.size == defs.size then wildcardParamDefs -> expr
else List(param) -> body
case _ => List(param) -> body
}

def isWildcardParam(param: tpd.ValDef)(using Context): Boolean =
param.name.toString.startsWith("_$") && param.symbol.is(Flags.Synthetic)

def findParamReferencePosition(param: tpd.ValDef, lambda: tpd.Tree)(using Context): Option[SourcePosition] =
var pos: Option[SourcePosition] = None
object FindParamReference extends tpd.TreeTraverser:
override def traverse(tree: tpd.Tree)(using Context): Unit =
tree match
case ident @ tpd.Ident(_) if ident.symbol == param.symbol =>
pos = Some(tree.sourcePos)
case _ =>
traverseChildren(tree)
FindParamReference.traverse(lambda)
pos
end findParamReferencePosition

def allDefAndRefNamesInTree(tree: tpd.Tree)(using Context): List[String] =
object FindDefinitionsAndRefs extends tpd.TreeAccumulator[List[String]]:
override def apply(x: List[String], tree: tpd.Tree)(using Context): List[String] =
tree match
case tpd.DefDef(name, _, _, _) =>
super.foldOver(x :+ name.toString, tree)
case tpd.ValDef(name, _, _) =>
super.foldOver(x :+ name.toString, tree)
case tpd.Ident(name) =>
super.foldOver(x :+ name.toString, tree)
case _ =>
super.foldOver(x, tree)
FindDefinitionsAndRefs.foldOver(Nil, tree)
end allDefAndRefNamesInTree

end ConvertToNamedLambdaParametersProvider
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,25 @@ case class ScalaPresentationCompiler(
case Right(edits: List[l.TextEdit]) => edits.asJava
}
end convertToNamedArguments

override def convertToNamedLambdaParameters(
params: OffsetParams
): ju.concurrent.CompletableFuture[ju.List[l.TextEdit]] =
val empty: Either[String, List[l.TextEdit]] = Right(List())
(compilerAccess
.withInterruptableCompiler(Some(params))(empty, params.token) { pc =>
new ConvertToNamedLambdaParametersProvider(
pc.compiler(),
params
).convertToNamedLambdaParameters
})
.thenApplyAsync {
case Left(error: String) => throw new DisplayableException(error)
case Right(edits: List[l.TextEdit]) => edits.asJava
}
end convertToNamedLambdaParameters


override def selectionRange(
params: ju.List[OffsetParams]
): CompletableFuture[ju.List[l.SelectionRange]] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import scala.meta.internal.metals.TextEdits
import munit.Location
import munit.TestOptions
import org.eclipse.{lsp4j => l}
import tests.BaseCodeActionSuite

class BaseExtractMethodSuite extends BaseCodeActionSuite {
def checkEdit(
Expand Down
Loading

0 comments on commit 0eabc2e

Please sign in to comment.