Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions sql/load_sql_context.sql
Original file line number Diff line number Diff line change
Expand Up @@ -289,14 +289,16 @@ select
),
array[]::text[]
),
'is_unique', pi.indisunique and pi.indpred is null,
'is_primary_key', pi.indisprimary
)
'is_unique', pi.indisunique and pi.indpred is null,
'is_primary_key', pi.indisprimary,
'name', pc_ix.relname
)
from
pg_catalog.pg_index pi
where
pi.indrelid = pc.oid
)
from
pg_catalog.pg_index pi
join pg_catalog.pg_class pc_ix on pi.indexrelid = pc_ix.oid
where
pi.indrelid = pc.oid
),
jsonb_build_array()
),
Expand Down
123 changes: 121 additions & 2 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::parser_util::*;
use crate::sql_types::*;
use graphql_parser::query::*;
use serde::Serialize;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::hash::Hash;
use std::ops::Deref;
use std::str::FromStr;
Expand Down Expand Up @@ -51,6 +51,7 @@ pub enum AggregateSelection {
pub struct InsertBuilder {
// args
pub objects: Vec<InsertRowBuilder>,
pub on_conflict: Option<OnConflictBuilder>,

// metadata
pub table: Arc<Table>,
Expand All @@ -59,6 +60,13 @@ pub struct InsertBuilder {
pub selections: Vec<InsertSelection>,
}

#[derive(Clone, Debug)]
pub struct OnConflictBuilder {
pub constraint: Index,
pub update_fields: HashSet<Arc<Column>>,
pub filter: FilterBuilder,
}

#[derive(Clone, Debug)]
pub struct InsertRowBuilder {
// String is Column name
Expand Down Expand Up @@ -325,6 +333,107 @@ where
Ok(objects)
}

fn read_argument_on_conflict<'a, T>(
field: &__Field,
query_field: &graphql_parser::query::Field<'a, T>,
variables: &serde_json::Value,
variable_definitions: &Vec<VariableDefinition<'a, T>>,
table: &Arc<Table>,
schema: &Arc<__Schema>,
) -> GraphQLResult<Option<OnConflictBuilder>>
where
T: Text<'a> + Eq + AsRef<str>,
{
let has_on_conflict = query_field
.arguments
.iter()
.any(|(name, _)| name.as_ref() == "onConflict");

if !has_on_conflict {
return Ok(None);
}

let validated: gson::Value = read_argument(
"onConflict",
field,
query_field,
variables,
variable_definitions,
)?;

match validated {
gson::Value::Absent | gson::Value::Null => Ok(None),
gson::Value::Object(obj) => {
let constraint_val = obj
.get("constraint")
.ok_or(GraphQLError::validation("constraint is required"))?;
let constraint_name = match constraint_val {
gson::Value::String(s) => s,
_ => return Err(GraphQLError::validation("constraint must be a string")),
};

let constraint = table
.on_conflict_indexes()
.iter()
.find(|idx| &idx.name == constraint_name)
.ok_or(GraphQLError::validation("Invalid constraint name"))?;

let update_fields_val = obj
.get("updateFields")
.ok_or(GraphQLError::validation("updateFields is required"))?;
let update_fields_arr = match update_fields_val {
gson::Value::Array(arr) => arr,
_ => return Err(GraphQLError::validation("updateFields must be an array")),
};

let mut update_fields = HashSet::new();
for val in update_fields_arr {
match val {
gson::Value::String(graphql_col_name) => {
// Map GraphQL column name back to actual Column
let column = table
.columns
.iter()
.filter(|c| c.permissions.is_updatable && !c.is_generated && !c.is_serial)
.find(|c| schema.graphql_column_field_name(c).as_str() == graphql_col_name.as_str())
.ok_or_else(|| {
GraphQLError::validation(format!(
"Invalid column in updateFields: {}",
graphql_col_name
))
})?;
update_fields.insert(Arc::clone(column));
}
_ => {
return Err(GraphQLError::validation(
"updateFields elements must be strings",
))
}
}
}

let filter = if let Some(filter_val) = obj.get(args::FILTER) {
let filter_type = __Type::FilterEntity(FilterEntityType {
table: Arc::clone(table),
schema: Arc::clone(schema),
});
let filter_field_map = input_field_map(&filter_type);
let filters = create_filters(filter_val, &filter_field_map)?;
Some(FilterBuilder { elems: filters })
} else {
None
};

Ok(Some(OnConflictBuilder {
constraint: (*constraint).clone(),
update_fields,
filter: filter.unwrap_or(FilterBuilder { elems: vec![] }),
}))
}
_ => Err(GraphQLError::validation("Invalid onConflict argument")),
}
}

pub fn to_insert_builder<'a, T>(
field: &__Field,
query_field: &graphql_parser::query::Field<'a, T>,
Expand All @@ -345,11 +454,20 @@ where
match &type_ {
__Type::InsertResponse(xtype) => {
// Raise for disallowed arguments
restrict_allowed_arguments(&[args::OBJECTS], query_field)?;
restrict_allowed_arguments(&[args::OBJECTS, "onConflict"], query_field)?;

let objects: Vec<InsertRowBuilder> =
read_argument_objects(field, query_field, variables, variable_definitions)?;

let on_conflict = read_argument_on_conflict(
field,
query_field,
variables,
variable_definitions,
&xtype.table,
&xtype.schema,
)?;

let mut builder_fields: Vec<InsertSelection> = vec![];

let selection_fields = normalize_selection_set(
Expand Down Expand Up @@ -394,6 +512,7 @@ where
Ok(InsertBuilder {
table: Arc::clone(&xtype.table),
objects,
on_conflict,
selections: builder_fields,
})
}
Expand Down
Loading
Loading