Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions src/Access/Common/RowPolicyDefs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,35 @@ const RowPolicyFilterTypeInfo & RowPolicyFilterTypeInfo::get(RowPolicyFilterType
throw Exception("Unknown type: " + std::to_string(static_cast<size_t>(type_)), ErrorCodes::LOGICAL_ERROR);
}

String toString(RowPolicyKind type)
{
return RowPolicyKindInfo::get(type).raw_name;
}

const RowPolicyKindInfo & RowPolicyKindInfo::get(RowPolicyKind kind_)
{
static constexpr auto make_info = [](const char * raw_name_)
{
String init_name = raw_name_;
boost::to_lower(init_name);
return RowPolicyKindInfo{raw_name_, std::move(init_name)};
};

switch (kind_)
{
case RowPolicyKind::PERMISSIVE:
{
static const auto info = make_info("PERMISSIVE");
return info;
}
case RowPolicyKind::RESTRICTIVE:
{
static const auto info = make_info("RESTRICTIVE");
return info;
}
case RowPolicyKind::MAX: break;
}
throw Exception("Unknown kind: " + std::to_string(static_cast<size_t>(kind_)), ErrorCodes::LOGICAL_ERROR);
}

}
21 changes: 21 additions & 0 deletions src/Access/Common/RowPolicyDefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,25 @@ struct RowPolicyFilterTypeInfo
static const RowPolicyFilterTypeInfo & get(RowPolicyFilterType type);
};


/// Kinds of row policies. It affects how row policies are applied.
/// A row is only accessible if at least one of the permissive policies passes,
/// in addition to all the restrictive policies.
enum class RowPolicyKind
{
PERMISSIVE,
RESTRICTIVE,

MAX,
};

String toString(RowPolicyKind kind);

struct RowPolicyKindInfo
{
const char * const raw_name;
const String name; /// Lowercased with underscores, e.g. "permissive".
static const RowPolicyKindInfo & get(RowPolicyKind kind);
};

}
40 changes: 40 additions & 0 deletions src/Access/RolesOrUsersSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,46 @@ std::vector<UUID> RolesOrUsersSet::getMatchingIDs(const AccessControl & access_c
}


bool RolesOrUsersSet::contains(const RolesOrUsersSet & other) const
{
if (all && other.all)
{
for (const auto & id : except_ids)
{
if (!other.except_ids.contains(id))
return false;
}
return true;
}
else if (all /* && !other.all */)
{
for (const auto & id : other.ids)
{
if (other.except_ids.contains(id))
continue;
if (except_ids.contains(id))
return false;
}
return true;
}
else if (other.all /* && !all */)
{
return false;
}
else /* !all && !other.all */
{
for (const auto & id : other.ids)
{
if (other.except_ids.contains(id))
continue;
if (!ids.contains(id) || except_ids.contains(id))
return false;
}
return true;
}
}


bool operator ==(const RolesOrUsersSet & lhs, const RolesOrUsersSet & rhs)
{
return (lhs.all == rhs.all) && (lhs.ids == rhs.ids) && (lhs.except_ids == rhs.except_ids);
Expand Down
3 changes: 3 additions & 0 deletions src/Access/RolesOrUsersSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ struct RolesOrUsersSet
/// Returns a list of matching users and roles.
std::vector<UUID> getMatchingIDs(const AccessControl & access_control) const;

/// Returns true if this set contains each element of another set.
bool contains(const RolesOrUsersSet & other) const;

friend bool operator ==(const RolesOrUsersSet & lhs, const RolesOrUsersSet & rhs);
friend bool operator !=(const RolesOrUsersSet & lhs, const RolesOrUsersSet & rhs) { return !(lhs == rhs); }

Expand Down
13 changes: 11 additions & 2 deletions src/Access/RowPolicy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,18 @@ bool RowPolicy::equal(const IAccessEntity & other) const
{
if (!IAccessEntity::equal(other))
return false;

const auto & other_policy = typeid_cast<const RowPolicy &>(other);
return (full_name == other_policy.full_name) && boost::range::equal(filters, other_policy.filters)
&& restrictive == other_policy.restrictive && (to_roles == other_policy.to_roles);
if ((full_name != other_policy.full_name) || !boost::range::equal(filters, other_policy.filters)
|| (kind != other_policy.kind) || (to_set != other_policy.to_set))
return false;

if (kind == RowPolicyKind::PERMISSIVE)
{
if (of_set != other_policy.of_set)
return false;
}
return true;
}

}
33 changes: 19 additions & 14 deletions src/Access/RowPolicy.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,31 +29,36 @@ struct RowPolicy : public IAccessEntity
/// for user or available for modification.
std::array<String, static_cast<size_t>(RowPolicyFilterType::MAX)> filters;

/// Sets that the policy is permissive.
/// A row is only accessible if at least one of the permissive policies passes,
/// in addition to all the restrictive policies.
void setPermissive(bool permissive_ = true) { setRestrictive(!permissive_); }
bool isPermissive() const { return !isRestrictive(); }

/// Sets that the policy is restrictive.
/// A row is only accessible if at least one of the permissive policies passes,
/// in addition to all the restrictive policies.
void setRestrictive(bool restrictive_ = true) { restrictive = restrictive_; }
bool isRestrictive() const { return restrictive; }
/// Sets the kind of the policy, it affects how row policies are applied.
void setKind(RowPolicyKind kind_) { kind = kind_; }
RowPolicyKind getKind() const { return kind; }

bool equal(const IAccessEntity & other) const override;
std::shared_ptr<IAccessEntity> clone() const override { return cloneImpl<RowPolicy>(); }
static constexpr const auto TYPE = AccessEntityType::ROW_POLICY;
AccessEntityType getType() const override { return TYPE; }

/// Which roles or users should use this row policy.
RolesOrUsersSet to_roles;
/// Users and roles written in the TO clause.
/// For each user in this set this row policy is applied,
/// and for each role in this set this row policy is applied for any user using that role.
RolesOrUsersSet to_set;

/// Users and roles written in the OF clause (used for permissive row policies only).
/// Contains a list of users this row policy affects.
/// There can be one of the three cases:
/// 1) If some user is in `to_set` set he will see filtered rows;
/// 2) If some user is not in `to_set` set but he's in `of_set` set
/// he won't see any rows unless other permissive row policies allow him
/// to see something;
/// 3) If some user is not in `to_set` set and not in `of_set` set
/// this row policy won't affect this user at all, but other row policies can.
RolesOrUsersSet of_set;

private:
void setName(const String &) override;

RowPolicyName full_name;
bool restrictive = false;
RowPolicyKind kind = RowPolicyKind::PERMISSIVE;
};

using RowPolicyPtr = std::shared_ptr<const RowPolicy>;
Expand Down
44 changes: 31 additions & 13 deletions src/Access/RowPolicyCache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,34 @@ namespace
class FiltersMixer
{
public:
void add(const ASTPtr & filter, bool is_restrictive)
void add(const ASTPtr & filter, RowPolicyKind kind)
{
if (is_restrictive)
restrictions.push_back(filter);
if (kind == RowPolicyKind::PERMISSIVE)
{
setPermissiveFiltersExist();
permissive_filters.push_back(filter);
}
else
permissions.push_back(filter);
{
restrictive_filters.push_back(filter);
}
}

void setPermissiveFiltersExist()
{
permissive_filters_exist = true;
}

ASTPtr getResult() &&
{
/// Process permissive filters.
restrictions.push_back(makeASTForLogicalOr(std::move(permissions)));
if (permissive_filters_exist)
{
/// Process permissive filters.
restrictive_filters.push_back(makeASTForLogicalOr(std::move(permissive_filters)));
}

/// Process restrictive filters.
auto result = makeASTForLogicalAnd(std::move(restrictions));
auto result = makeASTForLogicalAnd(std::move(restrictive_filters));

bool value;
if (tryGetLiteralBool(result.get(), value) && value)
Expand All @@ -44,16 +57,18 @@ namespace
}

private:
ASTs permissions;
ASTs restrictions;
ASTs permissive_filters;
bool permissive_filters_exist = false;
ASTs restrictive_filters;
};
}


void RowPolicyCache::PolicyInfo::setPolicy(const RowPolicyPtr & policy_)
{
policy = policy_;
roles = &policy->to_roles;
to_set = &policy->to_set;
of_set = &policy->of_set;
database_and_table_name = std::make_shared<std::pair<String, String>>(policy->getDatabase(), policy->getTableName());

for (auto filter_type : collections::range(0, RowPolicyFilterType::MAX))
Expand Down Expand Up @@ -211,7 +226,8 @@ void RowPolicyCache::mixFiltersFor(EnabledRowPolicies & enabled)
for (const auto & [policy_id, info] : all_policies)
{
const auto & policy = *info.policy;
bool match = info.roles->match(enabled.params.user_id, enabled.params.enabled_roles);
bool matches = info.to_set->match(enabled.params.user_id, enabled.params.enabled_roles);
bool affects = !matches && (policy.getKind() == RowPolicyKind::PERMISSIVE) && info.of_set->match(enabled.params.user_id, enabled.params.enabled_roles);
MixedFiltersKey key;
key.database = info.database_and_table_name->first;
key.table_name = info.database_and_table_name->second;
Expand All @@ -223,8 +239,10 @@ void RowPolicyCache::mixFiltersFor(EnabledRowPolicies & enabled)
key.filter_type = filter_type;
auto & mixer = mixers[key];
mixer.database_and_table_name = info.database_and_table_name;
if (match)
mixer.mixer.add(info.parsed_filters[filter_type_i], policy.isRestrictive());
if (matches)
mixer.mixer.add(info.parsed_filters[filter_type_i], policy.getKind());
else if (affects)
mixer.mixer.setPermissiveFiltersExist();
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/Access/RowPolicyCache.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ class RowPolicyCache
void setPolicy(const RowPolicyPtr & policy_);

RowPolicyPtr policy;
const RolesOrUsersSet * roles = nullptr;
const RolesOrUsersSet * to_set = nullptr;
const RolesOrUsersSet * of_set = nullptr;
std::shared_ptr<const std::pair<String, String>> database_and_table_name;
ASTPtr parsed_filters[static_cast<size_t>(RowPolicyFilterType::MAX)];
};
Expand Down
3 changes: 2 additions & 1 deletion src/Access/UsersConfigAccessStorage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,8 @@ namespace
auto policy = std::make_shared<RowPolicy>();
policy->setFullName(user_name, database, table_name);
policy->filters[static_cast<size_t>(RowPolicyFilterType::SELECT_FILTER)] = filter;
policy->to_roles.add(generateID(AccessEntityType::USER, user_name));
policy->to_set.add(generateID(AccessEntityType::USER, user_name));
policy->of_set.all = true;
policies.push_back(policy);
}
}
Expand Down
46 changes: 33 additions & 13 deletions src/Interpreters/Access/InterpreterCreateRowPolicyQuery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ namespace
RowPolicy & policy,
const ASTCreateRowPolicyQuery & query,
const RowPolicyName & override_name,
const std::optional<RolesOrUsersSet> & override_to_roles)
const std::optional<RolesOrUsersSet> & override_to_set,
const std::optional<RolesOrUsersSet> & override_of_set)
{
if (!override_name.empty())
policy.setFullName(override_name);
Expand All @@ -29,16 +30,33 @@ namespace
else if (query.names->full_names.size() == 1)
policy.setFullName(query.names->full_names.front());

if (query.is_restrictive)
policy.setRestrictive(*query.is_restrictive);
auto old_kind = policy.getKind();
if (query.kind)
policy.setKind(*query.kind);
bool kind_changed = (policy.getKind() != old_kind) || !query.alter;

for (const auto & [filter_type, filter] : query.filters)
policy.filters[static_cast<size_t>(filter_type)] = filter ? serializeAST(*filter) : String{};

if (override_to_roles)
policy.to_roles = *override_to_roles;
else if (query.roles)
policy.to_roles = *query.roles;
if (override_to_set)
policy.to_set = *override_to_set;
else if (query.to_set)
policy.to_set = *query.to_set;

if (override_of_set)
policy.of_set = *override_of_set;
else if (query.of_set)
policy.of_set = *query.of_set;
else if ((policy.getKind() == RowPolicyKind::PERMISSIVE) && (override_to_set || query.to_set || kind_changed))
policy.of_set = RolesOrUsersSet::AllTag{}; /// By default permissive row policies are OF ALL.
else if (policy.getKind() == RowPolicyKind::RESTRICTIVE)
policy.of_set.clear();

if ((policy.getKind() == RowPolicyKind::PERMISSIVE) && !policy.of_set.contains(policy.to_set))
throw Exception("Users and roles in the TO clause must be a subset of ones in the OF clause", ErrorCodes::BAD_ARGUMENTS);

if ((policy.getKind() == RowPolicyKind::RESTRICTIVE) && !policy.of_set.empty())
throw Exception("OF clause can only be used with permissive row policies", ErrorCodes::BAD_ARGUMENTS);
}
}

Expand All @@ -60,16 +78,18 @@ BlockIO InterpreterCreateRowPolicyQuery::execute()

query.replaceEmptyDatabase(getContext()->getCurrentDatabase());

std::optional<RolesOrUsersSet> roles_from_query;
if (query.roles)
roles_from_query = RolesOrUsersSet{*query.roles, access_control, getContext()->getUserID()};
std::optional<RolesOrUsersSet> to_set_from_query, of_set_from_query;
if (query.to_set)
to_set_from_query = RolesOrUsersSet{*query.to_set, access_control, getContext()->getUserID()};
if (query.of_set)
of_set_from_query = RolesOrUsersSet{*query.of_set, access_control, getContext()->getUserID()};

if (query.alter)
{
auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr
{
auto updated_policy = typeid_cast<std::shared_ptr<RowPolicy>>(entity->clone());
updateRowPolicyFromQueryImpl(*updated_policy, query, {}, roles_from_query);
updateRowPolicyFromQueryImpl(*updated_policy, query, {}, to_set_from_query, of_set_from_query);
return updated_policy;
};
Strings names = query.names->toStrings();
Expand All @@ -87,7 +107,7 @@ BlockIO InterpreterCreateRowPolicyQuery::execute()
for (const auto & full_name : query.names->full_names)
{
auto new_policy = std::make_shared<RowPolicy>();
updateRowPolicyFromQueryImpl(*new_policy, query, full_name, roles_from_query);
updateRowPolicyFromQueryImpl(*new_policy, query, full_name, to_set_from_query, of_set_from_query);
new_policies.emplace_back(std::move(new_policy));
}

Expand All @@ -105,7 +125,7 @@ BlockIO InterpreterCreateRowPolicyQuery::execute()

void InterpreterCreateRowPolicyQuery::updateRowPolicyFromQuery(RowPolicy & policy, const ASTCreateRowPolicyQuery & query)
{
updateRowPolicyFromQueryImpl(policy, query, {}, {});
updateRowPolicyFromQueryImpl(policy, query, {}, {}, {});
}


Expand Down
Loading