diff --git a/packages/effects/request/src/request-client/modules/sse.test.ts b/packages/effects/request/src/request-client/modules/sse.test.ts new file mode 100644 index 00000000..5d630a87 --- /dev/null +++ b/packages/effects/request/src/request-client/modules/sse.test.ts @@ -0,0 +1,131 @@ +import type { RequestClient } from '../request-client'; + +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import { SSE } from './sse'; + +// 模拟 TextDecoder +const OriginalTextDecoder = globalThis.TextDecoder; + +beforeEach(() => { + vi.stubGlobal( + 'TextDecoder', + class { + private decoder = new OriginalTextDecoder(); + decode(value: Uint8Array, opts?: any) { + return this.decoder.decode(value, opts); + } + }, + ); +}); + +// 创建 fetch mock +const createFetchMock = (chunks: string[], ok = true) => { + const encoder = new TextEncoder(); + let index = 0; + return vi.fn().mockResolvedValue({ + ok, + status: ok ? 200 : 500, + body: { + getReader: () => ({ + read: async () => { + if (index < chunks.length) { + return { done: false, value: encoder.encode(chunks[index++]) }; + } + return { done: true, value: undefined }; + }, + }), + }, + }); +}; + +describe('sSE', () => { + let client: RequestClient; + let sse: SSE; + + beforeEach(() => { + vi.restoreAllMocks(); + client = { + getBaseUrl: () => 'http://localhost', + instance: { + interceptors: { + request: { + handlers: [], + }, + }, + }, + } as unknown as RequestClient; + sse = new SSE(client); + }); + + it('should call requestSSE when postSSE is used', async () => { + const spy = vi.spyOn(sse, 'requestSSE').mockResolvedValue(undefined); + await sse.postSSE('/test', { foo: 'bar' }, { headers: { a: '1' } }); + expect(spy).toHaveBeenCalledWith( + '/test', + { foo: 'bar' }, + { + headers: { a: '1' }, + method: 'POST', + }, + ); + }); + + it('should throw error if fetch response not ok', async () => { + vi.stubGlobal('fetch', createFetchMock([], false)); + await expect(sse.requestSSE('/bad')).rejects.toThrow( + 'HTTP error! status: 500', + ); + }); + + it('should trigger onMessage and onEnd callbacks', async () => { + const messages: string[] = []; + const onMessage = vi.fn((msg: string) => messages.push(msg)); + const onEnd = vi.fn(); + + vi.stubGlobal('fetch', createFetchMock(['hello', ' world'])); + + await sse.requestSSE('/sse', undefined, { onMessage, onEnd }); + + expect(onMessage).toHaveBeenCalledTimes(2); + expect(messages.join('')).toBe('hello world'); + expect(onEnd).toHaveBeenCalledWith('hello world'); + }); + + it('should apply request interceptors', async () => { + const interceptor = vi.fn(async (config) => { + config.headers['x-test'] = 'intercepted'; + return config; + }); + (client.instance.interceptors.request as any).handlers.push({ + fulfilled: interceptor, + }); + + vi.stubGlobal('fetch', createFetchMock(['data'])); + + // 创建 fetch mock,并挂到全局 + const fetchMock = createFetchMock(['data']); + vi.stubGlobal('fetch', fetchMock); + await sse.requestSSE('/sse', undefined, {}); + + expect(interceptor).toHaveBeenCalled(); + expect(fetchMock).toHaveBeenCalledWith( + 'http://localhost//sse', + expect.objectContaining({ + headers: expect.objectContaining({ 'x-test': 'intercepted' }), + }), + ); + }); + + it('should throw error when no reader', async () => { + vi.stubGlobal( + 'fetch', + vi.fn().mockResolvedValue({ + ok: true, + status: 200, + body: null, + }), + ); + await expect(sse.requestSSE('/sse')).rejects.toThrow('No reader'); + }); +}); diff --git a/packages/effects/request/src/request-client/modules/sse.ts b/packages/effects/request/src/request-client/modules/sse.ts new file mode 100644 index 00000000..0edf2d4a --- /dev/null +++ b/packages/effects/request/src/request-client/modules/sse.ts @@ -0,0 +1,96 @@ +import type { AxiosRequestHeaders, InternalAxiosRequestConfig } from 'axios'; + +import type { RequestClient } from '../request-client'; +import type { SseRequestOptions } from '../types'; + +/** + * SSE模块 + */ +class SSE { + private client: RequestClient; + + constructor(client: RequestClient) { + this.client = client; + } + + public async postSSE( + url: string, + data?: any, + requestOptions?: SseRequestOptions, + ) { + return this.requestSSE(url, data, { + ...requestOptions, + method: 'POST', + }); + } + + /** + * SSE请求方法 + * @param url - 请求URL + * @param data - 请求数据 + * @param requestOptions - SSE请求选项 + */ + public async requestSSE( + url: string, + data?: any, + requestOptions?: SseRequestOptions, + ) { + const baseUrl = this.client.getBaseUrl() || ''; + const hasUrlSplit = baseUrl.endsWith('/') && url.startsWith('/'); + + const axiosConfig: InternalAxiosRequestConfig = { + headers: {} as AxiosRequestHeaders, + }; + const requestInterceptors = this.client.instance.interceptors + .request as any; + if ( + requestInterceptors.handlers && + requestInterceptors.handlers.length > 0 + ) { + for (const handler of requestInterceptors.handlers) { + if (handler.fulfilled) { + await handler.fulfilled(axiosConfig); + } + } + } + + const requestInit: RequestInit = { + ...requestOptions, + body: data, + headers: { + ...(axiosConfig.headers as Record), + ...requestOptions?.headers, + }, + }; + + const response = await fetch( + `${baseUrl}${hasUrlSplit ? '' : '/'}${url}`, + requestInit, + ); + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); + } + + const reader = response.body?.getReader(); + const decoder = new TextDecoder(); + + if (!reader) { + throw new Error('No reader'); + } + let isEnd = false; + let allMessage = ''; + while (!isEnd) { + const { done, value } = await reader.read(); + if (done) { + isEnd = true; + requestOptions?.onEnd?.(allMessage); + break; + } + const content = decoder.decode(value, { stream: true }); + requestOptions?.onMessage?.(content); + allMessage += content; + } + } +} + +export { SSE }; diff --git a/packages/effects/request/src/request-client/request-client.ts b/packages/effects/request/src/request-client/request-client.ts index e5811673..453913b2 100644 --- a/packages/effects/request/src/request-client/request-client.ts +++ b/packages/effects/request/src/request-client/request-client.ts @@ -9,6 +9,7 @@ import qs from 'qs'; import { FileDownloader } from './modules/downloader'; import { InterceptorManager } from './modules/interceptor'; +import { SSE } from './modules/sse'; import { FileUploader } from './modules/uploader'; function getParamsSerializer( @@ -41,12 +42,14 @@ class RequestClient { public addResponseInterceptor: InterceptorManager['addResponseInterceptor']; public download: FileDownloader['download']; + public readonly instance: AxiosInstance; // 是否正在刷新token public isRefreshing = false; + public postSSE: SSE['postSSE']; // 刷新token队列 public refreshTokenQueue: ((token: string) => void)[] = []; + public requestSSE: SSE['requestSSE']; public upload: FileUploader['upload']; - private readonly instance: AxiosInstance; /** * 构造函数,用于创建Axios实例 @@ -84,6 +87,10 @@ class RequestClient { // 实例化文件下载器 const fileDownloader = new FileDownloader(this); this.download = fileDownloader.download.bind(fileDownloader); + // 实例化SSE模块 + const sse = new SSE(this); + this.postSSE = sse.postSSE.bind(sse); + this.requestSSE = sse.requestSSE.bind(sse); } /** @@ -103,6 +110,13 @@ class RequestClient { return this.request(url, { ...config, method: 'GET' }); } + /** + * 获取基础URL + */ + public getBaseUrl() { + return this.instance.defaults.baseURL; + } + /** * POST请求方法 */ diff --git a/packages/effects/request/src/request-client/types.ts b/packages/effects/request/src/request-client/types.ts index 494741dc..aa1e7811 100644 --- a/packages/effects/request/src/request-client/types.ts +++ b/packages/effects/request/src/request-client/types.ts @@ -41,6 +41,14 @@ type RequestContentType = type RequestClientOptions = CreateAxiosDefaults & ExtendOptions; +/** + * SSE 请求选项 + */ +interface SseRequestOptions extends RequestInit { + onMessage?: (message: string) => void; + onEnd?: (message: string) => void; +} + interface RequestInterceptorConfig { fulfilled?: ( config: ExtendOptions & InternalAxiosRequestConfig, @@ -78,4 +86,5 @@ export type { RequestInterceptorConfig, RequestResponse, ResponseInterceptorConfig, + SseRequestOptions, };