Skip to content

Support IAuthorizationRequirementData in the implementation-first approach #8303

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,15 @@ namespace HotChocolate.AspNetCore.Authorization;

internal sealed class AuthorizationPolicyCache
{
private readonly ConcurrentDictionary<string, AuthorizationPolicy> _cache = new();
private readonly ConcurrentDictionary<AuthorizeDirective, AuthorizationPolicy> _cache = new();
Copy link
Preview

Copilot AI May 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using AuthorizeDirective as the dictionary key may lead to cache misses if two different instances have equivalent values. Consider overriding Equals and GetHashCode on AuthorizeDirective or reverting to a value-based cache key to ensure consistent behavior.

Copilot uses AI. Check for mistakes.

Copy link
Collaborator Author

@sunghwan2789 sunghwan2789 May 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is okay because the instance is at least bound to a field or type, and has reference-equality across requests.

This is needed for a directive that have no policy, no roles, but metadata.

AuthorizationRequirementDataTests.Multiple_NotAuthorized is a test...


public AuthorizationPolicy? LookupPolicy(AuthorizeDirective directive)
{
var cacheKey = directive.GetPolicyCacheKey();

return _cache.GetValueOrDefault(cacheKey);
return _cache.GetValueOrDefault(directive);
}

public void CachePolicy(AuthorizeDirective directive, AuthorizationPolicy policy)
{
var cacheKey = directive.GetPolicyCacheKey();

_cache.TryAdd(cacheKey, policy);
_cache.TryAdd(directive, policy);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,10 @@ private async ValueTask<AuthorizeResult> AuthorizeAsync(

if (authorizationPolicy is null)
{
authorizationPolicy = await BuildAuthorizationPolicy(directive.Policy, directive.Roles);
authorizationPolicy = await BuildAuthorizationPolicy(
directive.Policy,
directive.Roles,
directive.Metadata);

if (_canCachePolicies)
{
Expand All @@ -165,7 +168,8 @@ private async ValueTask<AuthorizeResult> AuthorizeAsync(

private async Task<AuthorizationPolicy> BuildAuthorizationPolicy(
string? policyName,
IReadOnlyList<string>? roles)
IReadOnlyList<string>? roles,
IReadOnlyList<object>? metadata)
{
var policyBuilder = new AuthorizationPolicyBuilder();

Expand Down Expand Up @@ -194,6 +198,21 @@ private async Task<AuthorizationPolicy> BuildAuthorizationPolicy(
policyBuilder = policyBuilder.RequireRole(roles);
}

var requirementData = metadata?.OfType<IAuthorizationRequirementData>()?.ToList() ?? [];
if (requirementData.Count > 0)
{
var reqPolicy = new AuthorizationPolicyBuilder();
foreach (var rd in requirementData)
{
foreach (var r in rd.GetRequirements())
{
reqPolicy.AddRequirements(r);
}
}

policyBuilder = policyBuilder.Combine(reqPolicy.Build());
}

return policyBuilder.Build();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
using System.Net;
using System.Reflection;
using System.Security.Claims;
using HotChocolate.AspNetCore.Tests.Utilities;
using HotChocolate.Execution.Configuration;
using HotChocolate.Types;
using HotChocolate.Types.Descriptors;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.DependencyInjection;

namespace HotChocolate.AspNetCore.Authorization;

public class AuthorizationRequirementDataTests(TestServerFactory serverFactory) : ServerTestBase(serverFactory)
{
[Fact]
public async Task Authorized()
{
// arrange
var server = CreateTestServer(
builder =>
{
builder.Services.AddTransient<IAuthorizationHandler, CustomAuthorizationHandler>();
builder
.AddQueryType<Query>()
.AddAuthorization();
},
context =>
{
var identity = new ClaimsIdentity("testauth");
identity.AddClaim(new Claim(
"foo",
"bar"));
context.User = new ClaimsPrincipal(identity);
});

// act
var result = await server.PostAsync(new ClientQueryRequest { Query = "{ foo }" });

// assert
Assert.Equal(HttpStatusCode.OK, result.StatusCode);
Assert.Null(result.Errors);
}

[Fact]
public async Task NotAuthorized()
{
// arrange
var server = CreateTestServer(
builder =>
{
builder.Services.AddTransient<IAuthorizationHandler, CustomAuthorizationHandler>();
builder
.AddQueryType<Query>()
.AddAuthorization();
},
context =>
{
var identity = new ClaimsIdentity("testauth");
identity.AddClaim(new Claim(
"foo",
"foo"));
context.User = new ClaimsPrincipal(identity);
});

// act
var result = await server.PostAsync(new ClientQueryRequest { Query = "{ foo }" });

// assert
Assert.Equal(HttpStatusCode.OK, result.StatusCode);
result.MatchSnapshot();
}

[Fact]
public async Task Multiple_Authorized()
{
// arrange
var server = CreateTestServer(
builder =>
{
builder.Services.AddTransient<IAuthorizationHandler, CustomAuthorizationHandler>();
builder
.AddQueryType<Query>()
.AddAuthorization();
},
context =>
{
var identity = new ClaimsIdentity("testauth");
identity.AddClaim(new Claim(
"foo",
"bar"));
identity.AddClaim(new Claim(
"bar",
"baz"));
context.User = new ClaimsPrincipal(identity);
});

// act
var result = await server.PostAsync(new ClientQueryRequest { Query = "{ fooMultiple }" });

// assert
Assert.Equal(HttpStatusCode.OK, result.StatusCode);
Assert.Null(result.Errors);
}

[Fact]
public async Task Multiple_NotAuthorized()
{
// arrange
var server = CreateTestServer(
builder =>
{
builder.Services.AddTransient<IAuthorizationHandler, CustomAuthorizationHandler>();
builder
.AddQueryType<Query>()
.AddAuthorization();
},
context =>
{
var identity = new ClaimsIdentity("testauth");
identity.AddClaim(new Claim(
"foo",
"bar"));
identity.AddClaim(new Claim(
"bar",
"bar"));
context.User = new ClaimsPrincipal(identity);
});

// act
var result = await server.PostAsync(new ClientQueryRequest { Query = "{ fooMultiple }" });

// assert
Assert.Equal(HttpStatusCode.OK, result.StatusCode);
result.MatchSnapshot();
}

public class Query
{
[CustomAuthorize("foo", "bar", Apply = HotChocolate.Authorization.ApplyPolicy.BeforeResolver)]
public string? GetFoo() => "foo";

[CustomAuthorize("foo", "bar", Apply = HotChocolate.Authorization.ApplyPolicy.BeforeResolver)]
[CustomAuthorize("bar", "baz", Apply = HotChocolate.Authorization.ApplyPolicy.BeforeResolver)]
public string? GetFooMultiple() => "foo";
}

private class CustomAuthorizeAttribute(string type, string value)
: HotChocolate.Authorization.AuthorizeAttribute,
IAuthorizationRequirement,
IAuthorizationRequirementData
{
public string Type => type;

public string Value => value;

public IEnumerable<IAuthorizationRequirement> GetRequirements()
{
yield return this;
}

protected internal override void TryConfigure(
IDescriptorContext context,
IDescriptor descriptor,
ICustomAttributeProvider element)
{
if (descriptor is IObjectTypeDescriptor type)
{
type.Directive(CreateDirective());
}
else if (descriptor is IObjectFieldDescriptor field)
{
field.Directive(CreateDirective());
}
}

private HotChocolate.Authorization.AuthorizeDirective CreateDirective()
{
return new HotChocolate.Authorization.AuthorizeDirective(metadata: [this]);
}
}

private class CustomAuthorizationHandler : AuthorizationHandler<CustomAuthorizeAttribute>
{
protected override Task HandleRequirementAsync(
AuthorizationHandlerContext context,
CustomAuthorizeAttribute requirement)
{
if (context.User.HasClaim(requirement.Type, requirement.Value))
{
context.Succeed(requirement);
}
return Task.CompletedTask;
}
}

private TestServer CreateTestServer(
Action<IRequestExecutorBuilder> build,
Action<HttpContext> configureUser)
{
return ServerFactory.Create(
services =>
{
build(services
.AddRouting()
.AddGraphQLServer()
.AddHttpRequestInterceptor(
(context, requestExecutor, requestBuilder, cancellationToken) =>
{
configureUser(context);
return default;
}));
},
app =>
{
app.UseRouting();
app.UseEndpoints(b => b.MapGraphQL());
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,104 +77,6 @@ public void CreateInstance_Roles_RolesHasItems()
t => Assert.Equal("b", t));
}

[Fact]
public void CacheKey_Policy_NoRoles()
{
// arrange
var authorizeDirective = new AuthorizeDirective(policy: "policy");

// act
var cacheKey = authorizeDirective.GetPolicyCacheKey();

// assert
Assert.Equal("policy;", cacheKey);
}

[Fact]
public void CacheKey_NoPolicy_Roles()
{
// arrange
var authorizeDirective = new AuthorizeDirective(
policy: null,
roles: ["a", "b"]);

// act
var cacheKey = authorizeDirective.GetPolicyCacheKey();

// assert
Assert.Equal(";a,b", cacheKey);
}

[Fact]
public void CacheKey_Policy_And_Roles()
{
// arrange
var authorizeDirective = new AuthorizeDirective(
policy: "policy",
roles: ["a", "b"]);

// act
var cacheKey = authorizeDirective.GetPolicyCacheKey();

// assert
Assert.Equal("policy;a,b", cacheKey);
}

[Fact]
public void CacheKey_NoPolicy_NoRoles()
{
// arrange
var authorizeDirective = new AuthorizeDirective(
policy: null,
roles: null);

// act
var cacheKey = authorizeDirective.GetPolicyCacheKey();

// assert
Assert.Equal("", cacheKey);
}

[Fact]
public void CacheKey_Policy_And_Role_Naming_Does_Not_Conflict()
{
// arrange
var authorizeDirective1 = new AuthorizeDirective(
policy: "policy",
roles: null);

var authorizeDirective2 = new AuthorizeDirective(
policy: null,
roles: ["policy"]);

// act
var cacheKey1 = authorizeDirective1.GetPolicyCacheKey();
var cacheKey2 = authorizeDirective2.GetPolicyCacheKey();

// assert
Assert.NotEqual(cacheKey1, cacheKey2);
}

[Fact]
public void CacheKey_Same_Roles_Albeit_Sorted_Differently_Have_Same_Cache_Key()
{
// arrange
var authorizeDirective1 = new AuthorizeDirective(
policy: null,
roles: ["a", "c", "b"]);

var authorizeDirective2 = new AuthorizeDirective(
policy: null,
roles: ["c", "b", "a"]);

// act
var cacheKey1 = authorizeDirective1.GetPolicyCacheKey();
var cacheKey2 = authorizeDirective2.GetPolicyCacheKey();

// assert
Assert.Equal(cacheKey1, cacheKey2);
}

[Fact]
public void TypeAuth_DefaultPolicy()
{
Expand Down
Loading
Loading