1- namespace GraphQL . EntityFramework ;
1+ using System . Security . Claims ;
2+
3+ namespace GraphQL . EntityFramework ;
24
35#region FiltersSignature
46
57public class Filters
68{
7- public delegate bool Filter < in TEntity > ( object userContext , TEntity input )
9+ public delegate bool Filter < in TEntity > ( object userContext , ClaimsPrincipal ? userPrincipal , TEntity input )
810 where TEntity : class ;
911
10- public delegate Task < bool > AsyncFilter < in TEntity > ( object userContext , TEntity input )
12+ public delegate Task < bool > AsyncFilter < in TEntity > ( object userContext , ClaimsPrincipal ? userPrincipal , TEntity input )
1113 where TEntity : class ;
1214
1315 #endregion
1416
1517 public void Add < TEntity > ( Filter < TEntity > filter )
1618 where TEntity : class =>
1719 funcs [ typeof ( TEntity ) ] =
18- ( context , item ) =>
20+ ( userContext , userPrincipal , item ) =>
1921 {
2022 try
2123 {
22- return Task . FromResult ( filter ( context , ( TEntity ) item ) ) ;
24+ return Task . FromResult ( filter ( userContext , userPrincipal , ( TEntity ) item ) ) ;
2325 }
2426 catch ( Exception exception )
2527 {
@@ -30,23 +32,23 @@ public void Add<TEntity>(Filter<TEntity> filter)
3032 public void Add < TEntity > ( AsyncFilter < TEntity > filter )
3133 where TEntity : class =>
3234 funcs [ typeof ( TEntity ) ] =
33- async ( context , item ) =>
35+ async ( userContext , userPrincipal , item ) =>
3436 {
3537 try
3638 {
37- return await filter ( context , ( TEntity ) item ) ;
39+ return await filter ( userContext , userPrincipal , ( TEntity ) item ) ;
3840 }
3941 catch ( Exception exception )
4042 {
4143 throw new ( $ "Failed to execute filter. { nameof ( TEntity ) } : { typeof ( TEntity ) } .", exception ) ;
4244 }
4345 } ;
4446
45- delegate Task < bool > Filter ( object userContext , object input ) ;
47+ delegate Task < bool > Filter ( object userContext , ClaimsPrincipal ? userPrincipal , object input ) ;
4648
4749 Dictionary < Type , Filter > funcs = new ( ) ;
4850
49- internal virtual async Task < IEnumerable < TEntity > > ApplyFilter < TEntity > ( IEnumerable < TEntity > result , object userContext )
51+ internal virtual async Task < IEnumerable < TEntity > > ApplyFilter < TEntity > ( IEnumerable < TEntity > result , object userContext , ClaimsPrincipal ? userPrincipal )
5052 where TEntity : class
5153 {
5254 if ( funcs . Count == 0 )
@@ -63,7 +65,7 @@ internal virtual async Task<IEnumerable<TEntity>> ApplyFilter<TEntity>(IEnumerab
6365 var list = new List < TEntity > ( ) ;
6466 foreach ( var item in result )
6567 {
66- if ( await ShouldInclude ( userContext , item , filters ) )
68+ if ( await ShouldInclude ( userContext , userPrincipal , item , filters ) )
6769 {
6870 list . Add ( item ) ;
6971 }
@@ -72,12 +74,12 @@ internal virtual async Task<IEnumerable<TEntity>> ApplyFilter<TEntity>(IEnumerab
7274 return list ;
7375 }
7476
75- static async Task < bool > ShouldInclude < TEntity > ( object userContext , TEntity item , List < AsyncFilter < TEntity > > filters )
77+ static async Task < bool > ShouldInclude < TEntity > ( object userContext , ClaimsPrincipal ? userPrincipal , TEntity item , List < AsyncFilter < TEntity > > filters )
7678 where TEntity : class
7779 {
7880 foreach ( var func in filters )
7981 {
80- if ( ! await func ( userContext , item ) )
82+ if ( ! await func ( userContext , userPrincipal , item ) )
8183 {
8284 return false ;
8385 }
@@ -86,7 +88,7 @@ static async Task<bool> ShouldInclude<TEntity>(object userContext, TEntity item,
8688 return true ;
8789 }
8890
89- internal virtual async Task < bool > ShouldInclude < TEntity > ( object userContext , TEntity ? item )
91+ internal virtual async Task < bool > ShouldInclude < TEntity > ( object userContext , ClaimsPrincipal ? userPrincipal , TEntity ? item )
9092 where TEntity : class
9193 {
9294 if ( item is null )
@@ -101,7 +103,7 @@ internal virtual async Task<bool> ShouldInclude<TEntity>(object userContext, TEn
101103
102104 foreach ( var func in FindFilters < TEntity > ( ) )
103105 {
104- if ( ! await func ( userContext , item ) )
106+ if ( ! await func ( userContext , userPrincipal , item ) )
105107 {
106108 return false ;
107109 }
@@ -116,7 +118,7 @@ IEnumerable<AsyncFilter<TEntity>> FindFilters<TEntity>()
116118 var type = typeof ( TEntity ) ;
117119 foreach ( var pair in funcs . Where ( x => x . Key . IsAssignableFrom ( type ) ) )
118120 {
119- yield return ( context , item ) => pair . Value ( context , item ) ;
121+ yield return ( userContext , userPrincipal , item ) => pair . Value ( userContext , userPrincipal , item ) ;
120122 }
121123 }
122124}
0 commit comments