Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/client/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -481,13 +481,14 @@ export class Client<
}

override async connect(transport: Transport, options?: RequestOptions): Promise<void> {
await super.connect(transport);
// When transport sessionId is already set this means we are trying to reconnect.
// In this case we don't need to initialize again.
if (transport.sessionId !== undefined) {
return;
}
try {
await super.connect(transport);
// When transport sessionId is already set this means we are trying to reconnect.
// In this case we don't need to initialize again.
if (transport.sessionId !== undefined) {
return;
}

const result = await this.request(
{
method: 'initialize',
Expand Down
10 changes: 8 additions & 2 deletions src/server/zod-compat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ export type ShapeOutput<Shape extends ZodRawShapeCompat> = {
// --- Runtime detection ---
export function isZ4Schema(s: AnySchema): s is z4.$ZodType {
// Present on Zod 4 (Classic & Mini) schemas; absent on Zod 3
const schema = s as unknown as ZodV4Internal;
return !!schema._zod;
const schema = s as unknown as ZodV4Internal | undefined;
return !!schema?._zod;
}

// --- Schema construction ---
Expand All @@ -79,6 +79,9 @@ export function safeParse<S extends AnySchema>(
schema: S,
data: unknown
): { success: true; data: SchemaOutput<S> } | { success: false; error: unknown } {
if ('safeParse' in schema && typeof schema.safeParse === 'function') {
return schema.safeParse(data);
}
if (isZ4Schema(schema)) {
// Mini exposes top-level safeParse
const result = z4mini.safeParse(schema, data);
Expand All @@ -93,6 +96,9 @@ export async function safeParseAsync<S extends AnySchema>(
schema: S,
data: unknown
): Promise<{ success: true; data: SchemaOutput<S> } | { success: false; error: unknown }> {
if ('safeParseAsync' in schema && typeof schema.safeParseAsync === 'function') {
return await schema.safeParseAsync(data);
}
if (isZ4Schema(schema)) {
// Mini exposes top-level safeParseAsync
const result = await z4mini.safeParseAsync(schema, data);
Expand Down
27 changes: 27 additions & 0 deletions test/client/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,33 @@ test('should reject unsupported protocol version', async () => {
expect(clientTransport.close).toHaveBeenCalled();
});

/***
* Test: Connection Initialization Failure Cleanup
*/
test('should close transport and clean up if connection initialization fails', async () => {
const clientTransport: Transport = {
start: vi.fn().mockRejectedValue(new Error('Transport start failed')),
close: vi.fn().mockResolvedValue(undefined),
send: vi.fn().mockResolvedValue(undefined)
};

const client = new Client(
{
name: 'test client',
version: '1.0'
},
{
capabilities: {}
}
);

await expect(client.connect(clientTransport)).rejects.toThrow('Transport start failed');

// The transport.close() method should be called to clean up resources
// because the client.connect() method caught the error and called this.close()
expect(clientTransport.close).toHaveBeenCalled();
});

/***
* Test: Connect New Client to Old Supported Server Version
*/
Expand Down
2 changes: 1 addition & 1 deletion test/integration-tests/processCleanup.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { LoggingMessageNotificationSchema } from '../../src/types.js';
const FIXTURES_DIR = path.resolve(__dirname, '../../src/__fixtures__');

describe('Process cleanup', () => {
vi.setConfig({ testTimeout: 5000 }); // 5 second timeout
vi.setConfig({ testTimeout: 10000 }); // 10 second timeout

it('server should exit cleanly after closing transport', async () => {
const server = new Server(
Expand Down
2 changes: 1 addition & 1 deletion test/integration-tests/taskLifecycle.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1040,7 +1040,7 @@ describe('Task Lifecycle Integration Tests', () => {
method: 'tasks/cancel',
params: { taskId }
},
z.object({ _meta: z.record(z.unknown()).optional() })
z.object({ _meta: z.record(z.string(), z.unknown()).optional() })
);

// Verify task is cancelled
Expand Down
10 changes: 5 additions & 5 deletions test/shared/protocol.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2086,7 +2086,7 @@ describe('Request Cancellation vs Task Cancellation', () => {
let wasAborted = false;
const TestRequestSchema = z.object({
method: z.literal('test/longRunning'),
params: z.optional(z.record(z.unknown()))
params: z.optional(z.record(z.string(), z.unknown()))
});
protocol.setRequestHandler(TestRequestSchema, async (_request, extra) => {
// Simulate a long-running operation
Expand Down Expand Up @@ -2301,7 +2301,7 @@ describe('Request Cancellation vs Task Cancellation', () => {
let requestCompleted = false;
const TestMethodSchema = z.object({
method: z.literal('test/method'),
params: z.optional(z.record(z.unknown()))
params: z.optional(z.record(z.string(), z.unknown()))
});
protocol.setRequestHandler(TestMethodSchema, async () => {
await new Promise(resolve => setTimeout(resolve, 50));
Expand Down Expand Up @@ -3681,7 +3681,7 @@ describe('Message Interception', () => {
method: z.literal('test/taskRequest'),
params: z
.object({
_meta: z.optional(z.record(z.unknown()))
_meta: z.optional(z.record(z.string(), z.unknown()))
})
.passthrough()
});
Expand Down Expand Up @@ -3728,7 +3728,7 @@ describe('Message Interception', () => {
method: z.literal('test/taskRequestError'),
params: z
.object({
_meta: z.optional(z.record(z.unknown()))
_meta: z.optional(z.record(z.string(), z.unknown()))
})
.passthrough()
});
Expand Down Expand Up @@ -3808,7 +3808,7 @@ describe('Message Interception', () => {
// Set up a request handler
const TestRequestSchema = z.object({
method: z.literal('test/normalRequest'),
params: z.optional(z.record(z.unknown()))
params: z.optional(z.record(z.string(), z.unknown()))
});

protocol.setRequestHandler(TestRequestSchema, async () => {
Expand Down
Loading