Skip to content

Commit b847bcd

Browse files
committed
move store to tool context
1 parent f17aba0 commit b847bcd

File tree

8 files changed

+24
-26
lines changed

8 files changed

+24
-26
lines changed

frontend/src/components/chat/chat-panel.tsx

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import { useChat } from "@ai-sdk/react";
55
import { storePrompt } from "@marimo-team/codemirror-ai";
66
import type { ReactCodeMirrorRef } from "@uiw/react-codemirror";
77
import { DefaultChatTransport, type ToolUIPart } from "ai";
8-
import { useAtom, useAtomValue, useSetAtom } from "jotai";
8+
import { useAtom, useAtomValue, useSetAtom, useStore } from "jotai";
99
import {
1010
AtSignIcon,
1111
BotMessageSquareIcon,
@@ -528,6 +528,7 @@ const ChatPanelBody = () => {
528528
const { invokeAiTool, sendRun } = useRequestClient();
529529

530530
const activeChatId = activeChat?.id;
531+
const store = useStore();
531532

532533
const { addStagedCell } = useStagedAICellsActions();
533534
const { createNewCell, prepareForRun } = useCellActions();
@@ -536,6 +537,7 @@ const ChatPanelBody = () => {
536537
createNewCell,
537538
prepareForRun,
538539
sendRun,
540+
store,
539541
};
540542

541543
const {

frontend/src/components/editor/ai/add-cell-with-ai.tsx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ export const AddCellWithAI: React.FC<{
9292

9393
const { createNewCell, prepareForRun } = useCellActions();
9494
const toolContext: ToolNotebookContext = {
95+
store,
9596
addStagedCell,
9697
createNewCell,
9798
prepareForRun,

frontend/src/core/ai/tools/__tests__/edit-notebook-tool.test.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ describe("EditNotebookTool", () => {
7070
createNewCell: ReturnType<typeof vi.fn>;
7171
prepareForRun: ReturnType<typeof vi.fn>;
7272
sendRun: ReturnType<typeof vi.fn>;
73+
store: ReturnType<typeof getDefaultStore>;
7374
};
7475

7576
beforeEach(() => {
@@ -86,8 +87,9 @@ describe("EditNotebookTool", () => {
8687
createNewCell: vi.fn(),
8788
prepareForRun: vi.fn(),
8889
sendRun: vi.fn().mockResolvedValue(null),
90+
store,
8991
};
90-
tool = new EditNotebookTool(store);
92+
tool = new EditNotebookTool();
9193

9294
cellId1 = "cell-1" as CellId;
9395
cellId2 = "cell-2" as CellId;

frontend/src/core/ai/tools/__tests__/run-cells-tool.test.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ describe("RunStaleCellsTool", () => {
3131
createNewCell: ReturnType<typeof vi.fn>;
3232
prepareForRun: ReturnType<typeof vi.fn>;
3333
sendRun: ReturnType<typeof vi.fn>;
34+
store: ReturnType<typeof getDefaultStore>;
3435
};
3536

3637
beforeEach(() => {
@@ -40,9 +41,10 @@ describe("RunStaleCellsTool", () => {
4041
createNewCell: vi.fn(),
4142
prepareForRun: vi.fn(),
4243
sendRun: vi.fn().mockResolvedValue(null),
44+
store,
4345
};
4446

45-
tool = new RunStaleCellsTool(store);
47+
tool = new RunStaleCellsTool();
4648

4749
cellId1 = "cell-1" as CellId;
4850
cellId2 = "cell-2" as CellId;

frontend/src/core/ai/tools/base.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import { z } from "zod";
44
import type { CreateNewCellAction } from "@/core/cells/cells";
55
import type { CellId } from "@/core/cells/ids";
66
import type { RunRequest } from "@/core/network/types";
7+
import type { JotaiStore } from "@/core/state/jotai";
78
import type { Edit } from "../staged-cells";
89
import type { CopilotMode } from "./registry";
910

@@ -101,6 +102,7 @@ export interface ToolDescription {
101102

102103
/** Utility functions for tools to interact with the notebook */
103104
export interface ToolNotebookContext {
105+
store: JotaiStore;
104106
addStagedCell: (payload: { cellId: CellId; edit: Edit }) => void;
105107
createNewCell: (payload: CreateNewCellAction) => void;
106108
prepareForRun: (payload: { cellId: CellId }) => void;

frontend/src/core/ai/tools/edit-notebook-tool.ts

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import {
1010
} from "@/core/cells/cells";
1111
import { CellId } from "@/core/cells/ids";
1212
import { updateEditorCodeFromPython } from "@/core/codemirror/language/utils";
13-
import type { JotaiStore } from "@/core/state/jotai";
1413
import type { CellColumnId } from "@/utils/id-tree";
1514
import {
1615
type AiTool,
@@ -87,28 +86,23 @@ export type EditType = EditOperation["type"];
8786
export class EditNotebookTool
8887
implements AiTool<EditNotebookInput, ToolOutputBase>
8988
{
90-
private store: JotaiStore;
9189
readonly name = "edit_notebook_tool";
9290
readonly description = description;
9391
readonly schema = editNotebookSchema;
9492
readonly outputSchema = toolOutputBaseSchema;
9593
readonly mode: CopilotMode[] = ["agent"];
9694

97-
constructor(store: JotaiStore) {
98-
this.store = store;
99-
}
100-
10195
handler = async (
10296
{ edit }: EditNotebookInput,
10397
toolContext: ToolNotebookContext,
10498
): Promise<ToolOutputBase> => {
105-
const { addStagedCell, createNewCell } = toolContext;
99+
const { addStagedCell, createNewCell, store } = toolContext;
106100

107101
switch (edit.type) {
108102
case "update_cell": {
109103
const { cellId, code } = edit;
110104

111-
const notebook = this.store.get(notebookAtom);
105+
const notebook = store.get(notebookAtom);
112106
this.validateCellIdExists(cellId, notebook);
113107
const editorView = this.getCellEditorView(cellId, notebook);
114108

@@ -133,7 +127,7 @@ export class EditNotebookTool
133127
const newCellId = CellId.create();
134128

135129
if (typeof position === "object") {
136-
const notebook = this.store.get(notebookAtom);
130+
const notebook = store.get(notebookAtom);
137131
if ("cellId" in position) {
138132
this.validateCellIdExists(position.cellId, notebook);
139133
notebookPosition = position.cellId;
@@ -164,7 +158,7 @@ export class EditNotebookTool
164158
case "delete_cell": {
165159
const { cellId } = edit;
166160

167-
const notebook = this.store.get(notebookAtom);
161+
const notebook = store.get(notebookAtom);
168162
this.validateCellIdExists(cellId, notebook);
169163

170164
const editorView = this.getCellEditorView(cellId, notebook);

frontend/src/core/ai/tools/registry.ts

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import type { components } from "@marimo-team/marimo-api";
44
import { Memoize } from "typescript-memoize";
55
import { type ZodObject, z } from "zod";
6-
import { store } from "@/core/state/jotai";
76
import {
87
type AiTool,
98
ToolExecutionError,
@@ -131,6 +130,6 @@ export class FrontendToolRegistry {
131130
}
132131

133132
export const FRONTEND_TOOL_REGISTRY = new FrontendToolRegistry([
134-
new EditNotebookTool(store),
135-
new RunStaleCellsTool(store),
133+
new EditNotebookTool(),
134+
new RunStaleCellsTool(),
136135
]);

frontend/src/core/ai/tools/run-cells-tool.ts

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,19 +66,14 @@ export class RunStaleCellsTool
6666
.optional(),
6767
}) satisfies z.ZodType<RunStaleCellsOutput>;
6868
readonly mode: CopilotMode[] = ["agent"];
69-
private store: JotaiStore;
70-
71-
constructor(store: JotaiStore) {
72-
this.store = store;
73-
}
7469

7570
handler = async (
7671
_args: EmptyToolInput,
7772
toolContext: ToolNotebookContext,
7873
): Promise<RunStaleCellsOutput> => {
79-
const { prepareForRun, sendRun } = toolContext;
74+
const { prepareForRun, sendRun, store } = toolContext;
8075

81-
const notebook = this.store.get(notebookAtom);
76+
const notebook = store.get(notebookAtom);
8277
const staleCells = staleCellIds(notebook);
8378

8479
if (staleCells.length === 0) {
@@ -96,7 +91,7 @@ export class RunStaleCellsTool
9691
});
9792

9893
// Wait for all cells to finish executing
99-
const allCellsFinished = await this.waitForCellsToFinish(staleCells);
94+
const allCellsFinished = await this.waitForCellsToFinish(store, staleCells);
10095
if (!allCellsFinished) {
10196
return {
10297
status: "success",
@@ -106,7 +101,7 @@ export class RunStaleCellsTool
106101
}
107102

108103
// Get notebook state after cells have finished
109-
const updatedNotebook = this.store.get(notebookAtom);
104+
const updatedNotebook = store.get(notebookAtom);
110105

111106
const cellsToOutput = new Map<CellId, CellOutput | null>();
112107
let resultMessage = "";
@@ -192,6 +187,7 @@ export class RunStaleCellsTool
192187
* Returns true if all cells finished executing, false if the timeout was reached
193188
*/
194189
private async waitForCellsToFinish(
190+
store: JotaiStore,
195191
cellIds: CellId[],
196192
timeout = 30_000,
197193
): Promise<boolean> {
@@ -207,7 +203,7 @@ export class RunStaleCellsTool
207203
};
208204

209205
// If already finished, return immediately
210-
if (checkAllFinished(this.store.get(notebookAtom))) {
206+
if (checkAllFinished(store.get(notebookAtom))) {
211207
return true;
212208
}
213209

0 commit comments

Comments
 (0)