From 9ebed93953c55e0c4161767ad4ec3d0ca48546ff Mon Sep 17 00:00:00 2001 From: sid293 Date: Wed, 18 Mar 2026 13:56:02 +0530 Subject: [PATCH] Refine Zod schema types for record keys, enhance client connection error handling with proper cleanup, and improve Zod compatibility utilities. --- src/client/index.ts | 13 ++++----- src/server/zod-compat.ts | 10 +++++-- test/client/index.test.ts | 27 +++++++++++++++++++ test/integration-tests/processCleanup.test.ts | 2 +- test/integration-tests/taskLifecycle.test.ts | 2 +- test/shared/protocol.test.ts | 10 +++---- 6 files changed, 49 insertions(+), 15 deletions(-) diff --git a/src/client/index.ts b/src/client/index.ts index 03a6b40b5..85ad859d5 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -481,13 +481,14 @@ export class Client< } override async connect(transport: Transport, options?: RequestOptions): Promise { - await super.connect(transport); - // When transport sessionId is already set this means we are trying to reconnect. - // In this case we don't need to initialize again. - if (transport.sessionId !== undefined) { - return; - } try { + await super.connect(transport); + // When transport sessionId is already set this means we are trying to reconnect. + // In this case we don't need to initialize again. + if (transport.sessionId !== undefined) { + return; + } + const result = await this.request( { method: 'initialize', diff --git a/src/server/zod-compat.ts b/src/server/zod-compat.ts index 9d25a5efc..0c9908aa5 100644 --- a/src/server/zod-compat.ts +++ b/src/server/zod-compat.ts @@ -56,8 +56,8 @@ export type ShapeOutput = { // --- Runtime detection --- export function isZ4Schema(s: AnySchema): s is z4.$ZodType { // Present on Zod 4 (Classic & Mini) schemas; absent on Zod 3 - const schema = s as unknown as ZodV4Internal; - return !!schema._zod; + const schema = s as unknown as ZodV4Internal | undefined; + return !!schema?._zod; } // --- Schema construction --- @@ -79,6 +79,9 @@ export function safeParse( schema: S, data: unknown ): { success: true; data: SchemaOutput } | { success: false; error: unknown } { + if ('safeParse' in schema && typeof schema.safeParse === 'function') { + return schema.safeParse(data); + } if (isZ4Schema(schema)) { // Mini exposes top-level safeParse const result = z4mini.safeParse(schema, data); @@ -93,6 +96,9 @@ export async function safeParseAsync( schema: S, data: unknown ): Promise<{ success: true; data: SchemaOutput } | { success: false; error: unknown }> { + if ('safeParseAsync' in schema && typeof schema.safeParseAsync === 'function') { + return await schema.safeParseAsync(data); + } if (isZ4Schema(schema)) { // Mini exposes top-level safeParseAsync const result = await z4mini.safeParseAsync(schema, data); diff --git a/test/client/index.test.ts b/test/client/index.test.ts index f5c6a348d..e9d51f9d4 100644 --- a/test/client/index.test.ts +++ b/test/client/index.test.ts @@ -343,6 +343,33 @@ test('should reject unsupported protocol version', async () => { expect(clientTransport.close).toHaveBeenCalled(); }); +/*** + * Test: Connection Initialization Failure Cleanup + */ +test('should close transport and clean up if connection initialization fails', async () => { + const clientTransport: Transport = { + start: vi.fn().mockRejectedValue(new Error('Transport start failed')), + close: vi.fn().mockResolvedValue(undefined), + send: vi.fn().mockResolvedValue(undefined) + }; + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: {} + } + ); + + await expect(client.connect(clientTransport)).rejects.toThrow('Transport start failed'); + + // The transport.close() method should be called to clean up resources + // because the client.connect() method caught the error and called this.close() + expect(clientTransport.close).toHaveBeenCalled(); +}); + /*** * Test: Connect New Client to Old Supported Server Version */ diff --git a/test/integration-tests/processCleanup.test.ts b/test/integration-tests/processCleanup.test.ts index 11940697b..bcd77c5bf 100644 --- a/test/integration-tests/processCleanup.test.ts +++ b/test/integration-tests/processCleanup.test.ts @@ -9,7 +9,7 @@ import { LoggingMessageNotificationSchema } from '../../src/types.js'; const FIXTURES_DIR = path.resolve(__dirname, '../../src/__fixtures__'); describe('Process cleanup', () => { - vi.setConfig({ testTimeout: 5000 }); // 5 second timeout + vi.setConfig({ testTimeout: 10000 }); // 10 second timeout it('server should exit cleanly after closing transport', async () => { const server = new Server( diff --git a/test/integration-tests/taskLifecycle.test.ts b/test/integration-tests/taskLifecycle.test.ts index 629a61b66..c11253662 100644 --- a/test/integration-tests/taskLifecycle.test.ts +++ b/test/integration-tests/taskLifecycle.test.ts @@ -1040,7 +1040,7 @@ describe('Task Lifecycle Integration Tests', () => { method: 'tasks/cancel', params: { taskId } }, - z.object({ _meta: z.record(z.unknown()).optional() }) + z.object({ _meta: z.record(z.string(), z.unknown()).optional() }) ); // Verify task is cancelled diff --git a/test/shared/protocol.test.ts b/test/shared/protocol.test.ts index 886dcbb21..8e2fdf4d7 100644 --- a/test/shared/protocol.test.ts +++ b/test/shared/protocol.test.ts @@ -2086,7 +2086,7 @@ describe('Request Cancellation vs Task Cancellation', () => { let wasAborted = false; const TestRequestSchema = z.object({ method: z.literal('test/longRunning'), - params: z.optional(z.record(z.unknown())) + params: z.optional(z.record(z.string(), z.unknown())) }); protocol.setRequestHandler(TestRequestSchema, async (_request, extra) => { // Simulate a long-running operation @@ -2301,7 +2301,7 @@ describe('Request Cancellation vs Task Cancellation', () => { let requestCompleted = false; const TestMethodSchema = z.object({ method: z.literal('test/method'), - params: z.optional(z.record(z.unknown())) + params: z.optional(z.record(z.string(), z.unknown())) }); protocol.setRequestHandler(TestMethodSchema, async () => { await new Promise(resolve => setTimeout(resolve, 50)); @@ -3681,7 +3681,7 @@ describe('Message Interception', () => { method: z.literal('test/taskRequest'), params: z .object({ - _meta: z.optional(z.record(z.unknown())) + _meta: z.optional(z.record(z.string(), z.unknown())) }) .passthrough() }); @@ -3728,7 +3728,7 @@ describe('Message Interception', () => { method: z.literal('test/taskRequestError'), params: z .object({ - _meta: z.optional(z.record(z.unknown())) + _meta: z.optional(z.record(z.string(), z.unknown())) }) .passthrough() }); @@ -3808,7 +3808,7 @@ describe('Message Interception', () => { // Set up a request handler const TestRequestSchema = z.object({ method: z.literal('test/normalRequest'), - params: z.optional(z.record(z.unknown())) + params: z.optional(z.record(z.string(), z.unknown())) }); protocol.setRequestHandler(TestRequestSchema, async () => {