@@ -2,11 +2,16 @@ import { DiscordSnowflake } from '@sapphire/snowflake';
2
2
import { configs , defaultSessionId , midjourneyBotConfigs } from './config' ;
3
3
import type {
4
4
MessageItem ,
5
+ MessageType ,
5
6
MessageTypeProps ,
6
7
MidjourneyProps ,
7
8
UpscaleProps ,
8
9
} from './interface' ;
9
- import { findMessageByPrompt , isInProgress } from './utils' ;
10
+ import {
11
+ findMessageByPrompt ,
12
+ getHashFromCustomId ,
13
+ isInProgress ,
14
+ } from './utils' ;
10
15
11
16
export class Midjourney {
12
17
protected readonly channelId : string ;
@@ -45,6 +50,17 @@ export class Midjourney {
45
50
}
46
51
}
47
52
53
+ async interactions ( payload : any ) {
54
+ return fetch ( `https://discord.com/api/v9/interactions` , {
55
+ method : 'POST' ,
56
+ body : JSON . stringify ( payload ) ,
57
+ headers : {
58
+ 'Content-Type' : 'application/json' ,
59
+ Authorization : this . token ,
60
+ } ,
61
+ } ) ;
62
+ }
63
+
48
64
async createImage ( prompt : string ) {
49
65
const payload = {
50
66
type : 2 ,
@@ -90,14 +106,7 @@ export class Midjourney {
90
106
nonce : DiscordSnowflake . generate ( ) . toString ( ) ,
91
107
} ;
92
108
93
- const res = await fetch ( `https://discord.com/api/v9/interactions` , {
94
- method : 'POST' ,
95
- body : JSON . stringify ( payload ) ,
96
- headers : {
97
- 'Content-Type' : 'application/json' ,
98
- Authorization : this . token ,
99
- } ,
100
- } ) ;
109
+ const res = await this . interactions ( payload ) ;
101
110
if ( res . status >= 400 ) {
102
111
let message = '' ;
103
112
try {
@@ -113,7 +122,10 @@ export class Midjourney {
113
122
}
114
123
}
115
124
116
- async createUpscale ( { messageId, index, hash, customId } : UpscaleProps ) {
125
+ async createUpscaleOrVariation (
126
+ type : Exclude < MessageType , 'imagine' > ,
127
+ { messageId, customId } : UpscaleProps
128
+ ) {
117
129
const payload = {
118
130
type : 3 ,
119
131
nonce : DiscordSnowflake . generate ( ) . toString ( ) ,
@@ -125,33 +137,26 @@ export class Midjourney {
125
137
session_id : defaultSessionId ,
126
138
data : {
127
139
component_type : 2 ,
128
- custom_id : customId || `MJ::JOB::upsample:: ${ index } :: ${ hash } ` ,
140
+ custom_id : customId ,
129
141
} ,
130
142
} ;
131
- const res = await fetch ( `https://discord.com/api/v9/interactions` , {
132
- method : 'POST' ,
133
- body : JSON . stringify ( payload ) ,
134
- headers : {
135
- 'Content-Type' : 'application/json' ,
136
- Authorization : this . token ,
137
- } ,
138
- } ) ;
143
+ const res = await this . interactions ( payload ) ;
139
144
if ( res . status >= 400 ) {
140
145
let message = '' ;
141
146
try {
142
147
const data = await res . json ( ) ;
143
148
if ( this . debugger ) {
144
- this . log ( ' Create upscale failed' , JSON . stringify ( data ) ) ;
149
+ this . log ( ` Create ${ type } failed` , JSON . stringify ( data ) ) ;
145
150
}
146
151
message = data ?. message ;
147
152
} catch ( e ) {
148
153
// catch JSON error
149
154
}
150
- throw new Error ( message || `Create upscale failed with ${ res . status } ` ) ;
155
+ throw new Error ( message || `Create ${ type } failed with ${ res . status } ` ) ;
151
156
}
152
157
}
153
158
154
- async getMessage ( prompt : string , options ? : MessageTypeProps ) {
159
+ async getMessage ( prompt : string , options : MessageTypeProps ) {
155
160
const res = await fetch (
156
161
`https://discord.com/api/v10/channels/${ this . channelId } /messages?limit=50` ,
157
162
{
@@ -170,7 +175,10 @@ export class Midjourney {
170
175
* Same with /imagine command
171
176
*/
172
177
async imagine ( prompt : string ) {
178
+ const timestamp = new Date ( ) . toISOString ( ) ;
179
+
173
180
await this . createImage ( prompt ) ;
181
+
174
182
const times = this . timeout / this . interval ;
175
183
let count = 0 ;
176
184
let result : MessageItem | undefined ;
@@ -179,7 +187,7 @@ export class Midjourney {
179
187
count += 1 ;
180
188
await new Promise ( ( res ) => setTimeout ( res , this . interval ) ) ;
181
189
this . log ( count , 'imagine' ) ;
182
- const message = await this . getMessage ( prompt ) ;
190
+ const message = await this . getMessage ( prompt , { timestamp } ) ;
183
191
if ( message && ! isInProgress ( message ) ) {
184
192
result = message ;
185
193
break ;
@@ -192,18 +200,63 @@ export class Midjourney {
192
200
}
193
201
194
202
async upscale ( { prompt, ...params } : UpscaleProps & { prompt : string } ) {
195
- await this . createUpscale ( params ) ;
203
+ const { index } = getHashFromCustomId ( 'upscale' , params . customId ) ;
196
204
const times = this . timeout / this . interval ;
197
205
let count = 0 ;
198
206
let result : MessageItem | undefined ;
207
+
208
+ if ( ! index ) {
209
+ throw new Error ( 'Create upscale failed with 400, unknown customId' ) ;
210
+ }
211
+
212
+ const timestamp = new Date ( ) . toISOString ( ) ;
213
+
214
+ await this . createUpscaleOrVariation ( 'upscale' , params ) ;
215
+
199
216
while ( count < times ) {
200
217
try {
201
218
count += 1 ;
202
219
await new Promise ( ( res ) => setTimeout ( res , this . interval ) ) ;
203
220
this . log ( count , 'upscale' ) ;
204
221
const message = await this . getMessage ( prompt , {
205
222
type : 'upscale' ,
206
- index : params . index ,
223
+ index,
224
+ timestamp,
225
+ } ) ;
226
+ if ( message && ! isInProgress ( message ) ) {
227
+ result = message ;
228
+ break ;
229
+ }
230
+ } catch {
231
+ continue ;
232
+ }
233
+ }
234
+ return result ;
235
+ }
236
+
237
+ async variation ( { prompt, ...params } : UpscaleProps & { prompt : string } ) {
238
+ const { index } = getHashFromCustomId ( 'variation' , params . customId ) ;
239
+ const times = this . timeout / this . interval ;
240
+ let count = 0 ;
241
+ let result : MessageItem | undefined ;
242
+
243
+ if ( ! index ) {
244
+ throw new Error ( 'Create variation failed with 400, unknown customId' ) ;
245
+ }
246
+
247
+ const timestamp = new Date ( ) . toISOString ( ) ;
248
+
249
+ await this . createUpscaleOrVariation ( 'variation' , params ) ;
250
+
251
+ while ( count < times ) {
252
+ try {
253
+ count += 1 ;
254
+ await new Promise ( ( res ) => setTimeout ( res , this . interval ) ) ;
255
+ this . log ( count , 'variation' ) ;
256
+ const message = await this . getMessage ( prompt , {
257
+ type : 'variation' ,
258
+ index,
259
+ timestamp,
207
260
} ) ;
208
261
if ( message && ! isInProgress ( message ) ) {
209
262
result = message ;
0 commit comments