diff --git a/crates/cairo-lang-language-server/src/ide/hover/render/definition.rs b/crates/cairo-lang-language-server/src/ide/hover/render/definition.rs index 6e7475b3f3b..c1c11fe1054 100644 --- a/crates/cairo-lang-language-server/src/ide/hover/render/definition.rs +++ b/crates/cairo-lang-language-server/src/ide/hover/render/definition.rs @@ -7,10 +7,10 @@ use cairo_lang_utils::Upcast; use tower_lsp::lsp_types::Hover; use crate::ide::hover::markdown_contents; +use crate::ide::hover::render::markdown::{fenced_code_block, RULE}; use crate::lang::db::AnalysisDatabase; use crate::lang::inspect::defs::{MemberDef, SymbolDef}; use crate::lang::lsp::ToLsp; -use crate::markdown::{fenced_code_block, RULE}; /// Get declaration and documentation "definition" of an item referred by the given identifier. #[tracing::instrument(level = "trace", skip_all)] diff --git a/crates/cairo-lang-language-server/src/ide/hover/render/legacy.rs b/crates/cairo-lang-language-server/src/ide/hover/render/legacy.rs index aea5f9cacb4..27608be8dc5 100644 --- a/crates/cairo-lang-language-server/src/ide/hover/render/legacy.rs +++ b/crates/cairo-lang-language-server/src/ide/hover/render/legacy.rs @@ -11,8 +11,8 @@ use cairo_lang_utils::Upcast; use tower_lsp::lsp_types::Hover; use crate::ide::hover::markdown_contents; +use crate::ide::hover::render::markdown::{fenced_code_block, RULE}; use crate::lang::db::{AnalysisDatabase, LsSemanticGroup}; -use crate::markdown::{fenced_code_block, RULE}; /// Legacy hover rendering backported from Cairo 2.6.3 codebase. /// diff --git a/crates/cairo-lang-language-server/src/markdown.rs b/crates/cairo-lang-language-server/src/ide/hover/render/markdown.rs similarity index 100% rename from crates/cairo-lang-language-server/src/markdown.rs rename to crates/cairo-lang-language-server/src/ide/hover/render/markdown.rs diff --git a/crates/cairo-lang-language-server/src/ide/hover/render/mod.rs b/crates/cairo-lang-language-server/src/ide/hover/render/mod.rs index 629eb0d709d..b9506f35925 100644 --- a/crates/cairo-lang-language-server/src/ide/hover/render/mod.rs +++ b/crates/cairo-lang-language-server/src/ide/hover/render/mod.rs @@ -3,3 +3,4 @@ pub use self::legacy::*; mod definition; mod legacy; +mod markdown; diff --git a/crates/cairo-lang-language-server/src/ide/navigation/goto_definition.rs b/crates/cairo-lang-language-server/src/ide/navigation/goto_definition.rs index 54a66517d39..2dfcb368b01 100644 --- a/crates/cairo-lang-language-server/src/ide/navigation/goto_definition.rs +++ b/crates/cairo-lang-language-server/src/ide/navigation/goto_definition.rs @@ -1,8 +1,11 @@ +use cairo_lang_filesystem::db::get_originating_location; +use cairo_lang_filesystem::ids::FileId; +use cairo_lang_filesystem::span::{TextPosition, TextSpan}; use cairo_lang_utils::Upcast; use tower_lsp::lsp_types::{GotoDefinitionParams, GotoDefinitionResponse, Location}; -use crate::get_definition_location; -use crate::lang::db::AnalysisDatabase; +use crate::lang::db::{AnalysisDatabase, LsSemanticGroup, LsSyntaxGroup}; +use crate::lang::inspect::defs::find_definition; use crate::lang::lsp::{LsProtoGroup, ToCairo, ToLsp}; /// Get the definition location of a symbol at a given text document position. @@ -23,3 +26,34 @@ pub fn goto_definition( let range = span.position_in_file(db.upcast(), found_file)?.to_lsp(); Some(GotoDefinitionResponse::Scalar(Location { uri: found_uri, range })) } + +/// Returns the file id and span of the definition of an expression from its position. +/// +/// # Arguments +/// +/// * `db` - Preloaded compilation database +/// * `uri` - Uri of the expression position +/// * `position` - Position of the expression +/// +/// # Returns +/// +/// The [FileId] and [TextSpan] of the expression definition if found. +fn get_definition_location( + db: &AnalysisDatabase, + file: FileId, + position: TextPosition, +) -> Option<(FileId, TextSpan)> { + let identifier = db.find_identifier_at_position(file, position)?; + + let syntax_db = db.upcast(); + let node = db.find_syntax_node_at_position(file, position)?; + let lookup_items = db.collect_lookup_items_stack(&node)?; + let (_, stable_ptr) = find_definition(db, &identifier, &lookup_items)?; + let node = stable_ptr.lookup(syntax_db); + let found_file = stable_ptr.file_id(syntax_db); + let span = node.span_without_trivia(syntax_db); + let width = span.width(); + let (file_id, mut span) = get_originating_location(db.upcast(), found_file, span.start_only()); + span.end = span.end.add_width(width); + Some((file_id, span)) +} diff --git a/crates/cairo-lang-language-server/src/lang/inspect/defs.rs b/crates/cairo-lang-language-server/src/lang/inspect/defs.rs index 750b9eb95af..8dc364c65c4 100644 --- a/crates/cairo-lang-language-server/src/lang/inspect/defs.rs +++ b/crates/cairo-lang-language-server/src/lang/inspect/defs.rs @@ -1,27 +1,33 @@ use std::iter; +use cairo_lang_defs::db::DefsGroup; use cairo_lang_defs::ids::{ - LanguageElementId, LookupItemId, MemberId, ModuleItemId, TopLevelLanguageElementId, TraitItemId, + FunctionTitleId, LanguageElementId, LookupItemId, MemberId, ModuleId, ModuleItemId, + SubmoduleLongId, TopLevelLanguageElementId, TraitItemId, }; +use cairo_lang_diagnostics::ToOption; use cairo_lang_doc::db::DocGroup; +use cairo_lang_parser::db::ParserGroup; use cairo_lang_semantic::db::SemanticGroup; use cairo_lang_semantic::expr::pattern::QueryPatternVariablesFromDb; use cairo_lang_semantic::items::function_with_body::SemanticExprLookup; +use cairo_lang_semantic::items::functions::GenericFunctionId; +use cairo_lang_semantic::items::imp::ImplLongId; use cairo_lang_semantic::lookup_item::LookupItemEx; use cairo_lang_semantic::resolve::{ResolvedConcreteItem, ResolvedGenericItem}; -use cairo_lang_semantic::{Binding, Mutability}; +use cairo_lang_semantic::{Binding, Expr, Mutability, TypeLongId}; use cairo_lang_syntax::node::ast::{Param, PatternIdentifier, PatternPtr, TerminalIdentifier}; +use cairo_lang_syntax::node::ids::SyntaxStablePtrId; use cairo_lang_syntax::node::kind::SyntaxKind; use cairo_lang_syntax::node::utils::is_grandparent_of_kind; -use cairo_lang_syntax::node::{SyntaxNode, Terminal, TypedSyntaxNode}; -use cairo_lang_utils::Upcast; +use cairo_lang_syntax::node::{ast, SyntaxNode, Terminal, TypedStablePtr, TypedSyntaxNode}; +use cairo_lang_utils::{Intern, LookupIntern, Upcast}; use itertools::Itertools; use smol_str::SmolStr; use tracing::error; -use crate::lang::db::{AnalysisDatabase, LsSemanticGroup}; +use crate::lang::db::{AnalysisDatabase, LsSemanticGroup, LsSyntaxGroup}; use crate::lang::inspect::defs::SymbolDef::Member; -use crate::{find_definition, ResolvedItem}; /// Keeps information about the symbol that is being searched for/inspected. /// @@ -40,6 +46,13 @@ pub struct MemberDef { pub structure: ItemDef, } +/// Either [`ResolvedGenericItem`], [`ResolvedConcreteItem`] or [`MemberId`]. +pub enum ResolvedItem { + Generic(ResolvedGenericItem), + Concrete(ResolvedConcreteItem), + Member(MemberId), +} + impl SymbolDef { /// Finds definition of the symbol referred by the given identifier. #[tracing::instrument(name = "SymbolDef::find", level = "trace", skip_all)] @@ -277,3 +290,186 @@ impl VariableDef { format!("{prefix}{mutability}{name}: {ty}") } } + +// TODO(mkaput): make private. +#[tracing::instrument(level = "trace", skip_all)] +pub fn find_definition( + db: &AnalysisDatabase, + identifier: &ast::TerminalIdentifier, + lookup_items: &[LookupItemId], +) -> Option<(ResolvedItem, SyntaxStablePtrId)> { + if let Some(parent) = identifier.as_syntax_node().parent() { + if parent.kind(db) == SyntaxKind::ItemModule { + let Some(containing_module_file_id) = db.find_module_file_containing_node(&parent) + else { + error!("`find_definition` failed: could not find module"); + return None; + }; + + let submodule_id = SubmoduleLongId( + containing_module_file_id, + ast::ItemModule::from_syntax_node(db, parent).stable_ptr(), + ) + .intern(db); + let item = ResolvedGenericItem::Module(ModuleId::Submodule(submodule_id)); + return Some(( + ResolvedItem::Generic(item.clone()), + resolved_generic_item_def(db, item)?, + )); + } + } + + if let Some(member_id) = try_extract_member(db, identifier, lookup_items) { + return Some((ResolvedItem::Member(member_id), member_id.untyped_stable_ptr(db))); + } + + for lookup_item_id in lookup_items.iter().copied() { + if let Some(item) = + db.lookup_resolved_generic_item_by_ptr(lookup_item_id, identifier.stable_ptr()) + { + return Some(( + ResolvedItem::Generic(item.clone()), + resolved_generic_item_def(db, item)?, + )); + } + + if let Some(item) = + db.lookup_resolved_concrete_item_by_ptr(lookup_item_id, identifier.stable_ptr()) + { + let stable_ptr = resolved_concrete_item_def(db.upcast(), item.clone())?; + return Some((ResolvedItem::Concrete(item), stable_ptr)); + } + } + + // Skip variable definition, otherwise we would get parent ModuleItem for variable. + if db.first_ancestor_of_kind(identifier.as_syntax_node(), SyntaxKind::StatementLet).is_none() { + let item = match lookup_items.first().copied()? { + LookupItemId::ModuleItem(item) => { + ResolvedGenericItem::from_module_item(db, item).to_option()? + } + LookupItemId::TraitItem(trait_item) => { + if let TraitItemId::Function(trait_fn) = trait_item { + ResolvedGenericItem::TraitFunction(trait_fn) + } else { + ResolvedGenericItem::Trait(trait_item.trait_id(db)) + } + } + LookupItemId::ImplItem(impl_item) => { + ResolvedGenericItem::Impl(impl_item.impl_def_id(db)) + } + }; + + Some((ResolvedItem::Generic(item.clone()), resolved_generic_item_def(db, item)?)) + } else { + None + } +} + +/// Extracts [`MemberId`] if the [`ast::TerminalIdentifier`] points to +/// right-hand side of access member expression e.g., to `xyz` in `self.xyz`. +fn try_extract_member( + db: &AnalysisDatabase, + identifier: &ast::TerminalIdentifier, + lookup_items: &[LookupItemId], +) -> Option { + let syntax_node = identifier.as_syntax_node(); + let binary_expr_syntax_node = + db.first_ancestor_of_kind(syntax_node.clone(), SyntaxKind::ExprBinary)?; + let binary_expr = ast::ExprBinary::from_syntax_node(db, binary_expr_syntax_node); + + let function_with_body = lookup_items.first()?.function_with_body()?; + + let expr_id = + db.lookup_expr_by_ptr(function_with_body, binary_expr.stable_ptr().into()).ok()?; + let semantic_expr = db.expr_semantic(function_with_body, expr_id); + + if let Expr::MemberAccess(expr_member_access) = semantic_expr { + let pointer_to_rhs = binary_expr.rhs(db).stable_ptr().untyped(); + + let mut current_node = syntax_node; + // Check if the terminal identifier points to a member, not a struct variable. + while pointer_to_rhs != current_node.stable_ptr() { + // If we found the node with the binary expression, then we are sure we won't find the + // node with the member. + if current_node.stable_ptr() == binary_expr.stable_ptr().untyped() { + return None; + } + current_node = current_node.parent().unwrap(); + } + + Some(expr_member_access.member) + } else { + None + } +} + +#[tracing::instrument(level = "trace", skip_all)] +fn resolved_concrete_item_def( + db: &AnalysisDatabase, + item: ResolvedConcreteItem, +) -> Option { + match item { + ResolvedConcreteItem::Type(ty) => { + if let TypeLongId::GenericParameter(param) = ty.lookup_intern(db) { + Some(param.untyped_stable_ptr(db.upcast())) + } else { + None + } + } + ResolvedConcreteItem::Impl(imp) => { + if let ImplLongId::GenericParameter(param) = imp.lookup_intern(db) { + Some(param.untyped_stable_ptr(db.upcast())) + } else { + None + } + } + _ => None, + } +} + +#[tracing::instrument(level = "trace", skip_all)] +fn resolved_generic_item_def( + db: &AnalysisDatabase, + item: ResolvedGenericItem, +) -> Option { + let defs_db = db.upcast(); + Some(match item { + ResolvedGenericItem::GenericConstant(item) => item.untyped_stable_ptr(defs_db), + ResolvedGenericItem::Module(module_id) => { + // Check if the module is an inline submodule. + if let ModuleId::Submodule(submodule_id) = module_id { + if let ast::MaybeModuleBody::Some(submodule_id) = + submodule_id.stable_ptr(defs_db).lookup(db.upcast()).body(db.upcast()) + { + // Inline module. + return Some(submodule_id.stable_ptr().untyped()); + } + } + let module_file = db.module_main_file(module_id).ok()?; + let file_syntax = db.file_module_syntax(module_file).ok()?; + file_syntax.as_syntax_node().stable_ptr() + } + ResolvedGenericItem::GenericFunction(item) => { + let title = match item { + GenericFunctionId::Free(id) => FunctionTitleId::Free(id), + GenericFunctionId::Extern(id) => FunctionTitleId::Extern(id), + GenericFunctionId::Impl(id) => { + // Note: Only the trait title is returned. + FunctionTitleId::Trait(id.function) + } + GenericFunctionId::Trait(id) => FunctionTitleId::Trait(id.trait_function(db)), + }; + title.untyped_stable_ptr(defs_db) + } + ResolvedGenericItem::GenericType(generic_type) => generic_type.untyped_stable_ptr(defs_db), + ResolvedGenericItem::GenericTypeAlias(type_alias) => type_alias.untyped_stable_ptr(defs_db), + ResolvedGenericItem::GenericImplAlias(impl_alias) => impl_alias.untyped_stable_ptr(defs_db), + ResolvedGenericItem::Variant(variant) => variant.id.stable_ptr(defs_db).untyped(), + ResolvedGenericItem::Trait(trt) => trt.stable_ptr(defs_db).untyped(), + ResolvedGenericItem::Impl(imp) => imp.stable_ptr(defs_db).untyped(), + ResolvedGenericItem::TraitFunction(trait_function) => { + trait_function.stable_ptr(defs_db).untyped() + } + ResolvedGenericItem::Variable(var) => var.untyped_stable_ptr(defs_db), + }) +} diff --git a/crates/cairo-lang-language-server/src/lib.rs b/crates/cairo-lang-language-server/src/lib.rs index 6c1aadf288f..dc91f7c0439 100644 --- a/crates/cairo-lang-language-server/src/lib.rs +++ b/crates/cairo-lang-language-server/src/lib.rs @@ -45,76 +45,57 @@ use std::path::PathBuf; use std::sync::Arc; use std::time::{Duration, SystemTime}; -use anyhow::{bail, Context}; +use anyhow::Context; use cairo_lang_compiler::db::validate_corelib; use cairo_lang_compiler::project::{setup_project, update_crate_roots_from_project_config}; use cairo_lang_defs::db::DefsGroup; -use cairo_lang_defs::ids::{ - FunctionTitleId, LanguageElementId, LookupItemId, MemberId, ModuleId, SubmoduleLongId, - TraitItemId, -}; -use cairo_lang_diagnostics::{Diagnostics, ToOption}; -use cairo_lang_filesystem::db::{ - get_originating_location, AsFilesGroupMut, FilesGroup, FilesGroupEx, PrivRawFileContentQuery, -}; +use cairo_lang_defs::ids::ModuleId; +use cairo_lang_diagnostics::Diagnostics; +use cairo_lang_filesystem::db::FilesGroup; use cairo_lang_filesystem::ids::{FileId, FileLongId}; -use cairo_lang_filesystem::span::{TextPosition, TextSpan}; use cairo_lang_lowering::db::LoweringGroup; use cairo_lang_lowering::diagnostic::LoweringDiagnostic; use cairo_lang_parser::db::ParserGroup; use cairo_lang_project::ProjectConfig; use cairo_lang_semantic::db::SemanticGroup; -use cairo_lang_semantic::items::function_with_body::SemanticExprLookup; -use cairo_lang_semantic::items::functions::GenericFunctionId; -use cairo_lang_semantic::items::imp::ImplLongId; -use cairo_lang_semantic::lookup_item::LookupItemEx; use cairo_lang_semantic::plugin::PluginSuite; -use cairo_lang_semantic::resolve::{ResolvedConcreteItem, ResolvedGenericItem}; -use cairo_lang_semantic::{Expr, SemanticDiagnostic, TypeLongId}; -use cairo_lang_syntax::node::ids::SyntaxStablePtrId; -use cairo_lang_syntax::node::kind::SyntaxKind; -use cairo_lang_syntax::node::{ast, TypedStablePtr, TypedSyntaxNode}; +use cairo_lang_semantic::SemanticDiagnostic; use cairo_lang_utils::ordered_hash_map::OrderedHashMap; use cairo_lang_utils::{Intern, LookupIntern, Upcast}; use itertools::Itertools; use salsa::{Cancelled, ParallelDatabase}; -use serde_json::Value; -use state::{FileDiagnostics, Owned, StateSnapshot}; +use state::{FileDiagnostics, StateSnapshot}; use tokio::sync::Semaphore; use tokio::task::spawn_blocking; use tower_lsp::jsonrpc::{Error as LSPError, Result as LSPResult}; use tower_lsp::lsp_types::request::Request; -use tower_lsp::lsp_types::*; +use tower_lsp::lsp_types::{TextDocumentPositionParams, Url}; use tower_lsp::{Client, ClientSocket, LanguageServer, LspService, Server}; use tracing::{debug, error, info, trace_span, warn, Instrument}; use crate::config::Config; -use crate::lang::db::{AnalysisDatabase, LsSemanticGroup, LsSyntaxGroup}; +use crate::lang::db::AnalysisDatabase; use crate::lang::diagnostics::lsp::map_cairo_diagnostics_to_lsp; use crate::lang::lsp::LsProtoGroup; -use crate::lsp::capabilities::server::{ - collect_dynamic_registrations, collect_server_capabilities, +use crate::lsp::ext::{ + CorelibVersionMismatch, ProvideVirtualFileRequest, ProvideVirtualFileResponse, }; -use crate::lsp::ext::CorelibVersionMismatch; use crate::project::scarb::update_crate_roots; use crate::project::unmanaged_core_crate::try_to_init_unmanaged_core; use crate::project::ProjectManifestPath; use crate::server::notifier::Notifier; use crate::state::State; use crate::toolchain::scarb::ScarbToolchain; -use crate::vfs::{ProvideVirtualFileRequest, ProvideVirtualFileResponse}; mod config; mod env_config; mod ide; mod lang; pub mod lsp; -mod markdown; mod project; mod server; mod state; mod toolchain; -mod vfs; /// Carries various customizations that can be applied to CairoLS. /// @@ -291,7 +272,7 @@ macro_rules! state_mut_async { impl Backend { fn build_service(tricks: Tricks) -> (LspService, ClientSocket) { LspService::build(|client| Self::new(client, tricks)) - .custom_method("vfs/provide", Self::vfs_provide) + .custom_method(lsp::ext::ProvideVirtualFile::METHOD, Self::vfs_provide) .custom_method(lsp::ext::ViewAnalyzedCrates::METHOD, Self::view_analyzed_crates) .custom_method(lsp::ext::ExpandMacro::METHOD, Self::expand_macro) .finish() @@ -806,461 +787,3 @@ impl Backend { self.refresh_diagnostics().await.ok(); } } - -enum ServerCommands { - Reload, -} - -impl TryFrom for ServerCommands { - type Error = anyhow::Error; - - fn try_from(value: String) -> anyhow::Result { - match value.as_str() { - "cairo.reload" => Ok(ServerCommands::Reload), - _ => bail!("Unrecognized command: {value}"), - } - } -} - -#[tower_lsp::async_trait] -impl LanguageServer for Backend { - #[tracing::instrument(level = "debug", skip_all)] - async fn initialize(&self, params: InitializeParams) -> LSPResult { - let client_capabilities = Owned::new(Arc::new(params.capabilities)); - let client_capabilities_snapshot = client_capabilities.snapshot(); - self.with_state_mut(move |state| { - state.client_capabilities = client_capabilities; - }) - .await; - - Ok(InitializeResult { - server_info: None, - capabilities: collect_server_capabilities(&client_capabilities_snapshot), - }) - } - - #[tracing::instrument(level = "debug", skip_all)] - async fn initialized(&self, _: InitializedParams) { - // Initialize the configuration. - self.reload_config().await; - - // Dynamically register capabilities. - let client_capabilities = self.state_snapshot().await.client_capabilities; - - let dynamic_registrations = collect_dynamic_registrations(&client_capabilities); - if !dynamic_registrations.is_empty() { - let result = self.client.register_capability(dynamic_registrations).await; - if let Err(err) = result { - warn!("failed to register dynamic capabilities: {err:#?}"); - } - } - } - - async fn shutdown(&self) -> LSPResult<()> { - Ok(()) - } - - #[tracing::instrument(level = "debug", skip_all)] - async fn did_change_configuration(&self, _: DidChangeConfigurationParams) { - self.reload_config().await; - } - - #[tracing::instrument(level = "debug", skip_all)] - async fn did_change_watched_files(&self, params: DidChangeWatchedFilesParams) { - // Invalidate changed cairo files. - self.with_state_mut(|state| { - for change in ¶ms.changes { - if is_cairo_file_path(&change.uri) { - let Some(file) = state.db.file_for_url(&change.uri) else { continue }; - PrivRawFileContentQuery - .in_db_mut(state.db.as_files_group_mut()) - .invalidate(&file); - } - } - }) - .await; - - // Reload workspace if a config file has changed. - for change in params.changes { - let changed_file_path = change.uri.to_file_path().unwrap_or_default(); - let changed_file_name = changed_file_path.file_name().unwrap_or_default(); - // TODO(pmagiera): react to Scarb.lock. Keep in mind Scarb does save Scarb.lock on each - // metadata call, so it is easy to fall in a loop here. - if ["Scarb.toml", "cairo_project.toml"].map(Some).contains(&changed_file_name.to_str()) - { - self.reload().await.ok(); - } - } - } - - #[tracing::instrument(level = "debug", skip_all, fields(command = params.command))] - async fn execute_command(&self, params: ExecuteCommandParams) -> LSPResult> { - let command = ServerCommands::try_from(params.command); - if let Ok(cmd) = command { - match cmd { - ServerCommands::Reload => { - self.reload().await?; - } - } - } - - match self.client.apply_edit(WorkspaceEdit::default()).await { - Ok(res) if res.applied => self.client.log_message(MessageType::INFO, "applied").await, - Ok(_) => self.client.log_message(MessageType::INFO, "rejected").await, - Err(err) => self.client.log_message(MessageType::ERROR, err).await, - } - - Ok(None) - } - - #[tracing::instrument(level = "debug", skip_all, fields(uri = %params.text_document.uri))] - async fn did_open(&self, params: DidOpenTextDocumentParams) { - let refresh = state_mut_async! {state, self, - let uri = params.text_document.uri; - - // Try to detect the crate for physical files. - // The crate for virtual files is already known. - if uri.scheme() == "file" { - let Ok(path) = uri.to_file_path() else { return false }; - self.detect_crate_for(&mut state.db, &state.config, path).await; - } - - let Some(file_id) = state.db.file_for_url(&uri) else { return false }; - state.open_files.insert(uri); - state.db.override_file_content(file_id, Some(params.text_document.text.into())); - - true - } - .await; - - if refresh { - self.refresh_diagnostics().await.ok(); - } - } - - #[tracing::instrument(level = "debug", skip_all, fields(uri = %params.text_document.uri))] - async fn did_change(&self, params: DidChangeTextDocumentParams) { - let text = if let Ok([TextDocumentContentChangeEvent { text, .. }]) = - TryInto::<[_; 1]>::try_into(params.content_changes) - { - text - } else { - error!("unexpected format of document change"); - return; - }; - let refresh = self - .with_state_mut(|state| { - let uri = params.text_document.uri; - let Some(file) = state.db.file_for_url(&uri) else { return false }; - state.db.override_file_content(file, Some(text.into())); - - true - }) - .await; - - if refresh { - self.refresh_diagnostics().await.ok(); - } - } - - #[tracing::instrument(level = "debug", skip_all, fields(uri = %params.text_document.uri))] - async fn did_save(&self, params: DidSaveTextDocumentParams) { - self.with_state_mut(|state| { - let Some(file) = state.db.file_for_url(¶ms.text_document.uri) else { return }; - PrivRawFileContentQuery.in_db_mut(state.db.as_files_group_mut()).invalidate(&file); - state.db.override_file_content(file, None); - }) - .await; - } - - #[tracing::instrument(level = "debug", skip_all, fields(uri = %params.text_document.uri))] - async fn did_close(&self, params: DidCloseTextDocumentParams) { - let refresh = self - .with_state_mut(|state| { - state.open_files.remove(¶ms.text_document.uri); - let Some(file) = state.db.file_for_url(¶ms.text_document.uri) else { - return false; - }; - state.db.override_file_content(file, None); - - true - }) - .await; - - if refresh { - self.refresh_diagnostics().await.ok(); - } - } - - #[tracing::instrument(level = "trace", skip_all)] - async fn completion(&self, params: CompletionParams) -> LSPResult> { - let db = self.db_snapshot().await; - self.catch_panics(move || ide::completion::complete(params, &db)).await - } - - #[tracing::instrument(level = "trace", skip_all)] - async fn semantic_tokens_full( - &self, - params: SemanticTokensParams, - ) -> LSPResult> { - let db = self.db_snapshot().await; - self.catch_panics(move || ide::semantic_highlighting::semantic_highlight_full(params, &db)) - .await - } - - #[tracing::instrument(level = "trace", skip_all)] - async fn formatting( - &self, - params: DocumentFormattingParams, - ) -> LSPResult>> { - let db = self.db_snapshot().await; - self.catch_panics(move || ide::formatter::format(params, &db)).await - } - - #[tracing::instrument(level = "trace", skip_all)] - async fn hover(&self, params: HoverParams) -> LSPResult> { - let db = self.db_snapshot().await; - self.catch_panics(move || ide::hover::hover(params, &db)).await - } - - #[tracing::instrument(level = "trace", skip_all)] - async fn goto_definition( - &self, - params: GotoDefinitionParams, - ) -> LSPResult> { - let db = self.db_snapshot().await; - self.catch_panics(move || ide::navigation::goto_definition::goto_definition(params, &db)) - .await - } - - #[tracing::instrument(level = "trace", skip_all)] - async fn code_action(&self, params: CodeActionParams) -> LSPResult> { - let db = self.db_snapshot().await; - self.catch_panics(move || ide::code_actions::code_actions(params, &db)).await - } -} - -/// Extracts [`MemberId`] if the [`ast::TerminalIdentifier`] points to -/// right-hand side of access member expression e.g., to `xyz` in `self.xyz`. -fn try_extract_member( - db: &AnalysisDatabase, - identifier: &ast::TerminalIdentifier, - lookup_items: &[LookupItemId], -) -> Option { - let syntax_node = identifier.as_syntax_node(); - let binary_expr_syntax_node = - db.first_ancestor_of_kind(syntax_node.clone(), SyntaxKind::ExprBinary)?; - let binary_expr = ast::ExprBinary::from_syntax_node(db, binary_expr_syntax_node); - - let function_with_body = lookup_items.first()?.function_with_body()?; - - let expr_id = - db.lookup_expr_by_ptr(function_with_body, binary_expr.stable_ptr().into()).ok()?; - let semantic_expr = db.expr_semantic(function_with_body, expr_id); - - if let Expr::MemberAccess(expr_member_access) = semantic_expr { - let pointer_to_rhs = binary_expr.rhs(db).stable_ptr().untyped(); - - let mut current_node = syntax_node; - // Check if the terminal identifier points to a member, not a struct variable. - while pointer_to_rhs != current_node.stable_ptr() { - // If we found the node with the binary expression, then we are sure we won't find the - // node with the member. - if current_node.stable_ptr() == binary_expr.stable_ptr().untyped() { - return None; - } - current_node = current_node.parent().unwrap(); - } - - Some(expr_member_access.member) - } else { - None - } -} - -/// Either [`ResolvedGenericItem`], [`ResolvedConcreteItem`] or [`MemberId`]. -enum ResolvedItem { - Generic(ResolvedGenericItem), - Concrete(ResolvedConcreteItem), - Member(MemberId), -} - -// TODO(mkaput): Move this to crate::lang::inspect::defs and make private. -#[tracing::instrument(level = "trace", skip_all)] -fn find_definition( - db: &AnalysisDatabase, - identifier: &ast::TerminalIdentifier, - lookup_items: &[LookupItemId], -) -> Option<(ResolvedItem, SyntaxStablePtrId)> { - if let Some(parent) = identifier.as_syntax_node().parent() { - if parent.kind(db) == SyntaxKind::ItemModule { - let Some(containing_module_file_id) = db.find_module_file_containing_node(&parent) - else { - error!("`find_definition` failed: could not find module"); - return None; - }; - - let submodule_id = SubmoduleLongId( - containing_module_file_id, - ast::ItemModule::from_syntax_node(db, parent).stable_ptr(), - ) - .intern(db); - let item = ResolvedGenericItem::Module(ModuleId::Submodule(submodule_id)); - return Some(( - ResolvedItem::Generic(item.clone()), - resolved_generic_item_def(db, item)?, - )); - } - } - - if let Some(member_id) = try_extract_member(db, identifier, lookup_items) { - return Some((ResolvedItem::Member(member_id), member_id.untyped_stable_ptr(db))); - } - - for lookup_item_id in lookup_items.iter().copied() { - if let Some(item) = - db.lookup_resolved_generic_item_by_ptr(lookup_item_id, identifier.stable_ptr()) - { - return Some(( - ResolvedItem::Generic(item.clone()), - resolved_generic_item_def(db, item)?, - )); - } - - if let Some(item) = - db.lookup_resolved_concrete_item_by_ptr(lookup_item_id, identifier.stable_ptr()) - { - let stable_ptr = resolved_concrete_item_def(db.upcast(), item.clone())?; - return Some((ResolvedItem::Concrete(item), stable_ptr)); - } - } - - // Skip variable definition, otherwise we would get parent ModuleItem for variable. - if db.first_ancestor_of_kind(identifier.as_syntax_node(), SyntaxKind::StatementLet).is_none() { - let item = match lookup_items.first().copied()? { - LookupItemId::ModuleItem(item) => { - ResolvedGenericItem::from_module_item(db, item).to_option()? - } - LookupItemId::TraitItem(trait_item) => { - if let TraitItemId::Function(trait_fn) = trait_item { - ResolvedGenericItem::TraitFunction(trait_fn) - } else { - ResolvedGenericItem::Trait(trait_item.trait_id(db)) - } - } - LookupItemId::ImplItem(impl_item) => { - ResolvedGenericItem::Impl(impl_item.impl_def_id(db)) - } - }; - - Some((ResolvedItem::Generic(item.clone()), resolved_generic_item_def(db, item)?)) - } else { - None - } -} - -#[tracing::instrument(level = "trace", skip_all)] -fn resolved_concrete_item_def( - db: &AnalysisDatabase, - item: ResolvedConcreteItem, -) -> Option { - match item { - ResolvedConcreteItem::Type(ty) => { - if let TypeLongId::GenericParameter(param) = ty.lookup_intern(db) { - Some(param.untyped_stable_ptr(db.upcast())) - } else { - None - } - } - ResolvedConcreteItem::Impl(imp) => { - if let ImplLongId::GenericParameter(param) = imp.lookup_intern(db) { - Some(param.untyped_stable_ptr(db.upcast())) - } else { - None - } - } - _ => None, - } -} - -#[tracing::instrument(level = "trace", skip_all)] -fn resolved_generic_item_def( - db: &AnalysisDatabase, - item: ResolvedGenericItem, -) -> Option { - let defs_db = db.upcast(); - Some(match item { - ResolvedGenericItem::GenericConstant(item) => item.untyped_stable_ptr(defs_db), - ResolvedGenericItem::Module(module_id) => { - // Check if the module is an inline submodule. - if let ModuleId::Submodule(submodule_id) = module_id { - if let ast::MaybeModuleBody::Some(submodule_id) = - submodule_id.stable_ptr(defs_db).lookup(db.upcast()).body(db.upcast()) - { - // Inline module. - return Some(submodule_id.stable_ptr().untyped()); - } - } - let module_file = db.module_main_file(module_id).ok()?; - let file_syntax = db.file_module_syntax(module_file).ok()?; - file_syntax.as_syntax_node().stable_ptr() - } - ResolvedGenericItem::GenericFunction(item) => { - let title = match item { - GenericFunctionId::Free(id) => FunctionTitleId::Free(id), - GenericFunctionId::Extern(id) => FunctionTitleId::Extern(id), - GenericFunctionId::Impl(id) => { - // Note: Only the trait title is returned. - FunctionTitleId::Trait(id.function) - } - GenericFunctionId::Trait(id) => FunctionTitleId::Trait(id.trait_function(db)), - }; - title.untyped_stable_ptr(defs_db) - } - ResolvedGenericItem::GenericType(generic_type) => generic_type.untyped_stable_ptr(defs_db), - ResolvedGenericItem::GenericTypeAlias(type_alias) => type_alias.untyped_stable_ptr(defs_db), - ResolvedGenericItem::GenericImplAlias(impl_alias) => impl_alias.untyped_stable_ptr(defs_db), - ResolvedGenericItem::Variant(variant) => variant.id.stable_ptr(defs_db).untyped(), - ResolvedGenericItem::Trait(trt) => trt.stable_ptr(defs_db).untyped(), - ResolvedGenericItem::Impl(imp) => imp.stable_ptr(defs_db).untyped(), - ResolvedGenericItem::TraitFunction(trait_function) => { - trait_function.stable_ptr(defs_db).untyped() - } - ResolvedGenericItem::Variable(var) => var.untyped_stable_ptr(defs_db), - }) -} - -fn is_cairo_file_path(file_path: &Url) -> bool { - file_path.path().ends_with(".cairo") -} - -/// Returns the file id and span of the definition of an expression from its position. -/// -/// # Arguments -/// -/// * `db` - Preloaded compilation database -/// * `uri` - Uri of the expression position -/// * `position` - Position of the expression -/// -/// # Returns -/// -/// The [FileId] and [TextSpan] of the expression definition if found. -fn get_definition_location( - db: &AnalysisDatabase, - file: FileId, - position: TextPosition, -) -> Option<(FileId, TextSpan)> { - let identifier = db.find_identifier_at_position(file, position)?; - - let syntax_db = db.upcast(); - let node = db.find_syntax_node_at_position(file, position)?; - let lookup_items = db.collect_lookup_items_stack(&node)?; - let (_, stable_ptr) = find_definition(db, &identifier, &lookup_items)?; - let node = stable_ptr.lookup(syntax_db); - let found_file = stable_ptr.file_id(syntax_db); - let span = node.span_without_trivia(syntax_db); - let width = span.width(); - let (file_id, mut span) = get_originating_location(db.upcast(), found_file, span.start_only()); - span.end = span.end.add_width(width); - Some((file_id, span)) -} diff --git a/crates/cairo-lang-language-server/src/lsp/controller.rs b/crates/cairo-lang-language-server/src/lsp/controller.rs new file mode 100644 index 00000000000..d4543c57057 --- /dev/null +++ b/crates/cairo-lang-language-server/src/lsp/controller.rs @@ -0,0 +1,259 @@ +use std::sync::Arc; + +use cairo_lang_filesystem::db::{AsFilesGroupMut, FilesGroupEx, PrivRawFileContentQuery}; +use serde_json::Value; +use tower_lsp::jsonrpc::Result as LSPResult; +use tower_lsp::lsp_types::{ + CodeActionParams, CodeActionResponse, CompletionParams, CompletionResponse, + DidChangeConfigurationParams, DidChangeTextDocumentParams, DidChangeWatchedFilesParams, + DidCloseTextDocumentParams, DidOpenTextDocumentParams, DidSaveTextDocumentParams, + DocumentFormattingParams, ExecuteCommandParams, GotoDefinitionParams, GotoDefinitionResponse, + Hover, HoverParams, InitializeParams, InitializeResult, InitializedParams, MessageType, + SemanticTokensParams, SemanticTokensResult, TextDocumentContentChangeEvent, TextEdit, Url, + WorkspaceEdit, +}; +use tower_lsp::LanguageServer; +use tracing::{error, warn}; + +use crate::lang::lsp::LsProtoGroup; +use crate::lsp::capabilities::server::{ + collect_dynamic_registrations, collect_server_capabilities, +}; +use crate::server::commands::ServerCommands; +use crate::state::Owned; +use crate::{ide, Backend}; + +/// TODO: Remove when we move to sync world. +/// This is macro because of lifetimes problems with `self`. +macro_rules! state_mut_async { + ($state:ident, $this:ident, $($f:tt)+) => { + async { + let mut state = $this.state_mutex.lock().await; + let $state = &mut *state; + + $($f)+ + } + }; +} + +#[tower_lsp::async_trait] +impl LanguageServer for Backend { + #[tracing::instrument(level = "debug", skip_all)] + async fn initialize(&self, params: InitializeParams) -> LSPResult { + let client_capabilities = Owned::new(Arc::new(params.capabilities)); + let client_capabilities_snapshot = client_capabilities.snapshot(); + self.with_state_mut(move |state| { + state.client_capabilities = client_capabilities; + }) + .await; + + Ok(InitializeResult { + server_info: None, + capabilities: collect_server_capabilities(&client_capabilities_snapshot), + }) + } + + #[tracing::instrument(level = "debug", skip_all)] + async fn initialized(&self, _: InitializedParams) { + // Initialize the configuration. + self.reload_config().await; + + // Dynamically register capabilities. + let client_capabilities = self.state_snapshot().await.client_capabilities; + + let dynamic_registrations = collect_dynamic_registrations(&client_capabilities); + if !dynamic_registrations.is_empty() { + let result = self.client.register_capability(dynamic_registrations).await; + if let Err(err) = result { + warn!("failed to register dynamic capabilities: {err:#?}"); + } + } + } + + async fn shutdown(&self) -> LSPResult<()> { + Ok(()) + } + + #[tracing::instrument(level = "debug", skip_all)] + async fn did_change_configuration(&self, _: DidChangeConfigurationParams) { + self.reload_config().await; + } + + #[tracing::instrument(level = "debug", skip_all)] + async fn did_change_watched_files(&self, params: DidChangeWatchedFilesParams) { + // Invalidate changed cairo files. + self.with_state_mut(|state| { + for change in ¶ms.changes { + if is_cairo_file_path(&change.uri) { + let Some(file) = state.db.file_for_url(&change.uri) else { continue }; + PrivRawFileContentQuery + .in_db_mut(state.db.as_files_group_mut()) + .invalidate(&file); + } + } + }) + .await; + + // Reload workspace if a config file has changed. + for change in params.changes { + let changed_file_path = change.uri.to_file_path().unwrap_or_default(); + let changed_file_name = changed_file_path.file_name().unwrap_or_default(); + // TODO(pmagiera): react to Scarb.lock. Keep in mind Scarb does save Scarb.lock on each + // metadata call, so it is easy to fall in a loop here. + if ["Scarb.toml", "cairo_project.toml"].map(Some).contains(&changed_file_name.to_str()) + { + self.reload().await.ok(); + } + } + } + + #[tracing::instrument(level = "debug", skip_all, fields(command = params.command))] + async fn execute_command(&self, params: ExecuteCommandParams) -> LSPResult> { + let command = ServerCommands::try_from(params.command); + if let Ok(cmd) = command { + match cmd { + ServerCommands::Reload => { + self.reload().await?; + } + } + } + + match self.client.apply_edit(WorkspaceEdit::default()).await { + Ok(res) if res.applied => self.client.log_message(MessageType::INFO, "applied").await, + Ok(_) => self.client.log_message(MessageType::INFO, "rejected").await, + Err(err) => self.client.log_message(MessageType::ERROR, err).await, + } + + Ok(None) + } + + #[tracing::instrument(level = "debug", skip_all, fields(uri = %params.text_document.uri))] + async fn did_open(&self, params: DidOpenTextDocumentParams) { + let refresh = state_mut_async! {state, self, + let uri = params.text_document.uri; + + // Try to detect the crate for physical files. + // The crate for virtual files is already known. + if uri.scheme() == "file" { + let Ok(path) = uri.to_file_path() else { return false }; + self.detect_crate_for(&mut state.db, &state.config, path).await; + } + + let Some(file_id) = state.db.file_for_url(&uri) else { return false }; + state.open_files.insert(uri); + state.db.override_file_content(file_id, Some(params.text_document.text.into())); + + true + } + .await; + + if refresh { + self.refresh_diagnostics().await.ok(); + } + } + + #[tracing::instrument(level = "debug", skip_all, fields(uri = %params.text_document.uri))] + async fn did_change(&self, params: DidChangeTextDocumentParams) { + let text = if let Ok([TextDocumentContentChangeEvent { text, .. }]) = + TryInto::<[_; 1]>::try_into(params.content_changes) + { + text + } else { + error!("unexpected format of document change"); + return; + }; + let refresh = self + .with_state_mut(|state| { + let uri = params.text_document.uri; + let Some(file) = state.db.file_for_url(&uri) else { return false }; + state.db.override_file_content(file, Some(text.into())); + + true + }) + .await; + + if refresh { + self.refresh_diagnostics().await.ok(); + } + } + + #[tracing::instrument(level = "debug", skip_all, fields(uri = %params.text_document.uri))] + async fn did_save(&self, params: DidSaveTextDocumentParams) { + self.with_state_mut(|state| { + let Some(file) = state.db.file_for_url(¶ms.text_document.uri) else { return }; + PrivRawFileContentQuery.in_db_mut(state.db.as_files_group_mut()).invalidate(&file); + state.db.override_file_content(file, None); + }) + .await; + } + + #[tracing::instrument(level = "debug", skip_all, fields(uri = %params.text_document.uri))] + async fn did_close(&self, params: DidCloseTextDocumentParams) { + let refresh = self + .with_state_mut(|state| { + state.open_files.remove(¶ms.text_document.uri); + let Some(file) = state.db.file_for_url(¶ms.text_document.uri) else { + return false; + }; + state.db.override_file_content(file, None); + + true + }) + .await; + + if refresh { + self.refresh_diagnostics().await.ok(); + } + } + + #[tracing::instrument(level = "trace", skip_all)] + async fn completion(&self, params: CompletionParams) -> LSPResult> { + let db = self.db_snapshot().await; + self.catch_panics(move || ide::completion::complete(params, &db)).await + } + + #[tracing::instrument(level = "trace", skip_all)] + async fn semantic_tokens_full( + &self, + params: SemanticTokensParams, + ) -> LSPResult> { + let db = self.db_snapshot().await; + self.catch_panics(move || ide::semantic_highlighting::semantic_highlight_full(params, &db)) + .await + } + + #[tracing::instrument(level = "trace", skip_all)] + async fn formatting( + &self, + params: DocumentFormattingParams, + ) -> LSPResult>> { + let db = self.db_snapshot().await; + self.catch_panics(move || ide::formatter::format(params, &db)).await + } + + #[tracing::instrument(level = "trace", skip_all)] + async fn hover(&self, params: HoverParams) -> LSPResult> { + let db = self.db_snapshot().await; + self.catch_panics(move || ide::hover::hover(params, &db)).await + } + + #[tracing::instrument(level = "trace", skip_all)] + async fn goto_definition( + &self, + params: GotoDefinitionParams, + ) -> LSPResult> { + let db = self.db_snapshot().await; + self.catch_panics(move || ide::navigation::goto_definition::goto_definition(params, &db)) + .await + } + + #[tracing::instrument(level = "trace", skip_all)] + async fn code_action(&self, params: CodeActionParams) -> LSPResult> { + let db = self.db_snapshot().await; + self.catch_panics(move || ide::code_actions::code_actions(params, &db)).await + } +} + +fn is_cairo_file_path(file_path: &Url) -> bool { + file_path.path().ends_with(".cairo") +} diff --git a/crates/cairo-lang-language-server/src/lsp/ext.rs b/crates/cairo-lang-language-server/src/lsp/ext.rs index c20ea799374..84366833d4f 100644 --- a/crates/cairo-lang-language-server/src/lsp/ext.rs +++ b/crates/cairo-lang-language-server/src/lsp/ext.rs @@ -1,8 +1,28 @@ //! CairoLS extensions to the Language Server Protocol. +use serde::{Deserialize, Serialize}; use tower_lsp::lsp_types::notification::Notification; use tower_lsp::lsp_types::request::Request; -use tower_lsp::lsp_types::TextDocumentPositionParams; +use tower_lsp::lsp_types::{TextDocumentPositionParams, Url}; + +/// Provides content of virtual file from the database. +pub struct ProvideVirtualFile; + +#[derive(Debug, Eq, PartialEq, Clone, Deserialize, Serialize)] +pub struct ProvideVirtualFileRequest { + pub uri: Url, +} + +#[derive(Debug, Eq, PartialEq, Clone, Deserialize, Serialize)] +pub struct ProvideVirtualFileResponse { + pub content: Option, +} + +impl Request for ProvideVirtualFile { + type Params = ProvideVirtualFileRequest; + type Result = ProvideVirtualFileResponse; + const METHOD: &'static str = "vfs/provide"; +} /// Collects information about all Cairo crates that are currently being analyzed. pub struct ViewAnalyzedCrates; @@ -13,7 +33,7 @@ impl Request for ViewAnalyzedCrates { const METHOD: &'static str = "cairo/viewAnalyzedCrates"; } -/// Provides string with code after macros expansion +/// Provides string with code after macros expansion. pub struct ExpandMacro; impl Request for ExpandMacro { diff --git a/crates/cairo-lang-language-server/src/lsp/mod.rs b/crates/cairo-lang-language-server/src/lsp/mod.rs index 3772adc66c1..c2997dde46b 100644 --- a/crates/cairo-lang-language-server/src/lsp/mod.rs +++ b/crates/cairo-lang-language-server/src/lsp/mod.rs @@ -1,2 +1,3 @@ pub(crate) mod capabilities; +mod controller; pub mod ext; diff --git a/crates/cairo-lang-language-server/src/server/commands.rs b/crates/cairo-lang-language-server/src/server/commands.rs new file mode 100644 index 00000000000..3597ed7aa8b --- /dev/null +++ b/crates/cairo-lang-language-server/src/server/commands.rs @@ -0,0 +1,16 @@ +use anyhow::bail; + +pub enum ServerCommands { + Reload, +} + +impl TryFrom for ServerCommands { + type Error = anyhow::Error; + + fn try_from(value: String) -> anyhow::Result { + match value.as_str() { + "cairo.reload" => Ok(ServerCommands::Reload), + _ => bail!("Unrecognized command: {value}"), + } + } +} diff --git a/crates/cairo-lang-language-server/src/server/mod.rs b/crates/cairo-lang-language-server/src/server/mod.rs index dd1a82c5b08..a5e45b9efc8 100644 --- a/crates/cairo-lang-language-server/src/server/mod.rs +++ b/crates/cairo-lang-language-server/src/server/mod.rs @@ -1 +1,2 @@ +pub mod commands; pub mod notifier; diff --git a/crates/cairo-lang-language-server/src/vfs.rs b/crates/cairo-lang-language-server/src/vfs.rs deleted file mode 100644 index 89c3bffffd1..00000000000 --- a/crates/cairo-lang-language-server/src/vfs.rs +++ /dev/null @@ -1,12 +0,0 @@ -use serde::{Deserialize, Serialize}; -use tower_lsp::lsp_types::Url; - -#[derive(Debug, Eq, PartialEq, Clone, Deserialize, Serialize)] -pub struct ProvideVirtualFileRequest { - pub uri: Url, -} - -#[derive(Debug, Eq, PartialEq, Clone, Deserialize, Serialize)] -pub struct ProvideVirtualFileResponse { - pub content: Option, -}