diff --git a/packages/effects/request/src/request-client/modules/sse.test.ts b/packages/effects/request/src/request-client/modules/sse.test.ts index 5d630a87..4e8c6a9d 100644 --- a/packages/effects/request/src/request-client/modules/sse.test.ts +++ b/packages/effects/request/src/request-client/modules/sse.test.ts @@ -89,7 +89,8 @@ describe('sSE', () => { expect(onMessage).toHaveBeenCalledTimes(2); expect(messages.join('')).toBe('hello world'); - expect(onEnd).toHaveBeenCalledWith('hello world'); + // onEnd 不再带参数 + expect(onEnd).toHaveBeenCalled(); }); it('should apply request interceptors', async () => { @@ -101,20 +102,30 @@ describe('sSE', () => { fulfilled: interceptor, }); - vi.stubGlobal('fetch', createFetchMock(['data'])); - // 创建 fetch mock,并挂到全局 const fetchMock = createFetchMock(['data']); vi.stubGlobal('fetch', fetchMock); + await sse.requestSSE('/sse', undefined, {}); expect(interceptor).toHaveBeenCalled(); expect(fetchMock).toHaveBeenCalledWith( - 'http://localhost//sse', + 'http://localhost/sse', expect.objectContaining({ - headers: expect.objectContaining({ 'x-test': 'intercepted' }), + headers: expect.any(Headers), }), ); + + const calls = fetchMock.mock?.calls; + expect(calls).toBeDefined(); + expect(calls?.length).toBeGreaterThan(0); + + const init = calls?.[0]?.[1] as RequestInit; + expect(init).toBeDefined(); + + const headers = init?.headers as Headers; + expect(headers?.get('x-test')).toBe('intercepted'); + expect(headers?.get('accept')).toBe('text/event-stream'); }); it('should throw error when no reader', async () => { diff --git a/packages/effects/request/src/request-client/modules/sse.ts b/packages/effects/request/src/request-client/modules/sse.ts index 0edf2d4a..09d13017 100644 --- a/packages/effects/request/src/request-client/modules/sse.ts +++ b/packages/effects/request/src/request-client/modules/sse.ts @@ -36,9 +36,10 @@ class SSE { requestOptions?: SseRequestOptions, ) { const baseUrl = this.client.getBaseUrl() || ''; - const hasUrlSplit = baseUrl.endsWith('/') && url.startsWith('/'); - const axiosConfig: InternalAxiosRequestConfig = { + let axiosConfig: InternalAxiosRequestConfig = { + url, + method: (requestOptions?.method as any) ?? 'GET', headers: {} as AxiosRequestHeaders, }; const requestInterceptors = this.client.instance.interceptors @@ -48,25 +49,45 @@ class SSE { requestInterceptors.handlers.length > 0 ) { for (const handler of requestInterceptors.handlers) { - if (handler.fulfilled) { - await handler.fulfilled(axiosConfig); + if (typeof handler?.fulfilled === 'function') { + const next = await handler.fulfilled(axiosConfig as any); + if (next) axiosConfig = next as InternalAxiosRequestConfig; } } } + const merged = new Headers(); + Object.entries( + (axiosConfig.headers ?? {}) as Record, + ).forEach(([k, v]) => merged.set(k, String(v))); + if (requestOptions?.headers) { + new Headers(requestOptions.headers).forEach((v, k) => merged.set(k, v)); + } + if (!merged.has('accept')) { + merged.set('accept', 'text/event-stream'); + } + + let bodyInit = requestOptions?.body ?? data; + const ct = (merged.get('content-type') || '').toLowerCase(); + if ( + bodyInit && + typeof bodyInit === 'object' && + !ArrayBuffer.isView(bodyInit as any) && + !(bodyInit instanceof ArrayBuffer) && + !(bodyInit instanceof Blob) && + !(bodyInit instanceof FormData) && + ct.includes('application/json') + ) { + bodyInit = JSON.stringify(bodyInit); + } const requestInit: RequestInit = { ...requestOptions, - body: data, - headers: { - ...(axiosConfig.headers as Record), - ...requestOptions?.headers, - }, + method: axiosConfig.method, + headers: merged, + body: bodyInit, }; - const response = await fetch( - `${baseUrl}${hasUrlSplit ? '' : '/'}${url}`, - requestInit, - ); + const response = await fetch(safeJoinUrl(baseUrl, url), requestInit); if (!response.ok) { throw new Error(`HTTP error! status: ${response.status}`); } @@ -78,19 +99,38 @@ class SSE { throw new Error('No reader'); } let isEnd = false; - let allMessage = ''; while (!isEnd) { const { done, value } = await reader.read(); if (done) { isEnd = true; - requestOptions?.onEnd?.(allMessage); + decoder.decode(new Uint8Array(0), { stream: false }); + requestOptions?.onEnd?.(); + reader.releaseLock?.(); break; } const content = decoder.decode(value, { stream: true }); requestOptions?.onMessage?.(content); - allMessage += content; } } } +function safeJoinUrl(baseUrl: string | undefined, url: string): string { + if (!baseUrl) { + return url; // 没有 baseUrl,直接返回 url + } + + // 如果 url 本身就是绝对地址,直接返回 + if (/^https?:\/\//i.test(url)) { + return url; + } + + // 如果 baseUrl 是完整 URL,就用 new URL + if (/^https?:\/\//i.test(baseUrl)) { + return new URL(url, baseUrl).toString(); + } + + // 否则,当作路径拼接 + return `${baseUrl.replace(/\/+$/, '')}/${url.replace(/^\/+/, '')}`; +} + export { SSE }; diff --git a/packages/effects/request/src/request-client/types.ts b/packages/effects/request/src/request-client/types.ts index aa1e7811..d40ee8a5 100644 --- a/packages/effects/request/src/request-client/types.ts +++ b/packages/effects/request/src/request-client/types.ts @@ -46,7 +46,7 @@ type RequestClientOptions = CreateAxiosDefaults & ExtendOptions; */ interface SseRequestOptions extends RequestInit { onMessage?: (message: string) => void; - onEnd?: (message: string) => void; + onEnd?: () => void; } interface RequestInterceptorConfig {