Skip to content

Commit 6a31dd5

Browse files
authored
Use chatId from URL rather than from payload for chats (microsoft#700)
### Motivation and Context The verify access to a chat, we use HandleRequest() with the chatId provided. Currently, we get this from the payload, which can differ from the chatId from the URL, which opens us to a security problem where a user could inject an arbitrary chatId in the payload, which doesn't match what's in the URL. ### Description - Use chatId from URL and only from URL - Add integrations test to validate this ### Contribution Checklist - [ ] The code builds clean without any errors or warnings - [ ] The PR follows the [Contribution Guidelines](https://github.com/microsoft/chat-copilot/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/chat-copilot/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [ ] All unit tests pass, and I have added new tests where possible - [ ] I didn't break anyone 😄
1 parent 8e8e284 commit 6a31dd5

File tree

5 files changed

+72
-53
lines changed

5 files changed

+72
-53
lines changed

integration-tests/ChatTests.cs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
3+
using System.Collections.Generic;
4+
using System.Net.Http;
5+
using System.Net.Http.Json;
6+
using System.Text.Json;
7+
using CopilotChat.WebApi.Models.Request;
8+
using CopilotChat.WebApi.Models.Response;
9+
using Xunit;
10+
using static CopilotChat.WebApi.Models.Storage.CopilotChatMessage;
11+
12+
namespace ChatCopilotIntegrationTests;
13+
14+
public class ChatTests : ChatCopilotIntegrationTest
15+
{
16+
[Fact]
17+
public async void ChatMessagePostSucceedsWithValidInput()
18+
{
19+
await this.SetUpAuth();
20+
21+
// Create chat session
22+
var createChatParams = new CreateChatParameters() { Title = nameof(ChatMessagePostSucceedsWithValidInput) };
23+
HttpResponseMessage response = await this._httpClient.PostAsJsonAsync("chats", createChatParams);
24+
response.EnsureSuccessStatusCode();
25+
26+
var contentStream = await response.Content.ReadAsStreamAsync();
27+
var createChatResponse = await JsonSerializer.DeserializeAsync<CreateChatResponse>(contentStream, new JsonSerializerOptions { PropertyNameCaseInsensitive = true });
28+
Assert.NotNull(createChatResponse);
29+
30+
// Ask something to the bot
31+
var ask = new Ask
32+
{
33+
Input = "Who is Satya Nadella?",
34+
Variables = new KeyValuePair<string, string>[] { new("MessageType", ChatMessageType.Message.ToString()) }
35+
};
36+
response = await this._httpClient.PostAsJsonAsync($"chats/{createChatResponse.ChatSession.Id}/messages", ask);
37+
response.EnsureSuccessStatusCode();
38+
39+
contentStream = await response.Content.ReadAsStreamAsync();
40+
var askResult = await JsonSerializer.DeserializeAsync<AskResult>(contentStream, new JsonSerializerOptions { PropertyNameCaseInsensitive = true });
41+
Assert.NotNull(askResult);
42+
Assert.False(string.IsNullOrEmpty(askResult.Value));
43+
44+
45+
// Clean up
46+
response = await this._httpClient.DeleteAsync($"chats/{createChatResponse.ChatSession.Id}");
47+
response.EnsureSuccessStatusCode();
48+
}
49+
}
50+

webapi/Controllers/ChatController.cs

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ public async Task<IActionResult> ChatAsync(
9999
[FromServices] IKernel kernel,
100100
[FromServices] IHubContext<MessageRelayHub> messageRelayHubContext,
101101
[FromServices] CopilotChatPlanner planner,
102-
[FromServices] AskConverter askConverter,
103102
[FromServices] ChatSessionRepository chatSessionRepository,
104103
[FromServices] ChatParticipantRepository chatParticipantRepository,
105104
[FromServices] IAuthInfo authInfo,
@@ -108,7 +107,7 @@ public async Task<IActionResult> ChatAsync(
108107
{
109108
this._logger.LogDebug("Chat message received.");
110109

111-
return await this.HandleRequest(ChatFunctionName, kernel, messageRelayHubContext, planner, askConverter, chatSessionRepository, chatParticipantRepository, authInfo, ask, chatId.ToString());
110+
return await this.HandleRequest(ChatFunctionName, kernel, messageRelayHubContext, planner, chatSessionRepository, chatParticipantRepository, authInfo, ask, chatId.ToString());
112111
}
113112

114113
/// <summary>
@@ -135,7 +134,6 @@ public async Task<IActionResult> ProcessPlanAsync(
135134
[FromServices] IKernel kernel,
136135
[FromServices] IHubContext<MessageRelayHub> messageRelayHubContext,
137136
[FromServices] CopilotChatPlanner planner,
138-
[FromServices] AskConverter askConverter,
139137
[FromServices] ChatSessionRepository chatSessionRepository,
140138
[FromServices] ChatParticipantRepository chatParticipantRepository,
141139
[FromServices] IAuthInfo authInfo,
@@ -144,7 +142,7 @@ public async Task<IActionResult> ProcessPlanAsync(
144142
{
145143
this._logger.LogDebug("plan request received.");
146144

147-
return await this.HandleRequest(ProcessPlanFunctionName, kernel, messageRelayHubContext, planner, askConverter, chatSessionRepository, chatParticipantRepository, authInfo, ask, chatId.ToString());
145+
return await this.HandleRequest(ProcessPlanFunctionName, kernel, messageRelayHubContext, planner, chatSessionRepository, chatParticipantRepository, authInfo, ask, chatId.ToString());
148146
}
149147

150148
/// <summary>
@@ -166,15 +164,14 @@ private async Task<IActionResult> HandleRequest(
166164
IKernel kernel,
167165
IHubContext<MessageRelayHub> messageRelayHubContext,
168166
CopilotChatPlanner planner,
169-
AskConverter askConverter,
170167
ChatSessionRepository chatSessionRepository,
171168
ChatParticipantRepository chatParticipantRepository,
172169
IAuthInfo authInfo,
173170
Ask ask,
174171
string chatId)
175172
{
176173
// Put ask's variables in the context we will use.
177-
var contextVariables = askConverter.GetContextVariables(ask);
174+
var contextVariables = GetContextVariables(ask, authInfo, chatId);
178175

179176
// Verify that the chat exists and that the user has access to it.
180177
ChatSession? chat = null;
@@ -415,6 +412,25 @@ await planner.Kernel.ImportOpenAIPluginFunctionsAsync(
415412
return;
416413
}
417414

415+
private static ContextVariables GetContextVariables(Ask ask, IAuthInfo authInfo, string chatId)
416+
{
417+
const string UserIdKey = "userId";
418+
const string UserNameKey = "userName";
419+
const string ChatIdKey = "chatId";
420+
421+
var contextVariables = new ContextVariables(ask.Input);
422+
foreach (var variable in ask.Variables)
423+
{
424+
contextVariables.Set(variable.Key, variable.Value);
425+
}
426+
427+
contextVariables.Set(UserIdKey, authInfo.UserId);
428+
contextVariables.Set(UserNameKey, authInfo.Name);
429+
contextVariables.Set(ChatIdKey, chatId);
430+
431+
return contextVariables;
432+
}
433+
418434
/// <summary>
419435
/// Dispose of the object.
420436
/// </summary>

webapi/Extensions/ServiceExtensions.cs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,6 @@ internal static void AddOptions<TOptions>(this IServiceCollection services, ICon
8181
.PostConfigure(TrimStringProperties);
8282
}
8383

84-
internal static IServiceCollection AddUtilities(this IServiceCollection services)
85-
{
86-
return services.AddScoped<AskConverter>();
87-
}
88-
8984
internal static IServiceCollection AddPlugins(this IServiceCollection services, IConfiguration configuration)
9085
{
9186
var plugins = configuration.GetSection("Plugins").Get<List<Plugin>>() ?? new List<Plugin>();

webapi/Program.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ public static async Task Main(string[] args)
4545
.AddOptions(builder.Configuration)
4646
.AddPersistentChatStore()
4747
.AddPlugins(builder.Configuration)
48-
.AddUtilities()
4948
.AddChatCopilotAuthentication(builder.Configuration)
5049
.AddChatCopilotAuthorization();
5150

webapi/Utilities/AskConverter.cs

Lines changed: 0 additions & 41 deletions
This file was deleted.

0 commit comments

Comments
 (0)