Skip to content

Commit e49738b

Browse files
authored
[mlir][lsp] Enable registering dialects based on URI. (#141331)
Previously the dialects registered were fixed per LSP binary. This works as long as all the dialects of interest from the different projects across which one uses the LSP, are disjoint. This expands this to support cases where there are dialects that overlap in dialect name but usage of these are separate wrt projects. The alternative is multiple binaries and switching LSP used in editor per project (there is some extra complexity in hosted instances). This handles a simple (I believe common case) where one can determine based on path and have single binary - the cost of dynamically doing so based on path would be either keeping different registries to return or repopulating dialect & extension maps.
1 parent f6c2ec2 commit e49738b

File tree

8 files changed

+143
-21
lines changed

8 files changed

+143
-21
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//===- MlirLspRegistryFunction.h - LSP registry functions -------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Registry function types for MLIR LSP.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_TOOLS_MLIR_LSP_SERVER_MLIRLSPREGISTRYFUNCTION_H
14+
#define MLIR_TOOLS_MLIR_LSP_SERVER_MLIRLSPREGISTRYFUNCTION_H
15+
16+
namespace llvm {
17+
template <typename Fn>
18+
class function_ref;
19+
} // namespace llvm
20+
21+
namespace mlir {
22+
class DialectRegistry;
23+
namespace lsp {
24+
class URIForFile;
25+
using DialectRegistryFn =
26+
llvm::function_ref<DialectRegistry &(const URIForFile &uri)>;
27+
} // namespace lsp
28+
} // namespace mlir
29+
30+
#endif // MLIR_TOOLS_MLIR_LSP_SERVER_MLIRLSPREGISTRYFUNCTION_H

mlir/include/mlir/Tools/mlir-lsp-server/MlirLspServerMain.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,27 @@
1212

1313
#ifndef MLIR_TOOLS_MLIR_LSP_SERVER_MLIRLSPSERVERMAIN_H
1414
#define MLIR_TOOLS_MLIR_LSP_SERVER_MLIRLSPSERVERMAIN_H
15+
#include "mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h"
1516

1617
namespace llvm {
1718
struct LogicalResult;
1819
} // namespace llvm
1920

2021
namespace mlir {
21-
class DialectRegistry;
2222

2323
/// Implementation for tools like `mlir-lsp-server`.
2424
/// - registry should contain all the dialects that can be parsed in source IR
25-
/// passed to the server.
25+
/// passed to the server.
2626
llvm::LogicalResult MlirLspServerMain(int argc, char **argv,
2727
DialectRegistry &registry);
2828

29+
/// Implementation for tools like `mlir-lsp-server`.
30+
/// - registry should contain all the dialects that can be parsed in source IR
31+
/// passed to the server and may register different dialects depending on the
32+
/// input URI.
33+
llvm::LogicalResult MlirLspServerMain(int argc, char **argv,
34+
lsp::DialectRegistryFn registry_fn);
35+
2936
} // namespace mlir
3037

3138
#endif // MLIR_TOOLS_MLIR_LSP_SERVER_MLIRLSPSERVERMAIN_H

mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -997,7 +997,7 @@ namespace {
997997
class MLIRTextFile {
998998
public:
999999
MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1000-
int64_t version, DialectRegistry &registry,
1000+
int64_t version, lsp::DialectRegistryFn registry_fn,
10011001
std::vector<lsp::Diagnostic> &diagnostics);
10021002

10031003
/// Return the current version of this text file.
@@ -1046,9 +1046,9 @@ class MLIRTextFile {
10461046
} // namespace
10471047

10481048
MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1049-
int64_t version, DialectRegistry &registry,
1049+
int64_t version, lsp::DialectRegistryFn registry_fn,
10501050
std::vector<lsp::Diagnostic> &diagnostics)
1051-
: context(registry, MLIRContext::Threading::DISABLED),
1051+
: context(registry_fn(uri), MLIRContext::Threading::DISABLED),
10521052
contents(fileContents.str()), version(version) {
10531053
context.allowUnregisteredDialects();
10541054

@@ -1263,11 +1263,11 @@ MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
12631263
//===----------------------------------------------------------------------===//
12641264

12651265
struct lsp::MLIRServer::Impl {
1266-
Impl(DialectRegistry &registry) : registry(registry) {}
1266+
Impl(lsp::DialectRegistryFn registry_fn) : registry_fn(registry_fn) {}
12671267

1268-
/// The registry containing dialects that can be recognized in parsed .mlir
1269-
/// files.
1270-
DialectRegistry &registry;
1268+
/// The registry factory for containing dialects that can be recognized in
1269+
/// parsed .mlir files.
1270+
lsp::DialectRegistryFn registry_fn;
12711271

12721272
/// The files held by the server, mapped by their URI file name.
12731273
llvm::StringMap<std::unique_ptr<MLIRTextFile>> files;
@@ -1277,15 +1277,15 @@ struct lsp::MLIRServer::Impl {
12771277
// MLIRServer
12781278
//===----------------------------------------------------------------------===//
12791279

1280-
lsp::MLIRServer::MLIRServer(DialectRegistry &registry)
1281-
: impl(std::make_unique<Impl>(registry)) {}
1280+
lsp::MLIRServer::MLIRServer(lsp::DialectRegistryFn registry_fn)
1281+
: impl(std::make_unique<Impl>(registry_fn)) {}
12821282
lsp::MLIRServer::~MLIRServer() = default;
12831283

12841284
void lsp::MLIRServer::addOrUpdateDocument(
12851285
const URIForFile &uri, StringRef contents, int64_t version,
12861286
std::vector<Diagnostic> &diagnostics) {
12871287
impl->files[uri.file()] = std::make_unique<MLIRTextFile>(
1288-
uri, contents, version, impl->registry, diagnostics);
1288+
uri, contents, version, impl->registry_fn, diagnostics);
12891289
}
12901290

12911291
std::optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) {
@@ -1348,7 +1348,7 @@ void lsp::MLIRServer::getCodeActions(const URIForFile &uri, const Range &pos,
13481348

13491349
llvm::Expected<lsp::MLIRConvertBytecodeResult>
13501350
lsp::MLIRServer::convertFromBytecode(const URIForFile &uri) {
1351-
MLIRContext tempContext(impl->registry);
1351+
MLIRContext tempContext(impl->registry_fn(uri));
13521352
tempContext.allowUnregisteredDialects();
13531353

13541354
// Collect any errors during parsing.

mlir/lib/Tools/mlir-lsp-server/MLIRServer.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define LIB_MLIR_TOOLS_MLIRLSPSERVER_SERVER_H_
1111

1212
#include "mlir/Support/LLVM.h"
13+
#include "mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h"
1314
#include "llvm/Support/Error.h"
1415
#include <memory>
1516
#include <optional>
@@ -28,15 +29,14 @@ struct Location;
2829
struct MLIRConvertBytecodeResult;
2930
struct Position;
3031
struct Range;
31-
class URIForFile;
3232

3333
/// This class implements all of the MLIR related functionality necessary for a
3434
/// language server. This class allows for keeping the MLIR specific logic
3535
/// separate from the logic that involves LSP server/client communication.
3636
class MLIRServer {
3737
public:
38-
/// Construct a new server with the given dialect regitstry.
39-
MLIRServer(DialectRegistry &registry);
38+
/// Construct a new server with the given dialect registry function.
39+
MLIRServer(DialectRegistryFn registry_fn);
4040
~MLIRServer();
4141

4242
/// Add or update the document, with the provided `version`, at the given URI.

mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ using namespace mlir;
1919
using namespace mlir::lsp;
2020

2121
LogicalResult mlir::MlirLspServerMain(int argc, char **argv,
22-
DialectRegistry &registry) {
22+
DialectRegistryFn registry_fn) {
2323
llvm::cl::opt<JSONStreamStyle> inputStyle{
2424
"input-style",
2525
llvm::cl::desc("Input JSON stream encoding"),
@@ -72,6 +72,15 @@ LogicalResult mlir::MlirLspServerMain(int argc, char **argv,
7272
URIForFile::registerSupportedScheme("mlir.bytecode-mlir");
7373

7474
// Configure the servers and start the main language server.
75-
MLIRServer server(registry);
75+
MLIRServer server(registry_fn);
7676
return runMlirLSPServer(server, transport);
7777
}
78+
79+
llvm::LogicalResult mlir::MlirLspServerMain(int argc, char **argv,
80+
DialectRegistry &registry) {
81+
auto registry_fn =
82+
[&registry](const lsp::URIForFile &uri) -> DialectRegistry & {
83+
return registry;
84+
};
85+
return MlirLspServerMain(argc, argv, registry_fn);
86+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: not mlir-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s
2+
{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"mlir","capabilities":{},"trace":"off"}}
3+
// -----
4+
// Just regular parse, successful.
5+
{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{
6+
"uri":"test:///foo-regular-registration.mlir",
7+
"languageId":"mlir",
8+
"version":1,
9+
"text":"func.func @fail_with_empty_registry() { return }"
10+
}}}
11+
// CHECK: "method": "textDocument/publishDiagnostics",
12+
// CHECK: "diagnostics": []
13+
// -----
14+
// Just regular parse, successful.
15+
{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{
16+
"uri":"test:///foo-disable-lsp-registration.mlir",
17+
"languageId":"mlir",
18+
"version":1,
19+
"text":"func.func @fail_with_empty_registry() { return }"
20+
}}}
21+
// CHECK: "method": "textDocument/publishDiagnostics",
22+
// CHECK: "message": "Dialect `func' not found for custom op 'func.func'
23+

mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "mlir/IR/Dialect.h"
109
#include "mlir/IR/MLIRContext.h"
1110
#include "mlir/InitAllDialects.h"
1211
#include "mlir/InitAllExtensions.h"
12+
#include "mlir/Tools/lsp-server-support/Protocol.h"
1313
#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h"
1414

1515
using namespace mlir;
@@ -23,7 +23,7 @@ void registerTestTransformDialectExtension(DialectRegistry &);
2323
#endif
2424

2525
int main(int argc, char **argv) {
26-
DialectRegistry registry;
26+
DialectRegistry registry, empty;
2727
registerAllDialects(registry);
2828
registerAllExtensions(registry);
2929

@@ -32,5 +32,18 @@ int main(int argc, char **argv) {
3232
::test::registerTestTransformDialectExtension(registry);
3333
::test::registerTestDynDialect(registry);
3434
#endif
35-
return failed(MlirLspServerMain(argc, argv, registry));
35+
36+
// Returns the registry, except in testing mode when the URI contains
37+
// "-disable-lsp-registration". Testing for/example of registering dialects
38+
// based on URI.
39+
auto registryFn = [&registry,
40+
&empty](const lsp::URIForFile &uri) -> DialectRegistry & {
41+
(void)empty;
42+
#ifdef MLIR_INCLUDE_TESTS
43+
if (uri.uri().contains("-disable-lsp-registration"))
44+
return empty;
45+
#endif
46+
return registry;
47+
};
48+
return failed(MlirLspServerMain(argc, argv, registryFn));
3649
}

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8904,11 +8904,51 @@ cc_binary(
89048904
name = "mlir-lsp-server",
89058905
srcs = ["tools/mlir-lsp-server/mlir-lsp-server.cpp"],
89068906
includes = ["include"],
8907+
local_defines = ["MLIR_INCLUDE_TESTS"],
89078908
deps = [
89088909
":AllExtensions",
89098910
":AllPassesAndDialects",
89108911
":IR",
89118912
":MlirLspServerLib",
8913+
":MlirLspServerSupportLib",
8914+
"//mlir/test:TestAffine",
8915+
"//mlir/test:TestAnalysis",
8916+
"//mlir/test:TestArith",
8917+
"//mlir/test:TestArmNeon",
8918+
"//mlir/test:TestArmSME",
8919+
"//mlir/test:TestBufferization",
8920+
"//mlir/test:TestControlFlow",
8921+
"//mlir/test:TestConvertToSPIRV",
8922+
"//mlir/test:TestDLTI",
8923+
"//mlir/test:TestDialect",
8924+
"//mlir/test:TestFunc",
8925+
"//mlir/test:TestFuncToLLVM",
8926+
"//mlir/test:TestGPU",
8927+
"//mlir/test:TestIR",
8928+
"//mlir/test:TestLLVM",
8929+
"//mlir/test:TestLinalg",
8930+
"//mlir/test:TestLoopLikeInterface",
8931+
"//mlir/test:TestMath",
8932+
"//mlir/test:TestMathToVCIX",
8933+
"//mlir/test:TestMemRef",
8934+
"//mlir/test:TestMesh",
8935+
"//mlir/test:TestNVGPU",
8936+
"//mlir/test:TestPDLL",
8937+
"//mlir/test:TestPass",
8938+
"//mlir/test:TestReducer",
8939+
"//mlir/test:TestRewrite",
8940+
"//mlir/test:TestSCF",
8941+
"//mlir/test:TestSPIRV",
8942+
"//mlir/test:TestShapeDialect",
8943+
"//mlir/test:TestTensor",
8944+
"//mlir/test:TestTestDynDialect",
8945+
"//mlir/test:TestTilingInterface",
8946+
"//mlir/test:TestTosaDialect",
8947+
"//mlir/test:TestTransformDialect",
8948+
"//mlir/test:TestTransforms",
8949+
"//mlir/test:TestVector",
8950+
"//mlir/test:TestVectorToSPIRV",
8951+
"//mlir/test:TestXeGPU",
89128952
],
89138953
)
89148954

0 commit comments

Comments
 (0)