Skip to content

Commit a5bb81b

Browse files
authored
Add ability to continue ability to continue thread (#297)
* Add ability to continue ability to continue thread * Remove unused fns
1 parent 571c775 commit a5bb81b

File tree

4 files changed

+274
-7
lines changed

4 files changed

+274
-7
lines changed

apps/src/package.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ export { applyTableDiffs } from "./common/utils";
33
export { getTableContextYAML, filterTablesByCatalog } from "./metabase/helpers/catalog";
44
export { getTableData, getDatabaseTablesAndModelsWithoutFields, getAllCardsAndModels as getAllCards, getAllCardsLegacy, getAllFields, getAllFieldsFiltered } from "./metabase/helpers/metabaseAPIHelpers";
55
export { fetchModelInfo } from "./metabase/helpers/metabaseAPI";
6-
export { getAllTemplateTagsInQuery } from "./metabase/helpers/sqlQuery";
6+
export { getAllTemplateTagsInQuery, applySQLEdits, type SQLEdits } from "./metabase/helpers/sqlQuery";
77
export { getModelsWithFields, getSelectedAndRelevantModels, modifySqlForMetabaseModels, replaceLLMFriendlyIdentifiersInSqlWithModels } from "./metabase/helpers/metabaseModels";
88
export { getCurrentQuery, getDashboardState } from "./metabase/helpers/metabaseStateAPI";
99
export { subscribeMB, onMBSubscription } from "./metabase/helpers/stateSubscriptions";

web/src/components/common/App.tsx

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import Auth from './Auth'
3333
import _, { attempt } from 'lodash'
3434
import { updateAppMode, setAnalystMode, setDRMode, setCurrentEmail } from '../../state/settings/reducer'
3535
import { DevToolsBox } from '../devtools';
36-
import { RootState } from '../../state/store'
36+
import { RootState, store } from '../../state/store'
3737
import { getPlatformShortcut } from '../../helpers/platformCustomization'
3838
import { getParsedIframeInfo } from '../../helpers/origin'
3939
import { getApp } from '../../helpers/app'
@@ -45,6 +45,7 @@ import { Markdown } from './Markdown'
4545
import { getMXToken, setMinusxMode, toggleMinusXRoot } from '../../app/rpc'
4646
import { configs } from '../../constants'
4747
import { abortPlan, startNewThread, updateThreadID } from '../../state/chat/reducer'
48+
import { intelligentThreadStart } from '../../helpers/threadHistory'
4849

4950
// Agent constants
5051
const AGENTS = {
@@ -191,9 +192,14 @@ const AppLoggedIn = forwardRef((_props, ref) => {
191192
// Update thread id on start
192193
useEffect(() => {
193194
// dispatch(updateThreadID())
194-
if (!configs.IS_DEV) {
195-
dispatch(startNewThread())
196-
}
195+
intelligentThreadStart(store.getState).then(result => {
196+
if (result.restored) {
197+
console.log('Restored thread context for SQL:', result.matchingSQL);
198+
// Show subtle notification that context was restored
199+
}
200+
}).catch(error => {
201+
console.error('Error in intelligent thread start:', error);
202+
});
197203
}, [])
198204

199205
useEffect(() => {
@@ -255,6 +261,7 @@ const AppLoggedIn = forwardRef((_props, ref) => {
255261
if (taskInProgress) {
256262
dispatch(abortPlan())
257263
}
264+
// For clear messages, always start a fresh thread (no intelligent restore)
258265
dispatch(startNewThread())
259266
}
260267

web/src/helpers/threadHistory.ts

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
/**
2+
* Thread History Management
3+
*
4+
* Functions for scanning thread history and restoring conversation context
5+
* based on SQL query matching.
6+
*/
7+
8+
import { ChatThread, ChatMessage, Action, startNewThread, cloneThreadFromHistory } from '../state/chat/reducer';
9+
import { applySQLEdits, SQLEdits, getCurrentQuery } from 'apps';
10+
import { dispatch } from '../state/dispatch';
11+
import { RootState } from '../state/store';
12+
import { queryURL } from '../app/rpc';
13+
14+
/**
15+
* Normalizes SQL for comparison by removing extra whitespace,
16+
* converting to lowercase, and standardizing formatting
17+
*/
18+
export function normalizeSQL(sql: string): string {
19+
if (!sql || typeof sql !== 'string') {
20+
return '';
21+
}
22+
23+
return sql
24+
.trim()
25+
.toLowerCase()
26+
.replace(/\s+/g, ' ') // Replace multiple spaces with single space
27+
.replace(/\(\s+/g, '(') // Remove spaces after opening parentheses
28+
.replace(/\s+\)/g, ')') // Remove spaces before closing parentheses
29+
.replace(/,\s+/g, ',') // Normalize comma spacing
30+
.replace(/;\s*$/, ''); // Remove trailing semicolon
31+
}
32+
33+
/**
34+
* Extracts SQL from tool call arguments, handling both ExecuteQuery and EditAndExecuteQuery
35+
*/
36+
function extractSQLFromAction(action: Action): string | null {
37+
try {
38+
const functionName = action.function.name;
39+
const args = JSON.parse(action.function.arguments);
40+
41+
if (functionName === 'ExecuteQuery') {
42+
return args.sql || null;
43+
}
44+
45+
if (functionName === 'EditAndExecuteQuery') {
46+
// For EditAndExecuteQuery, we need to reconstruct the final SQL
47+
// This is a simplified approach - in reality we'd need access to the base SQL
48+
const sql_edits = args.sql_edits as SQLEdits;
49+
// Note: We can't reconstruct without the original SQL, so return null for now
50+
// This could be enhanced to store the reconstructed SQL in the action results
51+
return null;
52+
}
53+
54+
return null;
55+
} catch (error) {
56+
console.warn('Error extracting SQL from action:', error);
57+
return null;
58+
}
59+
}
60+
61+
/**
62+
* Result of scanning threads for matching SQL
63+
*/
64+
export interface ThreadScanResult {
65+
threadIndex: number;
66+
messageIndex: number;
67+
matchingSQL: string;
68+
}
69+
70+
/**
71+
* Scans thread history for a matching SQL query
72+
* Returns the first match found (most recent threads searched first)
73+
*/
74+
export function scanThreadsForSQL(
75+
threads: ChatThread[],
76+
currentSQL: string
77+
): ThreadScanResult | null {
78+
if (!currentSQL || !threads || threads.length === 0) {
79+
return null;
80+
}
81+
82+
const normalizedCurrentSQL = normalizeSQL(currentSQL);
83+
if (!normalizedCurrentSQL) {
84+
return null;
85+
}
86+
87+
// Scan threads backwards (most recent first)
88+
for (let threadIndex = threads.length - 1; threadIndex >= 0; threadIndex--) {
89+
const thread = threads[threadIndex];
90+
if (!thread.messages) continue;
91+
92+
// Scan messages backwards within each thread
93+
for (let messageIndex = thread.messages.length - 1; messageIndex >= 0; messageIndex--) {
94+
const message = thread.messages[messageIndex];
95+
96+
// Only check tool messages with ExecuteQuery or EditAndExecuteQuery actions
97+
if (message.role === 'tool' && message.action) {
98+
const extractedSQL = extractSQLFromAction(message.action);
99+
if (extractedSQL) {
100+
const normalizedExtractedSQL = normalizeSQL(extractedSQL);
101+
if (normalizedExtractedSQL === normalizedCurrentSQL) {
102+
return {
103+
threadIndex,
104+
messageIndex,
105+
matchingSQL: extractedSQL
106+
};
107+
}
108+
}
109+
}
110+
}
111+
}
112+
113+
return null;
114+
}
115+
116+
117+
/**
118+
* Intelligent thread start function that checks for matching SQL in history
119+
* and restores context if found, otherwise starts a new thread
120+
*/
121+
export async function intelligentThreadStart(getState: () => RootState): Promise<{
122+
restored: boolean;
123+
matchingSQL?: string;
124+
}> {
125+
try {
126+
// Get current SQL from the page
127+
const currentURL = await queryURL()
128+
let currentSQL = ''
129+
try {
130+
const url = new URL(currentURL);
131+
const hash = url.hash;
132+
currentSQL = JSON.parse(atob(decodeURIComponent(hash.slice(1)))).dataset_query.native.query;
133+
} catch {
134+
console.warn('Failed to extract SQL from URL hash, using getCurrentQuery');
135+
}
136+
if (!currentSQL) {
137+
// No SQL on page, start new thread normally
138+
dispatch(startNewThread());
139+
return { restored: false };
140+
}
141+
142+
// Get current thread state
143+
const state = getState();
144+
const threads = state.chat.threads;
145+
146+
if (!threads || threads.length === 0) {
147+
// No threads to search, start new thread
148+
dispatch(startNewThread());
149+
return { restored: false };
150+
}
151+
152+
// Scan for matching SQL in thread history
153+
const matchResult = scanThreadsForSQL(threads, currentSQL);
154+
155+
if (matchResult) {
156+
// Found a match! Clone the thread up to that message
157+
dispatch(cloneThreadFromHistory({
158+
sourceThreadIndex: matchResult.threadIndex,
159+
upToMessageIndex: matchResult.messageIndex
160+
}));
161+
162+
return {
163+
restored: true,
164+
matchingSQL: matchResult.matchingSQL
165+
};
166+
} else {
167+
// No match found, start new thread
168+
dispatch(startNewThread());
169+
return { restored: false };
170+
}
171+
172+
} catch (error) {
173+
console.error('Error in intelligentThreadStart:', error);
174+
// Fallback to normal thread start on any error
175+
dispatch(startNewThread());
176+
return { restored: false };
177+
}
178+
}

web/src/state/chat/reducer.ts

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ export interface Task {
139139

140140
export type Tasks = Array<Task>
141141

142-
interface ChatThread {
142+
export interface ChatThread {
143143
index: number
144144
debugChatIndex: number
145145
messages: Array<ChatMessage>
@@ -536,11 +536,93 @@ export const chatSlice = createSlice({
536536
clearTasks: (state) => {
537537
const activeThread = getActiveThread(state)
538538
activeThread.tasks = []
539+
},
540+
cloneThreadFromHistory: (state, action: PayloadAction<{
541+
sourceThreadIndex: number,
542+
upToMessageIndex: number
543+
}>) => {
544+
try {
545+
const { sourceThreadIndex, upToMessageIndex } = action.payload
546+
547+
// Validate source thread exists
548+
if (sourceThreadIndex < 0 || sourceThreadIndex >= state.threads.length) {
549+
console.error('Invalid source thread index:', sourceThreadIndex)
550+
return
551+
}
552+
553+
const sourceThread = state.threads[sourceThreadIndex]
554+
if (!sourceThread || !sourceThread.messages) {
555+
console.error('Source thread or messages not found')
556+
return
557+
}
558+
559+
// Validate message index
560+
if (upToMessageIndex < 0 || upToMessageIndex >= sourceThread.messages.length) {
561+
console.error('Invalid message index:', upToMessageIndex)
562+
return
563+
}
564+
565+
// Handle thread limit (remove oldest if at max)
566+
if (state.threads.length >= MAX_THREADS) {
567+
const excessThreads = state.threads.length - MAX_THREADS + 1;
568+
state.threads.splice(0, excessThreads);
569+
570+
state.threads.forEach((thread, index) => {
571+
thread.index = index;
572+
});
573+
}
574+
575+
// Generate new thread ID
576+
const previousID = state.threads[state.threads.length - 1].id
577+
const newID = generateNextThreadID(previousID)
578+
579+
// Clone messages up to and including the assistant response after the tool call
580+
const endIndex = Math.min(upToMessageIndex + 1, sourceThread.messages.length - 1);
581+
const clonedMessages = sourceThread.messages
582+
.slice(0, endIndex + 1)
583+
.map((message, index) => ({
584+
...message,
585+
index,
586+
feedback: { reaction: 'unrated' as const },
587+
updatedAt: Date.now()
588+
}));
589+
590+
// Create new thread
591+
const newThread: ChatThread = {
592+
index: state.threads.length,
593+
debugChatIndex: -1,
594+
messages: clonedMessages,
595+
status: 'FINISHED',
596+
userConfirmation: {
597+
show: false,
598+
content: '',
599+
userInput: 'NULL'
600+
},
601+
clarification: {
602+
show: false,
603+
questions: [],
604+
answers: [],
605+
currentQuestionIndex: 0,
606+
isCompleted: false
607+
},
608+
interrupted: false,
609+
tasks: [],
610+
id: newID
611+
}
612+
613+
// Add thread and switch to it
614+
state.threads.push(newThread)
615+
state.activeThread = state.threads.length - 1
616+
617+
} catch (error) {
618+
console.error('Error cloning thread from history:', error)
619+
// Don't change state on error - let existing thread remain active
620+
}
539621
}
540622
},
541623
})
542624

543625
// Action creators are generated for each case reducer function
544-
export const { addUserMessage, deleteUserMessage, addActionPlanMessage, startAction, finishAction, interruptPlan, startNewThread, addReaction, removeReaction, updateDebugChatIndex, setActiveThreadStatus, toggleUserConfirmation, setUserConfirmationInput, toggleClarification, setClarificationAnswer, switchToThread, abortPlan, updateThreadID, updateLastWarmedOn, clearTasks } = chatSlice.actions
626+
export const { addUserMessage, deleteUserMessage, addActionPlanMessage, startAction, finishAction, interruptPlan, startNewThread, addReaction, removeReaction, updateDebugChatIndex, setActiveThreadStatus, toggleUserConfirmation, setUserConfirmationInput, toggleClarification, setClarificationAnswer, switchToThread, abortPlan, updateThreadID, updateLastWarmedOn, clearTasks, cloneThreadFromHistory } = chatSlice.actions
545627

546628
export default chatSlice.reducer

0 commit comments

Comments
 (0)