diff --git a/packages/language/res/stdlib.zmodel b/packages/language/res/stdlib.zmodel index cb604c74a..e7f9124e6 100644 --- a/packages/language/res/stdlib.zmodel +++ b/packages/language/res/stdlib.zmodel @@ -104,6 +104,13 @@ function nanoid(length: Int?, format: String?): String { function ulid(format: String?): String { } @@@expressionContext([DefaultValue]) +/** + * Generates a custom identifier. The ORM client must be initialized with an + * implementation of this function. + */ +function customId(length: Int?): String { +} @@@expressionContext([DefaultValue]) + /** * Creates a sequence of integers in the underlying database and assign the incremented * values to the ID values of the created records based on the sequence. diff --git a/packages/language/src/validators/function-invocation-validator.ts b/packages/language/src/validators/function-invocation-validator.ts index 8ce1035e3..25106f1a6 100644 --- a/packages/language/src/validators/function-invocation-validator.ts +++ b/packages/language/src/validators/function-invocation-validator.ts @@ -237,6 +237,20 @@ export default class FunctionInvocationValidator implements AstValidator(lengthArg); + if (length !== undefined && length <= 0) { + accept('error', 'first argument must be a positive number', { + node: expr.args[0]!, + }); + } + } + } + @func('auth') private _checkAuth(expr: InvocationExpr, accept: ValidationAcceptor) { if (!expr.$resolvedType) { diff --git a/packages/language/test/function-invocation.test.ts b/packages/language/test/function-invocation.test.ts index ff6bb45ef..0fbd469a7 100644 --- a/packages/language/test/function-invocation.test.ts +++ b/packages/language/test/function-invocation.test.ts @@ -414,4 +414,36 @@ describe('Function Invocation Tests', () => { ); }); }); + + describe('customId() length validation', () => { + it('should reject non-positive lengths', async () => { + await loadSchemaWithError( + ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id String @id @default(customId(0)) + } + `, + 'first argument must be a positive number', + ); + + await loadSchemaWithError( + ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id String @id @default(customId(-1)) + } + `, + 'first argument must be a positive number', + ); + }); + }); }); diff --git a/packages/orm/src/client/crud/operations/base.ts b/packages/orm/src/client/crud/operations/base.ts index fc75cac9d..33c3b62aa 100644 --- a/packages/orm/src/client/crud/operations/base.ts +++ b/packages/orm/src/client/crud/operations/base.ts @@ -25,6 +25,7 @@ import { NUMERIC_FIELD_TYPES } from '../../constants'; import { TransactionIsolationLevel, type ClientContract, type CRUD } from '../../contract'; import type { FindArgs, SelectIncludeOmit, WhereInput } from '../../crud-types'; import { + createConfigError, createDBQueryError, createInternalError, createInvalidInputError, @@ -1044,7 +1045,7 @@ export abstract class BaseOperationHandler { } if (!(field in data)) { if (typeof fieldDef?.default === 'object' && 'kind' in fieldDef.default) { - const generated = this.evalGenerator(fieldDef.default); + const generated = this.evalGenerator(fieldDef.default, modelDef.name); if (generated !== undefined) { values[field] = this.dialect.transformInput( generated, @@ -1072,7 +1073,7 @@ export abstract class BaseOperationHandler { return values; } - private evalGenerator(defaultValue: Expression) { + private evalGenerator(defaultValue: Expression, model: string) { if (ExpressionUtils.isCall(defaultValue)) { const firstArgVal = defaultValue.args?.[0] && ExpressionUtils.isLiteral(defaultValue.args[0]) @@ -1095,6 +1096,21 @@ export abstract class BaseOperationHandler { return this.formatGeneratedValue(generated, defaultValue.args?.[1]); }) .with('ulid', () => this.formatGeneratedValue(ulid(), defaultValue.args?.[0])) + .with('customId', () => { + if (!this.client.$options.customId) { + throw createConfigError('"customId" implementation not provided'); + } + const length = typeof firstArgVal === 'number' ? firstArgVal : undefined; + const generated = this.client.$options.customId({ + client: this.client, + model: model as GetModels, + length, + }); + if (!generated || typeof generated !== 'string') { + throw createConfigError('"customId" must return a non-empty string'); + } + return generated; + }) .otherwise(() => undefined); } else if ( ExpressionUtils.isMember(defaultValue) && diff --git a/packages/orm/src/client/options.ts b/packages/orm/src/client/options.ts index 6439e3996..7d21a3d3d 100644 --- a/packages/orm/src/client/options.ts +++ b/packages/orm/src/client/options.ts @@ -40,6 +40,25 @@ export type ZModelFunction = ( context: ZModelFunctionContext, ) => Expression; +export type CustomIdFunctionContext = { + /** + * ZenStack client instance. + */ + client: ClientContract; + + /** + * The model for which the ID should be generated. + */ + model: GetModels; + + /** + * The length of the ID as requested by the schema. + */ + length?: number; +}; + +export type CustomIdFunction = (ctx: CustomIdFunctionContext) => string; + /** * ZenStack client options. */ @@ -82,6 +101,12 @@ export type ClientOptions = { */ validateInput?: boolean; + /** + * Implementation of a custom ID generation function, which is called from ZModel as + * `@default(customId())`. + */ + customId?: CustomIdFunction; + /** * Options for omitting fields in ORM query results. */ diff --git a/tests/e2e/orm/client-api/custom-id.test.ts b/tests/e2e/orm/client-api/custom-id.test.ts new file mode 100644 index 000000000..15a82b663 --- /dev/null +++ b/tests/e2e/orm/client-api/custom-id.test.ts @@ -0,0 +1,180 @@ +import { createTestClient } from '@zenstackhq/testtools'; +import { describe, expect, it } from 'vitest'; + +const schema = ` + model User { + uid String @id @default(customId()) + posts Post[] + } + + model Post { + pid String @id @default(customId()) + userId String? + user User? @relation(fields: [userId], references: [uid]) + comments Comment[] + } + + model Comment { + cid String @id @default(customId()) + postId String? + post Post? @relation(fields: [postId], references: [pid]) + } +`; + +describe('customId', () => { + it('works with no arguments', async () => { + let client = await createTestClient(schema, { + customId: ({ model, length, client }) => `${model}.${length ?? 16}.${client.$auth!['uid']}`, + }); + + client = client.$setAuth({ + uid: '1', + }); + + await expect(client.user.create({ data: {} })).resolves.toMatchObject({ + uid: 'User.16.1', + }); + + await expect(client.post.create({ data: {} })).resolves.toMatchObject({ + pid: 'Post.16.1', + }); + + await expect(client.comment.create({ data: {} })).resolves.toMatchObject({ + cid: 'Comment.16.1', + }); + }); + + it('works with arguments', async () => { + const schema = ` + model User { + uid String @id @default(customId(8)) + posts Post[] + } + + model Post { + pid String @id @default(customId(8)) + userId String? + user User? @relation(fields: [userId], references: [uid]) + comments Comment[] + } + + model Comment { + cid String @id @default(customId(8)) + postId String? + post Post? @relation(fields: [postId], references: [pid]) + } + `; + + let client = await createTestClient(schema, { + customId: ({ model, length, client }) => `${model}.${length}.${client.$auth!['uid']}`, + }); + + client = client.$setAuth({ + uid: '1', + }); + + await expect(client.user.create({ data: {} })).resolves.toMatchObject({ + uid: 'User.8.1', + }); + + await expect(client.post.create({ data: {} })).resolves.toMatchObject({ + pid: 'Post.8.1', + }); + + await expect(client.comment.create({ data: {} })).resolves.toMatchObject({ + cid: 'Comment.8.1', + }); + }); + + it('works with nested', async () => { + let client = await createTestClient(schema, { + customId: ({ model, length, client }) => `${model}.${length ?? 16}.${client.$auth!['uid']}`, + }); + + client = client.$setAuth({ + uid: '1', + }); + + await expect(client.user.create({ + data: { + posts: { + create: {}, + }, + }, + })).resolves.toMatchObject({ + uid: 'User.16.1', + }); + + await expect(client.post.findUnique({ + where: { + pid: 'Post.16.1', + } + })).resolves.toBeTruthy(); + }); + + it('works with deeply nested', async () => { + let client = await createTestClient(schema, { + customId: ({ model, length, client }) => `${model}.${length ?? 16}.${client.$auth!['uid']}`, + }); + + client = client.$setAuth({ + uid: '1', + }); + + await expect(client.user.create({ + data: { + posts: { + create: { + comments: { + create: {}, + }, + }, + }, + }, + })).resolves.toMatchObject({ + uid: 'User.16.1', + }); + + await expect(client.post.findUnique({ + where: { + pid: 'Post.16.1', + } + })).resolves.toBeTruthy(); + + await expect(client.comment.findUnique({ + where: { + cid: 'Comment.16.1', + } + })).resolves.toBeTruthy(); + }); + + it('rejects without an implementation', async () => { + const client = await createTestClient(schema); + await expect(client.user.create({ data: {} })).rejects.toThrowError('implementation not provided'); + }); + + it('rejects without a valid implementation (undefined)', async () => { + // @ts-expect-error + const client = await createTestClient(schema, { + customId: () => undefined, + }); + // @ts-expect-error + await expect(client.user.create({ data: {} })).rejects.toThrowError('non-empty string'); + }); + + it('rejects without a valid implementation (empty string)', async () => { + const client = await createTestClient(schema, { + customId: () => '', + }); + await expect(client.user.create({ data: {} })).rejects.toThrowError('non-empty string'); + }); + + it('rejects without a valid implementation (non-string)', async () => { + // @ts-expect-error + const client = await createTestClient(schema, { + customId: () => 1, + }); + // @ts-expect-error + await expect(client.user.create({ data: {} })).rejects.toThrowError('non-empty string'); + }); +});