diff --git a/README.md b/README.md index 64e0e3c8..4da4a3f4 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ ZenStack is a TypeScript database toolkit for developing full-stack or backend N - Automatic CRUD web APIs with adapters for popular frameworks (coming soon) - Automatic [TanStack Query](https://github.com/TanStack/query) hooks for easy CRUD from the frontend (coming soon) -# What's new with V3 +# What's New in V3 ZenStack V3 is a major rewrite of [V2](https://github.com/zenstackhq/zenstack). The biggest change is V3 doesn't have a runtime dependency to Prisma anymore. Instead of working as a big wrapper of Prisma as in V2, V3 made a bold move and implemented the entire ORM engine using [Kysely](https://github.com/kysely-org/kysely), while keeping the query API fully compatible with Prisma. @@ -49,7 +49,7 @@ Even without using advanced features, ZenStack offers the following benefits as > Although ZenStack v3's ORM runtime doesn't depend on Prisma anymore (specifically, `@prisma/client`), it still relies on Prisma to handle database migration. See [database migration](https://zenstack.dev/docs/3.x/orm/migration) for more details. -# Quick start +# Quick Start - [ORM](./samples/orm): A simple example demonstrating ZenStack ORM usage. - [Next.js + TanStack Query](./samples/next.js): A full-stack sample demonstrating using TanStack Query to consume ZenStack's automatic CRUD services in a Next.js app. @@ -72,7 +72,7 @@ Or, if you have an existing project, use the CLI to initialize it: npx @zenstackhq/cli@next init ``` -### 3. Manual setup +### 3. Setting up manually Alternatively, you can set it up manually: diff --git a/packages/language/res/stdlib.zmodel b/packages/language/res/stdlib.zmodel index 7df81364..269a0f9c 100644 --- a/packages/language/res/stdlib.zmodel +++ b/packages/language/res/stdlib.zmodel @@ -382,12 +382,14 @@ attribute @map(_ name: String) @@@prisma attribute @@map(_ name: String) @@@prisma /** - * Exclude a field from the Prisma Client (for example, a field that you do not want Prisma users to update). + * Exclude a field from the ORM Client (for example, a field that you do not want Prisma users to update). + * The field is still recognized by database schema migrations. */ attribute @ignore() @@@prisma /** - * Exclude a model from the Prisma Client (for example, a model that you do not want Prisma users to update). + * Exclude a model from the ORM Client (for example, a model that you do not want Prisma users to update). + * The model is still recognized by database schema migrations. */ attribute @@ignore() @@@prisma diff --git a/packages/orm/src/client/client-impl.ts b/packages/orm/src/client/client-impl.ts index 150e00df..f8331a81 100644 --- a/packages/orm/src/client/client-impl.ts +++ b/packages/orm/src/client/client-impl.ts @@ -29,7 +29,7 @@ import { FindOperationHandler } from './crud/operations/find'; import { GroupByOperationHandler } from './crud/operations/group-by'; import { UpdateOperationHandler } from './crud/operations/update'; import { InputValidator } from './crud/validator'; -import { NotFoundError, QueryError } from './errors'; +import { createConfigError, createNotFoundError } from './errors'; import { ZenStackDriver } from './executor/zenstack-driver'; import { ZenStackQueryExecutor } from './executor/zenstack-query-executor'; import * as BuiltinFunctions from './functions'; @@ -223,7 +223,7 @@ export class ClientImpl { private async handleProc(name: string, args: unknown[]) { if (!('procedures' in this.$options) || !this.$options || typeof this.$options.procedures !== 'object') { - throw new QueryError('Procedures are not configured for the client.'); + throw createConfigError('Procedures are not configured for the client.'); } const procOptions = this.$options.procedures as ProceduresOptions< @@ -389,7 +389,7 @@ function createModelCrudHandler { if (this.supportsDistinctOn) { result = result.distinctOn(distinct.map((f) => this.eb.ref(`${modelAlias}.${f}`))); } else { - throw new QueryError(`"distinct" is not supported by "${this.schema.provider.type}" provider`); + throw createNotSupportedError(`"distinct" is not supported by "${this.schema.provider.type}" provider`); } } @@ -482,7 +482,7 @@ export abstract class BaseCrudDialect { } default: { - throw new InternalError(`Invalid array filter key: ${key}`); + throw createInvalidInputError(`Invalid array filter key: ${key}`); } } } @@ -510,10 +510,10 @@ export abstract class BaseCrudDialect { .with('Bytes', () => this.buildBytesFilter(fieldRef, payload)) // TODO: JSON filters .with('Json', () => { - throw new InternalError('JSON filters are not supported yet'); + throw createNotSupportedError('JSON filters are not supported yet'); }) .with('Unsupported', () => { - throw new QueryError(`Unsupported field cannot be used in filters`); + throw createInvalidInputError(`Unsupported field cannot be used in filters`); }) .exhaustive() ); @@ -589,7 +589,7 @@ export abstract class BaseCrudDialect { }) .otherwise(() => { if (throwIfInvalid) { - throw new QueryError(`Invalid filter key: ${op}`); + throw createInvalidInputError(`Invalid filter key: ${op}`); } else { return undefined; } @@ -642,7 +642,7 @@ export abstract class BaseCrudDialect { : this.eb(fieldRef, 'like', sql.val(`%${value}`)), ) .otherwise(() => { - throw new QueryError(`Invalid string filter key: ${key}`); + throw createInvalidInputError(`Invalid string filter key: ${key}`); }); if (condition) { @@ -815,7 +815,7 @@ export abstract class BaseCrudDialect { if (fieldDef.array) { // order by to-many relation if (typeof value !== 'object') { - throw new QueryError(`invalid orderBy value for field "${field}"`); + throw createInvalidInputError(`invalid orderBy value for field "${field}"`); } if ('_count' in value) { invariant( @@ -1084,7 +1084,7 @@ export abstract class BaseCrudDialect { computer = computedFields?.[fieldDef.originModel ?? model]?.[field]; } if (!computer) { - throw new QueryError(`Computed field "${field}" implementation not provided for model "${model}"`); + throw createConfigError(`Computed field "${field}" implementation not provided for model "${model}"`); } return computer(this.eb, { modelAlias }); } diff --git a/packages/orm/src/client/crud/dialects/postgresql.ts b/packages/orm/src/client/crud/dialects/postgresql.ts index 82eeb9a8..a37e603b 100644 --- a/packages/orm/src/client/crud/dialects/postgresql.ts +++ b/packages/orm/src/client/crud/dialects/postgresql.ts @@ -12,7 +12,6 @@ import { match } from 'ts-pattern'; import type { BuiltinType, FieldDef, GetModels, SchemaDef } from '../../../schema'; import { DELEGATE_JOINED_FIELD_PREFIX } from '../../constants'; import type { FindArgs } from '../../crud-types'; -import { QueryError } from '../../errors'; import type { ClientOptions } from '../../options'; import { buildJoinPairs, @@ -24,6 +23,7 @@ import { requireModel, } from '../../query-utils'; import { BaseCrudDialect } from './base-dialect'; +import { createInternalError } from '../../errors'; export class PostgresCrudDialect extends BaseCrudDialect { constructor(schema: Schema, options: ClientOptions) { @@ -438,7 +438,7 @@ export class PostgresCrudDialect extends BaseCrudDiale override getFieldSqlType(fieldDef: FieldDef) { // TODO: respect `@db.x` attributes if (fieldDef.relation) { - throw new QueryError('Cannot get SQL type of a relation field'); + throw createInternalError('Cannot get SQL type of a relation field'); } let result: string; diff --git a/packages/orm/src/client/crud/dialects/sqlite.ts b/packages/orm/src/client/crud/dialects/sqlite.ts index e163f464..464e30dc 100644 --- a/packages/orm/src/client/crud/dialects/sqlite.ts +++ b/packages/orm/src/client/crud/dialects/sqlite.ts @@ -12,7 +12,7 @@ import { match } from 'ts-pattern'; import type { BuiltinType, FieldDef, GetModels, SchemaDef } from '../../../schema'; import { DELEGATE_JOINED_FIELD_PREFIX } from '../../constants'; import type { FindArgs } from '../../crud-types'; -import { QueryError } from '../../errors'; +import { createInternalError } from '../../errors'; import { getDelegateDescendantModels, getManyToManyRelation, @@ -121,7 +121,7 @@ export class SqliteCrudDialect extends BaseCrudDialect try { return JSON.parse(value); } catch (e) { - throw new QueryError('Invalid JSON returned', e); + throw createInternalError('Invalid JSON returned', undefined, { cause: e }); } } return value; @@ -376,10 +376,10 @@ export class SqliteCrudDialect extends BaseCrudDialect override getFieldSqlType(fieldDef: FieldDef) { // TODO: respect `@db.x` attributes if (fieldDef.relation) { - throw new QueryError('Cannot get SQL type of a relation field'); + throw createInternalError('Cannot get SQL type of a relation field'); } if (fieldDef.array) { - throw new QueryError('SQLite does not support scalar list type'); + throw createInternalError('SQLite does not support scalar list type'); } if (this.schema.enums?.[fieldDef.type]) { diff --git a/packages/orm/src/client/crud/operations/base.ts b/packages/orm/src/client/crud/operations/base.ts index a309b268..a51a2557 100644 --- a/packages/orm/src/client/crud/operations/base.ts +++ b/packages/orm/src/client/crud/operations/base.ts @@ -14,14 +14,21 @@ import { nanoid } from 'nanoid'; import { match } from 'ts-pattern'; import { ulid } from 'ulid'; import * as uuid from 'uuid'; -import type { ClientContract } from '../..'; import type { BuiltinType, Expression, FieldDef } from '../../../schema'; import { ExpressionUtils, type GetModels, type ModelDef, type SchemaDef } from '../../../schema'; import { extractFields, fieldsToSelectObject } from '../../../utils/object-utils'; import { NUMERIC_FIELD_TYPES } from '../../constants'; -import { TransactionIsolationLevel, type CRUD } from '../../contract'; +import { TransactionIsolationLevel, type ClientContract, type CRUD } from '../../contract'; import type { FindArgs, SelectIncludeOmit, WhereInput } from '../../crud-types'; -import { InternalError, NotFoundError, QueryError } from '../../errors'; +import { + createDBQueryError, + createInternalError, + createInvalidInputError, + createNotFoundError, + createNotSupportedError, + ORMError, + ORMErrorReason, +} from '../../errors'; import type { ToKysely } from '../../query-builder'; import { ensureArray, @@ -172,8 +179,7 @@ export abstract class BaseOperationHandler { const r = await kysely.getExecutor().executeQuery(compiled, queryId); result = r.rows; } catch (err) { - const message = `Failed to execute query: ${err}, sql: ${compiled.sql}`; - throw new QueryError(message, err); + throw createDBQueryError('Failed to execute query', err, compiled.sql, compiled.parameters); } return result; @@ -212,7 +218,7 @@ export abstract class BaseOperationHandler { result = this.dialect.buildSelectField(result, model, parentAlias, field); } else { if (!fieldDef.array && !fieldDef.optional && payload.where) { - throw new QueryError(`Field "${field}" doesn't support filtering`); + throw createInternalError(`Field "${field}" does not support filtering`, model); } if (fieldDef.originModel) { result = this.dialect.buildRelationSelection( @@ -253,7 +259,7 @@ export abstract class BaseOperationHandler { // additional validations if (modelDef.isDelegate && !creatingForDelegate) { - throw new QueryError(`Model "${this.model}" is a delegate and cannot be created directly.`); + throw createNotSupportedError(`Model "${model}" is a delegate and cannot be created directly.`); } let createFields: any = {}; @@ -442,7 +448,7 @@ export abstract class BaseOperationHandler { select: { [pair.pk]: true }, } as any); if (!extraRead) { - throw new QueryError(`Field "${pair.pk}" not found in parent created data`); + throw createInternalError(`Field "${pair.pk}" not found in parent created data`, model); } else { // update the parent entity Object.assign(entity, extraRead); @@ -560,7 +566,7 @@ export abstract class BaseOperationHandler { select: fieldsToSelectObject(referencedPkFields) as any, }); if (!relationEntity) { - throw new NotFoundError( + throw createNotFoundError( relationModel, `Could not find the entity to connect for the relation "${relationField.name}"`, ); @@ -584,7 +590,7 @@ export abstract class BaseOperationHandler { } default: - throw new QueryError(`Invalid relation action: ${action}`); + throw createInvalidInputError(`Invalid relation action: ${action}`); } } @@ -650,7 +656,7 @@ export abstract class BaseOperationHandler { } default: - throw new QueryError(`Invalid relation action: ${action}`); + throw createInvalidInputError(`Invalid relation action: ${action}`); } } } @@ -681,7 +687,7 @@ export abstract class BaseOperationHandler { fromRelation.field, ); if (ownedByModel) { - throw new QueryError('incorrect relation hierarchy for createMany'); + throw createInvalidInputError('incorrect relation hierarchy for createMany', model); } relationKeyPairs = keyPairs; } @@ -739,7 +745,7 @@ export abstract class BaseOperationHandler { if (modelDef.baseModel) { if (input.skipDuplicates) { // TODO: simulate createMany with create in this case - throw new QueryError('"skipDuplicates" options is not supported for polymorphic models'); + throw createNotSupportedError('"skipDuplicates" options is not supported for polymorphic models'); } // create base hierarchy const baseCreateResult = await this.processBaseModelCreateMany( @@ -906,7 +912,7 @@ export abstract class BaseOperationHandler { fieldsToReturn?: string[], ): Promise { if (!data || typeof data !== 'object') { - throw new InternalError('data must be an object'); + throw createInvalidInputError('data must be an object'); } const parentWhere: any = {}; @@ -982,7 +988,7 @@ export abstract class BaseOperationHandler { select: this.makeIdSelect(model), }); if (!readResult && throwIfNotFound) { - throw new NotFoundError(model); + throw createNotFoundError(model); } combinedWhere = readResult; } @@ -1010,13 +1016,13 @@ export abstract class BaseOperationHandler { updateFields[field] = this.processScalarFieldUpdateData(model, field, finalData); } else { if (!allowRelationUpdate) { - throw new QueryError(`Relation update not allowed for field "${field}"`); + throw createNotSupportedError(`Relation update not allowed for field "${field}"`); } if (!thisEntity) { thisEntity = await this.getEntityIds(kysely, model, combinedWhere); if (!thisEntity) { if (throwIfNotFound) { - throw new NotFoundError(model); + throw createNotFoundError(model); } else { return null; } @@ -1065,7 +1071,7 @@ export abstract class BaseOperationHandler { const updatedEntity = await this.executeQueryTakeFirst(kysely, query, 'update'); if (!updatedEntity) { if (throwIfNotFound) { - throw new NotFoundError(model); + throw createNotFoundError(model); } else { return null; } @@ -1163,7 +1169,7 @@ export abstract class BaseOperationHandler { .with('multiply', () => eb(fieldRef, '*', value)) .with('divide', () => eb(fieldRef, '/', value)) .otherwise(() => { - throw new InternalError(`Invalid incremental update operation: ${key}`); + throw createInvalidInputError(`Invalid incremental update operation: ${key}`); }); } @@ -1185,7 +1191,7 @@ export abstract class BaseOperationHandler { return eb(fieldRef, '||', eb.val(ensureArray(value))); }) .otherwise(() => { - throw new InternalError(`Invalid array update operation: ${key}`); + throw createInvalidInputError(`Invalid array update operation: ${key}`); }); } @@ -1212,7 +1218,7 @@ export abstract class BaseOperationHandler { fieldsToReturn?: string[], ): Promise { if (typeof data !== 'object') { - throw new InternalError('data must be an object'); + throw createInvalidInputError('data must be an object'); } if (Object.keys(data).length === 0) { @@ -1221,7 +1227,7 @@ export abstract class BaseOperationHandler { const modelDef = this.requireModel(model); if (modelDef.baseModel && limit !== undefined) { - throw new QueryError('Updating with a limit is not supported for polymorphic models'); + throw createNotSupportedError('Updating with a limit is not supported for polymorphic models'); } filterModel ??= model; @@ -1465,7 +1471,7 @@ export abstract class BaseOperationHandler { } default: { - throw new Error('Not implemented yet'); + throw createInvalidInputError(`Invalid relation update operation: ${key}`); } } } @@ -1493,7 +1499,7 @@ export abstract class BaseOperationHandler { for (const d of _data) { const ids = await this.getEntityIds(kysely, model, d); if (!ids) { - throw new NotFoundError(model); + throw createNotFoundError(model); } const r = await this.handleManyToManyRelation( kysely, @@ -1511,7 +1517,7 @@ export abstract class BaseOperationHandler { // validate connect result if (_data.length > results.filter((r) => !!r).length) { - throw new NotFoundError(model); + throw createNotFoundError(model); } } else { const { ownedByModel, keyPairs } = getRelationForeignKeyFieldPairs( @@ -1527,7 +1533,7 @@ export abstract class BaseOperationHandler { where: _data[0], }); if (!target) { - throw new NotFoundError(model); + throw createNotFoundError(model); } for (const { fk, pk } of keyPairs) { @@ -1577,7 +1583,7 @@ export abstract class BaseOperationHandler { // validate connect result if (!updateResult.numAffectedRows || _data.length > updateResult.numAffectedRows) { // some entities were not connected - throw new NotFoundError(model); + throw createNotFoundError(model); } } } @@ -1742,7 +1748,7 @@ export abstract class BaseOperationHandler { for (const d of _data) { const ids = await this.getEntityIds(kysely, model, d); if (!ids) { - throw new NotFoundError(model); + throw createNotFoundError(model); } results.push( await this.handleManyToManyRelation( @@ -1761,7 +1767,7 @@ export abstract class BaseOperationHandler { // validate connect result if (_data.length > results.filter((r) => !!r).length) { - throw new NotFoundError(model); + throw createNotFoundError(model); } } else { const { ownedByModel, keyPairs } = getRelationForeignKeyFieldPairs( @@ -1771,7 +1777,7 @@ export abstract class BaseOperationHandler { ); if (ownedByModel) { - throw new InternalError('relation can only be set from the non-owning side'); + throw createInternalError('relation can only be set from the non-owning side', fromRelation.model); } const fkConditions = keyPairs.reduce( @@ -1827,7 +1833,7 @@ export abstract class BaseOperationHandler { // validate result if (!r.numAffectedRows || _data.length > r.numAffectedRows) { // some entities were not connected - throw new NotFoundError(model); + throw createNotFoundError(model); } } } @@ -1894,7 +1900,7 @@ export abstract class BaseOperationHandler { where: fromRelation.ids, }); if (!fromEntity) { - throw new NotFoundError(fromRelation.model); + throw createNotFoundError(fromRelation.model); } const fieldDef = this.requireField(fromRelation.model, fromRelation.field); @@ -1924,7 +1930,7 @@ export abstract class BaseOperationHandler { // validate result if (throwForNotFound && expectedDeleteCount > deleteResult.rows.length) { // some entities were not deleted - throw new NotFoundError(deleteFromModel); + throw createNotFoundError(deleteFromModel); } } @@ -1948,7 +1954,7 @@ export abstract class BaseOperationHandler { if (modelDef.baseModel) { if (limit !== undefined) { - throw new QueryError('Deleting with a limit is not supported for polymorphic models'); + throw createNotSupportedError('Deleting with a limit is not supported for polymorphic models'); } // just delete base and it'll cascade back to this model return this.processBaseModelDelete(kysely, modelDef.baseModel, where, limit, filterModel); @@ -2013,7 +2019,7 @@ export abstract class BaseOperationHandler { const oppositeRelation = this.requireField(fieldDef.type, fieldDef.relation.opposite); if (oppositeModelDef.baseModel && oppositeRelation.relation?.onDelete === 'Cascade') { if (limit !== undefined) { - throw new QueryError('Deleting with a limit is not supported for polymorphic models'); + throw createNotSupportedError('Deleting with a limit is not supported for polymorphic models'); } // the deletion will propagate upward to the base model chain await this.delete( @@ -2134,7 +2140,7 @@ export abstract class BaseOperationHandler { protected async executeQueryTakeFirstOrThrow(kysely: ToKysely, query: Compilable, operation: string) { const result = await kysely.executeQuery(query.compile(), this.makeQueryId(operation)); if (result.rows.length === 0) { - throw new QueryError('No rows found'); + throw new ORMError(ORMErrorReason.NOT_FOUND, 'No rows found'); } return result.rows[0]; } diff --git a/packages/orm/src/client/crud/operations/create.ts b/packages/orm/src/client/crud/operations/create.ts index 98124838..1e1c9869 100644 --- a/packages/orm/src/client/crud/operations/create.ts +++ b/packages/orm/src/client/crud/operations/create.ts @@ -1,7 +1,7 @@ import { match } from 'ts-pattern'; import type { GetModels, SchemaDef } from '../../../schema'; import type { CreateArgs, CreateManyAndReturnArgs, CreateManyArgs, WhereInput } from '../../crud-types'; -import { RejectedByPolicyError, RejectedByPolicyReason } from '../../errors'; +import { createRejectedByPolicyError, RejectedByPolicyReason } from '../../errors'; import { getIdValues } from '../../query-utils'; import { BaseOperationHandler } from './base'; @@ -48,7 +48,7 @@ export class CreateOperationHandler extends BaseOperat }); if (!result && this.hasPolicyEnabled) { - throw new RejectedByPolicyError( + throw createRejectedByPolicyError( this.model, RejectedByPolicyReason.CANNOT_READ_BACK, `result is not allowed to be read back`, diff --git a/packages/orm/src/client/crud/operations/delete.ts b/packages/orm/src/client/crud/operations/delete.ts index e6fb3c2a..af9942a9 100644 --- a/packages/orm/src/client/crud/operations/delete.ts +++ b/packages/orm/src/client/crud/operations/delete.ts @@ -1,7 +1,7 @@ import { match } from 'ts-pattern'; import type { SchemaDef } from '../../../schema'; import type { DeleteArgs, DeleteManyArgs } from '../../crud-types'; -import { NotFoundError, RejectedByPolicyError, RejectedByPolicyReason } from '../../errors'; +import { createNotFoundError, createRejectedByPolicyError, RejectedByPolicyReason } from '../../errors'; import { BaseOperationHandler } from './base'; export class DeleteOperationHandler extends BaseOperationHandler { @@ -34,13 +34,13 @@ export class DeleteOperationHandler extends BaseOperat } const deleteResult = await this.delete(tx, this.model, args.where, undefined, undefined, selectedFields); if (deleteResult.rows.length === 0) { - throw new NotFoundError(this.model); + throw createNotFoundError(this.model); } return needReadBack ? preDeleteRead : deleteResult.rows[0]; }); if (!result && this.hasPolicyEnabled) { - throw new RejectedByPolicyError( + throw createRejectedByPolicyError( this.model, RejectedByPolicyReason.CANNOT_READ_BACK, 'result is not allowed to be read back', diff --git a/packages/orm/src/client/crud/operations/update.ts b/packages/orm/src/client/crud/operations/update.ts index 5d8d7b19..9c81d169 100644 --- a/packages/orm/src/client/crud/operations/update.ts +++ b/packages/orm/src/client/crud/operations/update.ts @@ -1,7 +1,7 @@ import { match } from 'ts-pattern'; import type { GetModels, SchemaDef } from '../../../schema'; import type { UpdateArgs, UpdateManyAndReturnArgs, UpdateManyArgs, UpsertArgs, WhereInput } from '../../crud-types'; -import { RejectedByPolicyError, RejectedByPolicyReason } from '../../errors'; +import { createRejectedByPolicyError, RejectedByPolicyReason } from '../../errors'; import { getIdValues } from '../../query-utils'; import { BaseOperationHandler } from './base'; @@ -61,7 +61,7 @@ export class UpdateOperationHandler extends BaseOperat // update succeeded but result cannot be read back if (this.hasPolicyEnabled) { // if access policy is enabled, we assume it's due to read violation (not guaranteed though) - throw new RejectedByPolicyError( + throw createRejectedByPolicyError( this.model, RejectedByPolicyReason.CANNOT_READ_BACK, 'result is not allowed to be read back', @@ -120,7 +120,7 @@ export class UpdateOperationHandler extends BaseOperat if (readBackResult.length < updateResult.length && this.hasPolicyEnabled) { // some of the updated entities cannot be read back - throw new RejectedByPolicyError( + throw createRejectedByPolicyError( this.model, RejectedByPolicyReason.CANNOT_READ_BACK, 'result is not allowed to be read back', @@ -168,7 +168,7 @@ export class UpdateOperationHandler extends BaseOperat }); if (!result && this.hasPolicyEnabled) { - throw new RejectedByPolicyError( + throw createRejectedByPolicyError( this.model, RejectedByPolicyReason.CANNOT_READ_BACK, 'result is not allowed to be read back', diff --git a/packages/orm/src/client/crud/validator/index.ts b/packages/orm/src/client/crud/validator/index.ts index b4c012ed..7d97c942 100644 --- a/packages/orm/src/client/crud/validator/index.ts +++ b/packages/orm/src/client/crud/validator/index.ts @@ -31,7 +31,6 @@ import { type UpdateManyArgs, type UpsertArgs, } from '../../crud-types'; -import { InputValidationError, InternalError } from '../../errors'; import { fieldHasDefaultValue, getDiscriminatorField, @@ -48,6 +47,7 @@ import { addNumberValidation, addStringValidation, } from './utils'; +import { createInternalError, createInvalidInputError } from '../../errors'; const schemaCache = new WeakMap>(); @@ -230,10 +230,12 @@ export class InputValidator { } const { error, data } = schema.safeParse(args); if (error) { - throw new InputValidationError( - model, + throw createInvalidInputError( `Invalid ${operation} args for model "${model}": ${formatError(error)}`, - error, + model, + { + cause: error, + }, ); } return data as T; @@ -471,7 +473,7 @@ export class InputValidator { // requires at least one unique field (field set) is required const uniqueFields = getUniqueFields(this.schema, model); if (uniqueFields.length === 0) { - throw new InternalError(`Model "${model}" has no unique fields`); + throw createInternalError(`Model "${model}" has no unique fields`); } if (uniqueFields.length === 1) { diff --git a/packages/orm/src/client/crud/validator/utils.ts b/packages/orm/src/client/crud/validator/utils.ts index 5024b07d..94f4baca 100644 --- a/packages/orm/src/client/crud/validator/utils.ts +++ b/packages/orm/src/client/crud/validator/utils.ts @@ -13,7 +13,7 @@ import { match, P } from 'ts-pattern'; import { z } from 'zod'; import { ZodIssueCode } from 'zod/v3'; import { ExpressionUtils } from '../../../schema'; -import { QueryError } from '../../errors'; +import { createNotSupportedError } from '../../errors'; function getArgValue(expr: Expression | undefined): T | undefined { if (!expr || !ExpressionUtils.isLiteral(expr)) { @@ -452,7 +452,7 @@ function evalCall(data: any, expr: CallExpression) { return fieldArg.length === 0; }) .otherwise(() => { - throw new QueryError(`Unknown function "${expr.function}"`); + throw createNotSupportedError(`Unsupported function "${expr.function}"`); }) ); } diff --git a/packages/orm/src/client/errors.ts b/packages/orm/src/client/errors.ts index 89ef02f2..9908a6b2 100644 --- a/packages/orm/src/client/errors.ts +++ b/packages/orm/src/client/errors.ts @@ -1,45 +1,43 @@ -/** - * Base for all ZenStack runtime errors. - */ -export class ZenStackError extends Error {} +import { getDbErrorCode } from './executor/error-processor'; /** - * Error thrown when input validation fails. + * Reason code for ORM errors. */ -export class InputValidationError extends ZenStackError { - constructor( - public readonly model: string, - message: string, - cause?: unknown, - ) { - super(message, { cause }); - } -} +export enum ORMErrorReason { + /** + * ORM client configuration error. + */ + CONFIG_ERROR = 'config-error', -/** - * Error thrown when a query fails. - */ -export class QueryError extends ZenStackError { - constructor(message: string, cause?: unknown) { - super(message, { cause }); - } -} + /** + * Invalid input error. + */ + INVALID_INPUT = 'invalid-input', -/** - * Error thrown when an internal error occurs. - */ -export class InternalError extends ZenStackError {} + /** + * The specified record was not found. + */ + NOT_FOUND = 'not-found', -/** - * Error thrown when an entity is not found. - */ -export class NotFoundError extends ZenStackError { - constructor( - public readonly model: string, - details?: string, - ) { - super(`Entity not found for model "${model}"${details ? `: ${details}` : ''}`); - } + /** + * Operation is rejected by access policy. + */ + REJECTED_BY_POLICY = 'rejected-by-policy', + + /** + * Error was thrown by the underlying database driver. + */ + DB_QUERY_ERROR = 'db-query-error', + + /** + * The requested operation is not supported. + */ + NOT_SUPPORTED = 'not-supported', + + /** + * An internal error occurred. + */ + INTERNAL_ERROR = 'internal-error', } /** @@ -63,14 +61,91 @@ export enum RejectedByPolicyReason { } /** - * Error thrown when an operation is rejected by access policy. + * ZenStack ORM error. */ -export class RejectedByPolicyError extends ZenStackError { +export class ORMError extends Error { constructor( - public readonly model: string | undefined, - public readonly reason: RejectedByPolicyReason = RejectedByPolicyReason.NO_ACCESS, + public reason: ORMErrorReason, message?: string, + options?: ErrorOptions, ) { - super(message ?? `Operation rejected by policy${model ? ': ' + model : ''}`); + super(message, options); } + + /** + * The name of the model that the error pertains to. + */ + public model?: string; + + /** + * The error code given by the underlying database driver. + */ + public dbErrorCode?: unknown; + + /** + * The error message given by the underlying database driver. + */ + public dbErrorMessage?: string; + + /** + * The reason code for policy rejection. Only available when `reason` is `REJECTED_BY_POLICY`. + */ + public rejectedByPolicyReason?: RejectedByPolicyReason; + + /** + * The SQL query that was executed. Only available when `reason` is `DB_QUERY_ERROR`. + */ + public sql?: string; + + /** + * The parameters used in the SQL query. Only available when `reason` is `DB_QUERY_ERROR`. + */ + public sqlParams?: readonly unknown[]; +} + +export function createConfigError(message: string, options?: ErrorOptions) { + return new ORMError(ORMErrorReason.CONFIG_ERROR, message, options); +} + +export function createNotFoundError(model: string, message?: string, options?: ErrorOptions) { + const error = new ORMError(ORMErrorReason.NOT_FOUND, message ?? 'Record not found', options); + error.model = model; + return error; +} + +export function createInvalidInputError(message: string, model?: string, options?: ErrorOptions) { + const error = new ORMError(ORMErrorReason.INVALID_INPUT, message, options); + error.model = model; + return error; +} + +export function createDBQueryError(message: string, dbError: unknown, sql: string, parameters: readonly unknown[]) { + const error = new ORMError(ORMErrorReason.DB_QUERY_ERROR, message, { cause: dbError }); + error.dbErrorCode = getDbErrorCode(dbError); + error.dbErrorMessage = dbError instanceof Error ? dbError.message : undefined; + error.sql = sql; + error.sqlParams = parameters; + return error; +} + +export function createRejectedByPolicyError( + model: string, + reason: RejectedByPolicyReason, + message: string, + options?: ErrorOptions, +) { + const error = new ORMError(ORMErrorReason.REJECTED_BY_POLICY, message, options); + error.model = model; + error.rejectedByPolicyReason = reason; + return error; +} + +export function createNotSupportedError(message: string, options?: ErrorOptions) { + return new ORMError(ORMErrorReason.NOT_SUPPORTED, message, options); +} + +export function createInternalError(message: string, model?: string, options?: ErrorOptions) { + const error = new ORMError(ORMErrorReason.INTERNAL_ERROR, message, options); + error.model = model; + return error; } diff --git a/packages/orm/src/client/executor/error-processor.ts b/packages/orm/src/client/executor/error-processor.ts new file mode 100644 index 00000000..e00b5b5c --- /dev/null +++ b/packages/orm/src/client/executor/error-processor.ts @@ -0,0 +1,12 @@ +/** + * Extracts database error code from an error thrown by the database driver. + * + * @todo currently assumes the error has a code field + */ +export function getDbErrorCode(error: unknown): unknown | undefined { + if (error instanceof Error && 'code' in error) { + return error.code; + } else { + return undefined; + } +} diff --git a/packages/orm/src/client/executor/zenstack-query-executor.ts b/packages/orm/src/client/executor/zenstack-query-executor.ts index a2005bff..622c3238 100644 --- a/packages/orm/src/client/executor/zenstack-query-executor.ts +++ b/packages/orm/src/client/executor/zenstack-query-executor.ts @@ -25,7 +25,7 @@ import { match } from 'ts-pattern'; import type { GetModels, ModelDef, SchemaDef, TypeDefDef } from '../../schema'; import { type ClientImpl } from '../client-impl'; import { TransactionIsolationLevel, type ClientContract } from '../contract'; -import { InternalError, QueryError, ZenStackError } from '../errors'; +import { createDBQueryError, createInternalError, ORMError } from '../errors'; import type { AfterEntityMutationCallback, OnKyselyQueryCallback } from '../plugin'; import { stripAlias } from '../query-utils'; import { QueryNameMapper } from './name-mapper'; @@ -108,12 +108,16 @@ export class ZenStackQueryExecutor extends DefaultQuer if (startedTx) { await this.driver.rollbackTransaction(connection); } - if (err instanceof ZenStackError) { + if (err instanceof ORMError) { throw err; } else { // wrap error - const message = `Failed to execute query: ${err}, sql: ${compiledQuery?.sql}`; - throw new QueryError(message, err); + throw createDBQueryError( + 'Failed to execute query', + err, + compiledQuery.sql, + compiledQuery.parameters, + ); } } }); @@ -361,7 +365,7 @@ export class ZenStackQueryExecutor extends DefaultQuer return tableNode.table.identifier.name; }) .otherwise((node) => { - throw new InternalError(`Invalid query node: ${node}`); + throw createInternalError(`Invalid query node: ${node}`); }) as GetModels; } diff --git a/packages/orm/src/client/index.ts b/packages/orm/src/client/index.ts index 225aeba5..6a320300 100644 --- a/packages/orm/src/client/index.ts +++ b/packages/orm/src/client/index.ts @@ -3,7 +3,7 @@ export * from './contract'; export type * from './crud-types'; export { getCrudDialect } from './crud/dialects'; export { BaseCrudDialect } from './crud/dialects/base-dialect'; -export * from './errors'; +export { ORMError, ORMErrorReason, RejectedByPolicyReason } from './errors'; export * from './options'; export * from './plugin'; export type { ZenStackPromise } from './promise'; diff --git a/packages/orm/src/client/query-utils.ts b/packages/orm/src/client/query-utils.ts index 6bd51435..9797584c 100644 --- a/packages/orm/src/client/query-utils.ts +++ b/packages/orm/src/client/query-utils.ts @@ -13,7 +13,7 @@ import { ExpressionUtils, type FieldDef, type GetModels, type ModelDef, type Sch import { extractFields } from '../utils/object-utils'; import type { AGGREGATE_OPERATORS } from './constants'; import type { OrderBy } from './crud-types'; -import { InternalError, QueryError } from './errors'; +import { createInternalError } from './errors'; export function hasModel(schema: SchemaDef, model: string) { return Object.keys(schema.models) @@ -32,7 +32,7 @@ export function getTypeDef(schema: SchemaDef, type: string) { export function requireModel(schema: SchemaDef, model: string) { const modelDef = getModel(schema, model); if (!modelDef) { - throw new QueryError(`Model "${model}" not found in schema`); + throw createInternalError(`Model "${model}" not found in schema`, model); } return modelDef; } @@ -46,7 +46,7 @@ export function requireField(schema: SchemaDef, modelOrType: string, field: stri const modelDef = getModel(schema, modelOrType); if (modelDef) { if (!modelDef.fields[field]) { - throw new QueryError(`Field "${field}" not found in model "${modelOrType}"`); + throw createInternalError(`Field "${field}" not found in model "${modelOrType}"`, modelOrType); } else { return modelDef.fields[field]; } @@ -54,12 +54,12 @@ export function requireField(schema: SchemaDef, modelOrType: string, field: stri const typeDef = getTypeDef(schema, modelOrType); if (typeDef) { if (!typeDef.fields[field]) { - throw new QueryError(`Field "${field}" not found in type "${modelOrType}"`); + throw createInternalError(`Field "${field}" not found in type "${modelOrType}"`, modelOrType); } else { return typeDef.fields[field]; } } - throw new QueryError(`Model or type "${modelOrType}" not found in schema`); + throw createInternalError(`Model or type "${modelOrType}" not found in schema`, modelOrType); } export function getIdFields(schema: SchemaDef, model: GetModels) { @@ -71,7 +71,7 @@ export function requireIdFields(schema: SchemaDef, model: string) { const modelDef = requireModel(schema, model); const result = modelDef?.idFields; if (!result) { - throw new InternalError(`Model "${model}" does not have ID field(s)`); + throw createInternalError(`Model "${model}" does not have ID field(s)`, model); } return result; } @@ -80,12 +80,12 @@ export function getRelationForeignKeyFieldPairs(schema: SchemaDef, model: string const fieldDef = requireField(schema, model, relationField); if (!fieldDef?.relation) { - throw new InternalError(`Field "${relationField}" is not a relation`); + throw createInternalError(`Field "${relationField}" is not a relation`, model); } if (fieldDef.relation.fields) { if (!fieldDef.relation.references) { - throw new InternalError(`Relation references not defined for field "${relationField}"`); + throw createInternalError(`Relation references not defined for field "${relationField}"`, model); } // this model owns the fk return { @@ -97,19 +97,19 @@ export function getRelationForeignKeyFieldPairs(schema: SchemaDef, model: string }; } else { if (!fieldDef.relation.opposite) { - throw new InternalError(`Opposite relation not defined for field "${relationField}"`); + throw createInternalError(`Opposite relation not defined for field "${relationField}"`, model); } const oppositeField = requireField(schema, fieldDef.type, fieldDef.relation.opposite); if (!oppositeField.relation) { - throw new InternalError(`Field "${fieldDef.relation.opposite}" is not a relation`); + throw createInternalError(`Field "${fieldDef.relation.opposite}" is not a relation`, model); } if (!oppositeField.relation.fields) { - throw new InternalError(`Relation fields not defined for field "${relationField}"`); + throw createInternalError(`Relation fields not defined for field "${relationField}"`, model); } if (!oppositeField.relation.references) { - throw new InternalError(`Relation references not defined for field "${relationField}"`); + throw createInternalError(`Relation references not defined for field "${relationField}"`, model); } // the opposite model owns the fk @@ -153,7 +153,7 @@ export function getUniqueFields(schema: SchemaDef, model: string) { > = []; for (const [key, value] of Object.entries(modelDef.uniqueFields)) { if (typeof value !== 'object') { - throw new InternalError(`Invalid unique field definition for "${key}"`); + throw createInternalError(`Invalid unique field definition for "${key}"`, model); } if (typeof value.type === 'string') { @@ -173,7 +173,7 @@ export function getUniqueFields(schema: SchemaDef, model: string) { export function getIdValues(schema: SchemaDef, model: string, data: any): Record { const idFields = getIdFields(schema, model); if (!idFields) { - throw new InternalError(`ID fields not defined for model "${model}"`); + throw createInternalError(`ID fields not defined for model "${model}"`, model); } return idFields.reduce((acc, field) => ({ ...acc, [field]: data[field] }), {}); } @@ -328,7 +328,7 @@ export function getDiscriminatorField(schema: SchemaDef, model: string) { } const discriminator = delegateAttr.args?.find((arg) => arg.name === 'discriminator'); if (!discriminator || !ExpressionUtils.isField(discriminator.value)) { - throw new InternalError(`Discriminator field not defined for model "${model}"`); + throw createInternalError(`Discriminator field not defined for model "${model}"`, model); } return discriminator.value.field; } diff --git a/packages/plugins/policy/src/expression-transformer.ts b/packages/plugins/policy/src/expression-transformer.ts index ad448805..a86f793f 100644 --- a/packages/plugins/policy/src/expression-transformer.ts +++ b/packages/plugins/policy/src/expression-transformer.ts @@ -1,13 +1,5 @@ import { invariant } from '@zenstackhq/common-helpers'; -import { - getCrudDialect, - InternalError, - QueryError, - QueryUtils, - type BaseCrudDialect, - type ClientContract, - type CRUD_EXT, -} from '@zenstackhq/orm'; +import { getCrudDialect, QueryUtils, type BaseCrudDialect, type ClientContract, type CRUD_EXT } from '@zenstackhq/orm'; import type { BinaryExpression, BinaryOperator, @@ -48,7 +40,15 @@ import { } from 'kysely'; import { match } from 'ts-pattern'; import { ExpressionEvaluator } from './expression-evaluator'; -import { conjunction, disjunction, falseNode, isBeforeInvocation, logicalNot, trueNode } from './utils'; +import { + conjunction, + createUnsupportedError, + disjunction, + falseNode, + isBeforeInvocation, + logicalNot, + trueNode, +} from './utils'; export type ExpressionTransformerContext = { model: GetModels; @@ -92,7 +92,7 @@ export class ExpressionTransformer { get authType() { if (!this.schema.authType) { - throw new InternalError('Schema does not have an "authType" specified'); + invariant(false, 'Schema does not have an "authType" specified'); } return this.schema.authType!; } @@ -298,7 +298,7 @@ export class ExpressionTransformer { private transformAuthBinary(expr: BinaryExpression, context: ExpressionTransformerContext) { if (expr.op !== '==' && expr.op !== '!=') { - throw new QueryError( + throw createUnsupportedError( `Unsupported operator for \`auth()\` in policy of model "${context.model}": ${expr.op}`, ); } @@ -318,7 +318,7 @@ export class ExpressionTransformer { } else { const authModel = QueryUtils.getModel(this.schema, this.authType); if (!authModel) { - throw new QueryError( + throw createUnsupportedError( `Unsupported use of \`auth()\` in policy of model "${context.model}", comparing with \`auth()\` is only possible when auth type is a model`, ); } @@ -387,7 +387,7 @@ export class ExpressionTransformer { private transformCall(expr: CallExpression, context: ExpressionTransformerContext) { const func = this.getFunctionImpl(expr.function); if (!func) { - throw new QueryError(`Function not implemented: ${expr.function}`); + throw createUnsupportedError(`Function not implemented: ${expr.function}`); } const eb = expressionBuilder(); return func( @@ -444,7 +444,7 @@ export class ExpressionTransformer { // if (Expression.isMember(arg)) { // } - throw new InternalError(`Unsupported argument expression: ${arg.kind}`); + throw createUnsupportedError(`Unsupported argument expression: ${arg.kind}`); } @expr('member') diff --git a/packages/plugins/policy/src/policy-handler.ts b/packages/plugins/policy/src/policy-handler.ts index a9473f31..36c2da21 100644 --- a/packages/plugins/policy/src/policy-handler.ts +++ b/packages/plugins/policy/src/policy-handler.ts @@ -1,15 +1,6 @@ import { invariant } from '@zenstackhq/common-helpers'; import type { BaseCrudDialect, ClientContract, ProceedKyselyQueryFunction } from '@zenstackhq/orm'; -import { - getCrudDialect, - InternalError, - QueryError, - QueryUtils, - RejectedByPolicyError, - RejectedByPolicyReason, - SchemaUtils, - type CRUD_EXT, -} from '@zenstackhq/orm'; +import { getCrudDialect, QueryUtils, RejectedByPolicyReason, SchemaUtils, type CRUD_EXT } from '@zenstackhq/orm'; import { ExpressionUtils, type BuiltinType, @@ -55,7 +46,17 @@ import { match } from 'ts-pattern'; import { ColumnCollector } from './column-collector'; import { ExpressionTransformer } from './expression-transformer'; import type { Policy, PolicyOperation } from './types'; -import { buildIsFalse, conjunction, disjunction, falseNode, getTableName, isBeforeInvocation, trueNode } from './utils'; +import { + buildIsFalse, + conjunction, + createRejectedByPolicyError, + createUnsupportedError, + disjunction, + falseNode, + getTableName, + isBeforeInvocation, + trueNode, +} from './utils'; export type CrudQueryNode = SelectQueryNode | InsertQueryNode | UpdateQueryNode | DeleteQueryNode; @@ -76,7 +77,7 @@ export class PolicyHandler extends OperationNodeTransf async handle(node: RootOperationNode, proceed: ProceedKyselyQueryFunction) { if (!this.isCrudQueryNode(node)) { // non-CRUD queries are not allowed - throw new RejectedByPolicyError( + throw createRejectedByPolicyError( undefined, RejectedByPolicyReason.OTHER, 'non-CRUD queries are not allowed', @@ -104,7 +105,7 @@ export class PolicyHandler extends OperationNodeTransf if (constCondition === true) { needCheckPreCreate = false; } else if (constCondition === false) { - throw new RejectedByPolicyError(mutationModel, RejectedByPolicyReason.NO_ACCESS); + throw createRejectedByPolicyError(mutationModel, RejectedByPolicyReason.NO_ACCESS); } } @@ -134,7 +135,9 @@ export class PolicyHandler extends OperationNodeTransf for (const postRow of result.rows) { const beforeRow = beforeUpdateInfo.rows.find((r) => idFields.every((f) => r[f] === postRow[f])); if (!beforeRow) { - throw new QueryError( + throw createRejectedByPolicyError( + mutationModel, + RejectedByPolicyReason.OTHER, 'Before-update and after-update rows do not match by id. If you have post-update policies on a model, updating id fields is not supported.', ); } @@ -194,7 +197,7 @@ export class PolicyHandler extends OperationNodeTransf const postUpdateResult = await proceed(postUpdateQuery.toOperationNode()); if (!postUpdateResult.rows[0]?.$condition) { - throw new RejectedByPolicyError( + throw createRejectedByPolicyError( mutationModel, RejectedByPolicyReason.NO_ACCESS, 'some or all updated rows failed to pass post-update policy check', @@ -210,7 +213,7 @@ export class PolicyHandler extends OperationNodeTransf } else { const readBackResult = await this.processReadBack(node, result, proceed); if (readBackResult.rows.length !== result.rows.length) { - throw new RejectedByPolicyError( + throw createRejectedByPolicyError( mutationModel, RejectedByPolicyReason.CANNOT_READ_BACK, 'result is not allowed to be read back', @@ -543,14 +546,14 @@ export class PolicyHandler extends OperationNodeTransf const result = await proceed(queryNode); if (!result.rows[0]?.$conditionA) { - throw new RejectedByPolicyError( + throw createRejectedByPolicyError( m2m.firstModel as GetModels, RejectedByPolicyReason.CANNOT_READ_BACK, `many-to-many relation participant model "${m2m.firstModel}" not updatable`, ); } if (!result.rows[0]?.$conditionB) { - throw new RejectedByPolicyError( + throw createRejectedByPolicyError( m2m.secondModel as GetModels, RejectedByPolicyReason.NO_ACCESS, `many-to-many relation participant model "${m2m.secondModel}" not updatable`, @@ -621,7 +624,7 @@ export class PolicyHandler extends OperationNodeTransf const result = await proceed(preCreateCheck); if (!result.rows[0]?.$condition) { - throw new RejectedByPolicyError(model, RejectedByPolicyReason.NO_ACCESS); + throw createRejectedByPolicyError(model, RejectedByPolicyReason.NO_ACCESS); } } @@ -636,7 +639,7 @@ export class PolicyHandler extends OperationNodeTransf } else if (PrimitiveValueListNode.is(node)) { return [this.unwrapCreateValueRow(node.values, model, fields, isManyToManyJoinTable)]; } else { - throw new InternalError(`Unexpected node kind: ${node.kind} for unwrapping create values`); + invariant(false, `Unexpected node kind: ${node.kind} for unwrapping create values`); } } @@ -762,21 +765,21 @@ export class PolicyHandler extends OperationNodeTransf })) .when(UpdateQueryNode.is, (node) => { if (!node.table) { - throw new QueryError('Update query must have a table'); + invariant(false, 'Update query must have a table'); } const r = this.extractTableName(node.table); return r ? { mutationModel: r.model, alias: r.alias } : undefined; }) .when(DeleteQueryNode.is, (node) => { if (node.from.froms.length !== 1) { - throw new QueryError('Only one from table is supported for delete'); + throw createUnsupportedError('Only one from table is supported for delete'); } const r = this.extractTableName(node.from.froms[0]!); return r ? { mutationModel: r.model, alias: r.alias } : undefined; }) .exhaustive(); if (!r) { - throw new InternalError(`Unable to get table name for query node: ${node}`); + invariant(false, `Unable to get table name for query node: ${node}`); } return r; } diff --git a/packages/plugins/policy/src/utils.ts b/packages/plugins/policy/src/utils.ts index 321a2191..e15e0ccd 100644 --- a/packages/plugins/policy/src/utils.ts +++ b/packages/plugins/policy/src/utils.ts @@ -1,4 +1,4 @@ -import type { BaseCrudDialect } from '@zenstackhq/orm'; +import { ORMError, ORMErrorReason, RejectedByPolicyReason, type BaseCrudDialect } from '@zenstackhq/orm'; import { ExpressionUtils, type Expression, type SchemaDef } from '@zenstackhq/orm/schema'; import type { OperationNode } from 'kysely'; import { @@ -158,3 +158,18 @@ export function getTableName(node: OperationNode | undefined) { export function isBeforeInvocation(expr: Expression) { return ExpressionUtils.isCall(expr) && expr.function === 'before'; } + +export function createRejectedByPolicyError( + model: string | undefined, + reason: RejectedByPolicyReason, + message?: string, +) { + const err = new ORMError(ORMErrorReason.REJECTED_BY_POLICY, message ?? 'operation is rejected by access policies'); + err.rejectedByPolicyReason = reason; + err.model = model; + return err; +} + +export function createUnsupportedError(message: string) { + return new ORMError(ORMErrorReason.NOT_SUPPORTED, message); +} diff --git a/packages/server/src/api/rest/index.ts b/packages/server/src/api/rest/index.ts index bd65470c..536a45a5 100644 --- a/packages/server/src/api/rest/index.ts +++ b/packages/server/src/api/rest/index.ts @@ -1,16 +1,10 @@ import { clone, enumerate, lowerCaseFirst, paramCase } from '@zenstackhq/common-helpers'; -import { - InputValidationError, - NotFoundError, - QueryError, - RejectedByPolicyError, - ZenStackError, - type ClientContract, -} from '@zenstackhq/orm'; +import { ORMError, ORMErrorReason, type ClientContract } from '@zenstackhq/orm'; import type { FieldDef, ModelDef, SchemaDef } from '@zenstackhq/orm/schema'; import { Decimal } from 'decimal.js'; import SuperJSON from 'superjson'; import tsjapi, { type Linker, type Paginator, type Relator, type Serializer, type SerializerOptions } from 'ts-japi'; +import { match } from 'ts-pattern'; import UrlPattern from 'url-pattern'; import z from 'zod'; import type { ApiHandler, LogConfig, RequestContext, Response } from '../../types'; @@ -467,8 +461,8 @@ export class RestApiHandler implements ApiHandler implements ApiHandler { + return this.makeError('validationError', err.message, 422); + }) + .with(ORMErrorReason.REJECTED_BY_POLICY, () => { + return this.makeError('forbidden', err.message, 403, { reason: err.rejectedByPolicyReason }); + }) + .with(ORMErrorReason.NOT_FOUND, () => { + return this.makeError('notFound', err.message, 404); + }) + .with(ORMErrorReason.DB_QUERY_ERROR, () => { + return this.makeError('queryError', err.message, 400, { + dbErrorCode: err.dbErrorCode, + }); + }) + .otherwise(() => { + return this.makeError('unknownError', err.message); + }); } - private makeError(code: keyof typeof this.errors, detail?: string, status?: number, reason?: string) { + private makeError( + code: keyof typeof this.errors, + detail?: string, + status?: number, + otherFields: Record = {}, + ) { status = status ?? this.errors[code]?.status ?? 500; const error: any = { status, @@ -2057,9 +2053,7 @@ export class RestApiHandler implements ApiHandler implements ApiHandler implements ApiHandler { + status = 404; + error.model = err.model; + }) + .with(ORMErrorReason.INVALID_INPUT, () => { + status = 422; + error.rejectedByValidation = true; + error.model = err.model; + }) + .with(ORMErrorReason.REJECTED_BY_POLICY, () => { + status = 403; + error.rejectedByPolicy = true; + error.rejectReason = err.rejectedByPolicyReason; + error.model = err.model; + }) + .with(ORMErrorReason.DB_QUERY_ERROR, () => { + status = 400; + error.dbErrorCode = err.dbErrorCode; + }) + .otherwise(() => {}); const resp = { status, body: { error } }; log(this.options.log, 'debug', () => `sending error response: ${safeJSONStringify(resp)}`); diff --git a/packages/server/test/api/rest.test.ts b/packages/server/test/api/rest.test.ts index b309b5ae..b40ff604 100644 --- a/packages/server/test/api/rest.test.ts +++ b/packages/server/test/api/rest.test.ts @@ -2544,7 +2544,6 @@ describe('REST server tests', () => { expect(r.status).toBe(422); expect(r.body.errors[0].code).toBe('validation-error'); expect(r.body.errors[0].detail).toContain('Invalid email'); - expect(r.body.errors[0].reason).toContain('Invalid email'); }); }); }); diff --git a/packages/testtools/src/vitest-ext.ts b/packages/testtools/src/vitest-ext.ts index ab01d47c..64d5684f 100644 --- a/packages/testtools/src/vitest-ext.ts +++ b/packages/testtools/src/vitest-ext.ts @@ -1,19 +1,19 @@ -import { InputValidationError, NotFoundError, RejectedByPolicyError } from '@zenstackhq/orm'; +import { ORMError, ORMErrorReason } from '@zenstackhq/orm'; import { expect } from 'vitest'; function isPromise(value: any) { return typeof value.then === 'function' && typeof value.catch === 'function'; } -function expectError(err: any, errorType: any) { - if (err instanceof errorType) { +function expectErrorReason(err: any, errorReason: ORMErrorReason) { + if (err instanceof ORMError && err.reason === errorReason) { return { message: () => '', pass: true, }; } else { return { - message: () => `expected ${errorType}, got ${err}`, + message: () => `expected ORMError of reason ${errorReason}, got ${err}`, pass: false, }; } @@ -80,7 +80,7 @@ expect.extend({ try { await received; } catch (err) { - return expectError(err, NotFoundError); + return expectErrorReason(err, ORMErrorReason.NOT_FOUND); } return { message: () => `expected NotFoundError, got no error`, @@ -95,13 +95,13 @@ expect.extend({ try { await received; } catch (err) { - if (expectedMessages && err instanceof RejectedByPolicyError) { + if (expectedMessages && err instanceof ORMError && err.reason === ORMErrorReason.REJECTED_BY_POLICY) { const r = expectErrorMessages(expectedMessages, err.message || ''); if (r) { return r; } } - return expectError(err, RejectedByPolicyError); + return expectErrorReason(err, ORMErrorReason.REJECTED_BY_POLICY); } return { message: () => `expected PolicyError, got no error`, @@ -116,13 +116,13 @@ expect.extend({ try { await received; } catch (err) { - if (expectedMessages && err instanceof InputValidationError) { + if (expectedMessages && err instanceof ORMError && err.reason === ORMErrorReason.INVALID_INPUT) { const r = expectErrorMessages(expectedMessages, err.message || ''); if (r) { return r; } } - return expectError(err, InputValidationError); + return expectErrorReason(err, ORMErrorReason.INVALID_INPUT); } return { message: () => `expected InputValidationError, got no error`, diff --git a/tests/e2e/orm/client-api/delegate.test.ts b/tests/e2e/orm/client-api/delegate.test.ts index 704f134d..1497f91b 100644 --- a/tests/e2e/orm/client-api/delegate.test.ts +++ b/tests/e2e/orm/client-api/delegate.test.ts @@ -182,7 +182,7 @@ describe('Delegate model tests ', () => { rating: 3, }, }), - ).rejects.toThrow('constraint'); + ).rejects.toSatisfy((e) => e.cause.message.toLowerCase().includes('constraint')); await expect(client.ratedVideo.findMany()).toResolveWithLength(1); await expect(client.video.findMany()).toResolveWithLength(1); diff --git a/tests/e2e/orm/client-api/error-handling.test.ts b/tests/e2e/orm/client-api/error-handling.test.ts new file mode 100644 index 00000000..1227e525 --- /dev/null +++ b/tests/e2e/orm/client-api/error-handling.test.ts @@ -0,0 +1,52 @@ +import { ORMError, ORMErrorReason, RejectedByPolicyReason } from '@zenstackhq/orm'; +import { createPolicyTestClient, createTestClient } from '@zenstackhq/testtools'; +import { describe, expect, it } from 'vitest'; + +describe('Error handling tests', () => { + const schema = ` +model User { + id String @id @default(cuid()) + name String? + email String @unique @email +} +`; + + it('throws invalid input errors', async () => { + const db: any = await createTestClient(schema); + await expect(db.user.create({ data: { name: 'user' } })).toBeRejectedByValidation(); + await expect(db.user.create({ data: { name: 'user', email: 'foo' } })).toBeRejectedByValidation([ + 'Invalid email', + ]); + }); + + it('throws not found errors', async () => { + const db: any = await createTestClient(schema); + await expect(db.user.findUniqueOrThrow({ where: { id: 'non-existent-id' } })).toBeRejectedNotFound(); + }); + + it('throws rejected by policy errors', async () => { + const db: any = await createPolicyTestClient(schema); + await expect(db.user.create({ data: { name: 'user', email: 'user@example.com' } })).rejects.toSatisfy( + (e) => + e instanceof ORMError && + e.reason === ORMErrorReason.REJECTED_BY_POLICY && + e.rejectedByPolicyReason === RejectedByPolicyReason.NO_ACCESS, + ); + }); + + it('throws db query errors', async () => { + const db: any = await createTestClient(schema); + await db.user.create({ data: { email: 'user1@example.com' } }); + + const provider = db.$schema.provider.type; + const expectedCode = provider === 'sqlite' ? 'SQLITE_CONSTRAINT_UNIQUE' : '23505'; + + await expect(db.user.create({ data: { email: 'user1@example.com' } })).rejects.toSatisfy( + (e) => + e instanceof ORMError && + e.reason === ORMErrorReason.DB_QUERY_ERROR && + e.dbErrorCode === expectedCode && + !!e.dbErrorMessage?.includes('constraint'), + ); + }); +}); diff --git a/tests/e2e/orm/client-api/find.test.ts b/tests/e2e/orm/client-api/find.test.ts index 7e2194d4..765492cc 100644 --- a/tests/e2e/orm/client-api/find.test.ts +++ b/tests/e2e/orm/client-api/find.test.ts @@ -1,8 +1,7 @@ -import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import type { ClientContract } from '@zenstackhq/orm'; -import { InputValidationError, NotFoundError } from '@zenstackhq/orm'; -import { schema } from '../schemas/basic'; import { createTestClient } from '@zenstackhq/testtools'; +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; +import { schema } from '../schemas/basic'; import { createPosts, createUser } from './utils'; describe('Client find tests ', () => { @@ -53,7 +52,7 @@ describe('Client find tests ', () => { await expect(client.user.findMany({ take: 4 })).resolves.toHaveLength(3); // findFirst's take must be 1 - await expect(client.user.findFirst({ take: 2 })).rejects.toThrow(InputValidationError); + await expect(client.user.findFirst({ take: 2 })).toBeRejectedByValidation(); await expect(client.user.findFirst({ take: 1 })).toResolveTruthy(); // skip @@ -389,7 +388,7 @@ describe('Client find tests ', () => { r = await client.user.findUnique({ where: { id: 'none' } }); expect(r).toBeNull(); - await expect(client.user.findUniqueOrThrow({ where: { id: 'none' } })).rejects.toThrow(NotFoundError); + await expect(client.user.findUniqueOrThrow({ where: { id: 'none' } })).toBeRejectedNotFound(); }); it('works with non-unique finds', async () => { @@ -403,7 +402,7 @@ describe('Client find tests ', () => { r = await client.user.findFirst({ where: { name: 'User2' } }); expect(r).toBeNull(); - await expect(client.user.findFirstOrThrow({ where: { name: 'User2' } })).rejects.toThrow(NotFoundError); + await expect(client.user.findFirstOrThrow({ where: { name: 'User2' } })).toBeRejectedNotFound(); }); it('works with boolean composition', async () => { diff --git a/tests/e2e/orm/client-api/mixin.test.ts b/tests/e2e/orm/client-api/mixin.test.ts index 1e6d0f41..e373b8fe 100644 --- a/tests/e2e/orm/client-api/mixin.test.ts +++ b/tests/e2e/orm/client-api/mixin.test.ts @@ -75,7 +75,7 @@ model Bar with CommonFields { description: 'Bar', }, }), - ).rejects.toThrow('constraint'); + ).rejects.toSatisfy((e) => e.cause.message.toLowerCase().includes('constraint')); }); it('supports multiple-level mixins', async () => { diff --git a/tests/e2e/orm/client-api/update.test.ts b/tests/e2e/orm/client-api/update.test.ts index 88001f36..c79396d7 100644 --- a/tests/e2e/orm/client-api/update.test.ts +++ b/tests/e2e/orm/client-api/update.test.ts @@ -1050,7 +1050,7 @@ describe('Client update tests', () => { }, }, }), - ).rejects.toThrow('constraint'); + ).rejects.toSatisfy((e) => e.cause.message.toLowerCase().includes('constraint')); // transaction fails as a whole await expect(client.comment.findUnique({ where: { id: '3' } })).resolves.toMatchObject({ content: 'Comment3', diff --git a/tests/e2e/orm/plugin-infra/on-kysely-query.test.ts b/tests/e2e/orm/plugin-infra/on-kysely-query.test.ts index c6216c4a..68602613 100644 --- a/tests/e2e/orm/plugin-infra/on-kysely-query.test.ts +++ b/tests/e2e/orm/plugin-infra/on-kysely-query.test.ts @@ -211,7 +211,7 @@ describe('On kysely query tests', () => { data: { email: 'u1@test.com', name: 'Marvin' }, }), ), - ).rejects.toThrow('test error'); + ).rejects.toSatisfy((e) => (e as any).cause.message === 'test error'); await expect(called1).toBe(true); await expect(called2).toBe(true); diff --git a/tests/e2e/orm/policy/crud/update.test.ts b/tests/e2e/orm/policy/crud/update.test.ts index 9a83577f..7f060c88 100644 --- a/tests/e2e/orm/policy/crud/update.test.ts +++ b/tests/e2e/orm/policy/crud/update.test.ts @@ -758,7 +758,7 @@ model Post { }, }, }), - ).rejects.toThrow('constraint'); + ).rejects.toSatisfy((e) => e.cause.message.toLowerCase().includes('constraint')); await db.$unuseAll().post.update({ where: { id: 1 }, data: { title: 'Bar Post' } }); // can update await expect( @@ -1124,7 +1124,7 @@ model Foo { // can't update, but create violates unique constraint await expect( db.foo.upsert({ where: { id: 1 }, create: { id: 1, x: 1 }, update: { x: 1 } }), - ).rejects.toThrow('constraint'); + ).rejects.toSatisfy((e) => e.cause.message.toLowerCase().includes('constraint')); await db.$unuseAll().foo.update({ where: { id: 1 }, data: { x: 2 } }); // can update now await expect( diff --git a/tests/e2e/orm/policy/migrated/auth.test.ts b/tests/e2e/orm/policy/migrated/auth.test.ts index fb8f30de..b3e49980 100644 --- a/tests/e2e/orm/policy/migrated/auth.test.ts +++ b/tests/e2e/orm/policy/migrated/auth.test.ts @@ -536,7 +536,9 @@ model Post { ); await expect(db.user.create({ data: { id: 'userId-1' } })).toResolveTruthy(); - await expect(db.post.create({ data: { title: 'title' } })).rejects.toThrow('constraint'); + await expect(db.post.create({ data: { title: 'title' } })).rejects.toSatisfy((e) => + e.cause.message.toLowerCase().includes('constraint'), + ); await expect(db.post.findMany({})).toResolveTruthy(); }); diff --git a/tests/e2e/orm/policy/migrated/deep-nested.test.ts b/tests/e2e/orm/policy/migrated/deep-nested.test.ts index 6bd38e1f..f8bcea93 100644 --- a/tests/e2e/orm/policy/migrated/deep-nested.test.ts +++ b/tests/e2e/orm/policy/migrated/deep-nested.test.ts @@ -482,7 +482,7 @@ describe('deep nested operations tests', () => { }, }, }), - ).rejects.toThrow('constraint'); + ).rejects.toSatisfy((e) => e.cause.message.toLowerCase().includes('constraint')); // createMany skip duplicate await db.m1.update({ diff --git a/tests/e2e/orm/policy/migrated/multi-field-unique.test.ts b/tests/e2e/orm/policy/migrated/multi-field-unique.test.ts index 2cc265bc..b64d33f9 100644 --- a/tests/e2e/orm/policy/migrated/multi-field-unique.test.ts +++ b/tests/e2e/orm/policy/migrated/multi-field-unique.test.ts @@ -1,6 +1,6 @@ -import { describe, expect, it } from 'vitest'; -import { QueryError } from '@zenstackhq/orm'; +import { ORMError, ORMErrorReason } from '@zenstackhq/orm'; import { createPolicyTestClient } from '@zenstackhq/testtools'; +import { describe, expect, it } from 'vitest'; describe('Policy tests multi-field unique', () => { it('toplevel crud test unnamed constraint', async () => { @@ -20,7 +20,9 @@ describe('Policy tests multi-field unique', () => { ); await expect(db.model.create({ data: { a: 'a1', b: 'b1', x: 1 } })).toResolveTruthy(); - await expect(db.model.create({ data: { a: 'a1', b: 'b1', x: 2 } })).rejects.toThrow(QueryError); + await expect(db.model.create({ data: { a: 'a1', b: 'b1', x: 2 } })).rejects.toSatisfy( + (e) => e instanceof ORMError && e.reason === ORMErrorReason.DB_QUERY_ERROR, + ); await expect(db.model.create({ data: { a: 'a2', b: 'b2', x: 0 } })).toBeRejectedByPolicy(); await expect(db.model.findUnique({ where: { a_b: { a: 'a1', b: 'b1' } } })).toResolveTruthy(); @@ -83,8 +85,8 @@ describe('Policy tests multi-field unique', () => { ); await expect(db.m1.create({ data: { id: '1', m2: { create: { a: 'a1', b: 'b1', x: 1 } } } })).toResolveTruthy(); - await expect(db.m1.create({ data: { id: '2', m2: { create: { a: 'a1', b: 'b1', x: 2 } } } })).rejects.toThrow( - QueryError, + await expect(db.m1.create({ data: { id: '2', m2: { create: { a: 'a1', b: 'b1', x: 2 } } } })).rejects.toSatisfy( + (e) => e instanceof ORMError && e.reason === ORMErrorReason.DB_QUERY_ERROR, ); await expect( db.m1.create({ data: { id: '3', m2: { create: { a: 'a1', b: 'b2', x: 0 } } } }), diff --git a/tests/e2e/orm/policy/migrated/todo-sample.test.ts b/tests/e2e/orm/policy/migrated/todo-sample.test.ts deleted file mode 100644 index 0a19065d..00000000 --- a/tests/e2e/orm/policy/migrated/todo-sample.test.ts +++ /dev/null @@ -1,502 +0,0 @@ -import { beforeEach, describe, expect, it } from 'vitest'; -import type { ClientContract } from '@zenstackhq/orm'; -import { schema, type SchemaType } from '../../schemas/todo/schema'; -import { createPolicyTestClient } from '@zenstackhq/testtools'; - -describe('Todo Policy Tests', () => { - let db: ClientContract; - - beforeEach(async () => { - db = await createPolicyTestClient(schema); - }); - - it('user', async () => { - const user1 = { - id: 'user1', - email: 'user1@zenstack.dev', - name: 'User 1', - }; - const user2 = { - id: 'user2', - email: 'user2@zenstack.dev', - name: 'User 2', - }; - - const anonDb = db; - const user1Db = db.$setAuth({ id: user1.id }); - const user2Db = db.$setAuth({ id: user2.id }); - - // create user1 - // create should succeed but result can be read back anonymously - await expect(anonDb.user.create({ data: user1 })).toBeRejectedByPolicy([ - 'result is not allowed to be read back', - ]); - await expect(user1Db.user.findUnique({ where: { id: user1.id } })).toResolveTruthy(); - await expect(user2Db.user.findUnique({ where: { id: user1.id } })).toResolveNull(); - - // create user2 - await expect(anonDb.user.create({ data: user2 })).toBeRejectedByPolicy(); - - // find with user1 should only get user1 - const r = await user1Db.user.findMany(); - expect(r).toHaveLength(1); - expect(r[0]).toEqual(expect.objectContaining(user1)); - - // get user2 as user1 - await expect(user1Db.user.findUnique({ where: { id: user2.id } })).toResolveNull(); - - // add both users into the same space - await expect( - user1Db.space.create({ - data: { - name: 'Space 1', - slug: 'space1', - owner: { connect: { id: user1.id } }, - members: { - create: [ - { - user: { connect: { id: user1.id } }, - role: 'ADMIN', - }, - { - user: { connect: { id: user2.id } }, - role: 'USER', - }, - ], - }, - }, - }), - ).toResolveTruthy(); - - // now both user1 and user2 should be visible - await expect(user1Db.user.findMany()).resolves.toHaveLength(2); - await expect(user2Db.user.findMany()).resolves.toHaveLength(2); - - // update user2 as user1 - await expect( - user2Db.user.update({ - where: { id: user1.id }, - data: { name: 'hello' }, - }), - ).toBeRejectedNotFound(); - - // update user1 as user1 - await expect( - user1Db.user.update({ - where: { id: user1.id }, - data: { name: 'hello' }, - }), - ).toResolveTruthy(); - - // delete user2 as user1 - await expect(user1Db.user.delete({ where: { id: user2.id } })).toBeRejectedNotFound(); - - // delete user1 as user1 - await expect(user1Db.user.delete({ where: { id: user1.id } })).toResolveTruthy(); - await expect(user1Db.user.findUnique({ where: { id: user1.id } })).toResolveNull(); - }); - - it('todo list', async () => { - await createSpaceAndUsers(db.$unuseAll()); - - const anonDb = db; - const emptyUIDDb = db.$setAuth({ id: '' }); - const user1Db = db.$setAuth({ id: user1.id }); - const user2Db = db.$setAuth({ id: user2.id }); - const user3Db = db.$setAuth({ id: user3.id }); - - await expect( - anonDb.list.create({ - data: { - id: 'list1', - title: 'List 1', - owner: { connect: { id: user1.id } }, - space: { connect: { id: space1.id } }, - }, - }), - ).toBeRejectedByPolicy(); - - await expect( - user1Db.list.create({ - data: { - id: 'list1', - title: 'List 1', - owner: { connect: { id: user1.id } }, - space: { connect: { id: space1.id } }, - }, - }), - ).toResolveTruthy(); - - await expect(user1Db.list.findMany()).resolves.toHaveLength(1); - await expect(anonDb.list.findMany()).resolves.toHaveLength(0); - await expect(emptyUIDDb.list.findMany()).resolves.toHaveLength(0); - await expect(anonDb.list.findUnique({ where: { id: 'list1' } })).toResolveNull(); - - // accessible to owner - await expect(user1Db.list.findUnique({ where: { id: 'list1' } })).resolves.toEqual( - expect.objectContaining({ id: 'list1', title: 'List 1' }), - ); - - // accessible to user in the space - await expect(user2Db.list.findUnique({ where: { id: 'list1' } })).toResolveTruthy(); - - // inaccessible to user not in the space - await expect(user3Db.list.findUnique({ where: { id: 'list1' } })).toResolveNull(); - - // make a private list - await user1Db.list.create({ - data: { - id: 'list2', - title: 'List 2', - private: true, - owner: { connect: { id: user1.id } }, - space: { connect: { id: space1.id } }, - }, - }); - - // accessible to owner - await expect(user1Db.list.findUnique({ where: { id: 'list2' } })).toResolveTruthy(); - - // inaccessible to other user in the space - await expect(user2Db.list.findUnique({ where: { id: 'list2' } })).toResolveNull(); - - // create a list which doesn't match credential should fail - await expect( - user1Db.list.create({ - data: { - id: 'list3', - title: 'List 3', - owner: { connect: { id: user2.id } }, - space: { connect: { id: space1.id } }, - }, - }), - ).toBeRejectedByPolicy(); - - // create a list which doesn't match credential's space should fail - await expect( - user1Db.list.create({ - data: { - id: 'list3', - title: 'List 3', - owner: { connect: { id: user1.id } }, - space: { connect: { id: space2.id } }, - }, - }), - ).toBeRejectedByPolicy(); - - // update list - await expect( - user1Db.list.update({ - where: { id: 'list1' }, - data: { - title: 'List 1 updated', - }, - }), - ).resolves.toEqual(expect.objectContaining({ title: 'List 1 updated' })); - - await expect( - user2Db.list.update({ - where: { id: 'list1' }, - data: { - title: 'List 1 updated', - }, - }), - ).toBeRejectedNotFound(); - - // delete list - await expect(user2Db.list.delete({ where: { id: 'list1' } })).toBeRejectedNotFound(); - await expect(user1Db.list.delete({ where: { id: 'list1' } })).toResolveTruthy(); - await expect(user1Db.list.findUnique({ where: { id: 'list1' } })).toResolveNull(); - }); - - it('todo', async () => { - await createSpaceAndUsers(db.$unuseAll()); - - const user1Db = db.$setAuth({ id: user1.id }); - const user2Db = db.$setAuth({ id: user2.id }); - - // create a public list - await user1Db.list.create({ - data: { - id: 'list1', - title: 'List 1', - owner: { connect: { id: user1.id } }, - space: { connect: { id: space1.id } }, - }, - }); - - // create - await expect( - user1Db.todo.create({ - data: { - id: 'todo1', - title: 'Todo 1', - owner: { connect: { id: user1.id } }, - list: { - connect: { id: 'list1' }, - }, - }, - }), - ).toResolveTruthy(); - - await expect( - user2Db.todo.create({ - data: { - id: 'todo2', - title: 'Todo 2', - owner: { connect: { id: user2.id } }, - list: { - connect: { id: 'list1' }, - }, - }, - }), - ).toResolveTruthy(); - - // read - await expect(user1Db.todo.findMany()).resolves.toHaveLength(2); - await expect(user2Db.todo.findMany()).resolves.toHaveLength(2); - - // update, user in the same space can freely update - await expect( - user1Db.todo.update({ - where: { id: 'todo1' }, - data: { - title: 'Todo 1 updated', - }, - }), - ).toResolveTruthy(); - await expect( - user1Db.todo.update({ - where: { id: 'todo2' }, - data: { - title: 'Todo 2 updated', - }, - }), - ).toResolveTruthy(); - - // create a private list - await user1Db.list.create({ - data: { - id: 'list2', - private: true, - title: 'List 2', - owner: { connect: { id: user1.id } }, - space: { connect: { id: space1.id } }, - }, - }); - - // create - await expect( - user1Db.todo.create({ - data: { - id: 'todo3', - title: 'Todo 3', - owner: { connect: { id: user1.id } }, - list: { - connect: { id: 'list2' }, - }, - }, - }), - ).toResolveTruthy(); - - // reject because list2 is private - await expect( - user2Db.todo.create({ - data: { - id: 'todo4', - title: 'Todo 4', - owner: { connect: { id: user2.id } }, - list: { - connect: { id: 'list2' }, - }, - }, - }), - ).toBeRejectedByPolicy(); - - // update, only owner can update todo in a private list - await expect( - user1Db.todo.update({ - where: { id: 'todo3' }, - data: { - title: 'Todo 3 updated', - }, - }), - ).toResolveTruthy(); - await expect( - user2Db.todo.update({ - where: { id: 'todo3' }, - data: { - title: 'Todo 3 updated', - }, - }), - ).toBeRejectedNotFound(); - }); - - it('relation query', async () => { - await createSpaceAndUsers(db.$unuseAll()); - - const user1Db = db.$setAuth({ id: user1.id }); - const user2Db = db.$setAuth({ id: user2.id }); - - await user1Db.list.create({ - data: { - id: 'list1', - title: 'List 1', - owner: { connect: { id: user1.id } }, - space: { connect: { id: space1.id } }, - }, - }); - - await user1Db.list.create({ - data: { - id: 'list2', - title: 'List 2', - private: true, - owner: { connect: { id: user1.id } }, - space: { connect: { id: space1.id } }, - }, - }); - - const r = await user1Db.space.findFirstOrThrow({ - where: { id: 'space1' }, - include: { lists: true }, - }); - expect(r.lists).toHaveLength(2); - - const r1 = await user2Db.space.findFirstOrThrow({ - where: { id: 'space1' }, - include: { lists: true }, - }); - expect(r1.lists).toHaveLength(1); - }); - - it('post-update checks', async () => { - await createSpaceAndUsers(db.$unuseAll()); - - const user1Db = db.$setAuth({ id: user1.id }); - - await user1Db.list.create({ - data: { - id: 'list1', - title: 'List 1', - owner: { connect: { id: user1.id } }, - space: { connect: { id: space1.id } }, - todos: { - create: { - id: 'todo1', - title: 'Todo 1', - owner: { connect: { id: user1.id } }, - }, - }, - }, - }); - - // change list's owner - await expect( - user1Db.list.update({ - where: { id: 'list1' }, - data: { - owner: { connect: { id: user2.id } }, - }, - }), - ).toBeRejectedByPolicy(); - - // change todo's owner - await expect( - user1Db.todo.update({ - where: { id: 'todo1' }, - data: { - owner: { connect: { id: user2.id } }, - }, - }), - ).toBeRejectedByPolicy(); - - // nested change todo's owner - await expect( - user1Db.list.update({ - where: { id: 'list1' }, - data: { - todos: { - update: { - where: { id: 'todo1' }, - data: { - owner: { connect: { id: user2.id } }, - }, - }, - }, - }, - }), - ).toBeRejectedByPolicy(); - }); -}); - -const user1 = { - id: 'user1', - email: 'user1@zenstack.dev', - name: 'User 1', -}; - -const user2 = { - id: 'user2', - email: 'user2@zenstack.dev', - name: 'User 2', -}; - -const user3 = { - id: 'user3', - email: 'user3@zenstack.dev', - name: 'User 3', -}; - -const space1 = { - id: 'space1', - name: 'Space 1', - slug: 'space1', -}; - -const space2 = { - id: 'space2', - name: 'Space 2', - slug: 'space2', -}; - -async function createSpaceAndUsers(db: ClientContract) { - // create users - await db.user.create({ data: user1 }); - await db.user.create({ data: user2 }); - await db.user.create({ data: user3 }); - - // add user1 and user2 into space1 - await db.space.create({ - data: { - ...space1, - members: { - create: [ - { - user: { connect: { id: user1.id } }, - role: 'ADMIN', - }, - { - user: { connect: { id: user2.id } }, - role: 'USER', - }, - ], - }, - }, - }); - - // add user3 to space2 - await db.space.create({ - data: { - ...space2, - members: { - create: [ - { - user: { connect: { id: user3.id } }, - role: 'ADMIN', - }, - ], - }, - }, - }); -}