Skip to content

Add support for streaming to the generative module #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/lib/core/modules/generative/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ export type GenerateTextParams = {
}
>

stream?: boolean
responseFormat?: {
type: 'json_object' | 'json_schema'
schema?: Record<string, any>
Expand Down
2 changes: 2 additions & 0 deletions src/lib/core/modules/generative/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ const promptSchema = z.object({
.optional()
.default([])
.transform((data) => transformData(data)),
stream: z.boolean().optional().default(false),
})

const validationSchema = z.object({
Expand Down Expand Up @@ -107,6 +108,7 @@ const validationSchema = z.object({
.default({
type: 'text',
}),
stream: z.boolean().optional().default(false),
})

const schema = z.discriminatedUnion('type', [promptSchema, validationSchema])
Expand Down
14 changes: 10 additions & 4 deletions src/lib/core/services/generative/GenerativeService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import {
GenerativeImageMessage,
GenerativeTextMessage,
} from 'src/lib/plugins-common/generative'
import { Stream } from 'stream'
import { Modules } from '../../modules'
import { GenerateTextParams } from '../../modules/generative/types'
import { Plugins } from '../../plugins'
Expand Down Expand Up @@ -66,15 +67,20 @@ export class GenerativeService {
? validated.responseFormat?.schema
: undefined,
},
stream: validated.stream,
}),
)

if (err) throw new Error(err.message)

return {
content: res.content,
usageMetadata: res.metadata.usage,
finishReason: res.metadata.finishReason,
if (res instanceof Stream.Readable) {
return res
} else {
return {
content: res.content,
usageMetadata: res.metadata.usage,
finishReason: res.metadata.finishReason,
}
}
}
}
14 changes: 13 additions & 1 deletion src/lib/plugins-common/generative/Generative.interface.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import type * as stream from 'stream'
import type { z } from 'zod'
import { PluginContext } from '..'

Expand All @@ -19,7 +20,7 @@ export interface GenerativePlugin<
generateText: (
ctx: C,
params: GenerateTextParams,
) => Promise<GenerateTextResult>
) => Promise<GenerateTextResult | GenerateTextResultStream>
}

export type GetSupportedModelsParams = {}
Expand All @@ -33,6 +34,7 @@ export type GenerateTextParams<
messages: GenerativeMessage[]

options?: T
stream?: boolean
signal: AbortSignal
}

Expand All @@ -50,6 +52,16 @@ export type GenerateTextResult = {
}
}

export type GenerateTextResultStreamChunk = {
content: string | Record<string, any>
}

export type GenerateTextResultStreamPayload = GenerateTextResult & {
finished: true
}

export type GenerateTextResultStream = stream.Readable

export type ModelSpec = {
name: string
}
Expand Down
52 changes: 41 additions & 11 deletions src/modules/generative/controllers/Generative.controller.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import { Body, Controller, HttpCode, Post } from '@nestjs/common'
import { Body, Controller, HttpCode, Post, Res } from '@nestjs/common'
import { Response } from 'express'
import { Stream } from 'stream'
import { GenerativeService } from '../services/Generative.service'

@Controller('')
Expand All @@ -7,35 +9,63 @@ export class GenerativeController {

@HttpCode(200)
@Post('/content/generative/chat/completions')
async _chatCompletion(@Body() body: any) {
return this.chatCompletion(body)
async _chatCompletion(@Body() body: any, @Res() res: Response) {
return this.chatCompletion(body, res)
}

@HttpCode(200)
@Post('/generative/chat/completions')
async chatCompletion(@Body() body: any) {
async chatCompletion(@Body() body: any, @Res() res: Response) {
// @TODO: handle cancellation
const signal = new AbortController().signal
const controller = new AbortController()

res.on('close', () => {
controller.abort()
})

const result = await this.generativeService.generateText({
params: body,
signal,
signal: controller.signal,
})

return result
if (result instanceof Stream.Readable) {
res.setHeader('Content-Type', 'text/event-stream')
res.setHeader('Cache-Control', 'no-cache')
res.setHeader('Connection', 'keep-alive')

result.pipe(res)
} else {
return res.send({
data: result,
})
}
}

@HttpCode(200)
@Post('/generative/generate/text')
async text(@Body() body: any) {
async text(@Body() body: any, @Res() res: Response) {
// @TODO: handle cancellation
const signal = new AbortController().signal
const controller = new AbortController()

res.on('close', () => {
controller.abort()
})

const result = await this.generativeService.generateText({
params: body,
signal,
signal: controller.signal,
})

return result
if (result instanceof Stream.Readable) {
res.setHeader('Content-Type', 'text/event-stream')
res.setHeader('Cache-Control', 'no-cache')
res.setHeader('Connection', 'keep-alive')

result.pipe(res)
} else {
return res.send({
data: result,
})
}
}
}
125 changes: 124 additions & 1 deletion src/plugins/plugin-generative-openai/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@ import {
GenerateTextOptionsBase,
GenerateTextParams,
GenerateTextResult,
GenerateTextResultStream,
GenerateTextResultStreamChunk,
GenerateTextResultStreamPayload,
GenerativePlugin,
GetSupportedModelsParams,
GetSupportedModelsResult,
} from 'src/lib/plugins-common/generative'
import * as stream from 'stream'
import { z } from 'zod'
import { Config, Context, Model } from './plugin.types'

Expand Down Expand Up @@ -165,7 +169,7 @@ export class GenerativeOpenAI implements PluginLifecycle, GenerativePlugin {
generateText = async (
ctx: Context,
params: GenerateTextParams<GenerateTextOptionsBase>,
): Promise<GenerateTextResult> => {
): Promise<GenerateTextResult | GenerateTextResultStream> => {
const options = params.options

const model = options?.model || (this.config.options?.model as Model)
Expand Down Expand Up @@ -214,6 +218,9 @@ export class GenerativeOpenAI implements PluginLifecycle, GenerativePlugin {
if (options?.maxTokens && options.maxTokens > config.maxTokens)
throw new Error(`'${model}' maxTokens limit is ${config.maxTokens}`)

if (params.stream && !config.streaming)
throw new Error(`'${model}' doesn't support streaming`)

const llm =
responseFormat === 'json_object'
? chat.withStructuredOutput(options?.schema || {}, {
Expand Down Expand Up @@ -254,6 +261,51 @@ export class GenerativeOpenAI implements PluginLifecycle, GenerativePlugin {
totalTokens: 0,
}

if (params.stream) {
const response = await llm.stream(messages, {
options: {
body: {
stream_options: {
include_usage: true,
},
},
signal: params.signal,
},
})

const format = params.options?.responseFormat || 'text'

const stream = new Stream(
response as any,
(chunk) => {
if (format === 'text')
return {
content: chunk.content,
}

return {
content: chunk,
}
},
(acc, chunk) => {
if (format === 'text')
return {
...acc,
content:
(typeof acc.content === 'string' ? acc.content : '') +
chunk.content,
}

return {
...acc,
content: chunk,
}
},
)

return stream
}

const response = await llm.invoke(messages, {
callbacks: [
{
Expand Down Expand Up @@ -283,3 +335,74 @@ export class GenerativeOpenAI implements PluginLifecycle, GenerativePlugin {
}
}
}

/**
 * Readable stream that adapts a provider token stream into a
 * newline-delimited JSON (NDJSON) stream: one JSON line per mapped chunk,
 * followed by a terminal accumulated payload line with `finished: true`.
 *
 * - `_readChunk` maps each raw upstream chunk into the wire chunk shape.
 * - `_appendChunk` folds the mapped chunks into the final payload that is
 *   pushed after the upstream source ends.
 */
export class Stream extends stream.Readable {
  private _firstChunk: boolean = true
  private _finished: boolean = false

  constructor(
    private readonly _readable?: stream.Readable,
    private readonly _readChunk?: (chunk: any) => GenerateTextResultStreamChunk,
    private readonly _appendChunk?: (
      acc: GenerateTextResultStreamPayload,
      chunk: GenerateTextResultStreamChunk,
    ) => GenerateTextResultStreamPayload,
  ) {
    // No-op _read(): data is pushed as it arrives from the upstream source.
    super({ read: () => {} })

    if (this._readable) {
      let acc: null | GenerateTextResultStreamPayload = null
      ;(async () => {
        if (this._readable)
          for await (const chunk of this._readable) {
            const data = this._readChunk ? this._readChunk(chunk) : chunk
            this.pushChunk(data)

            if (!acc) {
              acc = {
                content: null as any,
                finished: true,
                metadata: {
                  finishReason: '',
                  usage: {
                    inputTokens: 0,
                    outputTokens: 0,
                    totalTokens: 0,
                  },
                },
              }
            }

            // BUGFIX: fold the *transformed* chunk (`data`), not the raw
            // upstream chunk, so the accumulator matches what was streamed.
            if (this._appendChunk && acc && data) {
              acc = this._appendChunk(acc, data)
            }
          }

        if (acc) {
          this.pushFinal(acc)
        }
        // BUGFIX: always end the stream, even when the source yielded no
        // chunks — previously an empty source left the consumer hanging.
        this.close()
      })().catch((err) => {
        this.emit('error', err)
        this.close()
      })
    }
  }

  /** Push one NDJSON-encoded chunk; lines are '\n'-separated. */
  pushChunk = (chunk: GenerateTextResultStreamChunk) => {
    const data = (!this._firstChunk ? '\n' : '') + JSON.stringify(chunk)
    this._firstChunk = false

    // Use Readable.push() so data flows through the stream's internal
    // buffer (backpressure, pause/resume, async iteration) instead of
    // bypassing it with a bare 'data' emit.
    return this.push(Buffer.from(data, 'utf-8'))
  }

  /** Push the terminal payload line, always marked `finished: true`. */
  pushFinal = (payload: GenerateTextResultStreamPayload) => {
    const data = '\n' + JSON.stringify({ ...payload, finished: true })
    return this.push(Buffer.from(data, 'utf-8'))
  }

  /** Signal EOF. Idempotent; 'close' fires naturally via auto-destroy. */
  close() {
    if (this._finished) return
    this._finished = true
    this.push(null)
  }
}