Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 25 additions & 13 deletions packages/cli/src/actions/db.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import {
} from './action-utils';
import { consolidateEnums, syncEnums, syncRelation, syncTable, type Relation } from './pull';
import { providers as pullProviders } from './pull/provider';
import { getDatasource, getDbName, getRelationFieldsKey, getRelationFkName, isDatabaseManagedAttribute } from './pull/utils';
import { getDatasource, getDbName, getRelationFieldsKey, getRelationFkName, getRelationName, isDatabaseManagedAttribute } from './pull/utils';
import type { DataSourceProviderType } from '@zenstackhq/schema';
import { CliError } from '../cli-error';

Expand Down Expand Up @@ -283,16 +283,9 @@ async function runPull(options: PullOptions) {
}

newDataModel.fields.forEach((f) => {
// Prioritized matching: exact db name > relation fields key > relation FK name > type reference
// Prioritized matching: exact db name > relation fields key > relation FK name > relation name > type reference
let originalFields = originalDataModel.fields.filter((d) => getDbName(d) === getDbName(f));

// If this is a back-reference relation field (has @relation but no `fields` arg), silently skip
const isRelationField =
f.$type === 'DataField' && !!(f as any).attributes?.some((a: any) => a?.decl?.ref?.name === '@relation');
if (originalFields.length === 0 && isRelationField && !getRelationFieldsKey(f as any)) {
return;
}

if (originalFields.length === 0) {
// Try matching by relation fields key (the `fields` attribute in @relation)
// This matches relation fields by their FK field references
Expand All @@ -315,10 +308,20 @@ async function runPull(options: PullOptions) {
}

if (originalFields.length === 0) {
// Try matching by type reference
// Try matching by relation name (the first positional arg in @relation)
// This is essential for back-reference fields that only have a relation name
const newRelName = getRelationName(f as any);
if (newRelName) {
originalFields = originalDataModel.fields.filter(
(d) => d.$type === 'DataField' && getRelationName(d as any) === newRelName,
);
}
}

if (originalFields.length === 0 && !getRelationName(f as any)) {
// Try matching by type reference (only for fields without a named relation)
// We need this because for relations that don't have @relation, we can only check if the original exists by the field type.
// Yes, in this case it can potentially result in multiple original fields, but we only want to ensure that at least one relation exists.
// In the future, we might implement some logic to detect how many of these types of relations we need and add/remove fields based on this.
// Fields with a named relation that didn't match above are genuinely new and should be added.
originalFields = originalDataModel.fields.filter(
(d) =>
f.$type === 'DataField' &&
Expand Down Expand Up @@ -499,7 +502,7 @@ async function runPull(options: PullOptions) {
});
originalDataModel.fields
.filter((f) => {
// Prioritized matching: exact db name > relation fields key > relation FK name > type reference
// Prioritized matching: exact db name > relation fields key > relation FK name > relation name > type reference
const matchByDbName = newDataModel.fields.find((d) => getDbName(d) === getDbName(f));
if (matchByDbName) return false;

Expand All @@ -520,6 +523,15 @@ async function runPull(options: PullOptions) {
);
if (matchByFkName) return false;

// Try matching by relation name (for named back-reference fields)
const originalRelName = getRelationName(f as any);
if (originalRelName) {
const matchByRelName = newDataModel.fields.find(
(d) => d.$type === 'DataField' && getRelationName(d as any) === originalRelName,
);
if (matchByRelName) return false;
}

const matchByTypeRef = newDataModel.fields.find(
(d) =>
f.$type === 'DataField' &&
Expand Down
14 changes: 14 additions & 0 deletions packages/cli/src/actions/pull/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,20 @@ export function getRelationFkName(decl: DataField): string | undefined {
return schemaAttrValue?.value;
}

/**
* Gets the relation name from the @relation attribute's first positional argument.
* e.g., @relation('myRelation', fields: [...], references: [...]) -> "myRelation"
* e.g., @relation(fields: [...], references: [...]) -> undefined
* e.g., @relation('backRef') -> "backRef"
*/
export function getRelationName(decl: DataField): string | undefined {
const relationAttr = decl?.attributes?.find((a) => a.decl?.ref?.name === '@relation');
if (!relationAttr) return undefined;
const firstPositionalArg = relationAttr.args.find((a) => !a.name);
if (!firstPositionalArg || firstPositionalArg.value?.$type !== 'StringLiteral') return undefined;
return (firstPositionalArg.value as StringLiteral).value;
}

/**
* Gets the FK field names from the @relation attribute's `fields` argument.
* Returns a sorted, comma-separated string of field names for comparison.
Expand Down
77 changes: 77 additions & 0 deletions packages/cli/test/db/pull.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,83 @@ model Tag {
expect(restoredSchema).toEqual(schema);
});

it('should restore opposite relation fields when multiple models have FKs to the same target', async () => {
const { workDir, schema } = await createProject(
`model Post {
id Int @id @default(autoincrement())
title String
postUserCreated User? @relation('Post_createdByToUser', fields: [createdBy], references: [id])
createdBy Int?
postUserUpdated User? @relation('Post_updatedByToUser', fields: [updatedBy], references: [id])
updatedBy Int?
}

model Comment {
id Int @id @default(autoincrement())
text String
commentUserCreated User? @relation('Comment_createdByToUser', fields: [createdBy], references: [id])
createdBy Int?
commentUserUpdated User? @relation('Comment_updatedByToUser', fields: [updatedBy], references: [id])
updatedBy Int?
}

model User {
id Int @id @default(autoincrement())
email String @unique
postUserCreatedToUsers Post[] @relation('Post_createdByToUser')
postUserUpdatedToUsers Post[] @relation('Post_updatedByToUser')
commentUserCreatedToUsers Comment[] @relation('Comment_createdByToUser')
commentUserUpdatedToUsers Comment[] @relation('Comment_updatedByToUser')
}`,
);
runCli('db push', workDir);

const schemaFile = path.join(workDir, 'zenstack/schema.zmodel');

fs.writeFileSync(schemaFile, getDefaultPrelude());
runCli('db pull --indent 4', workDir);

const restoredSchema = getSchema(workDir);
expect(restoredSchema).toEqual(schema);
});

it('should preserve opposite relation fields when multiple models have FKs to the same target', async () => {
const { workDir, schema } = await createProject(
`model Post {
id Int @id @default(autoincrement())
title String
postUserCreated User? @relation('Post_createdByToUser', fields: [createdBy], references: [id])
createdBy Int?
postUserUpdated User? @relation('Post_updatedByToUser', fields: [updatedBy], references: [id])
updatedBy Int?
}

model Comment {
id Int @id @default(autoincrement())
text String
commentUserCreated User? @relation('Comment_createdByToUser', fields: [createdBy], references: [id])
createdBy Int?
commentUserUpdated User? @relation('Comment_updatedByToUser', fields: [updatedBy], references: [id])
updatedBy Int?
}

model User {
id Int @id @default(autoincrement())
email String @unique
postUserCreatedToUsers Post[] @relation('Post_createdByToUser')
postUserUpdatedToUsers Post[] @relation('Post_updatedByToUser')
commentUserCreatedToUsers Comment[] @relation('Comment_createdByToUser')
commentUserUpdatedToUsers Comment[] @relation('Comment_updatedByToUser')
}`,
);
runCli('db push', workDir);

runCli('db pull --indent 4', workDir);

const restoredSchema = getSchema(workDir);
expect(restoredSchema).toEqual(schema);
});

it('should restore one-to-one relation when FK is the single-column primary key', async () => {
const { workDir, schema } = await createProject(
`model Profile {
Expand Down
Loading